diff --git a/.flake/pkgs/ffdb/ffdb.py b/.flake/pkgs/ffdb/ffdb.py index b5fc3956bf..befdab8849 100644 --- a/.flake/pkgs/ffdb/ffdb.py +++ b/.flake/pkgs/ffdb/ffdb.py @@ -1,8 +1,8 @@ -from proj.config_file import get_config_root +from proj import get_repo_root from pathlib import Path import gdb -gdb.execute(f'directory {get_config_root(Path.cwd())}') +gdb.execute(f'directory {get_repo_root(Path.cwd())}') gdb.prompt_hook = lambda x: '(ffdb) ' gdb.execute('set history save on') gdb.execute('catch throw') diff --git a/.proj.toml b/.proj.toml index b14d763339..463aa0bb07 100644 --- a/.proj.toml +++ b/.proj.toml @@ -6,6 +6,21 @@ cuda_launch_cmd = [ "nixGL", "--", ] +layout_ignore_paths = [ + "lib/runtime", + "lib/kernels", + "lib/utils/test/common", + "lib/compiler/ffi", + "lib/op-attrs/ffi", + "lib/pcg/ffi", + "lib/substitutions/ffi", + "lib/utils/ffi", + "lib/ffi", + "lib/utils/include/utils/graph/docs", + "lib/compiler/test/src/internal", + "lib/local-execution/test/src/internal", + "bin/protobuf-to-json", +] [targets.utils] type = "lib" @@ -63,12 +78,19 @@ has-cpu-only-benchmarks = false has-cuda-tests = false has-cuda-benchmarks = false -[targets.local-execution] -type = "lib" -has-cpu-only-tests = true -has-cpu-only-benchmarks = false -has-cuda-tests = true -has-cuda-benchmarks = false +# [targets.local-execution] +# type = "lib" +# has-cpu-only-tests = true +# has-cpu-only-benchmarks = false +# has-cuda-tests = true +# has-cuda-benchmarks = false + +# [targets.local-pcg-execution] +# type = "lib" +# has-cpu-only-tests = true +# has-cpu-only-benchmarks = false +# has-cuda-tests = false +# has-cuda-benchmarks = false [targets.models] type = "lib" diff --git a/CMakeLists.txt b/CMakeLists.txt index a5f5a6fa11..8b313f5d4f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -102,7 +102,6 @@ include(spdlog) include(doctestlib) # named doctestlib to avoid a name collision with doctest.cmake in rapidcheck include(gbenchmark) include(libassert) -include(visit_struct) 
include(CTest) include(fmt) include(legion) diff --git a/bin/export-model-arch/include/export-model-arch/json_sp_model_export.dtg.toml b/bin/export-model-arch/include/export-model-arch/json_sp_model_export.dtg.toml new file mode 100644 index 0000000000..9b6c1718d4 --- /dev/null +++ b/bin/export-model-arch/include/export-model-arch/json_sp_model_export.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "JsonSPModelExport" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/file_format/v1/v1_computation_graph.dtg.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", +] + +src_includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/json.h", +] + +[[fields]] +name = "sp_decomposition" +type = "::FlexFlow::V1BinarySPDecomposition" + +[[fields]] +name = "computation_graph" +type = "::FlexFlow::V1ComputationGraph" diff --git a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml deleted file mode 100644 index efaf10c255..0000000000 --- a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "JsonSPModelExport" -features = [ - "eq", - "hash", - "fmt", - "json", -] - -includes = [ - "pcg/file_format/v1/v1_computation_graph.dtg.h", - "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", -] - -src_includes = [ - "pcg/file_format/v1/v1_binary_sp_decomposition/json.h", -] - -[[fields]] -name = "sp_decomposition" -type = "::FlexFlow::V1BinarySPDecomposition" - -[[fields]] -name = "computation_graph" -type = "::FlexFlow::V1ComputationGraph" diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export-model-arch/main.cc similarity index 100% rename from bin/export-model-arch/src/export_model_arch.cc rename to 
bin/export-model-arch/src/export-model-arch/main.cc diff --git a/bin/substitution-to-dot/CMakeLists.txt b/bin/substitution-to-dot/CMakeLists.txt index ed9b017d52..6ea98e1fce 100644 --- a/bin/substitution-to-dot/CMakeLists.txt +++ b/bin/substitution-to-dot/CMakeLists.txt @@ -2,7 +2,9 @@ ff_add_executable( NAME substitution-to-dot SRC_PATTERNS - *.cc + src/*.cc + PRIVATE_INCLUDE + include/ DEPS substitution-generator ) diff --git a/bin/substitution-to-dot/substitution_to_dot.cc b/bin/substitution-to-dot/src/substitution-to-dot/main.cc similarity index 100% rename from bin/substitution-to-dot/substitution_to_dot.cc rename to bin/substitution-to-dot/src/substitution-to-dot/main.cc diff --git a/cmake/visit_struct.cmake b/cmake/visit_struct.cmake deleted file mode 100644 index 108745fc14..0000000000 --- a/cmake/visit_struct.cmake +++ /dev/null @@ -1,16 +0,0 @@ -add_library( - visit_struct - INTERFACE -) -target_include_directories( - visit_struct - INTERFACE - ${CMAKE_CURRENT_SOURCE_DIR}/deps/visit_struct/include/ -) -set_target_properties( - visit_struct - PROPERTIES - CXX_STANDARD 11 - CXX_STANDARD_REQUIRED YES - CXX_EXTENSIONS NO -) diff --git a/flake.lock b/flake.lock index f016c47f45..fdde6ce02c 100644 --- a/flake.lock +++ b/flake.lock @@ -66,11 +66,11 @@ ] }, "locked": { - "lastModified": 1752259929, - "narHash": "sha256-GkMRIi6Xk3qswrbekWtO1sQYz61mw25+62boDk1Gd7s=", + "lastModified": 1763685681, + "narHash": "sha256-VFtDhrXx49yQS2r5Oxz2mvw/60uIAZhy0Y0rDBMvEno=", "ref": "refs/heads/master", - "rev": "669773600c781ab8b29ac2379d0c070721417f9d", - "revCount": 117, + "rev": "72f7bd4008671613237681e29c9c90403a421ce0", + "revCount": 138, "type": "git", "url": "https://git.sr.ht/~lockshaw/proj" }, diff --git a/flake.nix b/flake.nix index 474a22f385..6ccd5616cd 100644 --- a/flake.nix +++ b/flake.nix @@ -92,10 +92,7 @@ -DFF_USE_EXTERNAL_RAPIDCHECK=ON \ -DFF_USE_EXTERNAL_EXPECTED=ON \ -DFF_USE_EXTERNAL_GBENCHMARK=ON \ - -DFF_USE_EXTERNAL_LIBASSERT=ON \ - 
-DFF_USE_EXTERNAL_RANGEV3=ON \ - -DFF_USE_EXTERNAL_BOOST_PREPROCESSOR=ON \ - -DFF_USE_EXTERNAL_TYPE_INDEX=ON" + -DFF_USE_EXTERNAL_LIBASSERT=ON" ''; buildInputs = builtins.concatLists [ diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index e2e561c384..2e71e577c0 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(runtime) add_subdirectory(op-attrs) add_subdirectory(kernels) add_subdirectory(local-execution) +add_subdirectory(local-pcg-execution) add_subdirectory(task-spec) add_subdirectory(utils) add_subdirectory(ffi) diff --git a/lib/compiler/include/compiler/allowed_machine_views.h b/lib/compiler/include/compiler/allowed_machine_views.h index 9bb73fd1a9..2a3de47b0d 100644 --- a/lib/compiler/include/compiler/allowed_machine_views.h +++ b/lib/compiler/include/compiler/allowed_machine_views.h @@ -1,18 +1,18 @@ #ifndef _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H #define _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H -#include "pcg/machine_specification.dtg.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/operator_task_space.dtg.h" +#include "compiler/machine_mapping/machine_view.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" namespace FlexFlow { bool is_valid_machine_view(MachineView const &mv, OperatorTaskSpace const &task, - MachineSpecification const &ms); + MachineComputeSpecification const &ms); std::unordered_set - get_allowed_machine_views(MachineSpecification const &machine_spec, + get_allowed_machine_views(MachineComputeSpecification const &machine_spec, OperatorTaskSpace const &task, DeviceType device_type); diff --git a/lib/compiler/include/compiler/compiler.h b/lib/compiler/include/compiler/compiler.h deleted file mode 100644 index 178ab19a53..0000000000 --- a/lib/compiler/include/compiler/compiler.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_COMPILER_H -#define _FLEXFLOW_COMPILER_COMPILER_H - -#include "pcg/cost_values.h" -#include 
"pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "pcg/tensor_mapping.h" - -namespace FlexFlow { - -enum class SearchAlgorithm { - DATA_PARALLEL, -}; - -using SearchAlgorithmConfig = std::variant<>; -using SearchSolution = std::variant<>; - -struct SearchResult { - ParallelComputationGraph pcg; - TensorMapping tensor_mapping; - SearchSolution solution; - CostValues cost_values; -}; - -SearchResult optimize(ComputationGraph const &, - MachineSpecification const &, - CostEstimator const &, - SearchAlgorithm, - optional const &); - -// struct SearchSolution { -// LabelledMultiDiGraph optimized_pcg; -// std::unordered_map device_assignments; -// /* std::unordered_map> tensor_mappings; */ -// }; -// -// SearchSolution run_data_parallelize(ComputationGraph const &, -// MachineSpecification const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/cost_estimator/communication_edge.h b/lib/compiler/include/compiler/cost_estimator/communication_edge.h new file mode 100644 index 0000000000..14ecc266f0 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/communication_edge.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COMMUNICATION_EDGE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COMMUNICATION_EDGE_H + +#include "pcg/machine_space_coordinate.dtg.h" + +namespace FlexFlow { + +struct CommunicationEdge { + CommunicationEdge() = delete; + + CommunicationEdge(MachineSpaceCoordinate const &src, + MachineSpaceCoordinate const &dst); + + bool operator==(CommunicationEdge const &) const; + bool operator!=(CommunicationEdge const &) const; + + bool operator<(CommunicationEdge const &) const; + bool operator>(CommunicationEdge const &) const; + bool operator<=(CommunicationEdge const &) const; + bool operator>=(CommunicationEdge const &) const; + + MachineSpaceCoordinate const &get_src() const; + MachineSpaceCoordinate const 
&get_dst() const; + +private: + MachineSpaceCoordinate src; + MachineSpaceCoordinate dst; + +private: + std::tuple tie() const; + + friend struct ::std::hash; +}; + +std::string format_as(CommunicationEdge const &); +std::ostream &operator<<(std::ostream &, CommunicationEdge const &); + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::CommunicationEdge> { + size_t operator()(::FlexFlow::CommunicationEdge const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/compiler/include/compiler/cost_estimator/cost_estimator.h b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h index 7b7255a89d..bd423d8956 100644 --- a/lib/compiler/include/compiler/cost_estimator/cost_estimator.h +++ b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h @@ -4,9 +4,9 @@ #include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" #include "compiler/cost_estimator/op_cost_metrics.dtg.h" #include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/machine_view.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" -#include "pcg/machine_view.dtg.h" #include namespace FlexFlow { diff --git a/lib/compiler/include/compiler/cost_estimator/network_cost_model.h b/lib/compiler/include/compiler/cost_estimator/network_cost_model.h new file mode 100644 index 0000000000..3772acd54a --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/network_cost_model.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_NETWORK_COST_MODEL_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_NETWORK_COST_MODEL_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "pcg/machine_specification.dtg.h" + +namespace FlexFlow { + +float estimate_communication_cost(MachineSpecification const &, + TensorSetMovement const &); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.dtg.toml b/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.dtg.toml new file mode 100644 index 0000000000..6a3d4987ac --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.dtg.toml @@ -0,0 +1,48 @@ +namespace = "FlexFlow" +name = "OpCostEstimateKey" +type = "struct" +features = [ + "eq", + "ord", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "compiler/machine_mapping/machine_view.dtg.h", + "pcg/optimizer_attrs.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", + "utils/ord/unordered_map.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "optimizer_attrs" +type = "::FlexFlow::OptimizerAttrs" + +[[fields]] +name = "machine_view" +type = "::FlexFlow::MachineView" diff --git a/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml deleted file mode 100644 index b153bd0072..0000000000 --- a/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml +++ /dev/null @@ -1,45 +0,0 @@ -namespace = "FlexFlow" -name = "OpCostEstimateKey" -features = [ - "eq", - "ord", - "fmt", - "hash", -] - -includes = [ - "op-attrs/pcg_operator_attrs.dtg.h", - "op-attrs/parallel_tensor_shape.dtg.h", - "", - "pcg/machine_view.dtg.h", - "pcg/optimizer_attrs.dtg.h", -] - 
-src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "op_attrs" -type = "::FlexFlow::PCGOperatorAttrs" - -[[fields]] -name = "input_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "weight_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "output_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "optimizer_attrs" -type = "::FlexFlow::OptimizerAttrs" - -[[fields]] -name = "machine_view" -type = "::FlexFlow::MachineView" diff --git a/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.dtg.toml b/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.dtg.toml new file mode 100644 index 0000000000..7a673c83b2 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "OpCostMetrics" +type = "struct" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "utils/units/milliseconds_t.h", + "utils/units/num_bytes_t.h", +] + +[[fields]] +name = "forward_runtime" +type = "::FlexFlow::milliseconds_t" + +[[fields]] +name = "backward_runtime" +type = "::FlexFlow::milliseconds_t" + +[[fields]] +name = "memory_usage" +type = "::FlexFlow::num_bytes_t" diff --git a/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.h b/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.h index f2d12aff71..aa638f7287 100644 --- a/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.h +++ b/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.h @@ -6,6 +6,9 @@ namespace FlexFlow { +bool is_pareto_optimal_in(OpCostMetrics const &, + std::unordered_set const &); + OpCostMetrics make_op_cost_metrics_from_runtime_only( RuntimeOnlyOpCostMetrics const &runtime_only, num_bytes_t const &memory_usage); diff --git a/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.struct.toml 
b/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.struct.toml deleted file mode 100644 index 7d0c7684a9..0000000000 --- a/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "OpCostMetrics" -features = [ - "eq", - "fmt", - "hash", -] - -includes = [ - "utils/units/milliseconds_t.h", - "utils/units/num_bytes_t.h", -] - -[[fields]] -name = "forward_runtime" -type = "::FlexFlow::milliseconds_t" - -[[fields]] -name = "backward_runtime" -type = "::FlexFlow::milliseconds_t" - -[[fields]] -name = "memory_usage" -type = "::FlexFlow::num_bytes_t" diff --git a/lib/compiler/include/compiler/cost_estimator/parallel_tensor_space_to_machine_space_mapping.dtg.toml b/lib/compiler/include/compiler/cost_estimator/parallel_tensor_space_to_machine_space_mapping.dtg.toml new file mode 100644 index 0000000000..b62e8ad611 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/parallel_tensor_space_to_machine_space_mapping.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ParallelTensorSpaceToMachineSpaceMapping" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "pcg/machine_space_coordinate.dtg.h", + "op-attrs/parallel_tensor_dim_degrees.dtg.h", + "utils/bidict/bidict.h", +] + +[[fields]] +name = "raw_mapping" +type = "::FlexFlow::bidict<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>" + +[[fields]] +name = "parallel_tensor_space" +type = "::FlexFlow::ParallelTensorDimDegrees" diff --git a/lib/compiler/include/compiler/cost_estimator/parallel_tensor_space_to_machine_space_mapping.h b/lib/compiler/include/compiler/cost_estimator/parallel_tensor_space_to_machine_space_mapping.h new file mode 100644 index 0000000000..ec42e11e42 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/parallel_tensor_space_to_machine_space_mapping.h @@ -0,0 +1,16 @@ +#ifndef 
_FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_PARALLEL_TENSOR_SPACE_TO_MACHINE_SPACE_MAPPING_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_PARALLEL_TENSOR_SPACE_TO_MACHINE_SPACE_MAPPING_H + +#include "compiler/cost_estimator/parallel_tensor_space_to_machine_space_mapping.dtg.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "pcg/operator_space_to_machine_space_mapping.dtg.h" + +namespace FlexFlow { + +ParallelTensorSpaceToMachineSpaceMapping ptensor_machine_map_from_composition( + OperatorSpaceToMachineSpaceMapping const &op_task_to_machine_space_mapping, + OperatorSpaceToParallelTensorSpaceMapping const &op_task_to_parallel); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/cost_estimator/runtime_only_cost_estimator.h b/lib/compiler/include/compiler/cost_estimator/runtime_only_cost_estimator.h index aa1c2d70b6..1d5c914afc 100644 --- a/lib/compiler/include/compiler/cost_estimator/runtime_only_cost_estimator.h +++ b/lib/compiler/include/compiler/cost_estimator/runtime_only_cost_estimator.h @@ -4,9 +4,9 @@ #include "compiler/cost_estimator/runtime_only_op_cost_estimate_key.dtg.h" #include "compiler/cost_estimator/runtime_only_op_cost_metrics.dtg.h" #include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/machine_view.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" -#include "pcg/machine_view.dtg.h" #include namespace FlexFlow { diff --git a/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_estimate_key.dtg.toml b/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_estimate_key.dtg.toml new file mode 100644 index 0000000000..99501645ce --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_estimate_key.dtg.toml @@ -0,0 +1,43 @@ +namespace = "FlexFlow" +name = "RuntimeOnlyOpCostEstimateKey" +type = "struct" +features = [ + 
"eq", + "ord", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "compiler/machine_mapping/machine_view.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/ord/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "machine_view" +type = "::FlexFlow::MachineView" diff --git a/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_estimate_key.struct.toml deleted file mode 100644 index 94be6f6e69..0000000000 --- a/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_estimate_key.struct.toml +++ /dev/null @@ -1,40 +0,0 @@ -namespace = "FlexFlow" -name = "RuntimeOnlyOpCostEstimateKey" -features = [ - "eq", - "ord", - "fmt", - "hash", -] - -includes = [ - "op-attrs/pcg_operator_attrs.dtg.h", - "op-attrs/parallel_tensor_shape.dtg.h", - "", - "pcg/machine_view.dtg.h", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "op_attrs" -type = "::FlexFlow::PCGOperatorAttrs" - -[[fields]] -name = "input_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "weight_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "output_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "machine_view" -type = "::FlexFlow::MachineView" diff 
--git a/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_metrics.dtg.toml b/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_metrics.dtg.toml new file mode 100644 index 0000000000..77c291ceb6 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_metrics.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "RuntimeOnlyOpCostMetrics" +type = "struct" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "utils/units/milliseconds_t.h", +] + +[[fields]] +name = "forward_runtime" +type = "::FlexFlow::milliseconds_t" + +[[fields]] +name = "backward_runtime" +type = "::FlexFlow::milliseconds_t" diff --git a/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_metrics.struct.toml b/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_metrics.struct.toml deleted file mode 100644 index 65ac318f0e..0000000000 --- a/lib/compiler/include/compiler/cost_estimator/runtime_only_op_cost_metrics.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "RuntimeOnlyOpCostMetrics" -features = [ - "eq", - "fmt", - "hash", -] - -includes = [ - "utils/units/milliseconds_t.h", -] - -[[fields]] -name = "forward_runtime" -type = "::FlexFlow::milliseconds_t" - -[[fields]] -name = "backward_runtime" -type = "::FlexFlow::milliseconds_t" diff --git a/lib/compiler/include/compiler/cost_estimator/single_communication.dtg.toml b/lib/compiler/include/compiler/cost_estimator/single_communication.dtg.toml new file mode 100644 index 0000000000..8f09321d0b --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/single_communication.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "SingleCommunication" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/machine_space_coordinate.dtg.h", +] + +[[fields]] +name = "src_machine_coord" +type = "::FlexFlow::MachineSpaceCoordinate" + +[[fields]] +name = "dst_machine_coord" 
+type = "::FlexFlow::MachineSpaceCoordinate" diff --git a/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.dtg.toml b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.dtg.toml new file mode 100644 index 0000000000..9dd77f4fe2 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "SingleTensorMovement" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/cost_estimator/communication_edge.h", + "utils/units/num_bytes_t.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "edge_to_size" +type = "std::unordered_map<::FlexFlow::CommunicationEdge, ::FlexFlow::num_bytes_t>" diff --git a/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml deleted file mode 100644 index 70f73ebe51..0000000000 --- a/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "SingleTensorMovement" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "op-attrs/parallel_tensor_shape.dtg.h", - "pcg/machine_view.dtg.h", - "", -] - -src_includes = [ - "utils/hash/unordered_set.h", - "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "parallel_tensor_shape" -type = "::FlexFlow::ParallelTensorShape" - -[[fields]] -name = "src_machine_views" -type = "std::unordered_set<::FlexFlow::MachineView>" - -[[fields]] -name = "dst_machine_views" -type = "std::unordered_set<::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.dtg.toml b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.dtg.toml new file mode 100644 index 0000000000..2660b0a3c3 --- /dev/null +++ 
b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "TensorSetMovement" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/cost_estimator/communication_edge.h", + "utils/units/num_bytes_t.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "edge_to_size" +type = "std::unordered_map<::FlexFlow::CommunicationEdge, ::FlexFlow::num_bytes_t>" diff --git a/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.h b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.h index 34188ff97c..fc1aa4777c 100644 --- a/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.h +++ b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.h @@ -2,12 +2,14 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_TENSOR_SET_MOVEMENT_H #include "compiler/cost_estimator/tensor_set_movement.dtg.h" -#include "pcg/machine_view.dtg.h" +#include "compiler/machine_mapping/machine_view.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" namespace FlexFlow { +TensorSetMovement empty_tensor_set_movement(); + TensorSetMovement get_tensor_set_movement_from_pcg_edge( ParallelComputationGraphEdge const &edge, ParallelComputationGraph const &pcg, diff --git a/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml deleted file mode 100644 index 3625605239..0000000000 --- a/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "TensorSetMovement" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "compiler/cost_estimator/single_tensor_movement.dtg.h", - "", -] - -src_includes 
= [ - "utils/fmt/unordered_multiset.h", - "utils/hash/unordered_multiset.h", -] - -[[fields]] -name = "single_tensor_movements" -type = "std::unordered_multiset<::FlexFlow::SingleTensorMovement>" diff --git a/lib/compiler/include/compiler/graph_optimize_result.dtg.toml b/lib/compiler/include/compiler/graph_optimize_result.dtg.toml new file mode 100644 index 0000000000..56fbb9ff5b --- /dev/null +++ b/lib/compiler/include/compiler/graph_optimize_result.dtg.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "GraphOptimizeResult" +type = "struct" +features = [] + +includes = [ + "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h", +] + +[[fields]] +name = "mapped_pcg" +type = "::FlexFlow::MappedParallelComputationGraph" diff --git a/lib/compiler/include/compiler/graph_optimize_result.h b/lib/compiler/include/compiler/graph_optimize_result.h deleted file mode 100644 index f3843e2a93..0000000000 --- a/lib/compiler/include/compiler/graph_optimize_result.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H - -#include "compiler/graph_optimize_result.dtg.h" - -namespace FlexFlow { - -std::string format_as(GraphOptimizeResult const &); -std::ostream &operator<<(std::ostream &, GraphOptimizeResult const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/graph_optimize_result.struct.toml b/lib/compiler/include/compiler/graph_optimize_result.struct.toml deleted file mode 100644 index 22f29cbd59..0000000000 --- a/lib/compiler/include/compiler/graph_optimize_result.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "GraphOptimizeResult" -features = [ ] - -includes = [ - "compiler/machine_mapping/machine_mapping.dtg.h", - "pcg/parallel_computation_graph/parallel_computation_graph.h" -] - -[[fields]] -name = "pcg" -type = "::FlexFlow::ParallelComputationGraph" - 
-[[fields]] -name = "machine_mapping" -type = "::FlexFlow::MachineMapping" diff --git a/lib/compiler/include/compiler/graph_optimize_state.h b/lib/compiler/include/compiler/graph_optimize_state.h index 404111ff8b..62c9f97331 100644 --- a/lib/compiler/include/compiler/graph_optimize_state.h +++ b/lib/compiler/include/compiler/graph_optimize_state.h @@ -17,9 +17,6 @@ struct GraphOptimizeState { bool operator<(GraphOptimizeState const &other) const; }; -std::string format_as(GraphOptimizeState const &); -std::ostream &operator<<(std::ostream &, GraphOptimizeState const &); - } // namespace FlexFlow namespace std { diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_communication_edge.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_communication_edge.dtg.toml new file mode 100644 index 0000000000..2bca9b9992 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_communication_edge.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "AbstractedCommunicationEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::AbstractedDevice" + +[[fields]] +name = "dst" +type = "::FlexFlow::AbstractedDevice" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.dtg.toml new file mode 100644 index 0000000000..1685fca931 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "AbstractedDevice" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + 
"op-attrs/task_space_coordinate.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", +] + +[[fields]] +name = "operator_tree_path" +type = "::FlexFlow::BinaryTreePath" + +[[fields]] +name = "task_space_coordinate" +type = "::FlexFlow::TaskSpaceCoordinate" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.h new file mode 100644 index 0000000000..b0a17309e9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_ABSTRACTED_DEVICE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_ABSTRACTED_DEVICE_H + +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" + +namespace FlexFlow { + +MachineSpaceCoordinate concretize_abstracted_device( + AbstractedDevice const &abstracted_device, + std::unordered_map const + &machine_space_stencils); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_communication.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_communication.dtg.toml new file mode 100644 index 0000000000..2b631ab2af --- /dev/null +++ 
b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_communication.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "AbstractedSingleCommunication" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_communication_edge.dtg.h", + "utils/units/num_bytes_t.h", +] + +[[fields]] +name = "edge" +type = "::FlexFlow::AbstractedCommunicationEdge" + +[[fields]] +name = "size" +type = "::FlexFlow::num_bytes_t" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication.dtg.toml new file mode 100644 index 0000000000..0ca953d9ab --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "AbstractedSingleTensorCommunication" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.dtg.h", + "utils/units/num_bytes_t.h", +] + +[[fields]] +name = "edge" +type = "::FlexFlow::AbstractedSingleTensorCommunicationEdge" + +[[fields]] +name = "size" +type = "::FlexFlow::num_bytes_t" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.dtg.toml new file mode 100644 index 0000000000..e36c8de82d --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.dtg.toml @@ -0,0 +1,21 @@ +namespace = 
"FlexFlow" +name = "AbstractedSingleTensorCommunicationEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.dtg.h", +] + +[[fields]] +name = "src_coord" +type = "::FlexFlow::TaskSpaceCoordinate" + +[[fields]] +name = "dst" +type = "::FlexFlow::AbstractedDevice" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.h new file mode 100644 index 0000000000..0b4f7a3a43 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_ABSTRACTED_SINGLE_TENSOR_COMMUNICATION_EDGE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_ABSTRACTED_SINGLE_TENSOR_COMMUNICATION_EDGE_H + +#include "compiler/cost_estimator/communication_edge.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" +namespace FlexFlow { + +std::optional + concretize_abstracted_single_tensor_communication_edge( + AbstractedSingleTensorCommunicationEdge const &edge, + MachineSpaceStencil const &src_machine_stencil, + std::unordered_map const + &dst_machine_stencils); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.toml new file mode 100644 index 0000000000..3658afe154 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "AbstractedSingleTensorMovement" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.dtg.h", + "utils/units/num_bytes_t.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "src_op_tree_path" +type = "::FlexFlow::BinaryTreePath" + +[[fields]] +name = "edge_to_size" +type = "std::unordered_map<::FlexFlow::AbstractedSingleTensorCommunicationEdge, ::FlexFlow::num_bytes_t>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.h new file mode 100644 index 0000000000..1a4062cc4c --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_ABSTRACTED_SINGLE_TENSOR_MOVEMENT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_ABSTRACTED_SINGLE_TENSOR_MOVEMENT_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication.dtg.h" +#include 
"compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.dtg.h" + +namespace FlexFlow { + +std::unordered_set + abstracted_single_tensor_movement_get_dst_layers( + AbstractedSingleTensorMovement const &); + +AbstractedSingleTensorMovement merge_abstracted_single_tensor_movements( + std::unordered_multiset const &); + +AbstractedSingleTensorMovement + abstracted_single_tensor_movement_from_communications( + BinaryTreePath const &src_op_tree_path, + std::unordered_set const + &communications); + +TensorSetMovement concretize_abstracted_single_tensor_movement( + AbstractedSingleTensorMovement const &, + std::unordered_map const + &pre_machine_stencils, + std::unordered_map const + &post_machine_stencils); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml deleted file mode 100644 index 449a448706..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "AbstractedSingleTensorMovement" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "op-attrs/parallel_tensor_shape.dtg.h", - "utils/full_binary_tree/binary_tree_path.dtg.h", - "", -] - -src_includes = [ - "utils/hash/unordered_set.h", - "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "parallel_tensor_shape" -type = "::FlexFlow::ParallelTensorShape" - -[[fields]] -name = "src_machine_views" -type = "std::unordered_set<::FlexFlow::BinaryTreePath>" - -[[fields]] -name = "dst_machine_views" -type = "std::unordered_set<::FlexFlow::BinaryTreePath>" diff --git 
a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.toml new file mode 100644 index 0000000000..b030692260 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "AbstractedTensorSetMovement" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", +] + +[[fields]] +name = "single_tensor_movements" +type = "std::unordered_set<::FlexFlow::AbstractedSingleTensorMovement>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h index 5b7e2f3613..d925df2762 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h @@ -2,14 +2,23 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H #include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_communication.dtg.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.dtg.h" #include "compiler/machine_mapping/machine_mapping.dtg.h" +#include 
"compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" namespace FlexFlow { AbstractedTensorSetMovement empty_abstracted_tensor_set_movement(); +AbstractedTensorSetMovement + abstracted_tensor_set_movement_from_single_tensor_movement( + AbstractedSingleTensorMovement const &); + std::unordered_set get_src_layers(AbstractedTensorSetMovement const &); std::unordered_set @@ -17,8 +26,10 @@ std::unordered_set TensorSetMovement concretize_abstracted_tensor_set_movement( AbstractedTensorSetMovement const &, - ParallelLayerGuidObliviousMachineMapping const &pre, - ParallelLayerGuidObliviousMachineMapping const &post); + std::unordered_map const + &pre_machine_stencils, + std::unordered_map const + &post_machine_stencils); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml deleted file mode 100644 index 4cf184706b..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "AbstractedTensorSetMovement" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h", - "", -] - -src_includes = [ - "utils/fmt/unordered_multiset.h", - "utils/hash/unordered_multiset.h", -] - -[[fields]] -name = "single_tensor_movements" -type = "std::unordered_multiset<::FlexFlow::AbstractedSingleTensorMovement>" diff --git 
a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h index 8567a7a3e6..02912f7938 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h @@ -1,12 +1,20 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_communication.dtg.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" #include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" #include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" namespace FlexFlow { +AbstractedSingleTensorMovement get_abstracted_single_tensor_movement_along_edge( + ParallelComputationGraph const &pcg, + ParallelComputationGraphEdge const &edge, + BinaryTreePath const &src_path, + BinaryTreePath const &dst_path); + AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( TransitiveReducedPCG const &transitive_reduced_pcg, PCGBinarySeriesSplit const &split); diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.dtg.toml new file mode 100644 index 0000000000..e86bb6a6af --- /dev/null +++ 
b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "MachineSpaceStencil" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +includes = [ + "op-attrs/operator_task_space.dtg.h", + "compiler/machine_mapping/machine_view.dtg.h", +] + +[[fields]] +name = "operator_task_space" +type = "::FlexFlow::OperatorTaskSpace" + +[[fields]] +name = "machine_view" +type = "::FlexFlow::MachineView" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.h new file mode 100644 index 0000000000..54dfc28cbf --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_MACHINE_SPACE_STENCIL_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_MACHINE_SPACE_STENCIL_H + +#include "compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.dtg.h" +#include "op-attrs/task_space_coordinate.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" + +namespace FlexFlow { + +MachineSpaceCoordinate + machine_space_stencil_compute_machine_coord(MachineSpaceStencil const &, + TaskSpaceCoordinate const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.dtg.toml b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.dtg.toml new file mode 100644 index 0000000000..ef2d7899ed --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "FeasibleMachineMappingResult" +type = "struct" 
+features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h", + "utils/units/milliseconds_t.h", +] + +[[fields]] +name = "runtime" +type = "::FlexFlow::milliseconds_t" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::ParallelLayerGuidObliviousMachineMapping" diff --git a/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml deleted file mode 100644 index 8dda2d15ba..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "FeasibleMachineMappingResult" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h", - "utils/units/milliseconds_t.h", -] - -[[fields]] -name = "runtime" -type = "::FlexFlow::milliseconds_t" - -[[fields]] -name = "machine_mapping" -type = "::FlexFlow::ParallelLayerGuidObliviousMachineMapping" diff --git a/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h b/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h deleted file mode 100644 index 990c1c8205..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_RESOURCE_SPLITS_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_RESOURCE_SPLITS_H - -#include "pcg/machine_specification.dtg.h" -#include -#include - -namespace FlexFlow { - -std::unordered_set> - get_machine_resource_splits(MachineSpecification const &resource); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h 
b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index 2cd3f3e289..3e49899003 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H +#include "compiler/machine_mapping/machine_compute_resource_slice.dtg.h" #include "compiler/machine_mapping/machine_mapping_cache.dtg.h" #include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" #include "compiler/machine_mapping/machine_mapping_context.dtg.h" @@ -8,7 +9,6 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" #include "compiler/machine_mapping/parallel_split_transformation.dtg.h" -#include "pcg/machine_specification.dtg.h" namespace FlexFlow { @@ -16,14 +16,14 @@ MachineMappingResult get_optimal_machine_mapping(MachineMappingCache &result_cache, MachineMappingContext const &context, MachineMappingProblemTree const &problem_tree, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints); MachineMappingResult get_optimal_machine_mapping(MachineMappingCache &result_cache, MachineMappingContext const &context, MMProblemTreeSeriesSplit const &series_split, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints, std::optional const ¶llel_split_transformation); @@ -32,14 +32,14 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, MachineMappingContext const &context, MMProblemTreeParallelSplit const ¶llel_split, - MachineSpecification const &resources, + 
MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints); MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, MachineMappingContext const &, UnmappedRuntimeOnlyOpCostEstimateKey const &leaf, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h index 2aed9a20e4..48eaf59592 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h +++ b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h @@ -5,6 +5,7 @@ #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" #include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" #include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/include_unconstrained.dtg.toml b/lib/compiler/include/compiler/machine_mapping/include_unconstrained.dtg.toml new file mode 100644 index 0000000000..9c99bc2c4e --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/include_unconstrained.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "IncludeUnconstrained" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [] + +[[fields]] +name = "raw_bool" +type = "bool" diff --git a/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml b/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml deleted file mode 100644 index b9a7f9ac59..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml 
+++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "IncludeUnconstrained" -features = [ - "eq", - "ord", - "hash", - "fmt", - "rapidcheck", - "json", -] - -includes = [] - -[[fields]] -name = "raw_bool" -type = "bool" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_compute_resource_slice.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_compute_resource_slice.dtg.toml new file mode 100644 index 0000000000..77d6b7558c --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_compute_resource_slice.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "MachineComputeResourceSlice" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "num_nodes" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_gpus_per_node" +type = "::FlexFlow::positive_int" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_compute_resource_slice.h b/lib/compiler/include/compiler/machine_mapping/machine_compute_resource_slice.h new file mode 100644 index 0000000000..99187999ec --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_compute_resource_slice.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_COMPUTE_RESOURCE_SLICE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_COMPUTE_RESOURCE_SLICE_H + +#include "compiler/machine_mapping/machine_compute_resource_slice.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" + +namespace FlexFlow { + +MachineComputeResourceSlice + compute_slice_from_specification(MachineComputeSpecification const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping.dtg.toml new file mode 100644 index 0000000000..18b61840fd --- /dev/null +++ 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "MachineMapping" +type = "struct" +features = [ + "eq", + "ord", + "hash", + # "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "compiler/machine_mapping/machine_view.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", + "utils/ord/unordered_map.h", +] + +[[fields]] +name = "machine_views" +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h index 7375cde985..3e5b9238dd 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h @@ -1,10 +1,10 @@ -#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_H -#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_H #include "compiler/machine_mapping/machine_mapping.dtg.h" #include "pcg/device_id_t.dtg.h" #include "pcg/machine_specification.dtg.h" -#include "pcg/operator_task_space.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" namespace FlexFlow { @@ -14,6 +14,10 @@ MachineMapping combine_disjoint_mappings(MachineMapping const &, bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); +MappedParallelComputationGraph + mapped_pcg_from_pcg_and_mapping(ParallelComputationGraph const &, + MachineMapping const &); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml deleted file mode 100644 index 92517c1110..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMapping" -features = [ - "eq", - # "ord", - "hash", - # "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "pcg/machine_view.dtg.h", -] - -src_includes = [ - "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", -] - -[[fields]] -name = "machine_views" -type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.dtg.toml new file mode 100644 index 0000000000..5683206177 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "MachineMappingCache" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "compiler/machine_mapping/machine_mapping_state.dtg.h", + "compiler/machine_mapping/machine_mapping_result.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_map" +type = "std::unordered_map<::FlexFlow::MachineMappingState, ::FlexFlow::MachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml deleted file mode 100644 index a76ff26eb9..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMappingCache" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "", - 
"compiler/machine_mapping/machine_mapping_state.dtg.h", - "compiler/machine_mapping/machine_mapping_result.dtg.h", -] - -src_includes = [ - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", -] - -[[fields]] -name = "raw_map" -type = "std::unordered_map<::FlexFlow::MachineMappingState, ::FlexFlow::MachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.dtg.toml new file mode 100644 index 0000000000..a83a7caa02 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "MachineMappingConstraints" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/machine_view.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", + "utils/fmt/optional.h", +] + +[[fields]] +name = "machine_views" +type = "std::unordered_map<::FlexFlow::BinaryTreePath, std::optional<::FlexFlow::MachineView>>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml deleted file mode 100644 index 8e13abedb9..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMappingConstraints" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "pcg/machine_view.dtg.h", - "utils/full_binary_tree/binary_tree_path.dtg.h", - "", -] - -src_includes = [ - "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", - "utils/fmt/optional.h", -] - -[[fields]] -name = "machine_views" -type = "std::unordered_map<::FlexFlow::BinaryTreePath, std::optional<::FlexFlow::MachineView>>" diff 
--git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.dtg.toml new file mode 100644 index 0000000000..ae5299ecdd --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "MachineMappingContext" +type = "struct" +features = [] + +includes = [ + "compiler/cost_estimator/runtime_only_cost_estimator.h", + "compiler/machine_mapping/machine_view.dtg.h", + "compiler/machine_mapping/machine_compute_resource_slice.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.h", +] + +[[fields]] +name = "cost_estimator" +type = "::FlexFlow::RuntimeOnlyCostEstimator" + +[[fields]] +name = "allowed_machine_views" +type = "std::function(::FlexFlow::UnmappedRuntimeOnlyOpCostEstimateKey const &, ::FlexFlow::MachineComputeResourceSlice const &)>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml deleted file mode 100644 index dd49aaa98a..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMappingContext" -features = [] - -includes = [ - "compiler/cost_estimator/runtime_only_cost_estimator.h", - "pcg/machine_view.dtg.h", - "pcg/machine_specification.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.h", -] - -[[fields]] -name = "cost_estimator" -type = "::FlexFlow::RuntimeOnlyCostEstimator" - -[[fields]] -name = "allowed_machine_views" -type = "std::function(::FlexFlow::UnmappedRuntimeOnlyOpCostEstimateKey const &, ::FlexFlow::MachineSpecification const &)>" diff --git 
a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h index 68d02aaa54..be8a4b9afa 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_view.dtg.h" #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" #include "pcg/machine_specification.dtg.h" -#include "pcg/machine_view.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.toml new file mode 100644 index 0000000000..2456dca145 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "MachineMappingProblemTree" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.h", +] + +[[values]] +type = "::FlexFlow::MMProblemTreeSeriesSplit" +key = "series" + +[[values]] +type = 
"::FlexFlow::MMProblemTreeParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::UnmappedRuntimeOnlyOpCostEstimateKey" +key = "leaf" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h index 65f7006b21..abd77bfa7b 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H +#include "compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" @@ -28,6 +29,9 @@ std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &, BinaryTreePath const &); +std::unordered_map + mm_problem_tree_get_path_to_leaf_map(MachineMappingProblemTree const &); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml deleted file mode 100644 index 808853994a..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = 
"MachineMappingProblemTree" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.h", -] - -[[values]] -type = "::FlexFlow::MMProblemTreeSeriesSplit" -key = "series" - -[[values]] -type = "::FlexFlow::MMProblemTreeParallelSplit" -key = "parallel" - -[[values]] -type = "::FlexFlow::UnmappedRuntimeOnlyOpCostEstimateKey" -key = "leaf" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.toml new file mode 100644 index 0000000000..b0dad05430 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "MMProblemTreeParallelSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct MachineMappingProblemTree", +] + +post_includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml deleted file mode 100644 index 5247b2006a..0000000000 --- 
a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "MMProblemTreeParallelSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct MachineMappingProblemTree", -] - -post_includes = [ - "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", -] - -includes = [] - -[[fields]] -name = "left_child" -type = "::FlexFlow::MachineMappingProblemTree" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::MachineMappingProblemTree" -indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.toml new file mode 100644 index 0000000000..64b05b0101 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "MMProblemTreeSeriesSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct MachineMappingProblemTree", +] + +post_includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h", +] + +[[fields]] +name = "tensor_set_movement" +type = "::FlexFlow::AbstractedTensorSetMovement" + +[[fields]] +name = "left_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml deleted file mode 100644 index d4f61bb3f5..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml +++ /dev/null @@ -1,33 +0,0 @@ -namespace = "FlexFlow" -name = "MMProblemTreeSeriesSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct MachineMappingProblemTree", -] - -post_includes = [ - "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", -] - -includes = [ - "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h", -] - -[[fields]] -name = "tensor_set_movement" -type = "::FlexFlow::AbstractedTensorSetMovement" - -[[fields]] -name = "left_child" -type = "::FlexFlow::MachineMappingProblemTree" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::MachineMappingProblemTree" -indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.toml new file mode 100644 index 0000000000..4bad66f7ee --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.toml @@ -0,0 +1,41 @@ +namespace = "FlexFlow" +name = "UnmappedOpCostEstimateKey" +type = "struct" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "pcg/optimizer_attrs.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = 
"std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "optimizer_attrs" +type = "::FlexFlow::OptimizerAttrs" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml deleted file mode 100644 index 5dcfd33859..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml +++ /dev/null @@ -1,39 +0,0 @@ -namespace = "FlexFlow" -name = "UnmappedOpCostEstimateKey" -features = [ - "eq", - "fmt", - "hash", -] - -includes = [ - "op-attrs/pcg_operator_attrs.dtg.h", - "op-attrs/parallel_tensor_shape.dtg.h", - "", - "pcg/optimizer_attrs.dtg.h", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "op_attrs" -type = "::FlexFlow::PCGOperatorAttrs" - -[[fields]] -name = "input_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "weight_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "output_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "optimizer_attrs" -type = "::FlexFlow::OptimizerAttrs" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.toml new file mode 100644 index 0000000000..8db92162a1 --- /dev/null +++ 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.toml @@ -0,0 +1,36 @@ +namespace = "FlexFlow" +name = "UnmappedRuntimeOnlyOpCostEstimateKey" +type = "struct" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorShape>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.struct.toml deleted file mode 100644 index e38ce06f03..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.struct.toml +++ /dev/null @@ -1,34 +0,0 @@ -namespace = "FlexFlow" -name = "UnmappedRuntimeOnlyOpCostEstimateKey" -features = [ - "eq", - "fmt", - "hash", -] - -includes = [ - "op-attrs/pcg_operator_attrs.dtg.h", - "op-attrs/parallel_tensor_shape.dtg.h", - "", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "op_attrs" -type = "::FlexFlow::PCGOperatorAttrs" - -[[fields]] -name = "input_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "weight_shapes" -type = 
"std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "output_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.dtg.toml new file mode 100644 index 0000000000..1c9b664246 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "MachineMappingResult" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/feasible_machine_mapping_result.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "raw_result" +type = "std::optional<::FlexFlow::FeasibleMachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h index 8924b1c110..f7b52ec574 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H #include "compiler/machine_mapping/machine_mapping_result.dtg.h" +#include "compiler/machine_mapping/machine_resource_split.dtg.h" #include "compiler/machine_mapping/parallel_split_transformation.dtg.h" #include "utils/units/milliseconds_t.h" @@ -20,8 +21,10 @@ FeasibleMachineMappingResult require_feasible(MachineMappingResult const &); MachineMappingResult const &post_result, std::optional const ¶llel_split_transformation); + [[nodiscard]] MachineMappingResult - parallel_combine(MachineMappingResult const &lhs_result, + parallel_combine(MachineResourceSplit const &split, + MachineMappingResult const &lhs_result, MachineMappingResult const &rhs_result); [[nodiscard]] MachineMappingResult diff --git 
a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml deleted file mode 100644 index 92a2873af5..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMappingResult" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "compiler/machine_mapping/feasible_machine_mapping_result.dtg.h", - "", -] - -src_includes = [ - "utils/fmt/optional.h", -] - -[[fields]] -name = "raw_result" -type = "std::optional<::FlexFlow::FeasibleMachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.dtg.toml new file mode 100644 index 0000000000..369cbfd851 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "MachineMappingState" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/machine_mapping_constraints.dtg.h", + "compiler/machine_mapping/machine_compute_resource_slice.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +[[fields]] +name = "problem_tree" +type = "::FlexFlow::MachineMappingProblemTree" + +[[fields]] +name = "resources" +type = "::FlexFlow::MachineComputeResourceSlice" + +[[fields]] +name = "constraints" +type = "::FlexFlow::MachineMappingConstraints" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml deleted file mode 100644 index 1346f6ebe7..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" 
-name = "MachineMappingState" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "pcg/machine_specification.dtg.h", - "compiler/machine_mapping/machine_mapping_constraints.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", -] - -[[fields]] -name = "problem_tree" -type = "::FlexFlow::MachineMappingProblemTree" - -[[fields]] -name = "resources" -type = "::FlexFlow::MachineSpecification" - -[[fields]] -name = "constraints" -type = "::FlexFlow::MachineMappingConstraints" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_resource_split.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_resource_split.dtg.toml new file mode 100644 index 0000000000..36aff3952e --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_resource_split.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "MachineResourceSplit" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "utils/positive_int/positive_int.h", + "pcg/machine_specification_dimension.dtg.h", +] + +[[fields]] +name = "offset" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "dimension" +type = "::FlexFlow::MachineSpecificationDimension" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_resource_split.h b/lib/compiler/include/compiler/machine_mapping/machine_resource_split.h new file mode 100644 index 0000000000..7573276b82 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_resource_split.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_RESOURCE_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_RESOURCE_SPLIT_H + +#include "compiler/machine_mapping/machine_compute_resource_slice.dtg.h" +#include "compiler/machine_mapping/machine_resource_split.dtg.h" +#include "compiler/machine_mapping/machine_view.dtg.h" +#include 
"compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" + +namespace FlexFlow { + +std::pair + apply_resource_split(MachineResourceSplit const &split, + MachineComputeResourceSlice const &resources); + +std::unordered_set + get_machine_resource_splits(MachineComputeResourceSlice const &); + +MachineSpaceCoordinate + offset_machine_space_coordinate_by(MachineSpaceCoordinate const &, + MachineResourceSplit const &); + +MachineView offset_machine_view_by(MachineView const &, + MachineResourceSplit const &); + +ParallelLayerGuidObliviousMachineMapping offset_layer_oblivious_mapping_by( + ParallelLayerGuidObliviousMachineMapping const &mapping, + MachineResourceSplit const &split); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_space_dim_subgrid.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_space_dim_subgrid.dtg.toml new file mode 100644 index 0000000000..f8b030be74 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_space_dim_subgrid.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "MachineSpaceDimSubgrid" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", + "compiler/machine_mapping/stride_t.dtg.h", + "utils/int_ge_two/int_ge_two.h", +] + +[[fields]] +name = "offset" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "stride" +type = "::FlexFlow::stride_t" + +[[fields]] +name = "num_points" +type = "::FlexFlow::int_ge_two" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_space_subgrid.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_space_subgrid.dtg.toml new file mode 100644 index 0000000000..0b5a2151c7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_space_subgrid.dtg.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "MachineSpaceSubgrid" +type = "struct" +features = [ + "eq", + 
"ord", + "hash", + "json", + "fmt", +] + +includes = [ + "compiler/machine_mapping/machine_space_dim_subgrid.dtg.h", + "", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", + "utils/ord/vector.h", +] + +[[fields]] +name = "inter_node_strides" +type = "std::vector<::FlexFlow::MachineSpaceDimSubgrid>" + +[[fields]] +name = "intra_node_strides" +type = "std::vector<::FlexFlow::MachineSpaceDimSubgrid>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_view.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_view.dtg.toml new file mode 100644 index 0000000000..ce3bca0d4e --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_view.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "MachineView" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "compiler/machine_mapping/machine_view_dimension.dtg.h", + "pcg/machine_space_coordinate.dtg.h" +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h" +] + + +[[fields]] +name = "start" +type = "::FlexFlow::MachineSpaceCoordinate" + +[[fields]] +name = "dimensions" +type = "std::vector<::FlexFlow::MachineViewDimension>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_view.h b/lib/compiler/include/compiler/machine_mapping/machine_view.h new file mode 100644 index 0000000000..6888fa6b94 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_view.h @@ -0,0 +1,82 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_VIEW_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_VIEW_H + +#include "compiler/machine_mapping/machine_view.dtg.h" +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/task_space_coordinate.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include 
"pcg/machine_compute_specification.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "pcg/operator_space_to_machine_space_mapping.dtg.h" +#include "utils/bidict/bidict.h" +#include +#include +#include + +namespace FlexFlow { + +nonnegative_int mv_get_expected_task_space_num_dims(MachineView const &mv); + +DeviceType get_device_type(MachineView const &mv); + +std::vector get_strides(MachineView const &mv); + +std::vector + get_dimensions(MachineView const &mv); + +MachineView machine_view_from_strides_and_machine_spec_dimensions( + MachineSpaceCoordinate const &start, + std::vector const &strides, + std::vector const &dims); + +MachineSpaceCoordinate get_machine_space_coordinate( + OperatorTaskSpace const &operator_task_space, + MachineView const &machine_view, + TaskSpaceCoordinate const &task_space_coordinate); + +TaskSpaceCoordinate + mv_task_space_coord_for_machine_space_coord(MachineView const &, + OperatorTaskSpace const &, + MachineSpaceCoordinate const &); + +OperatorSpaceToMachineSpaceMapping get_coordinate_mapping_for_machine_view( + OperatorTaskSpace const &operator_task_space, + MachineView const &machine_view); + +std::unordered_set + get_machine_space_coordinates(OperatorTaskSpace const &task, + MachineView const &mv); + +std::unordered_set + get_device_ids(OperatorTaskSpace const &task, + MachineView const &mv, + MachineComputeSpecification const &ms); + +MachineView make_1d_machine_view(MachineSpaceCoordinate const &start, + MachineSpecificationDimension const &dim, + stride_t stride); + +MachineView make_single_device_machine_view(MachineSpaceCoordinate const &); + +OperatorAtomicTaskShardBinding + operator_atomic_task_shard_binding_from_machine_view( + ComputationGraphOpAttrs const &, + std::vector const &, + MachineView const &, + MachineSpaceCoordinate const &); + +MappedOperatorTaskGroup 
mapped_operator_task_group_from_machine_view( + ComputationGraphOpAttrs const &, + std::unordered_map const &, + MachineView const &); + +bidict + get_tensor_shard_to_device_coord_mapping(ComputationGraphOpAttrs const &, + MachineView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_view_dimension.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_view_dimension.dtg.toml new file mode 100644 index 0000000000..81931359c0 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_view_dimension.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "MachineViewDimension" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "pcg/machine_specification_dimension.dtg.h", + "compiler/machine_mapping/stride_t.dtg.h", +] + + +[[fields]] +name = "stride" +type = "::FlexFlow::stride_t" + +[[fields]] +name = "projection" +type = "::FlexFlow::MachineSpecificationDimension" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.h b/lib/compiler/include/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.h index 74c6aee851..3c1dc5f9fb 100644 --- a/lib/compiler/include/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.h +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.h @@ -10,7 +10,7 @@ #include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.dtg.h" #include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_context.dtg.h" #include "compiler/machine_mapping/parallel_split_transformation.dtg.h" -#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" namespace FlexFlow { @@ -18,14 +18,14 @@ MachineMappingWithMemoryResult 
get_optimal_machine_mapping_with_memory( MachineMappingWithMemoryCache &result_cache, MachineMappingWithMemoryContext const &context, MachineMappingProblemTree const &problem_tree, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints); MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( MachineMappingWithMemoryCache &result_cache, MachineMappingWithMemoryContext const &context, MMProblemTreeSeriesSplit const &series_split, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints, std::optional const ¶llel_split_transformation); @@ -34,14 +34,14 @@ MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( MachineMappingWithMemoryCache &result_cache, MachineMappingWithMemoryContext const &context, MMProblemTreeParallelSplit const ¶llel_split, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints); MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( MachineMappingWithMemoryCache &result_cache, MachineMappingWithMemoryContext const &context, UnmappedRuntimeOnlyOpCostEstimateKey const &leaf, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_for_single_layer.struct.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_for_single_layer.struct.toml deleted file mode 100644 index b61dd134c0..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_for_single_layer.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMappingForSingleLayer" -features = [ - "eq", - "hash", - "fmt", -] 
- -includes = [ - "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h", - "compiler/cost_estimator/op_cost_metrics.dtg.h", -] - -[[fields]] -name = "cost" -type = "::FlexFlow::OpCostMetrics" - -[[fields]] -name = "machine_mapping" -type = "::FlexFlow::ParallelLayerGuidObliviousMachineMapping" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.dtg.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.dtg.toml new file mode 100644 index 0000000000..bfe5981466 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "MachineMappingWithMemoryCache" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "compiler/machine_mapping/machine_mapping_state.dtg.h", + "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_map" +type = "std::unordered_map<::FlexFlow::MachineMappingState, ::FlexFlow::MachineMappingWithMemoryResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.struct.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.struct.toml deleted file mode 100644 index c2fe393e99..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMappingWithMemoryCache" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "", - "compiler/machine_mapping/machine_mapping_state.dtg.h", - "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.dtg.h", -] - 
-src_includes = [ - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", -] - -[[fields]] -name = "raw_map" -type = "std::unordered_map<::FlexFlow::MachineMappingState, ::FlexFlow::MachineMappingWithMemoryResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_context.dtg.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_context.dtg.toml new file mode 100644 index 0000000000..fc47dff0ba --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_context.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "MachineMappingWithMemoryContext" +type = "struct" +features = [] + +includes = [ + "compiler/cost_estimator/cost_estimator.h", + "compiler/machine_mapping/machine_view.dtg.h", + "compiler/machine_mapping/machine_compute_resource_slice.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.h", + "pcg/optimizer_attrs.dtg.h", +] + +[[fields]] +name = "cost_estimator" +type = "::FlexFlow::CostEstimator" + +[[fields]] +name = "optimizer_attrs" +type = "::FlexFlow::OptimizerAttrs" + +[[fields]] +name = "allowed_machine_views" +type = "std::function(::FlexFlow::UnmappedRuntimeOnlyOpCostEstimateKey const &, ::FlexFlow::MachineComputeResourceSlice const &)>" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_context.struct.toml deleted file mode 100644 index 9530697632..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_context.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMappingWithMemoryContext" -features = [] - -includes = [ - "compiler/cost_estimator/cost_estimator.h", - 
"pcg/machine_view.dtg.h", - "pcg/machine_specification.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.h", - "pcg/optimizer_attrs.dtg.h", -] - -[[fields]] -name = "cost_estimator" -type = "::FlexFlow::CostEstimator" - -[[fields]] -name = "optimizer_attrs" -type = "::FlexFlow::OptimizerAttrs" - -[[fields]] -name = "allowed_machine_views" -type = "std::function(::FlexFlow::UnmappedRuntimeOnlyOpCostEstimateKey const &, ::FlexFlow::MachineSpecification const &)>" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h index 4cb865dece..ab648f48f3 100644 --- a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h @@ -1,12 +1,38 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_MEMORY_OPTIMIZATION_MACHINE_MAPPING_RESULT_WITH_MEMORY_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_MEMORY_OPTIMIZATION_MACHINE_MAPPING_RESULT_WITH_MEMORY_H -#include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.dtg.h" +#include "compiler/machine_mapping/machine_resource_split.dtg.h" +#include "compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.dtg.h" #include "compiler/machine_mapping/parallel_split_transformation.dtg.h" #include namespace FlexFlow { +struct MachineMappingWithMemoryResult { + MachineMappingWithMemoryResult() = delete; + + explicit MachineMappingWithMemoryResult( + std::unordered_set const &); + + bool operator==(MachineMappingWithMemoryResult const &) const; + bool operator!=(MachineMappingWithMemoryResult const &) const; + + std::unordered_set const & + get_pareto_frontier() const; + +private: + std::unordered_set m_pareto_frontier; + 
+private: + std::tuple tie() const; + + friend struct ::std::hash; +}; + +std::string format_as(MachineMappingWithMemoryResult const &); +std::ostream &operator<<(std::ostream &, + MachineMappingWithMemoryResult const &); + [[nodiscard]] MachineMappingWithMemoryResult empty_machine_mapping_with_memory_result(); [[nodiscard]] bool is_empty(MachineMappingWithMemoryResult const &); @@ -14,10 +40,6 @@ namespace FlexFlow { [[nodiscard]] MachineMappingWithMemoryResult get_mapping_with_minimal_runtime( std::unordered_set const &); -[[nodiscard]] MachineMappingWithMemoryResult - remove_non_pareto_optimal_machine_mapping_result( - MachineMappingWithMemoryResult const &); - [[nodiscard]] MachineMappingWithMemoryResult series_combine(milliseconds_t comm_cost, MachineMappingWithMemoryResult const &pre_result, @@ -25,7 +47,8 @@ namespace FlexFlow { std::optional const ¶llel_split_transformation); [[nodiscard]] MachineMappingWithMemoryResult - parallel_combine(MachineMappingWithMemoryResult const &lhs_result, + parallel_combine(MachineResourceSplit const &split, + MachineMappingWithMemoryResult const &lhs_result, MachineMappingWithMemoryResult const &rhs_result); [[nodiscard]] MachineMappingWithMemoryResult @@ -38,4 +61,13 @@ namespace FlexFlow { } // namespace FlexFlow +namespace std { + +template <> +struct hash<::FlexFlow::MachineMappingWithMemoryResult> { + size_t operator()(::FlexFlow::MachineMappingWithMemoryResult const &) const; +}; + +} // namespace std + #endif diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.struct.toml deleted file mode 100644 index c1e1ee1cac..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMappingWithMemoryResult" 
-features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "compiler/machine_mapping/memory_optimization/machine_mapping_for_single_layer.dtg.h", -] - -src_includes = [ - "utils/hash/unordered_set.h", - "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "machine_mappings" -type = "std::unordered_set<::FlexFlow::MachineMappingForSingleLayer>" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_state.dtg.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_state.dtg.toml new file mode 100644 index 0000000000..bfadc51ce4 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_state.dtg.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "MachineMappingWithMemoryState" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_specification.dtg.h", + "compiler/machine_mapping/machine_mapping_constraints.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", + "pcg/optimizer_attrs.dtg.h", +] + +[[fields]] +name = "problem_tree" +type = "::FlexFlow::MachineMappingProblemTree" + +[[fields]] +name = "resources" +type = "::FlexFlow::MachineSpecification" + +[[fields]] +name = "constraints" +type = "::FlexFlow::MachineMappingConstraints" + +[[fields]] +name = "optimizer_attrs" +type = "::FlexFlow::OptimizerAttrs" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_state.struct.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_state.struct.toml deleted file mode 100644 index 77af129094..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_state.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMappingWithMemoryState" -features = [ - "eq", 
- "hash", - "fmt", -] - -includes = [ - "pcg/machine_specification.dtg.h", - "compiler/machine_mapping/machine_mapping_constraints.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", - "pcg/optimizer_attrs.dtg.h", -] - -[[fields]] -name = "problem_tree" -type = "::FlexFlow::MachineMappingProblemTree" - -[[fields]] -name = "resources" -type = "::FlexFlow::MachineSpecification" - -[[fields]] -name = "constraints" -type = "::FlexFlow::MachineMappingConstraints" - -[[fields]] -name = "optimizer_attrs" -type = "::FlexFlow::OptimizerAttrs" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_memory_constraints.dtg.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_memory_constraints.dtg.toml new file mode 100644 index 0000000000..171cfb71e9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_memory_constraints.dtg.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "MachineMemoryConstraints" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "memory_limit" +type = "size_t" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_memory_constraints.struct.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_memory_constraints.struct.toml deleted file mode 100644 index 0d2572c783..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_memory_constraints.struct.toml +++ /dev/null @@ -1,13 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMemoryConstraints" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [] - -[[fields]] -name = "memory_limit" -type = "size_t" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.dtg.toml 
b/lib/compiler/include/compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.dtg.toml new file mode 100644 index 0000000000..fc33be6aae --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "ParetoOptimalMachineMapping" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h", + "compiler/cost_estimator/op_cost_metrics.dtg.h", +] + +[[fields]] +name = "cost" +type = "::FlexFlow::OpCostMetrics" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::ParallelLayerGuidObliviousMachineMapping" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.h new file mode 100644 index 0000000000..6e263fc412 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MEMORY_OPTIMIZATION_PARETO_OPTIMAL_MACHINE_MAPPING_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MEMORY_OPTIMIZATION_PARETO_OPTIMAL_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.dtg.h" + +namespace FlexFlow { + +bool is_pareto_optimal_in( + ParetoOptimalMachineMapping const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/multi_dimensional_stride.dtg.toml b/lib/compiler/include/compiler/machine_mapping/multi_dimensional_stride.dtg.toml new file mode 100644 index 0000000000..94afa7cd2a --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/multi_dimensional_stride.dtg.toml @@ -0,0 +1,26 @@ +namespace = 
"FlexFlow" +name = "MultiDimensionalStride" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "compiler/machine_mapping/stride_t.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h" + +] + +[[fields]] +name = "raw_strides" +type = "std::vector<::FlexFlow::stride_t>" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.toml b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.toml new file mode 100644 index 0000000000..344817bffc --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ParallelLayerGuidObliviousMachineMapping" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "rapidcheck", +] + +includes = [ + "compiler/machine_mapping/machine_view.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_mapping" +type = "std::unordered_map<::FlexFlow::BinaryTreePath, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h index cb3af9c689..9f2871239d 100644 --- a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h @@ -1,7 +1,11 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARALLEL_LAYER_GUID_OBLIVIOUS_MACHINE_MAPPING_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARALLEL_LAYER_GUID_OBLIVIOUS_MACHINE_MAPPING_H +#include 
"compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include namespace FlexFlow { @@ -19,6 +23,22 @@ std::optional get_machine_view_for_path(ParallelLayerGuidObliviousMachineMapping const &, BinaryTreePath const &); +std::unordered_map + get_machine_stencils_for_decomposition( + ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &decomposition, + ParallelLayerGuidObliviousMachineMapping const &mapping); + +std::unordered_map> + get_machine_stencils_for_mm_problem_tree( + MachineMappingProblemTree const &, + ParallelLayerGuidObliviousMachineMapping const &mapping); + +std::unordered_map + get_machine_stencils_for_partially_mapped_mm_problem_tree( + MachineMappingProblemTree const &, + ParallelLayerGuidObliviousMachineMapping const &); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml deleted file mode 100644 index f00fcc8490..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelLayerGuidObliviousMachineMapping" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "pcg/machine_view.dtg.h", - "utils/full_binary_tree/binary_tree_path.dtg.h", -] - -src_includes = [ - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", -] - -[[fields]] -name = "raw_mapping" -type = "std::unordered_map<::FlexFlow::BinaryTreePath, 
::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.dtg.toml b/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.dtg.toml new file mode 100644 index 0000000000..a8a02f8ec1 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "ParallelSplitTransformation" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LthenR" + +[[values]] +name = "RthenL" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml b/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml deleted file mode 100644 index 8247c0cbdc..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelSplitTransformation" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "LthenR" - -[[values]] -name = "RthenL" diff --git a/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.dtg.toml b/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.dtg.toml new file mode 100644 index 0000000000..cbdedd47e7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "PCGSplitBoundaryLayers" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "pre_split_boundary" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" + +[[fields]] +name = "post_split_boundary" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" diff --git 
a/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml b/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml deleted file mode 100644 index 155e526672..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "PCGSplitBoundaryLayers" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "", -] - -src_includes = [ - "utils/hash/unordered_set.h", "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "pre_split_boundary" -type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" - -[[fields]] -name = "post_split_boundary" -type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/machine_mapping/start_invariant_machine_view.dtg.toml b/lib/compiler/include/compiler/machine_mapping/start_invariant_machine_view.dtg.toml new file mode 100644 index 0000000000..b271ed1959 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/start_invariant_machine_view.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "StartInvariantMachineView" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "compiler/machine_mapping/machine_view_dimension.dtg.h", + "pcg/device_type.dtg.h" +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "dimensions" +type = "std::vector<::FlexFlow::MachineViewDimension>" + + +[[fields]] +name = "device_type" +type = "::FlexFlow::DeviceType" diff --git a/lib/compiler/include/compiler/machine_mapping/start_invariant_machine_view.h b/lib/compiler/include/compiler/machine_mapping/start_invariant_machine_view.h new file mode 100644 index 0000000000..631d23b07c --- /dev/null +++ 
b/lib/compiler/include/compiler/machine_mapping/start_invariant_machine_view.h @@ -0,0 +1,45 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_START_INVARIANT_MACHINE_VIEW_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_START_INVARIANT_MACHINE_VIEW_H + +#include "compiler/machine_mapping/machine_view.dtg.h" +#include "compiler/machine_mapping/start_invariant_machine_view.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/task_space_coordinate.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" +#include "pcg/machine_space_offset.h" +#include + +namespace FlexFlow { + +MachineView + machine_view_from_start_invariant(StartInvariantMachineView const &mv, + MachineSpaceCoordinate const &start); +StartInvariantMachineView + start_invariant_from_machine_view(MachineView const &mv); + +nonnegative_int num_dims(StartInvariantMachineView const &mv); + +DeviceType get_device_type(StartInvariantMachineView const &mv); + +std::vector get_strides(StartInvariantMachineView const &mv); + +std::vector + get_dimensions(StartInvariantMachineView const &mv); + +StartInvariantMachineView + start_invariant_machine_view_from_strides_and_machine_spec_dimensions( + std::vector const &strides, + std::vector const &dims); + +MachineSpaceOffset + get_machine_space_offset(OperatorTaskSpace const &task, + StartInvariantMachineView const &mv, + TaskSpaceCoordinate const &coordinates); + +std::unordered_set + get_machine_space_offsets(OperatorTaskSpace const &task, + StartInvariantMachineView const &mv); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/stride_t.dtg.toml b/lib/compiler/include/compiler/machine_mapping/stride_t.dtg.toml new file mode 100644 index 0000000000..5ff035656c --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/stride_t.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "stride_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + 
"json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "unwrapped" +type = "::FlexFlow::positive_int" diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.dtg.toml b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.dtg.toml new file mode 100644 index 0000000000..b495e43fee --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "TransitiveReducedPCG" +type = "struct" +features = [] + +includes = [ + "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h", +] + +[[fields]] +name = "full_pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "transitive_reduction" +type = "::FlexFlow::DiGraphView" + diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h index 2b2bc9bf84..8055d15b4e 100644 --- a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h @@ -7,11 +7,11 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/transitive_reduced_kwarg_dataflow_graph_view.dtg.h" namespace FlexFlow { -TransitiveReducedDataflowGraphView +TransitiveReducedKwargDataflowGraphView get_underlying_transitive_reduced_dataflow_graph( TransitiveReducedPCG const &); diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml 
b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml deleted file mode 100644 index bb76ec2ff7..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "TransitiveReducedPCG" -features = [] - -includes = [ - "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h", -] - -[[fields]] -name = "full_pcg" -type = "::FlexFlow::ParallelComputationGraph" - -[[fields]] -name = "transitive_reduction" -type = "::FlexFlow::DiGraphView" - diff --git a/lib/compiler/include/compiler/machine_mapping/unstructured_device_mapping.dtg.toml b/lib/compiler/include/compiler/machine_mapping/unstructured_device_mapping.dtg.toml new file mode 100644 index 0000000000..28391eddc0 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/unstructured_device_mapping.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "UnstructuredDeviceMapping" +type = "struct" +features = [ + "eq", + # "ord", + "hash", + # "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "pcg/device_id_t.dtg.h" +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h" +] + +[[fields]] +name = "raw_device_map" +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, std::unordered_set<::FlexFlow::device_id_t>>" diff --git a/lib/compiler/include/compiler/machine_mapping/unstructured_device_mapping.h b/lib/compiler/include/compiler/machine_mapping/unstructured_device_mapping.h index 0fb31210fd..8c1333fabc 100644 --- a/lib/compiler/include/compiler/machine_mapping/unstructured_device_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/unstructured_device_mapping.h @@ -3,15 +3,15 @@ #include "compiler/machine_mapping/machine_mapping.dtg.h" #include 
"compiler/machine_mapping/unstructured_device_mapping.dtg.h" -#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" namespace FlexFlow { -UnstructuredDeviceMapping - get_unstructured_device_mapping(MachineMapping const &machine_mapping, - MachineSpecification const &machine_spec, - ParallelComputationGraph const &pcg); +UnstructuredDeviceMapping get_unstructured_device_mapping( + MachineMapping const &machine_mapping, + MachineComputeSpecification const &machine_spec, + ParallelComputationGraph const &pcg); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/unstructured_device_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/unstructured_device_mapping.struct.toml deleted file mode 100644 index ae38a37292..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/unstructured_device_mapping.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "UnstructuredDeviceMapping" -features = [ - "eq", - # "ord", - "hash", - # "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "pcg/device_id_t.dtg.h" -] - -src_includes = [ - "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", - "utils/hash/unordered_set.h", - "utils/fmt/unordered_set.h" -] - -[[fields]] -name = "raw_device_map" -type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, std::unordered_set<::FlexFlow::device_id_t>>" diff --git a/lib/compiler/include/compiler/mapped_task_signature_tensor_key.dtg.toml b/lib/compiler/include/compiler/mapped_task_signature_tensor_key.dtg.toml new file mode 100644 index 0000000000..b504e2e713 --- /dev/null +++ b/lib/compiler/include/compiler/mapped_task_signature_tensor_key.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "MappedTaskSignatureTensorKey" +type = "struct" +features = [ + "eq", + "ord", + 
"hash", + "json", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", + "op-attrs/tensor_role.dtg.h", + "pcg/gpu_id_t.dtg.h", +] + +[[fields]] +name = "gpu_id" +type = "::FlexFlow::gpu_id_t" + +[[fields]] +name = "tensor_role" +type = "::FlexFlow::TensorRole" + +[[fields]] +name = "idx" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/compiler/include/compiler/optimizer_config.dtg.toml b/lib/compiler/include/compiler/optimizer_config.dtg.toml new file mode 100644 index 0000000000..395b22f46b --- /dev/null +++ b/lib/compiler/include/compiler/optimizer_config.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "OptimizerConfig" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ +] + +[[fields]] +name = "alpha" +type = "float" + +[[fields]] +name = "budget" +type = "int" + +[[fields]] +name = "threshold" +type = "float" + +[[fields]] +name = "max_num_ops" +type = "int" diff --git a/lib/compiler/include/compiler/optimizer_config.struct.toml b/lib/compiler/include/compiler/optimizer_config.struct.toml deleted file mode 100644 index b7f4f71e9c..0000000000 --- a/lib/compiler/include/compiler/optimizer_config.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "OptimizerConfig" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ -] - -[[fields]] -name = "alpha" -type = "float" - -[[fields]] -name = "budget" -type = "int" - -[[fields]] -name = "threshold" -type = "float" - -[[fields]] -name = "max_num_ops" -type = "int" diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.dtg.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.dtg.toml new file mode 100644 index 0000000000..70c5129d2e --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = 
"ComputationGraphBinaryParallelSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ComputationGraphBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml deleted file mode 100644 index 9654a2546e..0000000000 --- a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "ComputationGraphBinaryParallelSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct ComputationGraphBinarySPDecomposition", -] - -post_includes = [ - "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", -] - -includes = [] - -[[fields]] -name = "left_child" -type = "::FlexFlow::ComputationGraphBinarySPDecomposition" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::ComputationGraphBinarySPDecomposition" -indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.dtg.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.dtg.toml new file mode 100644 index 0000000000..b4f8f2fd64 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = 
"ComputationGraphBinarySeriesSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ComputationGraphBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml deleted file mode 100644 index aa66c80b43..0000000000 --- a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "ComputationGraphBinarySeriesSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct ComputationGraphBinarySPDecomposition", -] - -post_includes = [ - "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", -] - -includes = [] - -[[fields]] -name = "left_child" -type = "::FlexFlow::ComputationGraphBinarySPDecomposition" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::ComputationGraphBinarySPDecomposition" -indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.toml new file mode 100644 index 0000000000..390e808d0e --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = 
"ComputationGraphBinarySPDecomposition" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "compiler/series_parallel/computation_graph/computation_graph_binary_series_split.dtg.h", + "compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::ComputationGraphBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::ComputationGraphBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::layer_guid_t" +key = "leaf" diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml deleted file mode 100644 index 452470620b..0000000000 --- a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "ComputationGraphBinarySPDecomposition" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "pcg/layer_guid_t.dtg.h", - "compiler/series_parallel/computation_graph/computation_graph_binary_series_split.dtg.h", - "compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.dtg.h", -] - -[[values]] -type = "::FlexFlow::ComputationGraphBinarySeriesSplit" -key = "series" - -[[values]] -type = "::FlexFlow::ComputationGraphBinaryParallelSplit" -key = "parallel" - -[[values]] -type = "::FlexFlow::layer_guid_t" -key = "leaf" diff --git a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h deleted file mode 100644 index d43edaa79d..0000000000 --- a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h +++ /dev/null @@ 
-1,11 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H - -namespace FlexFlow { - -std::optional - get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.toml new file mode 100644 index 0000000000..35a4b7609e --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "PCGBinaryParallelSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct PCGBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml deleted file mode 100644 index f7f7026716..0000000000 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "PCGBinaryParallelSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct PCGBinarySPDecomposition", -] - -post_includes = [ - "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", -] - -includes = [] - -[[fields]] -name = "left_child" -type = "::FlexFlow::PCGBinarySPDecomposition" -indirect = true - -[[fields]] -name = "right_child" 
-type = "::FlexFlow::PCGBinarySPDecomposition" -indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.dtg.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.dtg.toml new file mode 100644 index 0000000000..e58f40a231 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "PCGBinarySeriesSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct PCGBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml deleted file mode 100644 index af2c8c4dae..0000000000 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "PCGBinarySeriesSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct PCGBinarySPDecomposition", -] - -post_includes = [ - "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", -] - -includes = [] - -[[fields]] -name = "left_child" -type = "::FlexFlow::PCGBinarySPDecomposition" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::PCGBinarySPDecomposition" -indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.toml new file mode 100644 index 0000000000..a1a92c952a --- /dev/null +++ 
b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "PCGBinarySPDecomposition" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h", + "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::PCGBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::PCGBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::parallel_layer_guid_t" +key = "leaf" diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h index 86fa1a59aa..74ad521884 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h @@ -29,10 +29,16 @@ std::unordered_multiset SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &); +std::unordered_set + pcg_sp_tree_get_all_leaf_paths(PCGBinarySPDecomposition const &); + std::unordered_set find_paths_to_leaf(PCGBinarySPDecomposition const &, parallel_layer_guid_t const &); +std::unordered_map + pcg_sp_tree_get_path_to_leaf_map(PCGBinarySPDecomposition const &); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml deleted file mode 100644 index 52372fb270..0000000000 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "PCGBinarySPDecomposition" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - 
"pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h", - "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h", -] - -[[values]] -type = "::FlexFlow::PCGBinarySeriesSplit" -key = "series" - -[[values]] -type = "::FlexFlow::PCGBinaryParallelSplit" -key = "parallel" - -[[values]] -type = "::FlexFlow::parallel_layer_guid_t" -key = "leaf" diff --git a/lib/compiler/include/compiler/task_graph_simulator/in_progress_task.dtg.toml b/lib/compiler/include/compiler/task_graph_simulator/in_progress_task.dtg.toml new file mode 100644 index 0000000000..0788cb196e --- /dev/null +++ b/lib/compiler/include/compiler/task_graph_simulator/in_progress_task.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "InProgressTask" +type = "struct" + +features = [ + "eq", + "hash", + "fmt", + "ord" +] + +includes = [ + "utils/graph/node/node.dtg.h" +] + + +[[fields]] +name = "start_time" +type = "float" + +[[fields]] +name = "end_time" +type = "float" + +[[fields]] +name = "node" +type = "::FlexFlow::Node" diff --git a/lib/compiler/include/compiler/task_graph_simulator/in_progress_task.struct.toml b/lib/compiler/include/compiler/task_graph_simulator/in_progress_task.struct.toml deleted file mode 100644 index 71e0e17f5e..0000000000 --- a/lib/compiler/include/compiler/task_graph_simulator/in_progress_task.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "InProgressTask" - -features = [ - "eq", - "hash", - "fmt", - "ord" -] - -includes = [ - "utils/graph/node/node.dtg.h" -] - - -[[fields]] -name = "start_time" -type = "float" - -[[fields]] -name = "end_time" -type = "float" - -[[fields]] -name = "node" -type = "::FlexFlow::Node" diff --git a/lib/compiler/include/compiler/task_graph_simulator/pcg_task.dtg.toml b/lib/compiler/include/compiler/task_graph_simulator/pcg_task.dtg.toml new file mode 100644 index 0000000000..48eb99e9c6 --- /dev/null +++ 
b/lib/compiler/include/compiler/task_graph_simulator/pcg_task.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "PCGTask" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/cost_estimator/runtime_only_op_cost_estimate_key.dtg.h", + "compiler/cost_estimator/tensor_set_movement.dtg.h", +] + +[[values]] +type = "::FlexFlow::RuntimeOnlyOpCostEstimateKey" +key = "operator" + +[[values]] +type = "::FlexFlow::TensorSetMovement" +key = "tensor_movement" diff --git a/lib/compiler/include/compiler/task_graph_simulator/pcg_task.variant.toml b/lib/compiler/include/compiler/task_graph_simulator/pcg_task.variant.toml deleted file mode 100644 index cb8490c861..0000000000 --- a/lib/compiler/include/compiler/task_graph_simulator/pcg_task.variant.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "PCGTask" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "compiler/cost_estimator/runtime_only_op_cost_estimate_key.dtg.h", - "compiler/cost_estimator/tensor_set_movement.dtg.h", -] - -[[values]] -type = "::FlexFlow::RuntimeOnlyOpCostEstimateKey" -key = "operator" - -[[values]] -type = "::FlexFlow::TensorSetMovement" -key = "tensor_movement" diff --git a/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.dtg.toml b/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.dtg.toml new file mode 100644 index 0000000000..2c5b5f56fc --- /dev/null +++ b/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.dtg.toml @@ -0,0 +1,35 @@ +namespace = "FlexFlow" +name = "PCGTaskGraph" +type = "struct" + +features = [ +] + +includes = [ + "utils/graph/digraph/digraph_view.h", + "utils/bidict/bidict.h", + "compiler/task_graph_simulator/pcg_task.dtg.h", + "pcg/device_id_t.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "", + "" +] + +src_includes = [ + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_map.h", + 
"utils/hash/unordered_map.h" +] + +[[fields]] +name = "graph" +type = "::FlexFlow::DiGraphView" + +[[fields]] +name = "node_to_task" +type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::PCGTask>" + +[[fields]] +name = "node_to_devices" +type = "std::unordered_map<::FlexFlow::Node, std::unordered_set<::FlexFlow::device_id_t>>" diff --git a/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.h b/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.h index 2c6d6514e8..1af1f15dd0 100644 --- a/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.h +++ b/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.h @@ -3,14 +3,15 @@ #include "compiler/machine_mapping/machine_mapping.dtg.h" #include "compiler/task_graph_simulator/pcg_task_graph.dtg.h" -#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" namespace FlexFlow { -PCGTaskGraph get_pcg_task_graph(ParallelComputationGraph const &pcg, - MachineMapping const &machine_mapping, - MachineSpecification const &machine_spec); +PCGTaskGraph + get_pcg_task_graph(ParallelComputationGraph const &pcg, + MachineMapping const &machine_mapping, + MachineComputeSpecification const &machine_spec); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.struct.toml b/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.struct.toml deleted file mode 100644 index 099f44c564..0000000000 --- a/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.struct.toml +++ /dev/null @@ -1,34 +0,0 @@ -namespace = "FlexFlow" -name = "PCGTaskGraph" - -features = [ -] - -includes = [ - "utils/graph/digraph/digraph_view.h", - "utils/bidict/bidict.h", - "compiler/task_graph_simulator/pcg_task.dtg.h", - "pcg/device_id_t.dtg.h", - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "", - "" -] - 
-src_includes = [ - "utils/fmt/unordered_set.h", - "utils/hash/unordered_set.h", - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h" -] - -[[fields]] -name = "graph" -type = "::FlexFlow::DiGraphView" - -[[fields]] -name = "node_to_task" -type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::PCGTask>" - -[[fields]] -name = "node_to_devices" -type = "std::unordered_map<::FlexFlow::Node, std::unordered_set<::FlexFlow::device_id_t>>" diff --git a/lib/compiler/include/compiler/task_graph_simulator/task_execution_constraint.dtg.toml b/lib/compiler/include/compiler/task_graph_simulator/task_execution_constraint.dtg.toml new file mode 100644 index 0000000000..a39b072fb3 --- /dev/null +++ b/lib/compiler/include/compiler/task_graph_simulator/task_execution_constraint.dtg.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "TaskExecutionConstraint" +type = "struct" +features = [ +] + +includes = [ + "utils/graph/node/node.dtg.h", + "", + "" +] + + +[[fields]] +name = "is_satisfied" +type = "std::function const &, std::unordered_set const &)>" diff --git a/lib/compiler/include/compiler/task_graph_simulator/task_execution_constraint.struct.toml b/lib/compiler/include/compiler/task_graph_simulator/task_execution_constraint.struct.toml deleted file mode 100644 index 004655b5ec..0000000000 --- a/lib/compiler/include/compiler/task_graph_simulator/task_execution_constraint.struct.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "TaskExecutionConstraint" -features = [ -] - -includes = [ - "utils/graph/node/node.dtg.h", - "", - "" -] - - -[[fields]] -name = "is_satisfied" -type = "std::function const &, std::unordered_set const &)>" diff --git a/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_state.dtg.toml b/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_state.dtg.toml new file mode 100644 index 0000000000..bc93b7b8bc --- /dev/null +++ 
b/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_state.dtg.toml @@ -0,0 +1,41 @@ +namespace = "FlexFlow" +name = "TaskGraphExecutionState" +type = "struct" + +features = [ +] + +includes = [ + "utils/deduplicated_priority_queue.h", + "utils/graph/node/node.dtg.h", + "compiler/task_graph_simulator/in_progress_task.dtg.h", + "compiler/task_graph_simulator/in_progress_task_comparator.h", + "", + "", + "" +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", + "utils/hash/set.h", + "utils/fmt/set.h", + "utils/fmt/vector.h", + "utils/hash/vector.h" +] + +[[fields]] +name = "ready_tasks" +type = "std::set<::FlexFlow::Node>" + +[[fields]] +name = "in_progress_tasks" +type = "::FlexFlow::DeduplicatedPriorityQueue<::FlexFlow::InProgressTask, std::vector<::FlexFlow::InProgressTask>, ::FlexFlow::InProgressTaskComparator>" + +[[fields]] +name = "finished_tasks" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "current_time" +type = "float" diff --git a/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_state.struct.toml b/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_state.struct.toml deleted file mode 100644 index b96d7264b9..0000000000 --- a/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_state.struct.toml +++ /dev/null @@ -1,40 +0,0 @@ -namespace = "FlexFlow" -name = "TaskGraphExecutionState" - -features = [ -] - -includes = [ - "utils/deduplicated_priority_queue.h", - "utils/graph/node/node.dtg.h", - "compiler/task_graph_simulator/in_progress_task.dtg.h", - "compiler/task_graph_simulator/in_progress_task_comparator.h", - "", - "", - "" -] - -src_includes = [ - "utils/hash/unordered_set.h", - "utils/fmt/unordered_set.h", - "utils/hash/set.h", - "utils/fmt/set.h", - "utils/fmt/vector.h", - "utils/hash/vector.h" -] - -[[fields]] -name = "ready_tasks" -type = "std::set<::FlexFlow::Node>" - -[[fields]] -name = "in_progress_tasks" 
-type = "::FlexFlow::DeduplicatedPriorityQueue<::FlexFlow::InProgressTask, std::vector<::FlexFlow::InProgressTask>, ::FlexFlow::InProgressTaskComparator>" - -[[fields]] -name = "finished_tasks" -type = "std::unordered_set<::FlexFlow::Node>" - -[[fields]] -name = "current_time" -type = "float" diff --git a/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_trace.dtg.toml b/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_trace.dtg.toml new file mode 100644 index 0000000000..629e222920 --- /dev/null +++ b/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_trace.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "TaskGraphExecutionTrace" +type = "struct" + +features = [ + "hash", + "fmt", + "eq" +] + +includes = [ + "compiler/task_graph_simulator/task_profile.dtg.h", + "" +] + +src_includes = [ + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h" +] + + +[[fields]] +name = "task_profiles" +type = "std::unordered_set<::FlexFlow::TaskProfile>" diff --git a/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_trace.struct.toml b/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_trace.struct.toml deleted file mode 100644 index 3003e5a157..0000000000 --- a/lib/compiler/include/compiler/task_graph_simulator/task_graph_execution_trace.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "TaskGraphExecutionTrace" - -features = [ - "hash", - "fmt", - "eq" -] - -includes = [ - "compiler/task_graph_simulator/task_profile.dtg.h", - "" -] - -src_includes = [ - "utils/fmt/unordered_set.h", - "utils/hash/unordered_set.h" -] - - -[[fields]] -name = "task_profiles" -type = "std::unordered_set<::FlexFlow::TaskProfile>" diff --git a/lib/compiler/include/compiler/task_graph_simulator/task_profile.dtg.toml b/lib/compiler/include/compiler/task_graph_simulator/task_profile.dtg.toml new file mode 100644 index 0000000000..67c7506b10 --- /dev/null +++ 
b/lib/compiler/include/compiler/task_graph_simulator/task_profile.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "TaskProfile" +type = "struct" + +features = [ + "eq", + "hash", + "fmt", + "ord" +] + +includes = [ + "utils/graph/node/node.dtg.h" +] + + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "start_time" +type = "float" + +[[fields]] +name = "end_time" +type = "float" diff --git a/lib/compiler/include/compiler/task_graph_simulator/task_profile.struct.toml b/lib/compiler/include/compiler/task_graph_simulator/task_profile.struct.toml deleted file mode 100644 index 1a47acfa0e..0000000000 --- a/lib/compiler/include/compiler/task_graph_simulator/task_profile.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "TaskProfile" - -features = [ - "eq", - "hash", - "fmt", - "ord" -] - -includes = [ - "utils/graph/node/node.dtg.h" -] - - -[[fields]] -name = "node" -type = "::FlexFlow::Node" - -[[fields]] -name = "start_time" -type = "float" - -[[fields]] -name = "end_time" -type = "float" diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index 232f2b9563..d8ba9158a6 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H -#define _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_UNITY_ALGORITHM_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_UNITY_ALGORITHM_H #include "compiler/cost_estimator/cost_estimator.h" #include "compiler/graph_optimize_result.dtg.h" diff --git a/lib/compiler/src/compiler/allowed_machine_views.cc b/lib/compiler/src/compiler/allowed_machine_views.cc index 370cb5a4ec..558f383adc 100644 --- a/lib/compiler/src/compiler/allowed_machine_views.cc +++ b/lib/compiler/src/compiler/allowed_machine_views.cc @@ -1,8 +1,8 @@ #include "compiler/allowed_machine_views.h" -#include 
"pcg/machine_specification.h" -#include "pcg/machine_view.h" -#include "pcg/multi_dimensional_stride.dtg.h" -#include "pcg/operator_task_space.h" +#include "compiler/machine_mapping/machine_view.h" +#include "compiler/machine_mapping/multi_dimensional_stride.dtg.h" +#include "op-attrs/operator_task_space.h" +#include "pcg/machine_compute_specification.h" #include "utils/containers/all_of.h" #include "utils/containers/cartesian_product.h" #include "utils/containers/extend.h" @@ -25,16 +25,17 @@ namespace FlexFlow { bool is_valid_machine_view(MachineView const &mv, - OperatorTaskSpace const &task, - MachineSpecification const &ms) { - if (num_dims(mv) != num_dims(task)) { + OperatorTaskSpace const &task_space, + MachineComputeSpecification const &ms) { + if (mv_get_expected_task_space_num_dims(mv) != + op_task_space_num_dims(task_space)) { return false; } - std::optional maximum_device_coord = - get_machine_space_coordinate( - task, mv, get_task_space_maximum_coordinate(task), ms); - return maximum_device_coord.has_value(); + MachineSpaceCoordinate maximum_device_coord = get_machine_space_coordinate( + task_space, mv, get_task_space_maximum_coordinate(task_space)); + + return is_valid_machine_space_coordinate(ms, maximum_device_coord); } /* @@ -46,8 +47,8 @@ bool is_valid_machine_view(MachineView const &mv, * the returned `MachineView`s to be invalid) */ static std::unordered_set - get_candidate_machine_views(MachineSpecification const &machine_spec, - OperatorTaskSpace const &task, + get_candidate_machine_views(MachineComputeSpecification const &machine_spec, + OperatorTaskSpace const &task_space, DeviceType const &device_type) { auto get_max_stride_upper_bound = @@ -83,14 +84,12 @@ static std::unordered_set return strides; }; - auto candidate_starts = [](MachineSpecification const &ms, + auto candidate_starts = [](MachineComputeSpecification const &ms, DeviceType const &device_type) { std::unordered_set result; - for (nonnegative_int node_idx : - 
nonnegative_range(ms.num_nodes.nonnegative_int_from_positive_int())) { + for (nonnegative_int node_idx : nonnegative_range(ms.num_nodes)) { for (nonnegative_int device_idx : - nonnegative_range(get_num_devices_per_node(ms, device_type) - .nonnegative_int_from_positive_int())) { + nonnegative_range(get_num_devices_per_node(ms, device_type))) { result.insert( MachineSpaceCoordinate{node_idx, device_idx, device_type}); } @@ -98,14 +97,19 @@ static std::unordered_set return result; }; - auto candidate_dimensions = [](OperatorTaskSpace const &task) { + auto candidate_dimensions = [](OperatorTaskSpace const &task_space) { std::unordered_set options = { MachineSpecificationDimension::INTER_NODE, MachineSpecificationDimension::INTRA_NODE}; - return get_all_permutations_with_repetition(options, num_dims(task)); + return get_all_permutations_with_repetition( + options, op_task_space_num_dims(task_space)); }; - std::vector tensor_dims = task.degrees; + std::vector tensor_dims = + transform(task_space.degrees.dims, [](int_ge_two dim) { + return dim.positive_int_from_int_ge_two(); + }); + positive_int total_devices = get_num_devices(machine_spec, device_type); std::unordered_set machine_views; @@ -115,7 +119,7 @@ static std::unordered_set for (MachineSpaceCoordinate start : candidate_starts(machine_spec, device_type)) { for (std::vector const &dims : - candidate_dimensions(task)) { + candidate_dimensions(task_space)) { machine_views.insert( machine_view_from_strides_and_machine_spec_dimensions( start, strides.raw_strides, dims)); @@ -126,14 +130,14 @@ static std::unordered_set } std::unordered_set - get_allowed_machine_views(MachineSpecification const &machine_spec, - OperatorTaskSpace const &task, + get_allowed_machine_views(MachineComputeSpecification const &machine_spec, + OperatorTaskSpace const &task_space, DeviceType device_type) { std::unordered_set views = - get_candidate_machine_views(machine_spec, task, device_type); + get_candidate_machine_views(machine_spec, 
task_space, device_type); return filter(views, [&](MachineView const &mv) { - return is_valid_machine_view(mv, task, machine_spec); + return is_valid_machine_view(mv, task_space, machine_spec); }); } diff --git a/lib/compiler/src/compiler/cost_estimator/communication_edge.cc b/lib/compiler/src/compiler/cost_estimator/communication_edge.cc new file mode 100644 index 0000000000..f86fdeaab1 --- /dev/null +++ b/lib/compiler/src/compiler/cost_estimator/communication_edge.cc @@ -0,0 +1,69 @@ +#include "compiler/cost_estimator/communication_edge.h" +#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" +#include + +namespace FlexFlow { + +CommunicationEdge::CommunicationEdge(MachineSpaceCoordinate const &src, + MachineSpaceCoordinate const &dst) + : src(src), dst(dst) { + ASSERT(src != dst); +} + +bool CommunicationEdge::operator==(CommunicationEdge const &other) const { + return this->tie() == other.tie(); +} + +bool CommunicationEdge::operator!=(CommunicationEdge const &other) const { + return this->tie() != other.tie(); +} + +bool CommunicationEdge::operator<(CommunicationEdge const &other) const { + return this->tie() < other.tie(); +} + +bool CommunicationEdge::operator>(CommunicationEdge const &other) const { + return this->tie() > other.tie(); +} + +bool CommunicationEdge::operator<=(CommunicationEdge const &other) const { + return this->tie() <= other.tie(); +} + +bool CommunicationEdge::operator>=(CommunicationEdge const &other) const { + return this->tie() >= other.tie(); +} + +MachineSpaceCoordinate const &CommunicationEdge::get_src() const { + return this->src; +} + +MachineSpaceCoordinate const &CommunicationEdge::get_dst() const { + return this->dst; +} + +std::tuple + CommunicationEdge::tie() const { + return std::tie(this->src, this->dst); +} + +std::string format_as(CommunicationEdge const &e) { + return fmt::format( + "", e.get_src(), e.get_dst()); +} + +std::ostream &operator<<(std::ostream &s, CommunicationEdge const &e) { + return (s << 
fmt::to_string(e)); +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::CommunicationEdge>::operator()( + ::FlexFlow::CommunicationEdge const &e) const { + return get_std_hash(e.tie()); +} + +} // namespace std diff --git a/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc b/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc new file mode 100644 index 0000000000..8002cfa526 --- /dev/null +++ b/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc @@ -0,0 +1,16 @@ +#include "compiler/cost_estimator/network_cost_model.h" +#include "utils/exception.h" + +namespace FlexFlow { + +float estimate_communication_cost( + MachineSpecification const &machine_spec, + TensorSetMovement const &tensor_set_movement) { + NOT_IMPLEMENTED(); // TODO @lockshaw + // for (SingleTensorMovement const &single_tensor_movement : + // tensor_set_movement.single_tensor_movements) { + // for + // } +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/cost_estimator/op_cost_estimate_key.cc b/lib/compiler/src/compiler/cost_estimator/op_cost_estimate_key.cc index 92b07bbe23..f3edd6a69a 100644 --- a/lib/compiler/src/compiler/cost_estimator/op_cost_estimate_key.cc +++ b/lib/compiler/src/compiler/cost_estimator/op_cost_estimate_key.cc @@ -1,12 +1,11 @@ #include "compiler/cost_estimator/op_cost_estimate_key.h" #include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/machine_view.dtg.h" +#include "compiler/machine_mapping/machine_view.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "pcg/device_id_t.dtg.h" #include "pcg/machine_specification.dtg.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/machine_view.h" -#include "pcg/operator_task_space.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include diff --git 
a/lib/compiler/src/compiler/cost_estimator/op_cost_metrics.cc b/lib/compiler/src/compiler/cost_estimator/op_cost_metrics.cc index 2bca184419..7eab1f0a2a 100644 --- a/lib/compiler/src/compiler/cost_estimator/op_cost_metrics.cc +++ b/lib/compiler/src/compiler/cost_estimator/op_cost_metrics.cc @@ -1,7 +1,17 @@ #include "compiler/cost_estimator/op_cost_metrics.h" +#include "utils/containers/all_of.h" namespace FlexFlow { +bool is_pareto_optimal_in(OpCostMetrics const &m, + std::unordered_set const &others) { + return all_of(others, [&](OpCostMetrics const &other) { + return m.forward_runtime <= other.forward_runtime || + m.backward_runtime <= other.backward_runtime || + m.memory_usage <= other.memory_usage; + }); +} + OpCostMetrics make_op_cost_metrics_from_runtime_only( RuntimeOnlyOpCostMetrics const &runtime_only, num_bytes_t const &memory_usage) { diff --git a/lib/compiler/src/compiler/cost_estimator/parallel_tensor_space_to_machine_space_mapping.cc b/lib/compiler/src/compiler/cost_estimator/parallel_tensor_space_to_machine_space_mapping.cc new file mode 100644 index 0000000000..270e615ff1 --- /dev/null +++ b/lib/compiler/src/compiler/cost_estimator/parallel_tensor_space_to_machine_space_mapping.cc @@ -0,0 +1,40 @@ +#include "compiler/cost_estimator/parallel_tensor_space_to_machine_space_mapping.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include "op-attrs/parallel_tensor_space_coordinate.h" +#include "op-attrs/task_space_coordinate.h" +#include "utils/bidict/algorithms/exhaustive_relational_join.h" +#include "utils/bidict/algorithms/transform_keys.h" +#include "utils/bidict/algorithms/transform_values.h" +#include + +namespace FlexFlow { + +ParallelTensorSpaceToMachineSpaceMapping ptensor_machine_map_from_composition( + OperatorSpaceToMachineSpaceMapping const &op_task_to_machine_space_mapping, + OperatorSpaceToParallelTensorSpaceMapping const + 
&op_task_to_parallel_tensor_space_mapping) { + ASSERT(op_task_to_machine_space_mapping.operator_task_space == + get_operator_task_space_for_mapping( + op_task_to_parallel_tensor_space_mapping)); + + bidict + pt_to_op_coord_map = transform_keys( + transform_values(op_task_to_parallel_tensor_space_mapping.raw_mapping + .coord_mapping.reversed(), + task_space_coordinate_from_dim_coord), + parallel_tensor_space_coord_from_dim_coord); + + bidict op_to_ms_coord_map = + op_task_to_machine_space_mapping.raw_mapping; + + return ParallelTensorSpaceToMachineSpaceMapping{ + /*raw_mapping=*/exhaustive_relational_join(pt_to_op_coord_map, + op_to_ms_coord_map), + /*parallel_tensor_space=*/ + parallel_tensor_dim_degrees_from_dim_domain( + op_task_to_parallel_tensor_space_mapping.raw_mapping.r_domain), + }; +}; + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/cost_estimator/tensor_set_movement.cc b/lib/compiler/src/compiler/cost_estimator/tensor_set_movement.cc index 8f2ab84b84..d5b9d8a7f5 100644 --- a/lib/compiler/src/compiler/cost_estimator/tensor_set_movement.cc +++ b/lib/compiler/src/compiler/cost_estimator/tensor_set_movement.cc @@ -1,16 +1,61 @@ #include "compiler/cost_estimator/tensor_set_movement.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" + namespace FlexFlow { +TensorSetMovement empty_tensor_set_movement() { + + return TensorSetMovement{{}}; +} + TensorSetMovement get_tensor_set_movement_from_pcg_edge( ParallelComputationGraphEdge const &edge, ParallelComputationGraph const &pcg, MachineView const &src_mv, MachineView const 
&dst_mv) { - ParallelTensorShape tensor_shape = - get_parallel_tensor_shape(pcg, parallel_tensor_guid_t{edge.raw_edge.src}); - return TensorSetMovement{ - {SingleTensorMovement{tensor_shape, {src_mv}, {dst_mv}}}}; + + parallel_layer_guid_t src = get_src_layer(edge); + parallel_layer_guid_t dst = get_dst_layer(edge); + + BinaryTreePath src_path = BinaryTreePath{{BinaryTreePathEntry::LEFT_CHILD}}; + BinaryTreePath dst_path = BinaryTreePath{{BinaryTreePathEntry::RIGHT_CHILD}}; + + AbstractedSingleTensorMovement abstracted_single_tensor_movement = + get_abstracted_single_tensor_movement_along_edge( + /*pcg=*/pcg, + /*edge=*/edge, + /*src_path=*/src_path, + /*dst_path=*/dst_path); + + AbstractedTensorSetMovement abstracted_tensor_set_movement = + abstracted_tensor_set_movement_from_single_tensor_movement( + abstracted_single_tensor_movement); + + MachineSpaceStencil src_machine_stencil = MachineSpaceStencil{ + /*operator_task_space=*/get_operator_task_space(pcg, src), + /*machine_view=*/src_mv, + }; + + MachineSpaceStencil dst_machine_stencil = MachineSpaceStencil{ + /*operator_task_space=*/get_operator_task_space(pcg, dst), + /*machine_view=*/dst_mv, + }; + + return concretize_abstracted_tensor_set_movement( + abstracted_tensor_set_movement, + /*pre_machine_stencils=*/ + std::unordered_map{ + {src_path, src_machine_stencil}, + }, + /*post_machine_stencils=*/ + std::unordered_map{ + {dst_path, dst_machine_stencil}, + }); } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/graph_optimize_result.cc b/lib/compiler/src/compiler/graph_optimize_result.cc deleted file mode 100644 index f48c119603..0000000000 --- a/lib/compiler/src/compiler/graph_optimize_result.cc +++ /dev/null @@ -1,15 +0,0 @@ -#include "compiler/graph_optimize_result.h" - -namespace FlexFlow { - -std::string format_as(GraphOptimizeResult const &r) { - return fmt::format("", - as_dot(r.pcg), - r.machine_mapping); -} - -std::ostream &operator<<(std::ostream &s, GraphOptimizeResult const &r) { - 
return (s << fmt::to_string(r)); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/graph_optimize_state.cc b/lib/compiler/src/compiler/graph_optimize_state.cc index 1091b92866..bf40df2f11 100644 --- a/lib/compiler/src/compiler/graph_optimize_state.cc +++ b/lib/compiler/src/compiler/graph_optimize_state.cc @@ -1,6 +1,9 @@ #include "compiler/graph_optimize_state.h" -#include "compiler/graph_optimize_result.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" +#include "utils/hash/tuple.h" +#include "utils/hash/unordered_map.h" +#include "utils/hash/unordered_multiset.h" namespace FlexFlow { @@ -8,43 +11,65 @@ GraphOptimizeState::GraphOptimizeState( GraphOptimizeResult const &graph_optimize_result, float runtime) : graph_optimize_result(graph_optimize_result), runtime(runtime) {} +static std::unordered_multiset>, + std::unordered_map>> + get_layer_signature_set(MappedParallelComputationGraph const &mapped_pcg) { + + auto get_layer_signature = [&](parallel_layer_guid_t l) + -> std::tuple>, + std::unordered_map> { + ParallelLayerAttrs layer_attrs = + get_parallel_layer_attrs(mapped_pcg.pcg, l); + + std::unordered_map< + TensorSlotName, + std::tuple> + inputs = + map_values(get_incoming_tensors(mapped_pcg.pcg, l), + [&](parallel_tensor_guid_t const &i) { + parallel_layer_guid_t src = get_source_layer(i); + TensorSlotName src_slot = i.raw_graph_output.slot_name; + ParallelTensorAttrs tensor_attrs = + get_parallel_tensor_attrs(mapped_pcg.pcg, i); + + return std::tuple{ + get_parallel_layer_attrs(mapped_pcg.pcg, src), + src_slot, + tensor_attrs, + }; + }); + + std::unordered_map outputs = + map_values(get_layer_outputs(mapped_pcg.pcg, l), + [&](parallel_tensor_guid_t const &o) { + return get_parallel_tensor_attrs(mapped_pcg.pcg, o); + }); + + return { + layer_attrs, + mapped_pcg.mapped_tasks.at(l), + inputs, + outputs, + }; + }; + + return 
transform(unordered_multiset_of(get_parallel_layers(mapped_pcg.pcg)), + get_layer_signature); +} + bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { - // Note(@wmdi): This is a hack to implement a partially correct homomorphism - // check. Switch to the homomorphism check used in substitutions right after - // https://github.com/flexflow/FlexFlow/pull/1471 is merged. - auto layers1 = topological_ordering(graph_optimize_result.pcg); - auto layers2 = topological_ordering(other.graph_optimize_result.pcg); - if (layers1.size() != layers2.size()) { - return false; - } - std::unordered_map mapping; - for (size_t i = 0; i < layers1.size(); ++i) { - if (get_parallel_layer_attrs(graph_optimize_result.pcg, layers1[i]) != - get_parallel_layer_attrs(other.graph_optimize_result.pcg, layers2[i])) { - return false; - } - auto inputs1 = get_incoming_tensors(graph_optimize_result.pcg, layers1[i]); - auto inputs2 = - get_incoming_tensors(other.graph_optimize_result.pcg, layers2[i]); - if (inputs1.size() != inputs2.size()) { - return false; - } - for (size_t j = 0; j < inputs1.size(); ++j) { - if (inputs1[j] != mapping.at(inputs2[j])) { - return false; - } - } - auto outputs1 = get_layer_outputs(graph_optimize_result.pcg, layers1[i]); - auto outputs2 = - get_layer_outputs(other.graph_optimize_result.pcg, layers2[i]); - if (outputs1.size() != outputs2.size()) { - return false; - } - for (size_t j = 0; j < outputs1.size(); ++j) { - mapping.emplace(outputs2[j], outputs1[j]); - } - } - return true; + return get_layer_signature_set(this->graph_optimize_result.mapped_pcg) == + get_layer_signature_set(other.graph_optimize_result.mapped_pcg); } bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { @@ -55,16 +80,6 @@ bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { return runtime < other.runtime; } -std::string format_as(GraphOptimizeState const &st) { - return fmt::format("", - st.graph_optimize_result, - 
st.runtime); -} - -std::ostream &operator<<(std::ostream &s, GraphOptimizeState const &st) { - return (s << fmt::to_string(st)); -} - } // namespace FlexFlow namespace std { @@ -73,24 +88,11 @@ size_t hash<::FlexFlow::GraphOptimizeState>::operator()( ::FlexFlow::GraphOptimizeState const &state) const { // TODO(@wmdi): Eventually it might be good to use a proper graph hash like // https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash.html#networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash - size_t seed = 0; - auto layers = topological_ordering(state.graph_optimize_result.pcg); - ::FlexFlow::hash_combine(seed, layers.size()); - for (auto layer : layers) { - ::FlexFlow::hash_combine( - seed, get_parallel_layer_attrs(state.graph_optimize_result.pcg, layer)); - auto inputs = get_incoming_tensors(state.graph_optimize_result.pcg, layer); - ::FlexFlow::hash_combine(seed, inputs.size()); - for (auto input : inputs) { - for (size_t i = 0; i < layers.size(); ++i) { - if (get_source_layer(input) == layers[i]) { - ::FlexFlow::hash_combine(seed, i); - break; - } - } - } - } - return seed; + using namespace ::FlexFlow; + + auto layers = get_layer_signature_set(state.graph_optimize_result.mapped_pcg); + + return get_std_hash(layers); } } // namespace std diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.cc new file mode 100644 index 0000000000..b5bdb42ece --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.cc @@ -0,0 +1,25 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.h" +#include 
"compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.h" +#include "compiler/machine_mapping/machine_view.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_operator_task_space.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/machine_compute_specification.dtg.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +MachineSpaceCoordinate concretize_abstracted_device( + AbstractedDevice const &abstracted_device, + std::unordered_map const &stencils) { + + return machine_space_stencil_compute_machine_coord( + stencils.at(abstracted_device.operator_tree_path), + abstracted_device.task_space_coordinate); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.cc new file mode 100644 index 0000000000..2a35b76849 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.cc @@ -0,0 +1,30 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.h" +#include "pcg/machine_compute_specification.dtg.h" + +namespace FlexFlow { + +std::optional + concretize_abstracted_single_tensor_communication_edge( + AbstractedSingleTensorCommunicationEdge const 
&edge, + MachineSpaceStencil const &src_machine_stencil, + std::unordered_map const + &dst_machine_stencils) { + + MachineSpaceCoordinate src = machine_space_stencil_compute_machine_coord( + src_machine_stencil, edge.src_coord); + MachineSpaceCoordinate dst = + concretize_abstracted_device(edge.dst, dst_machine_stencils); + + if (src == dst) { + return std::nullopt; + } else { + return CommunicationEdge{ + /*src=*/src, + /*dst=*/dst, + }; + } +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc new file mode 100644 index 0000000000..1088d02adb --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc @@ -0,0 +1,95 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.h" +#include "utils/containers/filtermap_keys.h" +#include "utils/containers/map_from_pairs.h" +#include "utils/containers/map_keys_with_value_merging.h" +#include "utils/containers/merge_maps_with.h" +#include "utils/containers/require_all_same1.h" +#include "utils/containers/require_same.h" +#include "utils/containers/transform.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +std::unordered_set + abstracted_single_tensor_movement_get_dst_layers( + AbstractedSingleTensorMovement const &m) { + return transform( + keys(m.edge_to_size), + [](AbstractedSingleTensorCommunicationEdge const &e) -> BinaryTreePath { + return e.dst.operator_tree_path; + }); +} + +AbstractedSingleTensorMovement merge_abstracted_single_tensor_movements( + std::unordered_multiset const &movements) { + + std::unordered_multiset src_paths = + transform(movements, 
[](AbstractedSingleTensorMovement const &m) { + return m.src_op_tree_path; + }); + + BinaryTreePath src_op_tree_path = require_all_same1(src_paths); + + return AbstractedSingleTensorMovement{ + /*src_op_tree_path=*/require_all_same1(src_paths), + /*edge_to_size=*/ + merge_maps_with(transform(vector_of(movements), + [](AbstractedSingleTensorMovement const &m) { + return m.edge_to_size; + }), + [](num_bytes_t l, num_bytes_t r) { return l + r; }), + }; +} + +AbstractedSingleTensorMovement + abstracted_single_tensor_movement_from_communications( + BinaryTreePath const &src_op_tree_path, + std::unordered_set const + &communications) { + + return AbstractedSingleTensorMovement{ + /*src_op_tree_path=*/src_op_tree_path, + /*edge_to_size=*/ + map_from_pairs( + transform(communications, + [](AbstractedSingleTensorCommunication const &c) { + return std::pair{c.edge, c.size}; + })), + }; +} + +TensorSetMovement concretize_abstracted_single_tensor_movement( + AbstractedSingleTensorMovement const &abstracted, + std::unordered_map const + &pre_machine_stencils, + std::unordered_map const + &post_machine_stencils) { + + MachineSpaceStencil pre_machine_stencil = + pre_machine_stencils.at(abstracted.src_op_tree_path); + + std::unordered_map, num_bytes_t> + communication_edges = map_keys_with_value_merging( + abstracted.edge_to_size, + /*key_func=*/ + [&](AbstractedSingleTensorCommunicationEdge const &k) { + return concretize_abstracted_single_tensor_communication_edge( + /*edge=*/k, + /*src_machine_stencils=*/pre_machine_stencil, + /*dst_machine_stencils=*/post_machine_stencils); + }, + /*merge_values=*/ + [](num_bytes_t lhs, num_bytes_t rhs) { + return require_same(lhs, rhs); + }); + + return TensorSetMovement{ + /*edge_to_size=*/ + filtermap_keys( + communication_edges, + [](std::optional const &e) { return e; }), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc 
b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc index 6f3deca138..98a7d9b0b2 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc @@ -1,8 +1,15 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/cost_estimator/tensor_set_movement.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/binary_merge_maps_with.h" #include "utils/containers/flatmap.h" +#include "utils/containers/map_keys_with_value_merging.h" +#include "utils/containers/merge_maps_with.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" +#include "utils/hash/unordered_map.h" namespace FlexFlow { @@ -10,53 +17,62 @@ AbstractedTensorSetMovement empty_abstracted_tensor_set_movement() { return AbstractedTensorSetMovement{{}}; } +AbstractedTensorSetMovement + abstracted_tensor_set_movement_from_single_tensor_movement( + AbstractedSingleTensorMovement const &m) { + return AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{m}, + }; +} + std::unordered_set get_src_layers(AbstractedTensorSetMovement const &m) { - return flatmap(unordered_set_of(m.single_tensor_movements), - [](AbstractedSingleTensorMovement const &s) { - return s.src_machine_views; - }); + return transform( + m.single_tensor_movements, + [](AbstractedSingleTensorMovement const &e) -> BinaryTreePath { + return e.src_op_tree_path; + }); } std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &m) { - return 
flatmap(unordered_set_of(m.single_tensor_movements), - [](AbstractedSingleTensorMovement const &s) { - return s.dst_machine_views; + return flatmap(m.single_tensor_movements, + [](AbstractedSingleTensorMovement const &m) + -> std::unordered_set { + return abstracted_single_tensor_movement_get_dst_layers(m); }); } TensorSetMovement concretize_abstracted_tensor_set_movement( AbstractedTensorSetMovement const &abstracted, - ParallelLayerGuidObliviousMachineMapping const &pre_mapping, - ParallelLayerGuidObliviousMachineMapping const &post_mapping) { - ParallelLayerGuidObliviousMachineMapping mapping = - binary_combine_mappings(/*lhs=*/pre_mapping, - /*rhs=*/post_mapping); - - auto concretize_tensor_movement = - [&](AbstractedSingleTensorMovement const &a) { - return SingleTensorMovement{ - /*parallel_tensor_shape=*/a.parallel_tensor_shape, - /*src_machine_views=*/ - transform( - a.src_machine_views, - [&](BinaryTreePath const &path) { - return get_machine_view_for_path(pre_mapping, path).value(); - }), - /*dst_machine_views=*/ - transform( - a.dst_machine_views, - [&](BinaryTreePath const &path) { - return get_machine_view_for_path(post_mapping, path).value(); - }), - }; - }; - - return TensorSetMovement{ - /*single_tensor_movements=*/transform(abstracted.single_tensor_movements, - concretize_tensor_movement), + std::unordered_map const + &pre_machine_stencils, + std::unordered_map const + &post_machine_stencils) { + + std::vector single_tensor_movements = + transform(vector_of(abstracted.single_tensor_movements), + [&](AbstractedSingleTensorMovement const &m) { + return concretize_abstracted_single_tensor_movement( + m, + /*pre_machine_stencils=*/pre_machine_stencils, + /*post_machine_stencils=*/post_machine_stencils); + }); + + auto merge_tensor_set_movements = + [](TensorSetMovement const &lhs, + TensorSetMovement const &rhs) -> TensorSetMovement { + return TensorSetMovement{ + binary_merge_maps_with( + lhs.edge_to_size, + rhs.edge_to_size, + [](num_bytes_t l, 
num_bytes_t r) { return l + r; }), + }; }; + + return foldl(single_tensor_movements, + empty_tensor_set_movement(), + merge_tensor_set_movements); } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 0e0f60c891..df02655ccc 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -1,62 +1,113 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_communication.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/bidict/algorithms/unordered_set_of.h" +#include "utils/containers/binary_cartesian_product.h" +#include "utils/containers/flatmap.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" +#include "utils/containers/group_by.h" +#include 
"utils/containers/map_from_pairs.h" +#include "utils/containers/merge_maps_with.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" #include "utils/containers/values.h" +#include "utils/containers/vector_of.h" namespace FlexFlow { +AbstractedSingleTensorMovement get_abstracted_single_tensor_movement_along_edge( + ParallelComputationGraph const &pcg, + ParallelComputationGraphEdge const &edge, + BinaryTreePath const &src_path, + BinaryTreePath const &dst_path) { + + parallel_layer_guid_t pcg_src = get_src_layer(edge); + parallel_layer_guid_t pcg_dst = get_dst_layer(edge); + + parallel_tensor_guid_t parallel_tensor = get_parallel_tensor(edge); + TensorShape tensor_piece = + get_piece_shape(get_parallel_tensor_shape(pcg, parallel_tensor)); + + OperatorTaskSpaceToOperatorTaskSpaceMapping mapping = + pcg_get_mapping_along_edge(pcg, edge); + + bidict coord_mapping = + op_to_op_get_coord_mapping(mapping); + + std::unordered_map + single_comms = map_from_pairs(transform( + unordered_set_of(coord_mapping), + [&](std::pair const & + src_dst) -> std::pair { + auto [src_task_coord, dst_task_coord] = src_dst; + + return std::pair{ + AbstractedSingleTensorCommunicationEdge{ + /*src_coord=*/src_task_coord, + /*dst=*/AbstractedDevice{dst_path, dst_task_coord}, + }, + get_size_in_bytes(tensor_piece), + }; + })); + + return AbstractedSingleTensorMovement{ + /*src_op_tree_path=*/src_path, + /*edge_to_size=*/single_comms, + }; +} + AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { std::unordered_set edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); - auto get_movement_for_tensor = - [&](parallel_tensor_guid_t const &t) -> AbstractedSingleTensorMovement { - std::unordered_set tensor_edges = - filter(edges_across_split, [&](ParallelComputationGraphEdge const &e) { - return get_parallel_tensor(e) == t; - }); - - 
std::unordered_set src_layers = - transform(tensor_edges, [&](ParallelComputationGraphEdge const &e) { - return get_src_layer(e); - }); - - std::unordered_set dst_layers = - transform(tensor_edges, [&](ParallelComputationGraphEdge const &e) { - return get_dst_layer(e); - }); - - return AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), - /*src_machine_views=*/ - transform(src_layers, - [&](parallel_layer_guid_t const &l) { - return get_only( - find_paths_to_leaf(split.get_left_child(), l)); - }), - /*dst_machine_views=*/ - transform(dst_layers, - [&](parallel_layer_guid_t const &l) { - return get_only( - find_paths_to_leaf(split.get_right_child(), l)); - }), - }; + OneToMany + edges_by_tensor = group_by(edges_across_split, + [](ParallelComputationGraphEdge const &e) { + return get_parallel_tensor(e); + }); + + auto get_src_layer_path = [&](parallel_layer_guid_t layer) -> BinaryTreePath { + return get_only(find_paths_to_leaf(split.get_left_child(), layer)); }; - std::unordered_map - single_tensor_movements = generate_map( - pcg_get_transitive_reduced_tensors_across_split(tr_pcg, split), - get_movement_for_tensor); + auto get_dst_layer_path = [&](parallel_layer_guid_t layer) -> BinaryTreePath { + return get_only(find_paths_to_leaf(split.get_right_child(), layer)); + }; + + auto to_abstracted_single_tensor_movement = + [&](ParallelComputationGraphEdge const &pcg_edge) + -> AbstractedSingleTensorMovement { + parallel_layer_guid_t pcg_src = get_src_layer(pcg_edge); + parallel_layer_guid_t pcg_dst = get_dst_layer(pcg_edge); + + return get_abstracted_single_tensor_movement_along_edge( + /*pcg=*/tr_pcg.full_pcg, + /*edge=*/pcg_edge, + /*src_path=*/get_src_layer_path(pcg_src), + /*dst_path=*/get_dst_layer_path(pcg_dst)); + }; return AbstractedTensorSetMovement{ - values(single_tensor_movements), + transform( + edges_by_tensor.right_groups(), + [&](std::unordered_set const &edges) { + return 
merge_abstracted_single_tensor_movements( + transform(unordered_multiset_of(edges), + to_abstracted_single_tensor_movement)); + }), }; } diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.cc new file mode 100644 index 0000000000..4571718ed4 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.cc @@ -0,0 +1,16 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.h" +#include "compiler/machine_mapping/machine_view.h" + +namespace FlexFlow { + +MachineSpaceCoordinate machine_space_stencil_compute_machine_coord( + MachineSpaceStencil const &machine_space_stencil, + TaskSpaceCoordinate const &task_space_coordinate) { + + return get_machine_space_coordinate( + /*operator_task_space=*/machine_space_stencil.operator_task_space, + /*machine_view=*/machine_space_stencil.machine_view, + /*task_space_coordinate=*/task_space_coordinate); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc deleted file mode 100644 index e921a0c465..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc +++ /dev/null @@ -1,34 +0,0 @@ -#include "compiler/machine_mapping/get_machine_resource_splits.h" -#include "utils/hash/pair.h" - -namespace FlexFlow { - -std::unordered_set> - get_machine_resource_splits(MachineSpecification const &resource) { - std::unordered_set> - result; - - for (int i = 1; i < resource.num_nodes; i *= 2) { - MachineSpecification sub_resource1 = resource; - MachineSpecification sub_resource2 = resource; - sub_resource1.num_nodes = positive_int{i}; - sub_resource2.num_nodes = - positive_int{resource.num_nodes.int_from_positive_int() - i}; - 
result.insert(std::make_pair(sub_resource1, sub_resource2)); - result.insert(std::make_pair(sub_resource2, sub_resource1)); - } - - for (int i = 1; i < resource.num_gpus_per_node; i *= 2) { - MachineSpecification sub_resource1 = resource; - MachineSpecification sub_resource2 = resource; - sub_resource1.num_gpus_per_node = positive_int{i}; - sub_resource2.num_gpus_per_node = - positive_int{resource.num_gpus_per_node.int_from_positive_int() - i}; - result.insert(std::make_pair(sub_resource1, sub_resource2)); - result.insert(std::make_pair(sub_resource2, sub_resource1)); - } - - return result; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 8ca033d0d6..2407297322 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,20 +1,24 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "compiler/cost_estimator/op_cost_metrics.dtg.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" -#include "compiler/machine_mapping/get_machine_resource_splits.h" #include "compiler/machine_mapping/machine_mapping_cache.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.h" #include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/machine_mapping/machine_resource_split.dtg.h" +#include "compiler/machine_mapping/machine_resource_split.h" +#include "compiler/machine_mapping/machine_view.dtg.h" +#include 
"compiler/machine_mapping/machine_view.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_operator_task_space.h" +#include "op-attrs/parallel_tensor_shape.h" #include "pcg/machine_specification.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/contains.h" #include "utils/containers/flatmap.h" @@ -30,7 +34,7 @@ MachineMappingResult get_optimal_machine_mapping(MachineMappingCache &result_cache, MachineMappingContext const &context, MachineMappingProblemTree const &problem_tree, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints) { MachineMappingState state = MachineMappingState{ @@ -75,21 +79,21 @@ MachineMappingResult get_optimal_machine_mapping(MachineMappingCache &result_cache, MachineMappingContext const &context, MMProblemTreeSeriesSplit const &series_split, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints, std::optional const ¶llel_split_transformation) { auto get_boundary_machine_view_assignments = - [&](std::unordered_set const &boundary_layers) + [&](MachineMappingProblemTree const &root, + std::unordered_set const &boundary_layers) -> std::unordered_set { std::unordered_map> allowed = generate_map( boundary_layers, [&](BinaryTreePath const &l) -> std::unordered_set { UnmappedRuntimeOnlyOpCostEstimateKey leaf = - mm_problem_tree_get_subtree_at_path( - MachineMappingProblemTree{series_split}, l) + 
mm_problem_tree_get_subtree_at_path(root, l) .value() .get(); return context.allowed_machine_views(leaf, resources); @@ -139,7 +143,8 @@ MachineMappingResult for (ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views : - get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { + get_boundary_machine_view_assignments(series_split.get_left_child(), + get_src_layers(tensor_movement))) { MachineMappingResult pre_result = eval_pre_boundary_mapping(assigned_pre_machine_views); @@ -147,7 +152,7 @@ MachineMappingResult for (ParallelLayerGuidObliviousMachineMapping const &assigned_post_machine_views : get_boundary_machine_view_assignments( - get_dst_layers(tensor_movement))) { + series_split.get_right_child(), get_dst_layers(tensor_movement))) { MachineMappingResult post_result = eval_post_boundary_mapping(assigned_post_machine_views); @@ -155,8 +160,13 @@ MachineMappingResult TensorSetMovement comm_across_split = concretize_abstracted_tensor_set_movement( tensor_movement, - /*pre_mapping=*/assigned_pre_machine_views, - /*post_mapping=*/assigned_post_machine_views); + /*pre_machine_stencils=*/ + get_machine_stencils_for_partially_mapped_mm_problem_tree( + series_split.get_left_child(), assigned_pre_machine_views), + /*post_machine_stencils=*/ + get_machine_stencils_for_partially_mapped_mm_problem_tree( + series_split.get_right_child(), assigned_post_machine_views)); + milliseconds_t cost_across_split = context.cost_estimator.estimate_cost(comm_across_split); @@ -175,7 +185,7 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, MachineMappingContext const &context, MMProblemTreeParallelSplit const ¶llel_split, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints) { MachineMappingProblemTree lhs = parallel_split.get_left_child(); @@ -202,18 +212,16 @@ MachineMappingResult get_optimal_machine_mapping( 
restrict_to_right_child(constraints); auto evaluate_resource_split = - [&](std::pair const - &resource_split) { + [&](MachineResourceSplit const &resource_split) { + auto [lhs_resources, rhs_resources] = + apply_resource_split(resource_split, resources); + MachineMappingResult left_result = get_optimal_machine_mapping( - result_cache, context, lhs, resource_split.first, left_constraints); - MachineMappingResult right_result = - get_optimal_machine_mapping(result_cache, - context, - rhs, - resource_split.second, - right_constraints); + result_cache, context, lhs, lhs_resources, left_constraints); + MachineMappingResult right_result = get_optimal_machine_mapping( + result_cache, context, rhs, rhs_resources, right_constraints); - return parallel_combine(left_result, right_result); + return parallel_combine(resource_split, left_result, right_result); }; std::unordered_set parallel_results = transform( @@ -227,7 +235,7 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, MachineMappingContext const &context, UnmappedRuntimeOnlyOpCostEstimateKey const &leaf, - MachineSpecification const &resource, + MachineComputeResourceSlice const &resource, MachineMappingConstraints const &constraints) { std::unordered_set candidates = [&] { diff --git a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index 6cc3f4329c..7d1f28337c 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -1,12 +1,15 @@ #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include 
"compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "utils/containers/generate_map.h" #include "utils/containers/keys.h" +#include "utils/containers/map_values.h" #include "utils/containers/sum.h" #include "utils/containers/values.h" @@ -17,10 +20,31 @@ TensorSetMovement get_tensor_set_movement_across_split( PCGBinarySeriesSplit const &split, ParallelLayerGuidObliviousMachineMapping const &pre_mapping, ParallelLayerGuidObliviousMachineMapping const &post_mapping) { + AbstractedTensorSetMovement abstracted = get_abstracted_tensor_set_movement_across_split(tr_pcg, split); + + auto get_task_spaces = [&](PCGBinarySPDecomposition const &t) + -> std::unordered_map { + return map_values(pcg_sp_tree_get_path_to_leaf_map(t), + [&](parallel_layer_guid_t parallel_layer_guid) { + return get_operator_task_space(tr_pcg.full_pcg, + parallel_layer_guid); + }); + }; + + std::unordered_map pre_stencils = + get_machine_stencils_for_decomposition( + tr_pcg.full_pcg, split.get_left_child(), pre_mapping); + + std::unordered_map post_stencils = + get_machine_stencils_for_decomposition( + tr_pcg.full_pcg, split.get_right_child(), post_mapping); + return concretize_abstracted_tensor_set_movement( - abstracted, pre_mapping, post_mapping); + abstracted, + /*pre_machine_stencils=*/pre_stencils, + /*post_machine_stencils=*/post_stencils); } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_compute_resource_slice.cc b/lib/compiler/src/compiler/machine_mapping/machine_compute_resource_slice.cc new file mode 100644 index 0000000000..46614269fc --- /dev/null +++ 
b/lib/compiler/src/compiler/machine_mapping/machine_compute_resource_slice.cc @@ -0,0 +1,14 @@ +#include "compiler/machine_mapping/machine_compute_resource_slice.h" + +namespace FlexFlow { + +MachineComputeResourceSlice + compute_slice_from_specification(MachineComputeSpecification const &spec) { + + return MachineComputeResourceSlice{ + /*num_nodes=*/spec.num_nodes, + /*num_gpus_per_node=*/spec.num_gpus_per_node, + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index 82c8274808..1676a6929d 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -1,14 +1,49 @@ #include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/machine_view.h" +#include "op-attrs/computation_graph_op_attrs.h" #include "utils/containers/are_disjoint.h" +#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/keys.h" -#include "utils/containers/merge_maps.h" namespace FlexFlow { +MappedParallelComputationGraph + mapped_pcg_from_pcg_and_mapping(ParallelComputationGraph const &pcg, + MachineMapping const &mapping) { + + std::unordered_set pcg_layers = + get_parallel_layers(pcg); + std::unordered_set mapped_layers = + keys(mapping.machine_views); + ASSERT(pcg_layers == mapped_layers); + + return MappedParallelComputationGraph{ + /*pcg=*/pcg, + /*mapped_tasks=*/ + generate_map( + get_parallel_layers(pcg), + [&](parallel_layer_guid_t l) -> MappedOperatorTaskGroup { + ComputationGraphOpAttrs op_attrs = + compgraph_op_attrs_from_pcg_op_attrs(pcg_get_op_attrs(pcg, l)) + .value(); + + std::unordered_map + inputs_dim_degrees = get_incoming_input_degrees(pcg, l); + + ASSERT(contains_key(mapping.machine_views, l)); + MachineView machine_view = mapping.machine_views.at(l); + + return mapped_operator_task_group_from_machine_view( + op_attrs, 
inputs_dim_degrees, machine_view); + }), + }; +} + MachineMapping combine_disjoint_mappings(MachineMapping const &m1, MachineMapping const &m2) { return MachineMapping{ - merge_disjoint_maps(m1.machine_views, m2.machine_views)}; + binary_merge_disjoint_maps(m1.machine_views, m2.machine_views), + }; } bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc index fbfccf737f..a430eed7a5 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -1,6 +1,7 @@ #include "compiler/machine_mapping/machine_mapping_cache.h" #include "utils/containers/contains_key.h" #include "utils/containers/try_at.h" +#include namespace FlexFlow { @@ -17,12 +18,9 @@ std::optional void machine_mapping_cache_save(MachineMappingCache &cache, MachineMappingState const &k, MachineMappingResult const &v) { - if (contains_key(cache.raw_map, k)) { - throw mk_runtime_error( - fmt::format("machine_mapping_cache_save expected key to not already " - "exist, but received existing key {}", - k)); - } + ASSERT(!contains_key(cache.raw_map, k), + "machine_mapping_cache_save expected key to not already exist", + k); cache.raw_map.emplace(k, v); } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc index 2cee866a01..20683777d5 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -80,16 +80,14 @@ MachineMappingConstraints with_additional_constraints( if (!current_machine_view.has_value()) { result.machine_views.at(layer) = machine_view; } else { - if (current_machine_view.value() != machine_view) { - throw mk_runtime_error( - 
fmt::format("with_additional_layer_machine_views received machine " - "view assignment for layer {} " - "to machine view {}, but that layer is already " - "assigned to machine view {}.", - layer, - machine_view, - current_machine_view.value())); - } + ASSERT(current_machine_view.value() == machine_view, + fmt::format("with_additional_layer_machine_views received machine " + "view assignment for layer {} " + "to machine view {}, but that layer is already " + "assigned to machine view {}.", + layer, + machine_view, + current_machine_view.value())); } } @@ -98,13 +96,11 @@ MachineMappingConstraints with_additional_constraints( std::optional require_only_root(MachineMappingConstraints const &constraints) { - if (keys(constraints.machine_views) != - std::unordered_set{binary_tree_root_path()}) { - throw mk_runtime_error( - fmt::format("require_only_root expected constraints to have only a " - "single key (the root path), but received {}", - constraints)); - } + ASSERT(keys(constraints.machine_views) == + std::unordered_set{binary_tree_root_path()}, + fmt::format("require_only_root expected constraints to have only a " + "single key (the root path), but received {}", + constraints)); return constraints.machine_views.at(binary_tree_root_path()); } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc index 09323b1800..340c448275 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -1,6 +1,7 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_path_to_leaf_map.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" namespace FlexFlow { @@ -89,4 +90,11 @@ std::optional tree, generic_binary_sp_impl_for_mm_problem_tree(), path); } +std::unordered_map + mm_problem_tree_get_path_to_leaf_map( + MachineMappingProblemTree const &tree) { + return get_path_to_leaf_map(tree, + generic_binary_sp_impl_for_mm_problem_tree()); +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.cc index 53155a9a9b..9d84f2ca81 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.cc @@ -14,13 +14,13 @@ UnmappedRuntimeOnlyOpCostEstimateKey return UnmappedRuntimeOnlyOpCostEstimateKey{ /*op_attrs=*/pcg_get_op_attrs(pcg, parallel_layer_guid), /*input_shapes=*/ - transform(get_incoming_inputs(pcg, parallel_layer_guid), - get_tensor_shape), + map_values(get_incoming_inputs(pcg, parallel_layer_guid), + get_tensor_shape), /*weight_shapes=*/ - transform(get_incoming_weights(pcg, parallel_layer_guid), - get_tensor_shape), + map_values(get_incoming_weights(pcg, parallel_layer_guid), + get_tensor_shape), /*output_shapes=*/ - transform(get_layer_outputs(pcg, parallel_layer_guid), get_tensor_shape), + map_values(get_layer_outputs(pcg, parallel_layer_guid), get_tensor_shape), }; } diff --git 
a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc index a370a6803d..7c9c7951eb 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -1,8 +1,8 @@ #include "compiler/machine_mapping/machine_mapping_result.h" #include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/machine_resource_split.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" #include "utils/containers/map_keys.h" -#include "utils/containers/merge_maps.h" #include "utils/full_binary_tree/binary_tree_path.h" namespace FlexFlow { @@ -72,7 +72,8 @@ MachineMappingResult } MachineMappingResult - parallel_combine(MachineMappingResult const &maybe_lhs_result, + parallel_combine(MachineResourceSplit const &split, + MachineMappingResult const &maybe_lhs_result, MachineMappingResult const &maybe_rhs_result) { FeasibleMachineMappingResult lhs_result = ({ if (is_infeasible(maybe_lhs_result)) { @@ -92,8 +93,10 @@ MachineMappingResult FeasibleMachineMappingResult{ /*runtime=*/std::max(lhs_result.runtime, rhs_result.runtime), /*machine_mapping=*/ - binary_combine_mappings(/*lhs=*/lhs_result.machine_mapping, - /*rhs=*/rhs_result.machine_mapping), + binary_combine_mappings( + /*lhs=*/lhs_result.machine_mapping, + /*rhs=*/offset_layer_oblivious_mapping_by( + rhs_result.machine_mapping, split)), }, }; } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_resource_split.cc b/lib/compiler/src/compiler/machine_mapping/machine_resource_split.cc new file mode 100644 index 0000000000..875f44a0c9 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_resource_split.cc @@ -0,0 +1,125 @@ +#include "compiler/machine_mapping/machine_resource_split.h" +#include "utils/containers/map_values.h" +#include + +namespace FlexFlow { + +std::pair + 
apply_resource_split(MachineResourceSplit const &split, + MachineComputeResourceSlice const &resources) { + if (split.dimension == MachineSpecificationDimension::INTER_NODE) { + ASSERT(split.offset < resources.num_nodes); + + return { + MachineComputeResourceSlice{ + /*num_nodes=*/split.offset, + /*num_gpus_per_node=*/resources.num_gpus_per_node, + }, + MachineComputeResourceSlice{ + /*num_nodes=*/positive_int{ + resources.num_nodes.int_from_positive_int() - + split.offset.int_from_positive_int()}, + /*num_gpus_per_node=*/resources.num_gpus_per_node, + }, + }; + } else { + ASSERT(split.dimension == MachineSpecificationDimension::INTRA_NODE); + + ASSERT(split.offset < resources.num_gpus_per_node); + + return { + MachineComputeResourceSlice{ + /*num_nodes=*/resources.num_nodes, + /*num_gpus_per_node=*/split.offset, + }, + MachineComputeResourceSlice{ + /*num_nodes=*/resources.num_nodes, + /*num_gpus_per_node=*/ + positive_int{ + resources.num_gpus_per_node.int_from_positive_int() - + split.offset.int_from_positive_int(), + }, + }, + }; + } +} + +std::unordered_set + get_machine_resource_splits(MachineComputeResourceSlice const &resources) { + + std::unordered_set result; + + for (positive_int i = 1_p; i < resources.num_nodes; i *= 2_p) { + result.insert(MachineResourceSplit{ + /*offset=*/i, + /*dimension=*/MachineSpecificationDimension::INTER_NODE, + }); + result.insert(MachineResourceSplit{ + /*offset=*/positive_int{ + resources.num_nodes.int_from_positive_int() - + i.int_from_positive_int(), + }, + /*dimension=*/MachineSpecificationDimension::INTER_NODE, + }); + } + + for (positive_int i = 1_p; i < resources.num_gpus_per_node; i *= 2_p) { + result.insert(MachineResourceSplit{ + /*offset=*/i, + /*dimension=*/MachineSpecificationDimension::INTRA_NODE, + }); + result.insert(MachineResourceSplit{ + /*offset=*/positive_int{ + resources.num_gpus_per_node.int_from_positive_int() - + i.int_from_positive_int(), + }, + /*dimension=*/MachineSpecificationDimension::INTRA_NODE, 
+ }); + } + + return result; +} + +MachineSpaceCoordinate + offset_machine_space_coordinate_by(MachineSpaceCoordinate const &coord, + MachineResourceSplit const &split) { + if (split.dimension == MachineSpecificationDimension::INTER_NODE) { + return MachineSpaceCoordinate{ + /*node_idx=*/(coord.node_idx + split.offset) + .nonnegative_int_from_positive_int(), + /*device_idx=*/coord.device_idx, + /*device_type=*/coord.device_type, + }; + } else { + ASSERT(split.dimension == MachineSpecificationDimension::INTRA_NODE); + + return MachineSpaceCoordinate{ + /*node_idx=*/coord.node_idx, + /*device_idx=*/ + (coord.device_idx + split.offset).nonnegative_int_from_positive_int(), + /*device_type=*/coord.device_type, + }; + } +} + +MachineView offset_machine_view_by(MachineView const &machine_view, + MachineResourceSplit const &split) { + return MachineView{ + /*start=*/offset_machine_space_coordinate_by(machine_view.start, split), + /*dimensions=*/machine_view.dimensions, + }; +} + +ParallelLayerGuidObliviousMachineMapping offset_layer_oblivious_mapping_by( + ParallelLayerGuidObliviousMachineMapping const &mapping, + MachineResourceSplit const &split) { + + return ParallelLayerGuidObliviousMachineMapping{ + map_values(mapping.raw_mapping, + [&](MachineView const &mv) { + return offset_machine_view_by(mv, split); + }), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_view.cc b/lib/compiler/src/compiler/machine_mapping/machine_view.cc new file mode 100644 index 0000000000..090dec5845 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_view.cc @@ -0,0 +1,269 @@ +#include "compiler/machine_mapping/machine_view.h" +#include "compiler/machine_mapping/machine_view_dimension.dtg.h" +#include "compiler/machine_mapping/stride_t.dtg.h" +#include "op-attrs/get_operator_space_to_parallel_tensor_space_mappings.h" +#include "op-attrs/get_operator_task_space.h" +#include 
"op-attrs/operator_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/operator_task_space.h" +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include "op-attrs/task_space_coordinate.h" +#include "op-attrs/tensor_role.dtg.h" +#include "pcg/machine_compute_specification.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_specification_dimension.dtg.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/contains.h" +#include "utils/containers/count.h" +#include "utils/containers/filter.h" +#include "utils/containers/get_only.h" +#include "utils/containers/scanl.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "utils/containers/zip3_strict.h" +#include "utils/containers/zip_with_strict.h" +#include "utils/exception.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/nonnegative_int/num_elements.h" + +namespace FlexFlow { + +nonnegative_int mv_get_expected_task_space_num_dims(MachineView const &mv) { + return num_elements(get_strides(mv)); +} + +DeviceType get_device_type(MachineView const &mv) { + return mv.start.device_type; +} + +std::vector get_strides(MachineView const &mv) { + return transform(mv.dimensions, + [](MachineViewDimension const &dim) { return dim.stride; }); +} + +std::vector + get_dimensions(MachineView const &mv) { + return transform(mv.dimensions, [](MachineViewDimension const &dim) { + return dim.projection; + }); +} + +MachineView machine_view_from_strides_and_machine_spec_dimensions( + MachineSpaceCoordinate const &start, + std::vector const &strides, + std::vector const &dims) { + ASSERT(strides.size() == dims.size()); + std::vector dimensions = zip_with_strict( + strides, dims, [](stride_t s, MachineSpecificationDimension d) { + return MachineViewDimension{s, d}; + }); + return MachineView{start, dimensions}; +} + +MachineSpaceCoordinate + 
get_machine_space_coordinate(OperatorTaskSpace const &task_space, + MachineView const &machine_view, + TaskSpaceCoordinate const &coord) { + + ASSERT(mv_get_expected_task_space_num_dims(machine_view) == + op_task_space_num_dims(task_space), + "Dimension of MachineView must match dimension of OperatorTaskSpace", + machine_view, + task_space); + ASSERT(op_task_space_num_dims(task_space) == + task_space_coord_num_dims(coord)); + ASSERT(operator_task_space_contains_coord(task_space, coord)); + + auto get_dimension_indices_for_dimension = + [&](MachineSpecificationDimension dimension) + -> std::vector { + std::vector mv_dimensions = + get_dimensions(machine_view); + return filter(nonnegative_range(num_elements(mv_dimensions)), + [&](nonnegative_int idx) { + return mv_dimensions.at(idx.unwrap_nonnegative()) == + dimension; + }); + }; + + auto compute_index = + [&](nonnegative_int start_idx, + std::vector const &dimension_indices) { + std::vector mv_strides = get_strides(machine_view); + + std::vector sizes = + transform(dimension_indices, [&](nonnegative_int i) { + return (task_space.degrees.dims.at(i.unwrap_nonnegative()) * + mv_strides.at(i.unwrap_nonnegative()).unwrapped) + .positive_int_from_int_ge_two(); + }); + std::vector coord_points = + transform(dimension_indices, [&](nonnegative_int i) { + return coord.orthotope_coord.raw.at(i.unwrap_nonnegative()); + }); + std::vector strides = + transform(dimension_indices, [&](nonnegative_int i) { + return mv_strides.at(i.unwrap_nonnegative()).unwrapped; + }); + + std::vector coeffs = + scanl(sizes, 1_p, std::multiplies()); + + nonnegative_int index = start_idx; + for (auto [coeff, coord_point, stride] : + zip3(coeffs, coord_points, strides)) { + index += coeff * coord_point * stride; + } + return index; + }; + + std::vector inter_dimension_indices = + get_dimension_indices_for_dimension( + MachineSpecificationDimension::INTER_NODE); + std::vector intra_dimension_indices = + get_dimension_indices_for_dimension( + 
MachineSpecificationDimension::INTRA_NODE); + + nonnegative_int node_idx = + compute_index(machine_view.start.node_idx, inter_dimension_indices); + nonnegative_int device_idx = + compute_index(machine_view.start.device_idx, intra_dimension_indices); + MachineSpaceCoordinate ms_coord = MachineSpaceCoordinate{ + node_idx, device_idx, get_device_type(machine_view)}; + + return ms_coord; +} + +TaskSpaceCoordinate mv_task_space_coord_for_machine_space_coord( + MachineView const &machine_view, + OperatorTaskSpace const &operator_task_space, + MachineSpaceCoordinate const &machine_space_coord) { + OperatorSpaceToMachineSpaceMapping mapping = + get_coordinate_mapping_for_machine_view(operator_task_space, + machine_view); + + return mapping.raw_mapping.at_r(machine_space_coord); +} + +OperatorSpaceToMachineSpaceMapping get_coordinate_mapping_for_machine_view( + OperatorTaskSpace const &operator_task_space, + MachineView const &machine_view) { + + return OperatorSpaceToMachineSpaceMapping{ + /*raw_mapping=*/generate_bidict( + get_task_space_coordinates(operator_task_space), + [&](TaskSpaceCoordinate const &task_space_coord) { + return get_machine_space_coordinate( + /*operator_task_space=*/operator_task_space, + /*machine_view=*/machine_view, + /*task_space_coordinate=*/task_space_coord); + }), + /*operator_task_space=*/operator_task_space, + }; +} + +std::unordered_set + get_machine_space_coordinates(OperatorTaskSpace const &task_space, + MachineView const &machine_view) { + + ASSERT(op_task_space_num_dims(task_space) == + mv_get_expected_task_space_num_dims(machine_view)); + + return transform(get_task_space_coordinates(task_space), + [&](TaskSpaceCoordinate const &coord) { + return get_machine_space_coordinate( + task_space, machine_view, coord); + }); +} + +std::unordered_set + get_device_ids(OperatorTaskSpace const &task_space, + MachineView const &mv, + MachineComputeSpecification const &ms) { + ASSERT(op_task_space_num_dims(task_space) == + 
mv_get_expected_task_space_num_dims(mv)); + + return transform(get_machine_space_coordinates(task_space, mv), + [&](MachineSpaceCoordinate const &coord) { + return get_device_id(ms, coord); + }); +} + +MachineView make_1d_machine_view(MachineSpaceCoordinate const &start, + MachineSpecificationDimension const &dim, + stride_t stride) { + + return machine_view_from_strides_and_machine_spec_dimensions( + start, {stride}, {dim}); +} + +MachineView + make_single_device_machine_view(MachineSpaceCoordinate const &coord) { + return machine_view_from_strides_and_machine_spec_dimensions(coord, {}, {}); +} + +static OperatorAtomicTaskShardBinding + operator_atomic_task_shard_binding_from_machine_view( + ComputationGraphOpAttrs const &op_attrs, + std::unordered_map const + &inputs_dim_degrees, + MachineView const &machine_view, + MachineSpaceCoordinate const &machine_space_coord) { + OperatorTaskSpace op_task_space = + get_operator_task_space(op_attrs, inputs_dim_degrees); + + TaskSpaceCoordinate task_space_coord = + mv_task_space_coord_for_machine_space_coord( + machine_view, op_task_space, machine_space_coord); + + std::unordered_map + mappings = get_operator_to_ptensor_mappings(op_attrs, inputs_dim_degrees); + + std::unordered_map + ptensor_coords = generate_map( + keys(inputs_dim_degrees), + [&](TensorSlotName const &slot_name) + -> ParallelTensorSpaceCoordinate { + num_ptensor_shard_dims_t num_shard_dims = + get_ptensor_dim_degrees_num_shard_dims( + inputs_dim_degrees.at(slot_name)); + + return ptensor_coord_for_task_space_coord( + mappings.at(slot_name), task_space_coord, num_shard_dims); + }); + + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/ptensor_coords, + }; +} + +MappedOperatorTaskGroup mapped_operator_task_group_from_machine_view( + ComputationGraphOpAttrs const &op_attrs, + std::unordered_map const + &inputs_dim_degrees, + MachineView const &machine_view) { + + OperatorTaskSpace op_task_space = + get_operator_task_space(op_attrs, 
inputs_dim_degrees); + + return MappedOperatorTaskGroup{ + generate_bidict( + get_machine_space_coordinates(op_task_space, machine_view), + [&](MachineSpaceCoordinate const &machine_space_coord) { + return operator_atomic_task_shard_binding_from_machine_view( + op_attrs, + inputs_dim_degrees, + machine_view, + machine_space_coord); + }), + }; +} + +bidict + get_tensor_shard_to_device_coord_mapping(ComputationGraphOpAttrs const &, + MachineView const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc b/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc index 74e8db6304..a3f2009a60 100644 --- a/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc +++ b/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc @@ -1,18 +1,19 @@ #include "compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" -#include "compiler/machine_mapping/get_machine_resource_splits.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/machine_resource_split.dtg.h" +#include "compiler/machine_mapping/machine_resource_split.h" +#include "compiler/machine_mapping/machine_view.dtg.h" +#include "compiler/machine_mapping/machine_view.h" #include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.h" #include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h" +#include 
"compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "pcg/machine_specification.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/contains.h" #include "utils/containers/flatmap.h" @@ -28,7 +29,7 @@ MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( MachineMappingWithMemoryCache &result_cache, MachineMappingWithMemoryContext const &context, MachineMappingProblemTree const &problem_tree, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints) { MachineMappingState state = MachineMappingState{ @@ -73,25 +74,26 @@ MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( MachineMappingWithMemoryCache &result_cache, MachineMappingWithMemoryContext const &context, MMProblemTreeSeriesSplit const &series_split, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints, std::optional const ¶llel_split_transformation) { auto get_boundary_machine_view_assignments = - [&](std::unordered_set const &boundary_layers) + [&](MachineMappingProblemTree const &root, + std::unordered_set const &boundary_layers) -> std::unordered_set { std::unordered_map> allowed = generate_map( boundary_layers, [&](BinaryTreePath const &l) -> std::unordered_set { UnmappedRuntimeOnlyOpCostEstimateKey leaf = - mm_problem_tree_get_subtree_at_path( - MachineMappingProblemTree{series_split}, l) + mm_problem_tree_get_subtree_at_path(root, l) .value() .get(); return context.allowed_machine_views(leaf, resources); }); + return 
transform( get_all_assignments(allowed), [](std::unordered_map const &m) { @@ -140,7 +142,8 @@ MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( for (ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views : - get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { + get_boundary_machine_view_assignments(series_split.get_left_child(), + get_src_layers(tensor_movement))) { MachineMappingWithMemoryResult pre_result = eval_pre_boundary_mapping(assigned_pre_machine_views); @@ -148,7 +151,7 @@ MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( for (ParallelLayerGuidObliviousMachineMapping const &assigned_post_machine_views : get_boundary_machine_view_assignments( - get_dst_layers(tensor_movement))) { + series_split.get_right_child(), get_dst_layers(tensor_movement))) { MachineMappingWithMemoryResult post_result = eval_post_boundary_mapping(assigned_post_machine_views); @@ -156,8 +159,13 @@ MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( TensorSetMovement comm_across_split = concretize_abstracted_tensor_set_movement( tensor_movement, - /*pre_mapping=*/assigned_pre_machine_views, - /*post_mapping=*/assigned_post_machine_views); + /*pre_machine_stencils=*/ + get_machine_stencils_for_partially_mapped_mm_problem_tree( + series_split.get_left_child(), assigned_pre_machine_views), + /*post_machine_stencils=*/ + get_machine_stencils_for_partially_mapped_mm_problem_tree( + series_split.get_right_child(), assigned_post_machine_views)); + milliseconds_t cost_across_split = context.cost_estimator.estimate_cost(comm_across_split); @@ -176,7 +184,7 @@ MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( MachineMappingWithMemoryCache &result_cache, MachineMappingWithMemoryContext const &context, MMProblemTreeParallelSplit const ¶llel_split, - MachineSpecification const &resources, + MachineComputeResourceSlice const &resources, MachineMappingConstraints const 
&constraints) { MachineMappingProblemTree lhs = parallel_split.get_left_child(); @@ -204,22 +212,18 @@ MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( restrict_to_right_child(constraints); auto evaluate_resource_split = - [&](std::pair const - &resource_split) { + [&](MachineResourceSplit const &resource_split) { + auto [lhs_resources, rhs_resources] = + apply_resource_split(resource_split, resources); + MachineMappingWithMemoryResult left_result = - get_optimal_machine_mapping_with_memory(result_cache, - context, - lhs, - resource_split.first, - left_constraints); + get_optimal_machine_mapping_with_memory( + result_cache, context, lhs, lhs_resources, left_constraints); MachineMappingWithMemoryResult right_result = - get_optimal_machine_mapping_with_memory(result_cache, - context, - rhs, - resource_split.second, - right_constraints); + get_optimal_machine_mapping_with_memory( + result_cache, context, rhs, rhs_resources, right_constraints); - return parallel_combine(left_result, right_result); + return parallel_combine(resource_split, left_result, right_result); }; std::unordered_set parallel_results = @@ -234,7 +238,7 @@ MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( MachineMappingWithMemoryCache &result_cache, MachineMappingWithMemoryContext const &context, UnmappedRuntimeOnlyOpCostEstimateKey const &leaf, - MachineSpecification const &resource, + MachineComputeResourceSlice const &resource, MachineMappingConstraints const &constraints) { std::unordered_set candidates = [&] { diff --git a/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.cc b/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.cc index 617ba682be..c17e306aa8 100644 --- a/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.cc +++ 
b/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.cc @@ -1,6 +1,7 @@ #include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.h" #include "utils/containers/contains_key.h" #include "utils/containers/try_at.h" +#include namespace FlexFlow { @@ -19,12 +20,10 @@ void machine_mapping_with_memory_cache_save( MachineMappingWithMemoryCache &cache, MachineMappingState const &k, MachineMappingWithMemoryResult const &v) { - if (contains_key(cache.raw_map, k)) { - throw mk_runtime_error(fmt::format( - "machine_mapping_with_memory_cache_save expected key to not already " - "exist, but received existing key {}", - k)); - } + ASSERT(!contains_key(cache.raw_map, k), + "machine_mapping_with_memory_cache_save expected key to not already " + "exist", + k); cache.raw_map.emplace(k, v); } diff --git a/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.cc b/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.cc index cff7984897..9021e0d382 100644 --- a/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.cc +++ b/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.cc @@ -1,10 +1,54 @@ #include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h" +#include "compiler/machine_mapping/machine_resource_split.h" +#include "compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/all_of.h" #include "utils/containers/set_union.h" +#include "utils/containers/transform.h" #include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/hash/tuple.h" +#include "utils/hash/unordered_set.h" namespace FlexFlow { +MachineMappingWithMemoryResult::MachineMappingWithMemoryResult( + 
std::unordered_set const &pareto_frontier) + : m_pareto_frontier(pareto_frontier) { + ASSERT(all_of(pareto_frontier, [&](ParetoOptimalMachineMapping const &m) { + return is_pareto_optimal_in(m, pareto_frontier); + })); +} + +bool MachineMappingWithMemoryResult::operator==( + MachineMappingWithMemoryResult const &other) const { + return this->tie() == other.tie(); +} + +bool MachineMappingWithMemoryResult::operator!=( + MachineMappingWithMemoryResult const &other) const { + return this->tie() != other.tie(); +} + +std::unordered_set const & + MachineMappingWithMemoryResult::get_pareto_frontier() const { + return this->m_pareto_frontier; +} + +std::string format_as(MachineMappingWithMemoryResult const &r) { + return fmt::format("", + r.get_pareto_frontier()); +} + +std::ostream &operator<<(std::ostream &s, + MachineMappingWithMemoryResult const &r) { + return (s << fmt::to_string(r)); +} + +std::tuple const &> + MachineMappingWithMemoryResult::tie() const { + return std::tie(this->m_pareto_frontier); +} + MachineMappingWithMemoryResult empty_machine_mapping_with_memory_result() { return MachineMappingWithMemoryResult{ {}, @@ -23,29 +67,6 @@ MachineMappingWithMemoryResult get_mapping_with_minimal_runtime( return result; } -MachineMappingWithMemoryResult remove_non_pareto_optimal_machine_mapping_result( - MachineMappingWithMemoryResult const &result) { - std::unordered_set non_pareto_optimal_mappings; - for (MachineMappingForSingleLayer const &mapping : result.machine_mappings) { - bool is_pareto_optimal = true; - for (MachineMappingForSingleLayer const &other_mapping : - result.machine_mappings) { - if (mapping.cost.forward_runtime >= other_mapping.cost.forward_runtime && - mapping.cost.backward_runtime >= - other_mapping.cost.backward_runtime && - mapping.cost.memory_usage >= other_mapping.cost.memory_usage && - mapping != other_mapping) { - is_pareto_optimal = false; - break; - } - } - if (is_pareto_optimal) { - non_pareto_optimal_mappings.insert(mapping); - } - } - 
return MachineMappingWithMemoryResult{std::move(non_pareto_optimal_mappings)}; -} - MachineMappingWithMemoryResult series_combine(milliseconds_t comm_cost, MachineMappingWithMemoryResult const &pre_result, @@ -53,8 +74,8 @@ MachineMappingWithMemoryResult std::optional const ¶llel_split_transformation) { auto combine_machine_mapping = - [&](MachineMappingForSingleLayer const &pre_mm, - MachineMappingForSingleLayer const &post_mm) { + [&](ParetoOptimalMachineMapping const &pre_mm, + ParetoOptimalMachineMapping const &post_mm) { OpCostMetrics cost = OpCostMetrics{ /*forward_runtime=*/pre_mm.cost.forward_runtime + comm_cost + post_mm.cost.forward_runtime, @@ -76,28 +97,34 @@ MachineMappingWithMemoryResult } }(); - return MachineMappingForSingleLayer{cost, mapping}; + return ParetoOptimalMachineMapping{cost, mapping}; }; - MachineMappingWithMemoryResult result = - empty_machine_mapping_with_memory_result(); - for (MachineMappingForSingleLayer const &pre_mm : - pre_result.machine_mappings) { - for (MachineMappingForSingleLayer const &post_mm : - post_result.machine_mappings) { - result.machine_mappings.insert(combine_machine_mapping(pre_mm, post_mm)); + std::unordered_set result; + + for (ParetoOptimalMachineMapping const &pre_mm : + pre_result.get_pareto_frontier()) { + for (ParetoOptimalMachineMapping const &post_mm : + post_result.get_pareto_frontier()) { + result.insert(combine_machine_mapping(pre_mm, post_mm)); } } - return remove_non_pareto_optimal_machine_mapping_result(result); + return MachineMappingWithMemoryResult{ + /*pareto_frontier=*/filter(result, + [&](ParetoOptimalMachineMapping const &m) { + return is_pareto_optimal_in(m, result); + }), + }; } MachineMappingWithMemoryResult - parallel_combine(MachineMappingWithMemoryResult const &lhs_result, + parallel_combine(MachineResourceSplit const &split, + MachineMappingWithMemoryResult const &lhs_result, MachineMappingWithMemoryResult const &rhs_result) { auto combine_machine_mapping = - 
[&](MachineMappingForSingleLayer const &lhs_mm, - MachineMappingForSingleLayer const &rhs_mm) { + [&](ParetoOptimalMachineMapping const &lhs_mm, + ParetoOptimalMachineMapping const &rhs_mm) { OpCostMetrics cost = OpCostMetrics{ /*forward_runtime=*/ std::max(lhs_mm.cost.forward_runtime, rhs_mm.cost.forward_runtime), @@ -110,38 +137,50 @@ MachineMappingWithMemoryResult ParallelLayerGuidObliviousMachineMapping mapping = binary_combine_mappings(lhs_mm.machine_mapping, - rhs_mm.machine_mapping); + offset_layer_oblivious_mapping_by( + rhs_mm.machine_mapping, split)); - return MachineMappingForSingleLayer{cost, mapping}; + return ParetoOptimalMachineMapping{cost, mapping}; }; - MachineMappingWithMemoryResult result = - empty_machine_mapping_with_memory_result(); - for (MachineMappingForSingleLayer const &lhs_mm : - lhs_result.machine_mappings) { - for (MachineMappingForSingleLayer const &rhs_mm : - rhs_result.machine_mappings) { - result.machine_mappings.insert(combine_machine_mapping(lhs_mm, rhs_mm)); + std::unordered_set result; + for (ParetoOptimalMachineMapping const &lhs_mm : + lhs_result.get_pareto_frontier()) { + + for (ParetoOptimalMachineMapping const &rhs_mm : + rhs_result.get_pareto_frontier()) { + + result.insert(combine_machine_mapping(lhs_mm, rhs_mm)); } } - return remove_non_pareto_optimal_machine_mapping_result(result); + return MachineMappingWithMemoryResult{ + /*pareto_frontier=*/filter(result, + [&](ParetoOptimalMachineMapping const &m) { + return is_pareto_optimal_in(m, result); + }), + }; } MachineMappingWithMemoryResult minimize_runtime(MachineMappingWithMemoryResult const &m1, MachineMappingWithMemoryResult const &m2) { - MachineMappingWithMemoryResult result = MachineMappingWithMemoryResult{ - set_union(m1.machine_mappings, m2.machine_mappings), + std::unordered_set result = + set_union(m1.get_pareto_frontier(), m2.get_pareto_frontier()); + + return MachineMappingWithMemoryResult{ + /*pareto_frontier=*/filter(result, + 
[&](ParetoOptimalMachineMapping const &m) { + return is_pareto_optimal_in(m, result); + }), }; - return remove_non_pareto_optimal_machine_mapping_result(result); } MachineMappingWithMemoryResult make_singleton_machine_mapping_with_memory_result( OpCostMetrics cost, MachineView const &machine_view) { return MachineMappingWithMemoryResult{{ - MachineMappingForSingleLayer{ + ParetoOptimalMachineMapping{ cost, ParallelLayerGuidObliviousMachineMapping{{ {binary_tree_root_path(), machine_view}, @@ -151,3 +190,12 @@ MachineMappingWithMemoryResult } } // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::MachineMappingWithMemoryResult>::operator()( + ::FlexFlow::MachineMappingWithMemoryResult const &r) const { + return get_std_hash(r.tie()); +} + +} // namespace std diff --git a/lib/compiler/src/compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.cc new file mode 100644 index 0000000000..ca6d762eed --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.cc @@ -0,0 +1,16 @@ +#include "compiler/machine_mapping/memory_optimization/pareto_optimal_machine_mapping.h" +#include "compiler/cost_estimator/op_cost_metrics.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +bool is_pareto_optimal_in( + ParetoOptimalMachineMapping const &m, + std::unordered_set const &others) { + return is_pareto_optimal_in( + m.cost, transform(others, [](ParetoOptimalMachineMapping const &m) { + return m.cost; + })); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc index ed60004bf4..6e2096afcc 100644 --- a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc +++ 
b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc @@ -1,6 +1,13 @@ #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_operator_task_space.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/map_keys.h" -#include "utils/containers/merge_maps.h" +#include "utils/containers/require_same.h" #include "utils/containers/try_at.h" #include "utils/full_binary_tree/binary_tree_path.h" @@ -10,15 +17,99 @@ ParallelLayerGuidObliviousMachineMapping binary_combine_mappings( ParallelLayerGuidObliviousMachineMapping const &lhs, ParallelLayerGuidObliviousMachineMapping const &rhs) { return ParallelLayerGuidObliviousMachineMapping{ - merge_disjoint_maps(map_keys(lhs.raw_mapping, nest_inside_left_child), - map_keys(rhs.raw_mapping, nest_inside_right_child)), + binary_merge_disjoint_maps( + map_keys(lhs.raw_mapping, nest_inside_left_child), + map_keys(rhs.raw_mapping, nest_inside_right_child)), }; } +ParallelLayerGuidObliviousMachineMapping + restrict_to_left_child(ParallelLayerGuidObliviousMachineMapping const &) { + NOT_IMPLEMENTED(); +} + +ParallelLayerGuidObliviousMachineMapping + restrict_to_right_child(ParallelLayerGuidObliviousMachineMapping const &) { + NOT_IMPLEMENTED(); +} + std::optional get_machine_view_for_path( ParallelLayerGuidObliviousMachineMapping const &mapping, BinaryTreePath const &path) { return try_at(mapping.raw_mapping, path); } +std::unordered_map + get_machine_stencils_for_decomposition( + ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &decomposition, + 
ParallelLayerGuidObliviousMachineMapping const &mapping) { + std::unordered_set leaf_paths = require_same( + pcg_sp_tree_get_all_leaf_paths(decomposition), keys(mapping.raw_mapping)); + + std::unordered_map + path_to_op_task_space_map = + map_values(pcg_sp_tree_get_path_to_leaf_map(decomposition), + [&](parallel_layer_guid_t l) -> OperatorTaskSpace { + return get_operator_task_space(pcg, l); + }); + + return generate_map( + leaf_paths, [&](BinaryTreePath const &p) -> MachineSpaceStencil { + return MachineSpaceStencil{ + /*operator_task_space=*/path_to_op_task_space_map.at(p), + /*machine_view=*/mapping.raw_mapping.at(p), + }; + }); +} + +std::unordered_map> + get_machine_stencils_for_mm_problem_tree( + MachineMappingProblemTree const &tree, + ParallelLayerGuidObliviousMachineMapping const &mapping) { + + std::unordered_map + tree_leaf_map = mm_problem_tree_get_path_to_leaf_map(tree); + + std::unordered_set mapping_paths = keys(mapping.raw_mapping); + std::unordered_set tree_paths = keys(tree_leaf_map); + + ASSERT(is_subseteq_of(mapping_paths, tree_paths)); + + return generate_map( + tree_paths, + [&](BinaryTreePath const &p) -> std::optional { + if (!contains_key(mapping.raw_mapping, p)) { + return std::nullopt; + } + + UnmappedRuntimeOnlyOpCostEstimateKey leaf = tree_leaf_map.at(p); + + ComputationGraphOpAttrs leaf_op_attrs = + compgraph_op_attrs_from_pcg_op_attrs(leaf.op_attrs).value(); + + std::unordered_map + leaf_input_degrees = + map_values(leaf.input_shapes, [](ParallelTensorShape const &s) { + return get_parallel_degrees(s); + }); + + return MachineSpaceStencil{ + /*operator_task_space=*/get_operator_task_space(leaf_op_attrs, + leaf_input_degrees), + /*machine_view=*/mapping.raw_mapping.at(p), + }; + }); +} + +std::unordered_map + get_machine_stencils_for_partially_mapped_mm_problem_tree( + MachineMappingProblemTree const &tree, + ParallelLayerGuidObliviousMachineMapping const &mappings) { + + return filtermap_values( + 
get_machine_stencils_for_mm_problem_tree(tree, mappings), + [](std::optional const &s) { return s; }); +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/start_invariant_machine_view.cc b/lib/compiler/src/compiler/machine_mapping/start_invariant_machine_view.cc new file mode 100644 index 0000000000..cbb64d5bcf --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/start_invariant_machine_view.cc @@ -0,0 +1,83 @@ +#include "compiler/machine_mapping/start_invariant_machine_view.h" +#include "compiler/machine_mapping/machine_view.h" +#include "op-attrs/operator_task_space.h" +#include "pcg/machine_space_offset.h" +#include "utils/containers/count.h" +#include "utils/containers/filter.h" +#include "utils/containers/scanl.h" +#include "utils/containers/transform.h" +#include "utils/containers/zip.h" +#include "utils/nonnegative_int/num_elements.h" +namespace FlexFlow { + +MachineView machine_view_from_start_invariant( + StartInvariantMachineView const &start_inv_mv, + MachineSpaceCoordinate const &start) { + return MachineView{start, start_inv_mv.dimensions}; +} + +StartInvariantMachineView + start_invariant_from_machine_view(MachineView const &mv) { + return StartInvariantMachineView{mv.dimensions, get_device_type(mv)}; +} + +nonnegative_int num_dims(StartInvariantMachineView const &start_inv_mv) { + return num_elements(start_inv_mv.dimensions); +} + +DeviceType get_device_type(StartInvariantMachineView const &start_inv_mv) { + return start_inv_mv.device_type; +} + +std::vector + get_strides(StartInvariantMachineView const &start_inv_mv) { + return transform(start_inv_mv.dimensions, + [](MachineViewDimension const &dim) { return dim.stride; }); +} + +std::vector + get_dimensions(StartInvariantMachineView const &start_inv_mv) { + return transform( + start_inv_mv.dimensions, + [](MachineViewDimension const &dim) { return dim.projection; }); +} + +StartInvariantMachineView + 
start_invariant_machine_view_from_strides_and_machine_spec_dimensions( + std::vector const &strides, + std::vector const &dims, + DeviceType device_type) { + std::vector dimensions = + transform(zip(strides, dims), [&](auto const &p) { + return MachineViewDimension{p.first, p.second}; + }); + return StartInvariantMachineView{dimensions, device_type}; +} + +MachineSpaceOffset get_machine_space_offset( + OperatorTaskSpace const &task, + StartInvariantMachineView const &start_inv_machine_view, + TaskSpaceCoordinate const &coord) { + + MachineSpaceCoordinate dummy_start = + MachineSpaceCoordinate{0_n, 0_n, get_device_type(start_inv_machine_view)}; + + MachineView mv = + machine_view_from_start_invariant(start_inv_machine_view, dummy_start); + + MachineSpaceCoordinate ms_coord = + get_machine_space_coordinate(task, mv, coord); + + return get_machine_space_offset_from_coordinate(dummy_start, ms_coord); +} + +std::unordered_set get_machine_space_offsets( + OperatorTaskSpace const &task, + StartInvariantMachineView const &start_inv_machine_view) { + return transform( + get_task_space_coordinates(task), [&](TaskSpaceCoordinate const &coord) { + return get_machine_space_offset(task, start_inv_machine_view, coord); + }); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc index 96c8106cad..5779edc382 100644 --- a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc +++ b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -4,20 +4,20 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "utils/containers/flatmap.h" -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" -#include 
"utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" #include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" #include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_boundary_nodes_for_kwarg_dataflow_graph_split.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_kwarg_dataflow_edges_across_split.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_kwarg_dataflow_outputs_across_split.h" namespace FlexFlow { -TransitiveReducedDataflowGraphView +TransitiveReducedKwargDataflowGraphView get_underlying_transitive_reduced_dataflow_graph( TransitiveReducedPCG const &tr_pcg) { - return TransitiveReducedDataflowGraphView{ + return TransitiveReducedKwargDataflowGraphView{ /*full_dataflow_graph=*/tr_pcg.full_pcg.raw_graph, /*transitive_reduction=*/tr_pcg.transitive_reduction, }; @@ -38,16 +38,17 @@ std::unordered_set pcg_get_transitive_reduced_edges_across_split( TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { - TransitiveReducedDataflowGraphView raw_tr_g = + TransitiveReducedKwargDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); - std::unordered_set raw_edges = - get_transitive_reduced_edges_across_split(raw_tr_g, raw_split); + std::unordered_set> raw_edges = + get_transitive_reduced_kwarg_dataflow_edges_across_split(raw_tr_g, + raw_split); - 
return transform(raw_edges, [](DataflowEdge const &e) { + return transform(raw_edges, [](KwargDataflowEdge const &e) { return ParallelComputationGraphEdge{e}; }); } @@ -55,30 +56,33 @@ std::unordered_set std::unordered_set pcg_get_transitive_reduced_tensors_across_split( TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { - TransitiveReducedDataflowGraphView raw_tr_g = + TransitiveReducedKwargDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); - std::unordered_set raw_outputs = - get_transitive_reduced_outputs_across_split(raw_tr_g, raw_split); + std::unordered_set> raw_outputs = + get_transitive_reduced_kwarg_dataflow_outputs_across_split(raw_tr_g, + raw_split); - return transform(raw_outputs, [](DataflowOutput const &o) { - return parallel_tensor_guid_t{o}; - }); + return transform(raw_outputs, + [](KwargDataflowOutput const &o) { + return parallel_tensor_guid_t{o}; + }); } PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { - TransitiveReducedDataflowGraphView raw_tr_g = + TransitiveReducedKwargDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); SplitBoundaryNodes raw_boundary = - get_transitive_reduced_boundary_nodes_for_split(raw_tr_g, raw_split); + get_transitive_reduced_boundary_nodes_for_kwarg_dataflow_graph_split( + raw_tr_g, raw_split); return PCGSplitBoundaryLayers{ /*pre_split_boundary=*/transform( diff --git a/lib/compiler/src/compiler/machine_mapping/unstructured_device_mapping.cc b/lib/compiler/src/compiler/machine_mapping/unstructured_device_mapping.cc index 63e359d9ac..80c09d2dba 100644 --- a/lib/compiler/src/compiler/machine_mapping/unstructured_device_mapping.cc +++ 
b/lib/compiler/src/compiler/machine_mapping/unstructured_device_mapping.cc @@ -1,20 +1,18 @@ - #include "compiler/machine_mapping/unstructured_device_mapping.h" +#include "compiler/machine_mapping/machine_view.h" #include "compiler/machine_mapping/unstructured_device_mapping.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_view.h" -#include "pcg/operator_task_space.dtg.h" -#include "pcg/operator_task_space.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/operator_task_space.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/keys.h" #include "utils/containers/map_values.h" namespace FlexFlow { -UnstructuredDeviceMapping - get_unstructured_device_mapping(MachineMapping const &machine_mapping, - MachineSpecification const &machine_spec, - ParallelComputationGraph const &pcg) { +UnstructuredDeviceMapping get_unstructured_device_mapping( + MachineMapping const &machine_mapping, + MachineComputeSpecification const &machine_spec, + ParallelComputationGraph const &pcg) { std::unordered_map> device_mapping; for (auto const &[layer, machine_view] : machine_mapping.machine_views) { diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc index 5eb993c6ef..e7cbef122c 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -2,6 +2,7 @@ #include "compiler/series_parallel/pcg/pcg_binary_series_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_path_to_leaf_map.h" #include "utils/overload.h" namespace FlexFlow { @@ -106,10 +107,20 @@ SPDecompositionTreeNodeType }); } +std::unordered_set + pcg_sp_tree_get_all_leaf_paths(PCGBinarySPDecomposition const &tree) { + return keys(pcg_sp_tree_get_path_to_leaf_map(tree)); +} + std::unordered_set find_paths_to_leaf(PCGBinarySPDecomposition const &tree, parallel_layer_guid_t const &leaf) { return find_paths_to_leaf(tree, generic_impl_for_pcg_sp_tree(), leaf); } +std::unordered_map + pcg_sp_tree_get_path_to_leaf_map(PCGBinarySPDecomposition const &tree) { + return get_path_to_leaf_map(tree, generic_impl_for_pcg_sp_tree()); +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc b/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc index c072b0e61e..d4d5a78d6a 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc @@ -2,11 +2,11 @@ #include "compiler/cost_estimator/runtime_only_op_cost_estimate_key.h" #include "compiler/cost_estimator/tensor_set_movement.h" #include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/machine_view.dtg.h" +#include "compiler/machine_mapping/machine_view.h" +#include "op-attrs/operator_task_space.h" #include "pcg/device_id_t.dtg.h" #include "pcg/machine_specification.dtg.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/machine_view.h" -#include "pcg/operator_task_space.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" @@ -18,9 +18,10 @@ namespace FlexFlow { -PCGTaskGraph get_pcg_task_graph(ParallelComputationGraph const &pcg, - MachineMapping const &machine_mapping, - 
MachineSpecification const &machine_spec) { +PCGTaskGraph + get_pcg_task_graph(ParallelComputationGraph const &pcg, + MachineMapping const &machine_mapping, + MachineComputeSpecification const &machine_spec) { DiGraph digraph = DiGraph::create(); bidict node_to_task; bidict node_to_layer; diff --git a/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc b/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc index 7b317df31a..30e345243c 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc @@ -8,7 +8,7 @@ #include "utils/containers/set_of.h" #include "utils/containers/sorted.h" #include "utils/exception.h" -#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" #include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/digraph/algorithms/is_acyclic.h" diff --git a/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc b/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc index a1aa53885b..bc528493a8 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc @@ -25,8 +25,8 @@ milliseconds_t task_simulator_estimate_forward_pass_time( MachineMapping const &machine_mapping, MachineSpecification const &machine_spec) { - PCGTaskGraph task_graph = - get_pcg_task_graph(pcg, machine_mapping, machine_spec); + PCGTaskGraph task_graph = get_pcg_task_graph( + pcg, machine_mapping, machine_spec.compute_specification); auto cost_function = [&](Node const &node) -> float { PCGTask task = task_graph.node_to_task.at_l(node); @@ -48,8 +48,8 @@ milliseconds_t task_simulator_estimate_forward_pass_time( std::unordered_set const &finished_tasks) -> bool { PCGTask 
current_task = task_graph.node_to_task.at_l(task); - UnstructuredDeviceMapping device_map = - get_unstructured_device_mapping(machine_mapping, machine_spec, pcg); + UnstructuredDeviceMapping device_map = get_unstructured_device_mapping( + machine_mapping, machine_spec.compute_specification, pcg); if (current_task.is_tensor_movement()) { return true; diff --git a/lib/compiler/src/compiler/unity_algorithm.cc b/lib/compiler/src/compiler/unity_algorithm.cc new file mode 100644 index 0000000000..9ae824c62a --- /dev/null +++ b/lib/compiler/src/compiler/unity_algorithm.cc @@ -0,0 +1,77 @@ +#include "compiler/unity_algorithm.h" +#include "compiler/graph_optimize_state.h" +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "pcg/machine_specification.dtg.h" +#include "substitutions/substitution.h" +#include "utils/deduplicated_priority_queue.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +GraphOptimizeResult graph_optimize( + ParallelComputationGraph &pcg, + CostEstimator const &cost_estimator, + MachineSpecification const &resources, + std::function( + ParallelLayerAttrs const &, MachineSpecification const &)> const + &allowed_machine_views, + OptimizerConfig const &opt_config) { + NOT_IMPLEMENTED(); + + // std::vector substitutions = + // get_all_applicable_substitutions(pcg); + // + // MachineMappingCache cached_subgraph_costs; + // DeduplicatedPriorityQueue candidates; + // + // MachineMappingResult original_pcg_cost = + // get_optimal_machine_mapping(pcg, + // allowed_machine_views, + // cost_estimator, + // resources, + // cached_subgraph_costs); + // + // GraphOptimizeState initial_state = { + // GraphOptimizeResult(pcg, original_pcg_cost.machine_mapping), + // original_pcg_cost.runtime}; + // + // GraphOptimizeState best_state = initial_state; + // candidates.push(initial_state); + // + // for (int iteration = 0; !candidates.empty() && iteration < + // opt_config.budget; + // ++iteration) { + // GraphOptimizeState 
current_state = candidates.top(); + // candidates.pop(); + // + // if (current_state.runtime < best_state.runtime) { + // best_state = current_state; + // } else if (current_state.runtime > best_state.runtime * opt_config.alpha) + // { + // continue; + // } + // + // for (Substitution const &substitution : substitutions) { + // for (ParallelComputationGraph const &new_pcg : apply_substitution( + // current_state.graph_optimize_result.pcg, substitution)) { + // MachineMappingResult new_pcg_cost = + // get_optimal_machine_mapping(new_pcg, + // allowed_machine_views, + // cost_estimator, + // resources, + // cached_subgraph_costs); + // GraphOptimizeState new_state{ + // GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), + // new_pcg_cost.runtime}; + // if (new_pcg_cost.runtime <= opt_config.threshold && + // get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { + // candidates.push(new_state); + // } + // } + // } + // } + + // return best_state.graph_optimize_result; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/optimizations.cc.todo b/lib/compiler/src/optimizations.cc.todo deleted file mode 100644 index 1d93fb169c..0000000000 --- a/lib/compiler/src/optimizations.cc.todo +++ /dev/null @@ -1,280 +0,0 @@ -#include "optimizations.h" -#include "op-attrs/get_op_type.h" -#include "utils/graph/rewriting.h" - -namespace FlexFlow { - -struct OptimizeUnnecessaryGradientCalculations { - OptimizeUnnecessaryGradientCalculations() = default; - - Operator operator()(LabelledOpenMultiDiGraph const &g, - Node const &n, - Operator const &op) { return op; } - ParallelTensorAttrs operator()(LabelledOpenMultiDiGraph const &g, - MultiDiEdge const &e, - ParallelTensorAttrs const &pt) { - ParallelTensorAttrs result = pt; - if (get_op_type(g.at(e.src).attrs) == OperatorType::INPUT) { - result.create_gradients = CreateGrad::NO; - } - return result; - } -}; - -ParallelComputationGraph optimize_unnecessary_gradient_calculations(ParallelComputationGraph const 
&pcg) { - // If an operator's input is training data - // No need to compute its gradients - return pcg.on_underlying([](LabelledOpenMultiDiGraph const &g) { - return rewrite(OptimizeUnnecessaryGradientCalculations{}, g); - }); -} - -ParallelComputationGraph enable_inplace_operators(ParallelComputationGraph const &pcg) { - -} - -void FFModel::perform_fusion_optimizations() { - fprintf(stderr, "Applying fusion optimizations during compilation...\n"); - fprintf(stderr, "%zu operators before fusion...\n", operators.size()); - std::vector new_operators; - std::vector old_operators = operators; - while (apply_fusion(operators, new_operators)) { - for (size_t i = 0; i < new_operators.size(); i++) { - for (int idx = 0; idx < new_operators[i]->numInputs; idx++) { - for (size_t j = i + 1; j < new_operators.size(); j++) { - if (new_operators[i]->inputs[idx]->owner_op == new_operators[j]) { - assert(false); - } - } - } - } - operators = new_operators; - } - // Check integrity - for (size_t l = 0; l < operators.size(); l++) { - if (operators[l]->op_type == OP_FUSED) { - FusedOp *fused = (FusedOp *)operators[l]; - int ioff = 0, woff = 0, ooff = 0; - for (int op = 0; op < fused->numOperators; op++) { - Op *old_op = fused->operators[op]; - for (int i = 0; i < fused->op_num_inputs[op]; i++) { - int my_off = fused->op_input_idx[i + ioff]; - if (fused->op_input_source[i + ioff] == FusedOp::SOURCE_INPUT) { - assert(fused->inputs[my_off]->region == - old_op->inputs[i]->region); - } else if (fused->op_input_source[i + ioff] == - FusedOp::SOURCE_OUTPUT) { - assert(fused->outputs[my_off]->region == - old_op->inputs[i]->region); - } else { - assert(false); - } - } - for (int i = 0; i < fused->op_num_weights[op]; i++) { - int my_off = fused->op_weight_idx[i + woff]; - assert(fused->op_weight_source[i + woff] == FusedOp::SOURCE_WEIGHT); - assert(fused->weights[my_off]->region == - old_op->weights[i]->region); - } - for (int i = 0; i < fused->op_num_outputs[op]; i++) { - int my_off = 
fused->op_output_idx[i + ooff]; - assert(fused->op_output_source[i + ooff] == FusedOp::SOURCE_OUTPUT); - assert(fused->outputs[my_off]->region == - old_op->outputs[i]->region); - } - ioff += fused->op_num_inputs[op]; - woff += fused->op_num_weights[op]; - ooff += fused->op_num_outputs[op]; - } - } else { - bool found = false; - for (size_t i = 0; i < old_operators.size(); i++) { - if (old_operators[i] == operators[l]) { - assert(!found); - found = true; - } - } - assert(found); - } - } - fprintf(stderr, "%zu operators after fusion...\n", operators.size()); - for (size_t i = 0; i < operators.size(); i++) { - Op *op = operators[i]; - printf("operator[%zu]: type(%s) guid(%lu)\n", - i, - get_operator_type_name(operators[i]->op_type).c_str(), - operators[i]->op_guid); - for (int j = 0; j < op->numInputs; j++) { - LogicalRegion handle = op->inputs[j]->region; - printf("inputs[%d] region(%d,%d,%d)\n", - j, - handle.get_index_space().get_id(), - handle.get_field_space().get_id(), - handle.get_tree_id()); - } - for (int j = 0; j < op->numOutputs; j++) { - LogicalRegion handle = op->outputs[j]->region; - printf("outputs[%d] region(%d,%d,%d)\n", - j, - handle.get_index_space().get_id(), - handle.get_field_space().get_id(), - handle.get_tree_id()); - } - for (int j = 0; j < op->numWeights; j++) { - LogicalRegion handle = op->weights[j]->region; - printf("weights[%d] region(%d,%d,%d)\n", - j, - handle.get_index_space().get_id(), - handle.get_field_space().get_id(), - handle.get_tree_id()); - } - } -} - - -void FFModel::perform_inplace_optimizations() { - for (size_t l = 1; l < operators.size(); l++) { - if (operators[l]->can_inplace_output()) { - // Assume outputs[0] is inplace with inputs[0] - assert(operators[l]->numOutputs == 1); - if (operators[l]->inputs[0]->owner_op != NULL) { - // int dim1 = operators[l]->outputs[0]->num_dims; - // int dim2 = operators[l]->inputs[0]->num_dims; - MachineView view1 = operators[l]->outputs[0]->machine_view.value(); - MachineView view2 = 
operators[l]->inputs[0]->machine_view.value(); - if (view1 == view2) { - // Check no others also need operators[l]->inputs[0] - bool found = false; - for (size_t i = 0; i < operators.size(); i++) { - if (i == l) { - continue; - } - for (int j = 0; j < operators[i]->numInputs; j++) { - if ((operators[i]->inputs[j]->owner_op == - operators[l]->inputs[0]->owner_op) && - (operators[i]->inputs[j]->owner_idx == - operators[l]->inputs[0]->owner_idx)) { - found = true; - } - } - } - if (!found) { - // Perform inplace - operators[l]->do_inplace_output(); - } - } - } - } - } -} -bool FFModel::apply_fusion(std::vector const &operators, - std::vector &new_operators) { - // Context ctx = config.lg_ctx; - // Runtime* runtime = config.lg_hlr; - for (size_t l = 1; l < operators.size() - 1; l++) { - // don't fuse input and weight operator since they don't involve any - // forward/backward task launches - if (operators[l]->op_type == OP_INPUT || - operators[l]->op_type == OP_WEIGHT) { - continue; - } - // don't fuse parallel op since they have different parallel_is in - // forward/backward - if (operators[l]->is_parallel_op()) { - continue; - } - size_t start = 0; - { - Op *opl = operators[l]; - for (int idx = 0; idx < opl->numInputs; idx++) { - bool found = false; - for (size_t i = 0; i < l; i++) { - if (opl->inputs[idx]->owner_op == operators[i]) { - assert(!found); - found = true; - if (i > start) { - start = i; - } - } - } - assert(found || (opl->inputs[idx]->owner_op == NULL)); - } - } - for (size_t i = start; i < l; i++) { - // Domain d1 = - // runtime->get_index_space_domain(operators[l]->outputs[0]->parallel_is); - // Domain d2 = - // runtime->get_index_space_domain(operators[i]->outputs[0]->parallel_is); - MachineView view1 = operators[l]->outputs[0]->machine_view.value(); - MachineView view2 = operators[i]->outputs[0]->machine_view.value(); - if (view1 == view2) { - FusedOp *fused_op = nullptr; - bool allocate_new_fused_op = false; - if (operators[i]->op_type == OP_FUSED) 
{ - fused_op = (FusedOp *)operators[i]; - } else { - // cannot be an in-place operator - if (operators[i]->has_inplace_output()) { - continue; - } - // don't fuse input and weight operator since they don't involve any - // forward/backward kernels - if (operators[i]->op_type == OP_INPUT || - operators[i]->op_type == OP_WEIGHT) { - continue; - } - // don't fuse parallel op since they have different parallel_is in - // forward/backward - if (operators[i]->is_parallel_op()) { - continue; - } - fused_op = new FusedOp(*this, operators[i]); - allocate_new_fused_op = true; - } - if (fused_op->add_operator(*this, operators[l])) { - // Construct new operators - new_operators.clear(); - for (size_t j = 0; j < i; j++) { - new_operators.push_back(operators[j]); - } - new_operators.push_back(fused_op); - for (size_t j = i + 1; j < operators.size(); j++) { - if (j == l) { - continue; // l and i are fused - } - Op *op = operators[j]; - // Update input tensors that belong to operator[l] or operator[i] - for (int idx = 0; idx < op->numInputs; idx++) { - if ((op->inputs[idx]->owner_op == operators[l]) || - (op->inputs[idx]->owner_op == operators[i])) { - int found = -1; - for (int k = 0; k < fused_op->numOutputs; k++) { - if (fused_op->outputs[k]->region == op->inputs[idx]->region) { - assert(found == -1); - found = k; - } - } - assert(found >= 0); - op->inputs[idx] = fused_op->outputs[found]; - } - } - // Insert op - new_operators.push_back(op); - } - // We are exact one operator fewer than the original - assert(new_operators.size() + 1 == operators.size()); - return true; - } else { - // TODO: delete fused_op to avoid memory leakage - if (allocate_new_fused_op) { - delete fused_op; - } - continue; - } - } - } - } - return false; -} - -} diff --git a/lib/compiler/src/optimizations.h.todo b/lib/compiler/src/optimizations.h.todo deleted file mode 100644 index c2846f26ca..0000000000 --- a/lib/compiler/src/optimizations.h.todo +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef 
_FLEXFLOW_RUNTIME_SRC_OPTIMIZATIONS_H -#define _FLEXFLOW_RUNTIME_SRC_OPTIMIZATIONS_H - -#include "parallel_computation_graph.h" - -namespace FlexFlow { - -ParallelComputationGraph fuse_operators(ParallelComputationGraph const &); -ParallelComputationGraph - remove_unnecessary_gradient_calculations(ParallelComputationGraph const &); -ParallelComputationGraph - enable_inplace_operators(ParallelComputationGraph const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/substitutions_implementation.h b/lib/compiler/src/substitutions_implementation.h deleted file mode 100644 index 786d9425ea..0000000000 --- a/lib/compiler/src/substitutions_implementation.h +++ /dev/null @@ -1,39 +0,0 @@ -#ifndef _FLEXFLOW_FFC_SUBSTITUTIONS_IMPLEMENTATION_H -#define _FLEXFLOW_FFC_SUBSTITUTIONS_IMPLEMENTATION_H - -#include "substitutions/substitutions_v2.h" - -namespace FlexFlow { - -substitutions::SubstitutionPattern - create_combine_inception(int num_convs, int num_dims, int num_parts); -substitutions::SubstitutionPattern - create_combine_concat(int num_inputs, int num_dims, int num_parts); -substitutions::SubstitutionPattern create_replicate_linear_combine( - int num_dims, int num_parts, ActiMode activation, bool use_bias); -substitutions::SubstitutionPattern create_partition_linear_combine( - int num_dims, int num_parts, ActiMode activation, bool use_bias); -substitutions::SubstitutionPattern - create_partition_conv2d_combine(int num_dims, int num_parts); -substitutions::SubstitutionPattern - create_partition_attention_combine(int num_heads, int num_parts); -substitutions::SubstitutionPattern - create_replicate_attention_reduce(int num_heads, int num_parts); -substitutions::SubstitutionPattern - create_partition_add_combine(int parallel_dim, int num_parts); -substitutions::SubstitutionPattern - create_partition_relu_combine(int parallel_dim, int num_parts); -substitutions::SubstitutionPattern create_partition_concat_combine( - int num_inputs, int concat_dim, int 
parallel_dim, int num_parts); -substitutions::SubstitutionPattern create_partition_softmax_combine( - int softmax_dim, int part_dim, int num_parts); -substitutions::SubstitutionPattern leading_relu_branch_combine( - int parallel_dim, int num_parts, int num_combines); -substitutions::SubstitutionPattern leading_relu_branch_partition( - int parallel_dim, int num_parts, int num_partitions); -substitutions::SubstitutionPattern create_linear_relu_merge(int num_dims, - bool use_bias); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc deleted file mode 100644 index 86a211c535..0000000000 --- a/lib/compiler/src/unity_algorithm.cc +++ /dev/null @@ -1,93 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "compiler/graph_optimize_state.h" -#include "compiler/machine_mapping/get_optimal_machine_mapping.h" -#include "pcg/machine_specification.dtg.h" -#include "substitutions/substitution.h" -#include "utils/deduplicated_priority_queue.h" -#include "utils/graph/node/algorithms.h" -namespace FlexFlow { - -/* - * Gets all substitutions applicable to a PCG - */ -std::vector - get_all_applicable_substitutions(ParallelComputationGraph const &pcg) { - NOT_IMPLEMENTED(); -} - -/* - * Applies a substitution to all possible positions in PCG - */ -std::vector - apply_substitution(ParallelComputationGraph const &pcg, - Substitution const &) { - NOT_IMPLEMENTED(); -} - -GraphOptimizeResult graph_optimize( - ParallelComputationGraph &pcg, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimizerConfig const &opt_config) { - NOT_IMPLEMENTED(); - - // std::vector substitutions = - // get_all_applicable_substitutions(pcg); - // - // MachineMappingCache cached_subgraph_costs; - // DeduplicatedPriorityQueue candidates; - // - // MachineMappingResult original_pcg_cost = - // 
get_optimal_machine_mapping(pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // - // GraphOptimizeState initial_state = { - // GraphOptimizeResult(pcg, original_pcg_cost.machine_mapping), - // original_pcg_cost.runtime}; - // - // GraphOptimizeState best_state = initial_state; - // candidates.push(initial_state); - // - // for (int iteration = 0; !candidates.empty() && iteration < - // opt_config.budget; - // ++iteration) { - // GraphOptimizeState current_state = candidates.top(); - // candidates.pop(); - // - // if (current_state.runtime < best_state.runtime) { - // best_state = current_state; - // } else if (current_state.runtime > best_state.runtime * opt_config.alpha) - // { - // continue; - // } - // - // for (Substitution const &substitution : substitutions) { - // for (ParallelComputationGraph const &new_pcg : apply_substitution( - // current_state.graph_optimize_result.pcg, substitution)) { - // MachineMappingResult new_pcg_cost = - // get_optimal_machine_mapping(new_pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // GraphOptimizeState new_state{ - // GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), - // new_pcg_cost.runtime}; - // if (new_pcg_cost.runtime <= opt_config.threshold && - // get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { - // candidates.push(new_state); - // } - // } - // } - // } - - // return best_state.graph_optimize_result; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/utils/recursive_logger.cc.todo b/lib/compiler/src/utils/recursive_logger.cc.todo deleted file mode 100644 index 892b8a63c1..0000000000 --- a/lib/compiler/src/utils/recursive_logger.cc.todo +++ /dev/null @@ -1,40 +0,0 @@ -#include "utils/recursive_logger.h" -#include "utils/exception.h" - -namespace FlexFlow { -namespace utils { - -RecursiveLogger::RecursiveLogger(std::shared_ptr const &logger) - : logger(logger) {} - 
-RecursiveLogger::RecursiveLogger(std::string const &logger_name) { - this->logger = spdlog::get(logger_name); -} - -std::string RecursiveLogger::get_prefix() const { - return std::string(this->depth * 2, ' '); -} - -void RecursiveLogger::enter() { - this->depth++; -} - -void RecursiveLogger::leave() { - this->depth--; - assert(this->depth >= 0); -} - -std::unique_ptr RecursiveLogger::enter_tag() { - return std::unique_ptr(new DepthTag(*this)); -} - -DepthTag::DepthTag(RecursiveLogger &_logger) : logger(_logger) { - this->logger.enter(); -} - -DepthTag::~DepthTag() { - this->logger.leave(); -} - -} // namespace utils -} // namespace FlexFlow diff --git a/lib/compiler/src/utils/recursive_logger.h.todo b/lib/compiler/src/utils/recursive_logger.h.todo deleted file mode 100644 index e78726d9b6..0000000000 --- a/lib/compiler/src/utils/recursive_logger.h.todo +++ /dev/null @@ -1,66 +0,0 @@ -#ifndef _FLEXFLOW_RECURSIVE_LOGGER_H -#define _FLEXFLOW_RECURSIVE_LOGGER_H - -#include "spdlog/spdlog.h" -#include - -#define CONCAT(a, b) CONCAT_INNER(a, b) -#define CONCAT_INNER(a, b) a##b -#define UNIQUE_TAG() CONCAT(tag, __COUNTER__) -#define TAG_ENTER(mlogger) auto UNIQUE_TAG() = mlogger->enter_tag() - -namespace FlexFlow { -namespace utils { - -class RecursiveLogger; - -class DepthTag { -public: - DepthTag() = delete; - DepthTag(RecursiveLogger &); - DepthTag(DepthTag const &) = delete; - ~DepthTag(); - -private: - RecursiveLogger &logger; -}; - -class RecursiveLogger { -public: - RecursiveLogger(std::shared_ptr const &logger); - RecursiveLogger(std::string const &logger_name); - - RecursiveLogger(RecursiveLogger const &) = delete; - - template - void info(std::string const &fmt, Args &&...args) { - this->logger->info(this->get_prefix() + fmt, std::forward(args)...); - } - - template - void debug(std::string const &fmt, Args &&...args) { - this->logger->debug(this->get_prefix() + fmt, std::forward(args)...); - } - - template - void spew(std::string const &fmt, Args &&...args) { 
- this->logger->trace(this->get_prefix() + fmt, std::forward(args)...); - } - - void enter(); - void leave(); - - std::unique_ptr enter_tag(); - -private: - std::string get_prefix() const; - -private: - int depth = 0; - std::shared_ptr logger; -}; - -} // namespace utils -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/test/src/allowed_machine_views.cc b/lib/compiler/test/src/compiler/allowed_machine_views.cc similarity index 89% rename from lib/compiler/test/src/allowed_machine_views.cc rename to lib/compiler/test/src/compiler/allowed_machine_views.cc index 15f7d60060..e768d8540c 100644 --- a/lib/compiler/test/src/allowed_machine_views.cc +++ b/lib/compiler/test/src/compiler/allowed_machine_views.cc @@ -1,11 +1,11 @@ #include "compiler/allowed_machine_views.h" -#include "doctest/doctest.h" #include "utils/containers/extend.h" #include "utils/containers/range.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/zip.h" #include "utils/fmt/unordered_set.h" +#include using namespace FlexFlow; @@ -14,15 +14,13 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_allowed_machine_views") { SUBCASE("1 degree of parallelism") { - MachineSpecification ms = MachineSpecification{ + MachineComputeSpecification ms = MachineComputeSpecification{ /*num_nodes=*/1_p, /*num_cpus_per_node=*/5_p, /*num_gpus_per_node=*/5_p, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0, }; - OperatorTaskSpace task = OperatorTaskSpace{{3_p}}; + OperatorTaskSpace task = OperatorTaskSpace{MinimalOrthotope{{3_ge2}}}; std::unordered_set correct = { MachineView{ @@ -60,14 +58,13 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("2 degrees of parallelism") { - MachineSpecification ms = MachineSpecification{ + MachineComputeSpecification ms = MachineComputeSpecification{ /*num_nodes=*/3_p, /*num_cpus_per_node=*/3_p, /*num_gpus_per_node=*/3_p, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0, }; - OperatorTaskSpace task = 
OperatorTaskSpace{{2_p, 3_p}}; + OperatorTaskSpace task = + OperatorTaskSpace{MinimalOrthotope{{2_ge2, 3_ge2}}}; auto make_2d_view = [&](nonnegative_int start_node_idx, nonnegative_int start_device_idx, diff --git a/lib/compiler/test/src/compiler/graph_optimize_state.cc b/lib/compiler/test/src/compiler/graph_optimize_state.cc new file mode 100644 index 0000000000..d99f754609 --- /dev/null +++ b/lib/compiler/test/src/compiler/graph_optimize_state.cc @@ -0,0 +1,132 @@ +#include "compiler/graph_optimize_state.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/machine_view.dtg.h" +#include "compiler/machine_mapping/machine_view.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "test/utils/doctest/check_without_stringify.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("GraphOptimizeState operator==") { + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 32_p, + 16_p, + }, + }, + DataType::FLOAT, + }; + + InitializerAttrs zero_init = InitializerAttrs{ZeroInitializerAttrs{}}; + + auto create_pcg = [&]() -> ParallelComputationGraph { + ParallelComputationGraphBuilder builder; + + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape, "input0"); + parallel_tensor_guid_t dense0 = + builder.dense(/*input=*/input0, + /*outDim=*/8_p, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/zero_init, + /*bias_initializer=*/zero_init, + /*name=*/"dense0"); + + parallel_tensor_guid_t dense1 = + builder.dense(/*input=*/dense0, + /*outDim=*/4_p, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/zero_init, + /*bias_initializer=*/zero_init, + 
/*name=*/"dense1"); + + return builder.pcg; + }; + + auto create_machine_mapping_for_pcg = + [](ParallelComputationGraph const &pcg) -> MachineMapping { + MachineSpaceCoordinate device = MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }; + + MachineView machine_view = make_single_device_machine_view(device); + + return MachineMapping{ + generate_map(get_parallel_layers(pcg), + [&](parallel_layer_guid_t) { return machine_view; }), + }; + }; + + ParallelComputationGraph pcg1 = create_pcg(); + MachineMapping machine_mapping_1 = create_machine_mapping_for_pcg(pcg1); + + SUBCASE("returns true if the PCGs are isomorphic") { + ParallelComputationGraph pcg2 = create_pcg(); + MachineMapping machine_mapping_2 = create_machine_mapping_for_pcg(pcg2); + + GraphOptimizeState state1 = GraphOptimizeState{ + GraphOptimizeResult{ + mapped_pcg_from_pcg_and_mapping(pcg1, machine_mapping_1), + }, + 0, + }; + + GraphOptimizeState state2 = GraphOptimizeState{ + GraphOptimizeResult{ + mapped_pcg_from_pcg_and_mapping(pcg2, machine_mapping_2), + }, + 0, + }; + + CHECK_WITHOUT_STRINGIFY(state1 == state2); + } + + SUBCASE("returns false it the PCGs are not isomorphic") { + ParallelComputationGraphBuilder builder_; + + parallel_tensor_guid_t input0_ = + builder_.create_input_tensor(input_shape, "input0"); + parallel_tensor_guid_t dense0_ = + builder_.dense(/*input=*/input0_, + /*outDim=*/8_p, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/zero_init, + /*bias_initializer=*/zero_init, + /*name=*/"dense0"); + + ParallelComputationGraph other_pcg = builder_.pcg; + + MachineMapping other_machine_mapping = + create_machine_mapping_for_pcg(other_pcg); + + GraphOptimizeState state1 = GraphOptimizeState{ + GraphOptimizeResult{ + mapped_pcg_from_pcg_and_mapping(pcg1, machine_mapping_1), + }, + 0, + }; + + GraphOptimizeState state_ = GraphOptimizeState{ + GraphOptimizeResult{ + 
mapped_pcg_from_pcg_and_mapping(other_pcg, other_machine_mapping), + }, + 0, + }; + + CHECK_FALSE_WITHOUT_STRINGIFY(state1 == state_); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 0416a73660..7adb94b5f4 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -1,14 +1,111 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include "utils/full_binary_tree/binary_tree_path.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_abstracted_single_tensor_movement_along_edge") { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 12_p, + }, + }, + DataType::FLOAT, + }; + + ParallelTensorShape par_input_shape = lift_to_parallel(input_shape); + + ParallelLayerAttrs partition_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{0_n}, + /*repartition_degree=*/2_p, + }, + }, + 
/*name=*/std::nullopt, + }; + + ParallelLayerAttrs relu_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_input = + require_only_key(input.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult partition_input = add_parallel_layer( + pcg, partition_attrs, {{TensorSlotName::INPUT, t_input}}, {}); + parallel_tensor_guid_t t_partition_input = + require_only_key(partition_input.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_partition_input}}, {}); + parallel_tensor_guid_t t_layer_1 = + require_only_key(layer_1.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_layer_1}}, {}); + + ParallelComputationGraphEdge edge = + get_only(get_pcg_edges_from_layer_to_layer( + /*pcg=*/pcg, + /*src=*/layer_1.parallel_layer, + /*dst=*/layer_2.parallel_layer)); + + BinaryTreePath src_path = BinaryTreePath{{}}; + BinaryTreePath dst_path = BinaryTreePath{{}}; + + AbstractedSingleTensorMovement result = + get_abstracted_single_tensor_movement_along_edge( + pcg, edge, src_path, dst_path); + + num_bytes_t shard_size = + get_piece_size_in_bytes(get_parallel_tensor_shape(pcg, t_layer_1)); + + auto mk_single_tensor_communication = + [&](nonnegative_int src_coord, + nonnegative_int dst_coord) -> AbstractedSingleTensorCommunication { + return AbstractedSingleTensorCommunication{ + /*edge=*/AbstractedSingleTensorCommunicationEdge{ + /*src_coord=*/TaskSpaceCoordinate{OrthotopeCoord{{src_coord}}}, + /*dst=*/ + AbstractedDevice{ + /*operator_tree_path=*/dst_path, + /*task_space_coordinate=*/ + TaskSpaceCoordinate{OrthotopeCoord{{dst_coord}}}, + }, + }, + /*size=*/shard_size, + }; + }; + + 
AbstractedSingleTensorMovement correct = + abstracted_single_tensor_movement_from_communications( + /*src_op_tree_path=*/src_path, + /*communications=*/{ + mk_single_tensor_communication(0_n, 0_n), + mk_single_tensor_communication(1_n, 1_n), + }); + + CHECK(result == correct); + } + TEST_CASE("get_abstracted_tensor_set_movement_across_split") { auto make_series_split = [](PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { @@ -70,16 +167,34 @@ TEST_SUITE(FF_TEST_SUITE) { /*name=*/std::nullopt, }; + auto mk_task_space_coord = [&](nonnegative_int coord) { + return TaskSpaceCoordinate{ + OrthotopeCoord{{ + coord, + }}, + }; + }; + + auto mk_abstracted_device = [&](BinaryTreePath const &path, + nonnegative_int coord) { + return AbstractedDevice{ + /*operator_tree_path=*/path, + /*task_space_coordinate=*/mk_task_space_coord(coord), + }; + }; + SUBCASE("no edges across split") { ParallelLayerAddedResult input1 = pcg_add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_input1 = get_only(input1.outputs); - ParallelLayerAddedResult partition_input1 = - add_parallel_layer(pcg, partition_attrs, {t_input1}, {}); + parallel_tensor_guid_t t_input1 = + require_only_key(input1.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult partition_input1 = add_parallel_layer( + pcg, partition_attrs, {{TensorSlotName::INPUT, t_input1}}, {}); ParallelLayerAddedResult input2 = pcg_add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_input2 = get_only(input2.outputs); - ParallelLayerAddedResult partition_input2 = - add_parallel_layer(pcg, partition_attrs, {t_input2}, {}); + parallel_tensor_guid_t t_input2 = + require_only_key(input2.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult partition_input2 = add_parallel_layer( + pcg, partition_attrs, {{TensorSlotName::INPUT, t_input2}}, {}); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ make_series_split(make_leaf(input1.parallel_layer), @@ -101,16 +216,19 @@ TEST_SUITE(FF_TEST_SUITE) 
{ SUBCASE("single edge across split") { ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_input = get_only(input.outputs); - ParallelLayerAddedResult partition_input = - add_parallel_layer(pcg, partition_attrs, {t_input}, {}); - parallel_tensor_guid_t t_partition_input = get_only(input.outputs); - - ParallelLayerAddedResult layer_1 = - add_parallel_layer(pcg, relu_attrs, {t_partition_input}, {}); - parallel_tensor_guid_t t_layer_1 = get_only(layer_1.outputs); - ParallelLayerAddedResult layer_2 = - add_parallel_layer(pcg, relu_attrs, {t_layer_1}, {}); + parallel_tensor_guid_t t_input = + require_only_key(input.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult partition_input = add_parallel_layer( + pcg, partition_attrs, {{TensorSlotName::INPUT, t_input}}, {}); + parallel_tensor_guid_t t_partition_input = + require_only_key(partition_input.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_partition_input}}, {}); + parallel_tensor_guid_t t_layer_1 = + require_only_key(layer_1.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_layer_1}}, {}); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ make_series_split( @@ -124,19 +242,31 @@ TEST_SUITE(FF_TEST_SUITE) { get_abstracted_tensor_set_movement_across_split( pcg_get_transitive_reduction(pcg), split); + BinaryTreePath src_path = BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}; + + BinaryTreePath dst_path = BinaryTreePath{{}}; + + auto mk_abstracted_edge = [&](nonnegative_int src_coord, + nonnegative_int dst_coord) { + return AbstractedSingleTensorCommunicationEdge{ + /*src=*/mk_task_space_coord(src_coord), + /*dst=*/mk_abstracted_device(dst_path, dst_coord), + }; + }; + + num_bytes_t shard_size = get_size_in_bytes( + get_reduced_shape(get_parallel_tensor_shape(pcg, t_layer_1))); 
+ AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ /*single_tensor_movements=*/{ AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/par_input_shape, - /*src_machine_views=*/ - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, - /*dst_machine_views=*/ + /*src_op_tree_path=*/src_path, + /*edge_to_size=*/ { - BinaryTreePath{{}}, + {mk_abstracted_edge(0_n, 0_n), shard_size}, + {mk_abstracted_edge(1_n, 1_n), shard_size}, }, }, }, @@ -147,21 +277,38 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("does not include edges removed by transitive reduction") { ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_input = get_only(input.outputs); - ParallelLayerAddedResult partition_input = - add_parallel_layer(pcg, partition_attrs, {t_input}, {}); - parallel_tensor_guid_t t_partition_input = get_only(input.outputs); - - ParallelLayerAddedResult layer_1 = - add_parallel_layer(pcg, relu_attrs, {t_partition_input}, {}); - parallel_tensor_guid_t t_layer_1 = get_only(layer_1.outputs); - - ParallelLayerAddedResult layer_2 = - add_parallel_layer(pcg, relu_attrs, {t_layer_1}, {}); - parallel_tensor_guid_t t_layer_2 = get_only(layer_2.outputs); + parallel_tensor_guid_t t_input = + require_only_key(input.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult partition_input = add_parallel_layer( + pcg, partition_attrs, {{TensorSlotName::INPUT, t_input}}, {}); + parallel_tensor_guid_t t_partition_input = + require_only_key(partition_input.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_partition_input}}, {}); + parallel_tensor_guid_t t_layer_1 = + require_only_key(layer_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_layer_1}}, {}); + parallel_tensor_guid_t t_layer_2 = + require_only_key(layer_2.outputs, 
TensorSlotName::OUTPUT); ParallelLayerAddedResult layer_3 = - add_parallel_layer(pcg, ew_add_attrs, {t_layer_1, t_layer_2}, {}); + add_parallel_layer(pcg, + ew_add_attrs, + /*inputs=*/ + { + { + TensorSlotName::LHS_INPUT, + t_layer_1, + }, + { + TensorSlotName::RHS_INPUT, + t_layer_2, + }, + }, + /*weights=*/{}); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ make_series_split( @@ -176,20 +323,32 @@ TEST_SUITE(FF_TEST_SUITE) { get_abstracted_tensor_set_movement_across_split( pcg_get_transitive_reduction(pcg), split); + BinaryTreePath src_path = BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}; + + BinaryTreePath dst_path = BinaryTreePath{{}}; + + auto mk_abstracted_edge = [&](nonnegative_int src_coord, + nonnegative_int dst_coord) { + return AbstractedSingleTensorCommunicationEdge{ + /*src=*/mk_task_space_coord(src_coord), + /*dst=*/mk_abstracted_device(dst_path, dst_coord), + }; + }; + + num_bytes_t shard_size = get_size_in_bytes( + get_reduced_shape(get_parallel_tensor_shape(pcg, t_layer_2))); + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ /*single_tensor_movements=*/{ AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/par_input_shape, - /*src_machine_views=*/ + /*src_op_tree_path=*/src_path, + /*edge_to_size=*/ { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, - /*dst_machine_views=*/ - { - BinaryTreePath{{}}, + {mk_abstracted_edge(0_n, 0_n), shard_size}, + {mk_abstracted_edge(1_n, 1_n), shard_size}, }, }, }, @@ -200,20 +359,23 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("single tensor, multiple consumers across split") { ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_input = get_only(input.outputs); - ParallelLayerAddedResult partition_input = - add_parallel_layer(pcg, partition_attrs, {t_input}, {}); - parallel_tensor_guid_t t_partition_input = get_only(input.outputs); + 
parallel_tensor_guid_t t_input = + require_only_key(input.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult partition_input = add_parallel_layer( + pcg, partition_attrs, {{TensorSlotName::INPUT, t_input}}, {}); + parallel_tensor_guid_t t_partition_input = + require_only_key(partition_input.outputs, TensorSlotName::OUTPUT); - ParallelLayerAddedResult layer_1 = - add_parallel_layer(pcg, relu_attrs, {t_partition_input}, {}); - parallel_tensor_guid_t t_layer_1 = get_only(layer_1.outputs); + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_partition_input}}, {}); + parallel_tensor_guid_t t_layer_1 = + require_only_key(layer_1.outputs, TensorSlotName::OUTPUT); - ParallelLayerAddedResult layer_2 = - add_parallel_layer(pcg, relu_attrs, {t_layer_1}, {}); + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_layer_1}}, {}); - ParallelLayerAddedResult layer_3 = - add_parallel_layer(pcg, relu_attrs, {t_layer_1}, {}); + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_layer_1}}, {}); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ make_series_split( @@ -228,24 +390,40 @@ TEST_SUITE(FF_TEST_SUITE) { get_abstracted_tensor_set_movement_across_split( pcg_get_transitive_reduction(pcg), split); + BinaryTreePath src_path = BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}; + + BinaryTreePath dst1_path = BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}; + + BinaryTreePath dst2_path = BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}; + + auto mk_abstracted_edge = [&](nonnegative_int src_coord, + BinaryTreePath dst_path, + nonnegative_int dst_coord) { + return AbstractedSingleTensorCommunicationEdge{ + /*src=*/mk_task_space_coord(src_coord), + /*dst=*/mk_abstracted_device(dst_path, dst_coord), + }; + }; + + num_bytes_t shard_size = get_size_in_bytes( + 
get_reduced_shape(get_parallel_tensor_shape(pcg, t_layer_1))); + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ /*single_tensor_movements=*/{ AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/par_input_shape, - /*src_machine_views=*/ - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, - /*dst_machine_views=*/ + /*src_op_tree_path=*/src_path, + /*edge_to_size=*/ { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, + {mk_abstracted_edge(0_n, dst1_path, 0_n), shard_size}, + {mk_abstracted_edge(1_n, dst1_path, 1_n), shard_size}, + {mk_abstracted_edge(0_n, dst2_path, 0_n), shard_size}, + {mk_abstracted_edge(1_n, dst2_path, 1_n), shard_size}, }, }, }, @@ -256,25 +434,41 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("multiple tensors, multiple consumers across split") { ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_input = get_only(input.outputs); - ParallelLayerAddedResult partition_input = - add_parallel_layer(pcg, partition_attrs, {t_input}, {}); - parallel_tensor_guid_t t_partition_input = get_only(input.outputs); - - ParallelLayerAddedResult layer_1 = - add_parallel_layer(pcg, relu_attrs, {t_partition_input}, {}); - - ParallelLayerAddedResult layer_2 = - add_parallel_layer(pcg, relu_attrs, {t_partition_input}, {}); - - ParallelLayerAddedResult layer_3 = - add_parallel_layer(pcg, relu_attrs, {get_only(layer_1.outputs)}, {}); - - ParallelLayerAddedResult layer_4 = add_parallel_layer( - pcg, - ew_add_attrs, - {get_only(layer_1.outputs), get_only(layer_2.outputs)}, - {}); + parallel_tensor_guid_t t_input = + require_only_key(input.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult partition_input = add_parallel_layer( + pcg, partition_attrs, {{TensorSlotName::INPUT, t_input}}, {}); + parallel_tensor_guid_t t_partition_input = + require_only_key(partition_input.outputs, TensorSlotName::OUTPUT); + + 
ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_partition_input}}, {}); + parallel_tensor_guid_t t_layer_1 = + require_only_key(layer_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_partition_input}}, {}); + parallel_tensor_guid_t t_layer_2 = + require_only_key(layer_2.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_layer_1}}, {}); + + ParallelLayerAddedResult layer_4 = + add_parallel_layer(pcg, + ew_add_attrs, + /*inputs=*/ + { + { + TensorSlotName::LHS_INPUT, + t_layer_1, + }, + { + TensorSlotName::RHS_INPUT, + t_layer_2, + }, + }, + /*weights=*/{}); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ make_series_split( @@ -289,41 +483,56 @@ TEST_SUITE(FF_TEST_SUITE) { get_abstracted_tensor_set_movement_across_split( pcg_get_transitive_reduction(pcg), split); + BinaryTreePath src1_path = BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}; + + BinaryTreePath src2_path = BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}; + + BinaryTreePath dst1_path = BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}; + + BinaryTreePath dst2_path = BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}; + + auto mk_abstracted_edge = [&](nonnegative_int src_coord, + BinaryTreePath dst_path, + nonnegative_int dst_coord) { + return AbstractedSingleTensorCommunicationEdge{ + /*src=*/mk_task_space_coord(src_coord), + /*dst=*/mk_abstracted_device(dst_path, dst_coord), + }; + }; + + num_bytes_t t1_shard_size = get_size_in_bytes( + get_reduced_shape(get_parallel_tensor_shape(pcg, t_layer_1))); + num_bytes_t t2_shard_size = get_size_in_bytes( + get_reduced_shape(get_parallel_tensor_shape(pcg, t_layer_2))); + AbstractedTensorSetMovement correct = 
AbstractedTensorSetMovement{ /*single_tensor_movements=*/{ AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/par_input_shape, - /*src_machine_views=*/ - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::LEFT_CHILD, - }}, - }, - /*dst_machine_views=*/ + /*src_op_tree_path=*/src1_path, + /*edge_to_size=*/ { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, + {mk_abstracted_edge(0_n, dst1_path, 0_n), t1_shard_size}, + {mk_abstracted_edge(1_n, dst1_path, 1_n), t1_shard_size}, + {mk_abstracted_edge(0_n, dst2_path, 0_n), t1_shard_size}, + {mk_abstracted_edge(1_n, dst2_path, 1_n), t1_shard_size}, }, }, AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/par_input_shape, - /*src_machine_views=*/ - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, - /*dst_machine_views=*/ + /*src_op_tree_path=*/src2_path, + /*edge_to_size=*/ { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, + {mk_abstracted_edge(0_n, dst2_path, 0_n), t2_shard_size}, + {mk_abstracted_edge(1_n, dst2_path, 1_n), t2_shard_size}, }, }, }, diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc deleted file mode 100644 index 5ae89a8123..0000000000 --- a/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc +++ /dev/null @@ -1,240 +0,0 @@ -#include "compiler/machine_mapping/get_machine_resource_splits.h" -#include "test/utils/doctest/fmt/pair.h" -#include "test/utils/doctest/fmt/unordered_set.h" -#include "utils/hash/pair.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_machine_resource_splits") { - auto make_machine_spec = [](positive_int num_nodes, - positive_int num_gpus_per_node) { - return MachineSpecification{ - /*num_nodes=*/num_nodes, - 
/*num_cpus_per_node=*/1_p, - /*num_gpus_per_node=*/num_gpus_per_node, - /*inter_node_bandwidth=*/1.0, - /*intra_node_bandwidth=*/1.0, - }; - }; - - SUBCASE("returns no splits if no splits are possible") { - MachineSpecification input = make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/1_p); - - std::unordered_set> - result = get_machine_resource_splits(input); - std::unordered_set> - correct = {}; - - CHECK(result == correct); - } - - SUBCASE( - "returns splits in gpu and node dimensions, but not at the same time") { - MachineSpecification input = make_machine_spec(/*num_nodes=*/2_p, - /*num_gpus_per_node=*/2_p); - - std::unordered_set> - result = get_machine_resource_splits(input); - - std::unordered_set> - correct = { - { - make_machine_spec(/*num_nodes=*/2_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/2_p, - /*num_gpus_per_node=*/1_p), - }, - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/2_p), - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/2_p), - }, - - }; - - CHECK(result == correct); - } - - SUBCASE("returns splits in node dimension in powers of two") { - SUBCASE("num_nodes is a power of 2") { - MachineSpecification input = - make_machine_spec(/*num_nodes=*/8_p, - /*num_gpus_per_node=*/1_p); - - std::unordered_set< - std::pair> - result = get_machine_resource_splits(input); - - std::unordered_set< - std::pair> - correct = { - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/7_p, - /*num_gpus_per_node=*/1_p), - }, - { - make_machine_spec(/*num_nodes=*/2_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/6_p, - /*num_gpus_per_node=*/1_p), - }, - { - make_machine_spec(/*num_nodes=*/4_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/4_p, - /*num_gpus_per_node=*/1_p), - }, - { - make_machine_spec(/*num_nodes=*/6_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/2_p, - 
/*num_gpus_per_node=*/1_p), - }, - { - make_machine_spec(/*num_nodes=*/7_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/1_p), - }, - }; - - CHECK(result == correct); - } - - SUBCASE("num_nodes is not a power of 2") { - MachineSpecification input = - make_machine_spec(/*num_nodes=*/6_p, - /*num_gpus_per_node=*/1_p); - - std::unordered_set< - std::pair> - result = get_machine_resource_splits(input); - - std::unordered_set< - std::pair> - correct = { - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/5_p, - /*num_gpus_per_node=*/1_p), - }, - { - make_machine_spec(/*num_nodes=*/2_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/4_p, - /*num_gpus_per_node=*/1_p), - }, - { - make_machine_spec(/*num_nodes=*/4_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/2_p, - /*num_gpus_per_node=*/1_p), - }, - { - make_machine_spec(/*num_nodes=*/5_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/1_p), - }, - }; - - CHECK(result == correct); - } - } - - SUBCASE("returns splits in gpu dimension in powers of two") { - SUBCASE("num_gpus_per_node is a power of 2") { - MachineSpecification input = - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/8_p); - - std::unordered_set< - std::pair> - result = get_machine_resource_splits(input); - - std::unordered_set< - std::pair> - correct = { - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/7_p), - }, - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/2_p), - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/6_p), - }, - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/4_p), - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/4_p), - }, - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/6_p), - 
make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/2_p), - }, - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/7_p), - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/1_p), - }, - }; - - CHECK(result == correct); - } - - SUBCASE("num_gpus_per_node is not a power of 2") { - MachineSpecification input = - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/6_p); - - std::unordered_set< - std::pair> - result = get_machine_resource_splits(input); - - std::unordered_set< - std::pair> - correct = { - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/1_p), - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/5_p), - }, - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/2_p), - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/4_p), - }, - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/4_p), - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/2_p), - }, - { - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/5_p), - make_machine_spec(/*num_nodes=*/1_p, - /*num_gpus_per_node=*/1_p), - }, - }; - } - } - } -} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 2cbc87cffe..e70c0b75d2 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -2,14 +2,16 @@ #include "compiler/cost_estimator/runtime_only_op_cost_estimate_key.dtg.h" #include "compiler/cost_estimator/runtime_only_op_cost_metrics.dtg.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/machine_compute_resource_slice.h" #include "compiler/machine_mapping/machine_mapping_cache.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" #include 
"compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.h" +#include "compiler/machine_mapping/machine_view.h" #include "internal/runtime_only_cost_estimator_for_test.h" #include "op-attrs/parallel_tensor_shape.h" -#include "pcg/machine_view.h" +#include "op-attrs/task_space_coordinate.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "utils/containers/get_only.h" #include "utils/full_binary_tree/binary_tree_path.h" @@ -47,7 +49,7 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; - MachineView mv1 = MachineView{ + MachineView mv_stride_1 = MachineView{ /*start=*/MachineSpaceCoordinate{ /*node_idx=*/0_n, /*device_idx=*/0_n, @@ -57,12 +59,12 @@ TEST_SUITE(FF_TEST_SUITE) { { MachineViewDimension{ stride_t{1_p}, - MachineSpecificationDimension::INTRA_NODE, + MachineSpecificationDimension::INTER_NODE, }, }, }; - MachineView mv2 = MachineView{ + MachineView mv_stride_2 = MachineView{ /*start=*/MachineSpaceCoordinate{ /*node_idx=*/0_n, /*device_idx=*/0_n, @@ -72,35 +74,33 @@ TEST_SUITE(FF_TEST_SUITE) { { MachineViewDimension{ stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE, + MachineSpecificationDimension::INTER_NODE, }, }, }; - MachineSpecification full_machine_spec = MachineSpecification{ - /*num_nodes=*/2_p, - /*num_cpus_per_node=*/1_p, - /*num_gpus_per_node=*/1_p, - /*inter_node_bandwidth=*/1, - /*intra_node_bandwidth=*/1, - }; + MachineComputeResourceSlice four_nodes_resources = + MachineComputeResourceSlice{ + /*num_nodes=*/4_p, + /*num_gpus_per_node=*/1_p, + }; - MachineSpecification split_machine_spec = MachineSpecification{ - /*num_nodes=*/1_p, - /*num_cpus_per_node=*/1_p, - /*num_gpus_per_node=*/1_p, - /*inter_node_bandwidth=*/1, - /*intra_node_bandwidth=*/1, - }; + MachineComputeResourceSlice 
three_nodes_resources = + MachineComputeResourceSlice{ + /*num_nodes=*/3_p, + /*num_gpus_per_node=*/1_p, + }; - auto allowed_machine_views1 = - [&](UnmappedRuntimeOnlyOpCostEstimateKey const &, - MachineSpecification const &resources) { - if (resources == full_machine_spec) { - return std::unordered_set{mv1, mv2}; - } else { - return std::unordered_set{mv2}; - } + MachineComputeResourceSlice two_nodes_resources = + MachineComputeResourceSlice{ + /*num_nodes=*/2_p, + /*num_gpus_per_node=*/1_p, + }; + + MachineComputeResourceSlice one_node_resources = + MachineComputeResourceSlice{ + /*num_nodes=*/1_p, + /*num_gpus_per_node=*/1_p, }; TensorShape tensor_shape = TensorShape{ @@ -113,83 +113,98 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; + BinaryTreePath src_path = binary_tree_root_path(); + + ParallelLayerGuidObliviousMachineMapping mm1 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv_stride_1}, + }}; + ParallelLayerGuidObliviousMachineMapping mm2 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv_stride_2}, + }}; + + OperatorTaskSpace task_space = OperatorTaskSpace{ + MinimalOrthotope{{ + 2_ge2, + }}, + }; + + MachineMappingCache cache = empty_machine_mapping_cache(); + + ParallelTensorShape par_tensor_shape = lift_to_parallel_with_degrees( + tensor_shape, + ParallelTensorDimDegrees{ + /*sum_degree=*/SumDegree{1_p}, + /*discard_copy_degree=*/DiscardCopyDegree{1_p}, + /*shard_degrees=*/ + FFOrdered{ + 2_p, + 1_p, + }, + }); + UnmappedRuntimeOnlyOpCostEstimateKey k1 = UnmappedRuntimeOnlyOpCostEstimateKey{ - /*op_attrs=*/PCGOperatorAttrs{InputAttrs{tensor_shape}}, - /*input_shapes=*/{}, + /*op_attrs=*/PCGOperatorAttrs{ElementUnaryAttrs{ + /*type=*/OperatorType::GELU, + /*scalar=*/std::nullopt, + }}, + /*input_shapes=*/ + { + { + TensorSlotName::INPUT, + par_tensor_shape, + }, + }, /*weight_shapes=*/{}, - /*output_shapes=*/{}, + /*output_shapes=*/ + { + { + TensorSlotName::OUTPUT, + par_tensor_shape, + }, + }, 
}; UnmappedRuntimeOnlyOpCostEstimateKey k2 = UnmappedRuntimeOnlyOpCostEstimateKey{ - /*op_attrs=*/PCGOperatorAttrs{ElementBinaryAttrs{ - /*type=*/OperatorType::EW_ADD, - /*compute_type=*/DataType::FLOAT, - /*should_broadcast_lhs=*/false, - /*should_broadcast_rhs=*/false, + /*op_attrs=*/PCGOperatorAttrs{ElementUnaryAttrs{ + /*type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, }}, - /*input_shapes=*/{}, + /*input_shapes=*/ + { + { + TensorSlotName::INPUT, + par_tensor_shape, + }, + }, /*weight_shapes=*/{}, - /*output_shapes=*/{}, + /*output_shapes=*/ + { + { + TensorSlotName::OUTPUT, + par_tensor_shape, + }, + }, }; - ParallelTensorShape par_tensor_shape = lift_to_parallel(tensor_shape); - - AbstractedTensorSetMovement movement1 = AbstractedTensorSetMovement{{ - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/par_tensor_shape, - /*src_machine_views=*/{}, - /*dst_machine_views=*/{}, - }, - }}; - - ParallelLayerGuidObliviousMachineMapping mm1 = - ParallelLayerGuidObliviousMachineMapping{{ - {binary_tree_root_path(), mv1}, - }}; - ParallelLayerGuidObliviousMachineMapping mm2 = - ParallelLayerGuidObliviousMachineMapping{{ - {binary_tree_root_path(), mv2}, - }}; - - auto map1 = std::unordered_map{{ - {map_unmapped_runtime_only_op_cost_estimate_key(k1, mv1), - RuntimeOnlyOpCostMetrics{/*forward_runtime=*/0.5_ms, - /*backward_runtime=*/0.5_ms}}, - {map_unmapped_runtime_only_op_cost_estimate_key(k2, mv1), - RuntimeOnlyOpCostMetrics{/*forward_runtime=*/1.0_ms, - /*backward_runtime=*/1.0_ms}}, - {map_unmapped_runtime_only_op_cost_estimate_key(k1, mv2), - RuntimeOnlyOpCostMetrics{/*forward_runtime=*/0.75_ms, - /*backward_runtime=*/0.75_ms}}, - {map_unmapped_runtime_only_op_cost_estimate_key(k2, mv2), - RuntimeOnlyOpCostMetrics{/*forward_runtime=*/1.25_ms, - /*backward_runtime=*/1.25_ms}}, - }}; - - RuntimeOnlyCostEstimator runtime_only_cost_estimator = - make_fake_runtime_only_cost_estimator( - map1, - std::unordered_map{{ - {TensorSetMovement{{}}, 0.0_ms}, - 
{concretize_abstracted_tensor_set_movement(movement1, mm1, mm1), - 0.1_ms}, - {concretize_abstracted_tensor_set_movement(movement1, mm2, mm2), - 0.2_ms}, - {concretize_abstracted_tensor_set_movement(movement1, mm1, mm2), - 0.3_ms}, - {concretize_abstracted_tensor_set_movement(movement1, mm2, mm1), - 0.4_ms}, - }}); - - MachineMappingContext context = MachineMappingContext{ - runtime_only_cost_estimator, - allowed_machine_views1, + auto mk_cost_metrics = [&](float total_cost) { + return RuntimeOnlyOpCostMetrics{ + /*forward_runtime=*/milliseconds_t{total_cost / 2}, + /*backward_runtime=*/milliseconds_t{total_cost / 2}, + }; }; - MachineMappingCache cache = empty_machine_mapping_cache(); + auto mk_cost_entry = [&](UnmappedRuntimeOnlyOpCostEstimateKey const &key, + MachineView const &mv, + float total_cost) { + return std::pair{ + map_unmapped_runtime_only_op_cost_estimate_key(key, mv), + mk_cost_metrics(total_cost), + }; + }; SUBCASE("single layer") { MachineMappingProblemTree problem_tree = make_leaf(k1); @@ -198,14 +213,40 @@ TEST_SUITE(FF_TEST_SUITE) { get_unconstrained_solution_for_layers( get_all_leaf_paths(problem_tree)); + auto allowed_machine_views = + [&](UnmappedRuntimeOnlyOpCostEstimateKey const &k, + MachineComputeResourceSlice const &resources) { + ASSERT(k == k1); + ASSERT(resources == four_nodes_resources); + + return std::unordered_set{ + mv_stride_1, + mv_stride_2, + }; + }; + + RuntimeOnlyCostEstimator runtime_only_cost_estimator = + make_fake_runtime_only_cost_estimator( + { + mk_cost_entry(k1, mv_stride_1, 1), + mk_cost_entry(k1, mv_stride_2, 2), + }, + std::unordered_map{{}}); + + MachineMappingContext context = MachineMappingContext{ + /*cost_estimator=*/runtime_only_cost_estimator, + /*allowed_machine_views=*/allowed_machine_views, + }; + MachineMappingResult result = get_optimal_machine_mapping( - cache, context, problem_tree, full_machine_spec, constraints); + cache, context, problem_tree, four_nodes_resources, constraints); + 
MachineMappingResult correct = MachineMappingResult{ FeasibleMachineMappingResult{ - /*runtime=*/1.0_ms, + /*runtime=*/1_ms, /*machine_mapping=*/ ParallelLayerGuidObliviousMachineMapping{{ - {binary_tree_root_path(), mv1}, + {binary_tree_root_path(), mv_stride_1}, }}, }, }; @@ -214,37 +255,195 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("pair of layers in sequence") { + AbstractedTensorSetMovement k1_to_k2 = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*src_op_tree_path=*/binary_tree_root_path(), + /*edge_to_size=*/ + { + { + AbstractedSingleTensorCommunicationEdge{ + /*src_coord=*/make_task_space_coordinate({0_n}), + /*dst=*/ + AbstractedDevice{ + /*operator_tree_path=*/ + binary_tree_root_path(), + /*task_space_coordinate=*/ + make_task_space_coordinate({0_n}), + }, + }, + get_size_in_bytes(tensor_shape), + }, + { + AbstractedSingleTensorCommunicationEdge{ + /*src_coord=*/make_task_space_coordinate({1_n}), + /*dst=*/ + AbstractedDevice{ + /*operator_tree_path=*/ + binary_tree_root_path(), + /*task_space_coordinate=*/ + make_task_space_coordinate({1_n}), + }, + }, + get_size_in_bytes(tensor_shape), + }, + }, + }, + }, + }; + MachineMappingProblemTree problem_tree = - make_series_split(movement1, make_leaf(k1), make_leaf(k2)); + make_series_split(k1_to_k2, make_leaf(k1), make_leaf(k2)); + + auto mk_tensor_set_movement = [&](MachineView const &src_mv, + MachineView const &dst_mv) { + MachineSpaceStencil src_stencil = MachineSpaceStencil{ + /*operator_task_space=*/task_space, + /*machine_view=*/src_mv, + }; + + MachineSpaceStencil dst_stencil = MachineSpaceStencil{ + /*operator_task_space=*/task_space, + /*machine_view=*/dst_mv, + }; + + return concretize_abstracted_tensor_set_movement( + k1_to_k2, + /*pre_machine_stencils=*/{{binary_tree_root_path(), src_stencil}}, + /*post_machine_stencils=*/{{binary_tree_root_path(), dst_stencil}}); + }; + + auto allowed_machine_views = + [&](UnmappedRuntimeOnlyOpCostEstimateKey 
const &k, + MachineComputeResourceSlice const &resources) { + if (resources == four_nodes_resources) { + return std::unordered_set{mv_stride_1, mv_stride_2}; + } else if (resources == three_nodes_resources) { + return std::unordered_set{mv_stride_1, mv_stride_2}; + } else if (resources == two_nodes_resources) { + return std::unordered_set{mv_stride_1}; + } else { + return std::unordered_set{}; + } + }; MachineMappingConstraints constraints = get_unconstrained_solution_for_layers( get_all_leaf_paths(problem_tree)); - MachineMappingResult result = get_optimal_machine_mapping( - cache, context, problem_tree, full_machine_spec, constraints); - MachineMappingResult correct = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/1.0_ms + 2.0_ms + 0.1_ms, - /*machine_mapping=*/ - ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - mv1, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - mv1, - }, - }}, - }, - }; + SUBCASE("solution requires taking comm cost into account") { + RuntimeOnlyCostEstimator runtime_only_cost_estimator = + make_fake_runtime_only_cost_estimator( + std::unordered_map{{ + mk_cost_entry(k1, mv_stride_1, 1), + mk_cost_entry(k1, mv_stride_2, 3), + mk_cost_entry(k2, mv_stride_1, 4), + mk_cost_entry(k2, mv_stride_2, 1), + }}, + std::unordered_map{{ + { + TensorSetMovement{{}}, + 0.0_ms, + }, + { + mk_tensor_set_movement(mv_stride_1, mv_stride_2), + 5_ms, + }, + { + mk_tensor_set_movement(mv_stride_2, mv_stride_1), + 5_ms, + }, + }}); - CHECK(result == correct); + MachineMappingContext context = MachineMappingContext{ + /*cost_estimator=*/runtime_only_cost_estimator, + /*allowed_machine_views=*/allowed_machine_views, + }; + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, four_nodes_resources, constraints); + + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/1.0_ms + 
3.0_ms, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv_stride_2, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv_stride_2, + }, + }}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("solution places operators on different machine views") { + RuntimeOnlyCostEstimator runtime_only_cost_estimator = + make_fake_runtime_only_cost_estimator( + std::unordered_map{{ + mk_cost_entry(k1, mv_stride_1, 1), + mk_cost_entry(k1, mv_stride_2, 3), + mk_cost_entry(k2, mv_stride_1, 4), + mk_cost_entry(k2, mv_stride_2, 1), + }}, + std::unordered_map{{ + { + TensorSetMovement{{}}, + 0.0_ms, + }, + { + mk_tensor_set_movement(mv_stride_1, mv_stride_2), + 1_ms, + }, + { + mk_tensor_set_movement(mv_stride_2, mv_stride_1), + 1_ms, + }, + }}); + + MachineMappingContext context = MachineMappingContext{ + /*cost_estimator=*/runtime_only_cost_estimator, + /*allowed_machine_views=*/allowed_machine_views, + }; + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, four_nodes_resources, constraints); + + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/1.0_ms + 1.0_ms + 1.0_ms, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv_stride_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv_stride_2, + }, + }}, + }, + }; + + CHECK(result == correct); + } } SUBCASE("pair of layers in parallel") { @@ -255,30 +454,181 @@ TEST_SUITE(FF_TEST_SUITE) { get_unconstrained_solution_for_layers( get_all_leaf_paths(problem_tree)); - MachineMappingResult result = get_optimal_machine_mapping( - cache, context, problem_tree, full_machine_spec, constraints); - MachineMappingResult correct = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/2.5_ms, - /*machine_mapping=*/ - 
ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - mv2, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - mv2, - }, - }}, - }, - }; + auto allowed_machine_views = + [&](UnmappedRuntimeOnlyOpCostEstimateKey const &k, + MachineComputeResourceSlice const &resources) { + if (resources == four_nodes_resources) { + return std::unordered_set{mv_stride_1, mv_stride_2}; + } else if (resources == three_nodes_resources) { + return std::unordered_set{mv_stride_1, mv_stride_2}; + } else if (resources == two_nodes_resources) { + return std::unordered_set{mv_stride_1}; + } else { + return std::unordered_set{}; + } + }; - CHECK(result == correct); + SUBCASE("cannot use overlapping machine views in parallel") { + RuntimeOnlyCostEstimator runtime_only_cost_estimator = + make_fake_runtime_only_cost_estimator( + std::unordered_map{{ + mk_cost_entry(k1, mv_stride_1, 1), + mk_cost_entry(k1, mv_stride_2, 3), + mk_cost_entry(k2, mv_stride_1, 4), + mk_cost_entry(k2, mv_stride_2, 1), + }}, + std::unordered_map{{ + { + TensorSetMovement{{}}, + 0.0_ms, + }, + }}); + + MachineMappingContext context = MachineMappingContext{ + /*cost_estimator=*/runtime_only_cost_estimator, + /*allowed_machine_views=*/allowed_machine_views, + }; + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, four_nodes_resources, constraints); + + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2_ms, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv_stride_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv_stride_2, + }, + }}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("solution is running operators in parallel") { + RuntimeOnlyCostEstimator runtime_only_cost_estimator = + make_fake_runtime_only_cost_estimator( + 
std::unordered_map{{ + mk_cost_entry(k1, mv_stride_1, 1), + mk_cost_entry(k1, mv_stride_2, 3), + mk_cost_entry(k2, mv_stride_1, 3), + mk_cost_entry(k2, mv_stride_2, 4), + }}, + std::unordered_map{{ + { + TensorSetMovement{{}}, + 0.0_ms, + }, + }}); + + MachineMappingContext context = MachineMappingContext{ + /*cost_estimator=*/runtime_only_cost_estimator, + /*allowed_machine_views=*/allowed_machine_views, + }; + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, four_nodes_resources, constraints); + + MachineView translated_mv_stride_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/2_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + /*stride=*/stride_t{1_p}, + /*projection=*/MachineSpecificationDimension::INTER_NODE, + }, + }, + }; + + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/3_ms, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv_stride_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + translated_mv_stride_1, + }, + }}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("solution is running operators in series") { + RuntimeOnlyCostEstimator runtime_only_cost_estimator = + make_fake_runtime_only_cost_estimator( + std::unordered_map{{ + mk_cost_entry(k1, mv_stride_1, 3), + mk_cost_entry(k1, mv_stride_2, 1), + mk_cost_entry(k2, mv_stride_1, 4), + mk_cost_entry(k2, mv_stride_2, 1), + }}, + std::unordered_map{{ + { + TensorSetMovement{{}}, + 0.0_ms, + }, + }}); + + MachineMappingContext context = MachineMappingContext{ + /*cost_estimator=*/runtime_only_cost_estimator, + /*allowed_machine_views=*/allowed_machine_views, + }; + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, four_nodes_resources, constraints); + + MachineMappingResult 
correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2_ms, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv_stride_2, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv_stride_2, + }, + }}, + }, + }; + + CHECK(result == correct); + } } } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index 586a2b7764..e6af704cf1 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -1,26 +1,16 @@ #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/machine_view.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "internal/cost_estimator_for_test.h" -#include "pcg/machine_view.h" +#include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" -#include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include #include using namespace ::FlexFlow; -bool isDebuggerActive() { - std::ifstream in("/proc/self/status"); - for (std::string line; std::getline(in, line);) { - static int const PREFIX_LEN = 11; - if (line.compare(0, PREFIX_LEN, "TracerPid:\t") == 0) { - return line.length() > PREFIX_LEN && line[PREFIX_LEN] != '0'; - } - } - return false; -} - TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_tensor_set_movement_across_split") { auto make_pcg_series_split = [](PCGBinarySPDecomposition const &lhs, @@ -50,7 +40,8 @@ TEST_SUITE(FF_TEST_SUITE) { }; ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); - 
parallel_tensor_guid_t t_input = get_only(input.outputs); + parallel_tensor_guid_t t_input = + require_only_key(input.outputs, TensorSlotName::OUTPUT); ParallelLayerAttrs partition_attrs = ParallelLayerAttrs{ /*op_attrs=*/PCGOperatorAttrs{ @@ -61,10 +52,10 @@ TEST_SUITE(FF_TEST_SUITE) { }, /*name=*/std::nullopt, }; - ParallelLayerAddedResult partition_input = - add_parallel_layer(pcg, partition_attrs, {t_input}, {}); + ParallelLayerAddedResult partition_input = add_parallel_layer( + pcg, partition_attrs, {{TensorSlotName::INPUT, t_input}}, {}); parallel_tensor_guid_t t_partition_input = - get_only(partition_input.outputs); + require_only_key(partition_input.outputs, TensorSlotName::OUTPUT); ParallelTensorShape partitioned_input_shape = get_parallel_tensor_shape(pcg, t_partition_input); @@ -91,11 +82,12 @@ TEST_SUITE(FF_TEST_SUITE) { /*name=*/std::nullopt, }; - ParallelLayerAddedResult relu_1 = - add_parallel_layer(pcg, relu_attrs, {t_partition_input}, {}); - parallel_tensor_guid_t t_relu_1 = get_only(relu_1.outputs); - ParallelLayerAddedResult relu_2 = - add_parallel_layer(pcg, relu_attrs, {t_relu_1}, {}); + ParallelLayerAddedResult relu_1 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_partition_input}}, {}); + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_1.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult relu_2 = add_parallel_layer( + pcg, relu_attrs, {{TensorSlotName::INPUT, t_relu_1}}, {}); MachineView pre_mv1 = MachineView{ /*start=*/MachineSpaceCoordinate{ @@ -114,14 +106,14 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView pre_mv2 = MachineView{ /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, + /*node_idx=*/1_n, /*device_idx=*/0_n, /*device_type=*/DeviceType::GPU, }, /*dimensions=*/ { MachineViewDimension{ - stride_t{2_p}, + stride_t{1_p}, MachineSpecificationDimension::INTRA_NODE, }, }, @@ -129,14 +121,14 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView post_mv1 = MachineView{ /*start=*/MachineSpaceCoordinate{ - 
/*node_idx=*/0_n, + /*node_idx=*/2_n, /*device_idx=*/0_n, /*device_type=*/DeviceType::GPU, }, /*dimensions=*/ { MachineViewDimension{ - stride_t{3_p}, + stride_t{1_p}, MachineSpecificationDimension::INTRA_NODE, }, }, @@ -144,34 +136,54 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView post_mv2 = MachineView{ /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, + /*node_idx=*/3_n, /*device_idx=*/0_n, /*device_type=*/DeviceType::GPU, }, /*dimensions=*/ { MachineViewDimension{ - stride_t{4_p}, + stride_t{1_p}, MachineSpecificationDimension::INTRA_NODE, }, }, }; + auto mk_communication_edge = [](MachineView const &src_mv, + nonnegative_int src_task_idx, + MachineView const &dst_mv, + nonnegative_int dst_task_idx) { + ASSERT(src_task_idx < 2); + ASSERT(dst_task_idx < 2); + + return CommunicationEdge{ + /*src=*/MachineSpaceCoordinate{ + /*node_idx=*/src_mv.start.node_idx, + /*device_idx=*/src_task_idx, + /*device_type=*/DeviceType::GPU, + }, + /*dst=*/ + MachineSpaceCoordinate{ + /*node_idx=*/dst_mv.start.node_idx, + /*device_idx=*/dst_task_idx, + /*device_type=*/DeviceType::GPU, + }, + }; + }; + + num_bytes_t piece_size = get_piece_size_in_bytes(partitioned_input_shape); + SUBCASE("single edge across split") { PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ - make_pcg_series_split( - make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), - make_pcg_leaf_node(partition_input.parallel_layer)), - make_pcg_leaf_node(relu_1.parallel_layer)), + make_pcg_leaf_node(relu_1.parallel_layer), make_pcg_leaf_node(relu_2.parallel_layer), }; auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ - {BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - pre_mv1}, + { + BinaryTreePath{{}}, + pre_mv1, + }, }}; auto post_mapping = ParallelLayerGuidObliviousMachineMapping{{ @@ -183,12 +195,16 @@ TEST_SUITE(FF_TEST_SUITE) { TensorSetMovement result = get_tensor_set_movement_across_split( pcg_get_transitive_reduction(pcg), split, pre_mapping, post_mapping); + 
TensorSetMovement correct = TensorSetMovement{ - /*single_tensor_movements=*/{ - SingleTensorMovement{ - /*parallel_tensor_shape=*/partitioned_input_shape, - /*src_machine_views=*/{pre_mv1}, - /*dst_machine_views=*/{post_mv1}, + /*edge_to_size=*/{ + { + mk_communication_edge(pre_mv1, 0_n, post_mv1, 0_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 1_n, post_mv1, 1_n), + piece_size, }, }, }; @@ -196,18 +212,79 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE("does not include edges removed by transitive reduction") {} + SUBCASE("does not include edges removed by transitive reduction") { + ParallelLayerAddedResult ew_add = add_parallel_layer( + pcg, + ew_add_attrs, + /*inputs=*/ + { + { + TensorSlotName::LHS_INPUT, + require_only_key(relu_1.outputs, TensorSlotName::OUTPUT), + }, + {TensorSlotName::RHS_INPUT, + require_only_key(relu_2.outputs, TensorSlotName::OUTPUT)}, + }, + /*weights=*/{}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_pcg_series_split(make_pcg_leaf_node(relu_1.parallel_layer), + make_pcg_leaf_node(relu_2.parallel_layer)), + make_pcg_leaf_node(ew_add.parallel_layer), + }; + + auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{BinaryTreePathEntry::LEFT_CHILD}}, + pre_mv2, + }, + { + BinaryTreePath{{BinaryTreePathEntry::RIGHT_CHILD}}, + pre_mv1, + }, + }}; + + auto post_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + post_mv1, + }, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split, pre_mapping, post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*edge_to_size=*/{ + { + mk_communication_edge(pre_mv1, 0_n, post_mv1, 0_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 1_n, post_mv1, 1_n), + piece_size, + }, + }, + }; + + CHECK(result == correct); + } SUBCASE("single tensor, multiple consumers across split") { - ParallelLayerAddedResult relu_3 = - 
add_parallel_layer(pcg, relu_attrs, {get_only(relu_1.outputs)}, {}); + ParallelLayerAddedResult relu_3 = add_parallel_layer( + pcg, + relu_attrs, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + require_only_key(relu_1.outputs, TensorSlotName::OUTPUT), + }, + }, + /*weights=*/{}); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ - make_pcg_series_split( - make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), - make_pcg_leaf_node(partition_input.parallel_layer)), - make_pcg_leaf_node(relu_1.parallel_layer)), + make_pcg_leaf_node(relu_1.parallel_layer), make_pcg_parallel_split(make_pcg_leaf_node(relu_2.parallel_layer), make_pcg_leaf_node(relu_3.parallel_layer)), }; @@ -215,9 +292,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("consumers have same view") { auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, + BinaryTreePath{{}}, pre_mv1, }, }}; @@ -244,11 +319,14 @@ TEST_SUITE(FF_TEST_SUITE) { post_mapping); TensorSetMovement correct = TensorSetMovement{ - /*single_tensor_movements=*/{ - SingleTensorMovement{ - /*parallel_tensor_shape=*/partitioned_input_shape, - /*src_machine_views=*/{pre_mv1}, - /*dst_machine_views=*/{post_mv1}, + /*edge_to_size=*/{ + { + mk_communication_edge(pre_mv1, 0_n, post_mv1, 0_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 1_n, post_mv1, 1_n), + piece_size, }, }, }; @@ -259,9 +337,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("consumers have different views") { auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, + BinaryTreePath{{}}, pre_mv1, }, }}; @@ -288,11 +364,22 @@ TEST_SUITE(FF_TEST_SUITE) { post_mapping); TensorSetMovement correct = TensorSetMovement{ - /*single_tensor_movements=*/{ - SingleTensorMovement{ - /*parallel_tensor_shape=*/partitioned_input_shape, - /*src_machine_views=*/{pre_mv1}, - /*dst_machine_views=*/{post_mv1, post_mv2}, + /*edge_to_size=*/{ + { + 
mk_communication_edge(pre_mv1, 0_n, post_mv1, 0_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 1_n, post_mv1, 1_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 0_n, post_mv2, 0_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 1_n, post_mv2, 1_n), + piece_size, }, }, }; @@ -302,78 +389,269 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("multiple tensors, multiple consumers across split") { - ParallelLayerAddedResult relu_3 = add_parallel_layer( - pcg, relu_attrs, {get_only(partition_input.outputs)}, {}); + ParallelLayerAddedResult relu_3 = + add_parallel_layer(pcg, + relu_attrs, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + require_only_key(partition_input.outputs, + TensorSlotName::OUTPUT), + }, + }, + /*outputs=*/{}); ParallelLayerAddedResult relu_4 = add_parallel_layer( pcg, ew_add_attrs, - {get_only(relu_1.outputs), get_only(relu_3.outputs)}, - {}); + /*inputs=*/ + { + { + TensorSlotName::LHS_INPUT, + require_only_key(relu_1.outputs, TensorSlotName::OUTPUT), + }, + {TensorSlotName::RHS_INPUT, + require_only_key(relu_3.outputs, TensorSlotName::OUTPUT)}, + }, + /*weights=*/{}); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ - make_pcg_series_split( - make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), - make_pcg_leaf_node(partition_input.parallel_layer)), - make_pcg_parallel_split( - make_pcg_leaf_node(relu_1.parallel_layer), - make_pcg_leaf_node(relu_3.parallel_layer))), + make_pcg_parallel_split(make_pcg_leaf_node(relu_1.parallel_layer), + make_pcg_leaf_node(relu_3.parallel_layer)), make_pcg_parallel_split(make_pcg_leaf_node(relu_2.parallel_layer), make_pcg_leaf_node(relu_4.parallel_layer)), }; - auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::LEFT_CHILD, - }}, - pre_mv1, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::RIGHT_CHILD, - }}, - pre_mv2, - }, - }}; + auto 
mk_pre_mapping = [](MachineView const &src1_mv, + MachineView const &src2_mv) { + return ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + src1_mv, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + src2_mv, + }, + }}; + }; - auto post_mapping = ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - post_mv1, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - post_mv2, - }, - }}; + auto mk_post_mapping = [](MachineView const &dst1_mv, + MachineView const &dst2_mv) { + return ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + dst1_mv, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + dst2_mv, + }, + }}; + }; - TensorSetMovement result = get_tensor_set_movement_across_split( - pcg_get_transitive_reduction(pcg), split, pre_mapping, post_mapping); + SUBCASE( + "producers have different views and consumers have different views") { + ParallelLayerGuidObliviousMachineMapping pre_mapping = + mk_pre_mapping(pre_mv1, pre_mv2); + ParallelLayerGuidObliviousMachineMapping post_mapping = + mk_post_mapping(post_mv1, post_mv2); - TensorSetMovement correct = TensorSetMovement{ - /*single_tensor_movements=*/{ - SingleTensorMovement{ - /*parallel_tensor_shape=*/partitioned_input_shape, - /*src_machine_views=*/{pre_mv1}, - /*dst_machine_views=*/{post_mv1, post_mv2}, - }, - SingleTensorMovement{ - /*parallel_tensor_shape=*/partitioned_input_shape, - /*src_machine_views=*/{pre_mv2}, - /*dst_machine_views=*/{post_mv2}, - }, - }, - }; + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); - CHECK(result == correct); + TensorSetMovement correct = TensorSetMovement{ + /*edge_to_size=*/{ + { + mk_communication_edge(pre_mv1, 0_n, post_mv1, 0_n), + piece_size, + }, + { + 
mk_communication_edge(pre_mv1, 1_n, post_mv1, 1_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 0_n, post_mv2, 0_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 1_n, post_mv2, 1_n), + piece_size, + }, + { + mk_communication_edge(pre_mv2, 0_n, post_mv2, 0_n), + piece_size, + }, + { + mk_communication_edge(pre_mv2, 1_n, post_mv2, 1_n), + piece_size, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE( + "producers have different views and consumers have the same view") { + ParallelLayerGuidObliviousMachineMapping pre_mapping = + mk_pre_mapping(pre_mv1, pre_mv2); + ParallelLayerGuidObliviousMachineMapping post_mapping = + mk_post_mapping(post_mv1, post_mv1); + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*edge_to_size=*/{ + { + mk_communication_edge(pre_mv1, 0_n, post_mv1, 0_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 1_n, post_mv1, 1_n), + piece_size, + }, + { + mk_communication_edge(pre_mv2, 0_n, post_mv1, 0_n), + piece_size, + }, + { + mk_communication_edge(pre_mv2, 1_n, post_mv1, 1_n), + piece_size, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE( + "producers have the same view and consumers have different views") { + ParallelLayerGuidObliviousMachineMapping pre_mapping = + mk_pre_mapping(pre_mv1, pre_mv1); + ParallelLayerGuidObliviousMachineMapping post_mapping = + mk_post_mapping(post_mv1, post_mv2); + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*edge_to_size=*/{ + { + mk_communication_edge(pre_mv1, 0_n, post_mv1, 0_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 1_n, post_mv1, 1_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 0_n, post_mv2, 0_n), + piece_size + piece_size, 
+ }, + { + mk_communication_edge(pre_mv1, 1_n, post_mv2, 1_n), + piece_size + piece_size, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("producers have the same view and consumers have the same view") { + ParallelLayerGuidObliviousMachineMapping pre_mapping = + mk_pre_mapping(pre_mv1, pre_mv1); + ParallelLayerGuidObliviousMachineMapping post_mapping = + mk_post_mapping(post_mv1, post_mv1); + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*edge_to_size=*/{ + { + mk_communication_edge(pre_mv1, 0_n, post_mv1, 0_n), + piece_size + piece_size, + }, + { + mk_communication_edge(pre_mv1, 1_n, post_mv1, 1_n), + piece_size + piece_size, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("all producers and consumers have the same view") { + ParallelLayerGuidObliviousMachineMapping pre_mapping = + mk_pre_mapping(pre_mv1, pre_mv1); + ParallelLayerGuidObliviousMachineMapping post_mapping = + mk_post_mapping(pre_mv1, pre_mv1); + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*edge_to_size=*/{{}}, + }; + + CHECK(result == correct); + } + + SUBCASE("producers and one consumer share the same view") { + ParallelLayerGuidObliviousMachineMapping pre_mapping = + mk_pre_mapping(pre_mv1, pre_mv1); + ParallelLayerGuidObliviousMachineMapping post_mapping = + mk_post_mapping(post_mv1, pre_mv1); + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*edge_to_size=*/{ + { + mk_communication_edge(pre_mv1, 0_n, post_mv1, 0_n), + piece_size, + }, + { + mk_communication_edge(pre_mv1, 1_n, post_mv1, 1_n), + piece_size, + }, + }, + 
}; + + CHECK(result == correct); + } } } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc index 928d30ecaa..8af07a032c 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc @@ -1,6 +1,6 @@ #include "compiler/machine_mapping/machine_mapping.h" -#include "doctest/doctest.h" -#include "pcg/machine_view.h" +#include "compiler/machine_mapping/machine_view.h" +#include using namespace FlexFlow; diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index 2fcffac29a..26a643f327 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -4,6 +4,8 @@ #include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" +#include "utils/full_binary_tree/binary_tree_path.h" #include using namespace ::FlexFlow; @@ -98,10 +100,19 @@ TEST_SUITE(FF_TEST_SUITE) { /*op_attrs=*/input_attrs, /*input_shapes=*/{}, /*weight_shapes=*/{}, - /*output_shapes=*/{parallel_tensor_shape}, + /*output_shapes=*/ + { + { + TensorSlotName::OUTPUT, + parallel_tensor_shape, + }, + }, }; }; + TaskSpaceCoordinate empty_task_space_coord = + TaskSpaceCoordinate{OrthotopeCoord{{}}}; + SUBCASE("single layer") { ParallelLayerAddedResult input_added = add_parallel_layer(pcg, @@ -130,7 +141,8 @@ TEST_SUITE(FF_TEST_SUITE) { /*inputs=*/{}, /*output_labels=*/{}); parallel_layer_guid_t input_layer = input_added.parallel_layer; 
- parallel_tensor_guid_t input = get_only(input_added.outputs); + parallel_tensor_guid_t input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); UnmappedRuntimeOnlyOpCostEstimateKey input_key = make_input_key(par_input_shape); @@ -142,17 +154,39 @@ TEST_SUITE(FF_TEST_SUITE) { }, }; ParallelTensorShape relu_output_shape = par_input_shape; - ParallelLayerAddedResult relu_added = - add_parallel_layer(pcg, make_layer_attrs(relu_attrs), {input}, {}); + ParallelLayerAddedResult relu_added = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(relu_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + input, + }, + }, + /*weights=*/{}); parallel_layer_guid_t relu_layer = relu_added.parallel_layer; - parallel_tensor_guid_t relu_output = get_only(relu_added.outputs); + parallel_tensor_guid_t relu_output = + require_only_key(relu_added.outputs, TensorSlotName::OUTPUT); UnmappedRuntimeOnlyOpCostEstimateKey relu_key = UnmappedRuntimeOnlyOpCostEstimateKey{ /*op_attrs=*/relu_attrs, - /*input_shapes=*/{par_input_shape}, + /*input_shapes=*/ + { + { + TensorSlotName::INPUT, + par_input_shape, + }, + }, /*weight_shapes=*/{}, - /*output_shapes=*/{relu_output_shape}, + /*output_shapes=*/ + { + { + TensorSlotName::OUTPUT, + relu_output_shape, + }, + }, }; PCGBinarySPDecomposition sp_decomposition = pcg_make_series( @@ -162,19 +196,26 @@ TEST_SUITE(FF_TEST_SUITE) { get_machine_mapping_problem_tree(pcg, sp_decomposition); MachineMappingProblemTree correct = mm_problem_tree_make_series( - AbstractedTensorSetMovement{{ - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/par_input_shape, - /*src_machine_views=*/ + AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{AbstractedSingleTensorMovement{ + /*src_op_tree_path=*/binary_tree_root_path(), + /*edge_to_size=*/ { - BinaryTreePath{{}}, + { + AbstractedSingleTensorCommunicationEdge{ + /*src_coord=*/empty_task_space_coord, + /*dst=*/ + AbstractedDevice{ + /*operator_tree_path=*/ + 
binary_tree_root_path(), + /*task_space_coordinate=*/ + empty_task_space_coord, + }, + }, + get_piece_size_in_bytes(par_input_shape), + }, }, - /*dst_machine_views=*/ - { - BinaryTreePath{{}}, - }, - }, - }}, + }}}, mm_problem_tree_make_leaf(input_key), mm_problem_tree_make_leaf(relu_key)); @@ -211,14 +252,16 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input1_added = pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t input1_layer = input1_added.parallel_layer; - parallel_tensor_guid_t input1_tensor = get_only(input1_added.outputs); + parallel_tensor_guid_t input1_tensor = + require_only_key(input1_added.outputs, TensorSlotName::OUTPUT); UnmappedRuntimeOnlyOpCostEstimateKey input1_key = make_input_key(par_input_shape); ParallelLayerAddedResult input2_added = pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t input2_layer = input2_added.parallel_layer; - parallel_tensor_guid_t input2_tensor = get_only(input2_added.outputs); + parallel_tensor_guid_t input2_tensor = + require_only_key(input2_added.outputs, TensorSlotName::OUTPUT); UnmappedRuntimeOnlyOpCostEstimateKey input2_key = make_input_key(par_input_shape); @@ -231,18 +274,45 @@ TEST_SUITE(FF_TEST_SUITE) { }, }; ParallelTensorShape ew_op_output_shape = par_input_shape; - ParallelLayerAddedResult ew_op_added = - add_parallel_layer(pcg, - make_layer_attrs(ew_op_attrs), - {input1_tensor, input2_tensor}, - {}); + ParallelLayerAddedResult ew_op_added = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(ew_op_attrs), + /*inputs=*/ + { + { + TensorSlotName::LHS_INPUT, + input1_tensor, + }, + { + TensorSlotName::RHS_INPUT, + input2_tensor, + }, + }, + /*outputs=*/{}); parallel_layer_guid_t ew_op_layer = ew_op_added.parallel_layer; + UnmappedRuntimeOnlyOpCostEstimateKey ew_op_key = UnmappedRuntimeOnlyOpCostEstimateKey{ /*op_attrs=*/ew_op_attrs, - /*input_shapes=*/{par_input_shape, par_input_shape}, + /*input_shapes=*/ + { + { + TensorSlotName::LHS_INPUT, + par_input_shape, 
+ }, + { + TensorSlotName::RHS_INPUT, + par_input_shape, + }, + }, /*weight_shapes=*/{}, - /*output_shapes=*/{ew_op_output_shape}, + /*output_shapes=*/ + { + { + TensorSlotName::OUTPUT, + ew_op_output_shape, + }, + }, }; PCGBinarySPDecomposition sp_decomposition = @@ -253,35 +323,43 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); - MachineMappingProblemTree correct = mm_problem_tree_make_series( - AbstractedTensorSetMovement{{ - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/par_input_shape, - /*src_machine_views=*/ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - }, - /*dst_machine_views=*/ - { - BinaryTreePath{{}}, - }, + BinaryTreePath src1_path = BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}; + + BinaryTreePath src2_path = BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}; + + AbstractedSingleTensorCommunicationEdge edge = + AbstractedSingleTensorCommunicationEdge{ + /*src_coord=*/empty_task_space_coord, + /*dst=*/ + AbstractedDevice{ + /*operator_tree_path=*/binary_tree_root_path(), + /*task_space_coordinate=*/empty_task_space_coord, }, - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/par_input_shape, - /*src_machine_views=*/ - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, + }; + + MachineMappingProblemTree correct = mm_problem_tree_make_series( + AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*src_op_tree_path=*/src1_path, + /*edge_to_size=*/ + { + {edge, get_piece_size_in_bytes(par_input_shape)}, + }, }, - /*dst_machine_views=*/ - { - BinaryTreePath{{}}, + AbstractedSingleTensorMovement{ + /*src_op_tree_path=*/src2_path, + /*edge_to_size=*/ + { + {edge, get_piece_size_in_bytes(par_input_shape)}, + }, }, }, - }}, + }, /*pre=*/ mm_problem_tree_make_parallel(mm_problem_tree_make_leaf(input1_key), mm_problem_tree_make_leaf(input2_key)), diff --git 
a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc index 26f61253c3..75dd63cccb 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc @@ -1,5 +1,5 @@ #include "compiler/machine_mapping/machine_mapping_result.h" -#include "pcg/machine_view.h" +#include "compiler/machine_mapping/machine_view.h" #include using namespace FlexFlow; @@ -253,29 +253,50 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingResult infeasible = infeasible_machine_mapping_result(); + MachineResourceSplit split = MachineResourceSplit{ + /*offset=*/3_p, + /*dimension=*/MachineSpecificationDimension::INTER_NODE, + }; + SUBCASE("lhs is infeasible") { - MachineMappingResult result = parallel_combine(infeasible, rhs); + MachineMappingResult result = parallel_combine(split, infeasible, rhs); MachineMappingResult correct = infeasible; CHECK(result == correct); } SUBCASE("rhs is infeasible") { - MachineMappingResult result = parallel_combine(lhs, infeasible); + MachineMappingResult result = parallel_combine(split, lhs, infeasible); MachineMappingResult correct = infeasible; CHECK(result == correct); } SUBCASE("both are infeasible") { - MachineMappingResult result = parallel_combine(infeasible, infeasible); + MachineMappingResult result = + parallel_combine(split, infeasible, infeasible); MachineMappingResult correct = infeasible; CHECK(result == correct); } SUBCASE("both are feasible") { - MachineMappingResult result = parallel_combine(lhs, rhs); + MachineView translated_machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/3_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineMappingResult result = parallel_combine(split, lhs, rhs); 
MachineMappingResult correct = MachineMappingResult{ FeasibleMachineMappingResult{ /*runtime=*/4_ms, @@ -299,7 +320,7 @@ TEST_SUITE(FF_TEST_SUITE) { BinaryTreePath{{ BinaryTreePathEntry::RIGHT_CHILD, }}, - machine_view_1, + translated_machine_view_1, }, }}, }, diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_resource_split.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_resource_split.cc new file mode 100644 index 0000000000..3b47e63143 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_resource_split.cc @@ -0,0 +1,117 @@ +#include "compiler/machine_mapping/machine_resource_split.h" +#include "pcg/machine_compute_specification.dtg.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "utils/hash/pair.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_machine_resource_splits") { + SUBCASE("returns no splits if no splits are possible") { + MachineComputeResourceSlice input = MachineComputeResourceSlice{ + /*num_nodes=*/1_p, + /*num_gpus_per_node=*/1_p, + }; + + std::unordered_set result = + get_machine_resource_splits(input); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns splits in gpu and node dimensions") { + MachineComputeResourceSlice input = MachineComputeResourceSlice{ + /*num_nodes=*/2_p, + /*num_gpus_per_node=*/2_p, + }; + + std::unordered_set result = + get_machine_resource_splits(input); + + std::unordered_set correct = { + MachineResourceSplit{ + /*offset=*/1_p, + /*dimension=*/MachineSpecificationDimension::INTRA_NODE, + }, + MachineResourceSplit{ + /*offset=*/1_p, + /*dimension=*/MachineSpecificationDimension::INTER_NODE, + }}; + + CHECK(result == correct); + } + + SUBCASE("returns splits in node dimension in powers of two") { + MachineComputeResourceSlice input = MachineComputeResourceSlice{ + /*num_nodes=*/8_p, + /*num_gpus_per_node=*/1_p, + }; + + 
std::unordered_set result = + get_machine_resource_splits(input); + + std::unordered_set correct = { + MachineResourceSplit{ + /*offset=*/1_p, + /*dimension=*/MachineSpecificationDimension::INTER_NODE, + }, + MachineResourceSplit{ + /*offset=*/2_p, + /*dimension=*/MachineSpecificationDimension::INTER_NODE, + }, + MachineResourceSplit{ + /*offset=*/4_p, + /*dimension=*/MachineSpecificationDimension::INTER_NODE, + }, + MachineResourceSplit{ + /*offset=*/6_p, + /*dimension=*/MachineSpecificationDimension::INTER_NODE, + }, + MachineResourceSplit{ + /*offset=*/7_p, + /*dimension=*/MachineSpecificationDimension::INTER_NODE, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("returns splits in gpu dimension in powers of two") { + MachineComputeResourceSlice input = MachineComputeResourceSlice{ + /*num_nodes=*/1_p, + /*num_gpus_per_node=*/8_p, + }; + + std::unordered_set result = + get_machine_resource_splits(input); + + std::unordered_set correct = { + MachineResourceSplit{ + /*offset=*/1_p, + /*dimension=*/MachineSpecificationDimension::INTRA_NODE, + }, + MachineResourceSplit{ + /*offset=*/2_p, + /*dimension=*/MachineSpecificationDimension::INTRA_NODE, + }, + MachineResourceSplit{ + /*offset=*/4_p, + /*dimension=*/MachineSpecificationDimension::INTRA_NODE, + }, + MachineResourceSplit{ + /*offset=*/6_p, + /*dimension=*/MachineSpecificationDimension::INTRA_NODE, + }, + MachineResourceSplit{ + /*offset=*/7_p, + /*dimension=*/MachineSpecificationDimension::INTRA_NODE, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_view.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_view.cc new file mode 100644 index 0000000000..2ea8312991 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_view.cc @@ -0,0 +1,466 @@ +#include "compiler/machine_mapping/machine_view.h" +#include "op-attrs/ff_ordered/ff_ordered.h" +#include "op-attrs/task_space_coordinate.h" +#include 
"pcg/gpu_id_t.dtg.h" +#include "test/utils/doctest/fmt/optional.h" +#include "utils/containers/transform.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("mv_get_expected_task_space_num_dims") { + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + DeviceType::GPU, + }, + { + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTER_NODE, + }, + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTER_NODE, + }, + }, + }; + + CHECK(mv_get_expected_task_space_num_dims(mv) == 2_n); + } + + TEST_CASE("get_device_type") { + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + DeviceType::GPU, + }, + { + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTER_NODE, + }, + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTER_NODE, + }, + }, + }; + + CHECK(get_device_type(mv) == DeviceType::GPU); + } + + TEST_CASE("get_machine_space_coordinate") { + SUBCASE("1D case") { + /** + * This operator has shape (3,), and thus 3 tasks. + * The (only) dimension is projected on the INTER (device) dimension with + * a stride of 2. The start of the projection defined by MachineView + * starts at MachineSpaceCoordinate (0,1), and the machine space has 1 + * node and 6 devices per node. + * + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+-------+ + * | | (0,) | | (1,) | | (2,) | + * +-------+-------+-------+-------+-------+-------+ + * Where the (x,) are the `TaskSpaceCoordinate`s, and the underlying grid + * is the machine space. 
+ */ + OperatorTaskSpace task = OperatorTaskSpace{ + MinimalOrthotope{{ + 3_ge2, + }}, + }; + + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/1_n, + DeviceType::GPU, + }, + { + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + SUBCASE("Task with TaskSpaceCoordinate = (0,)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n}); + + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/1_n, + DeviceType::GPU, + }; + + CHECK(result == correct); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n}); + + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/3_n, + DeviceType::GPU, + }; + + CHECK(result == correct); + } + + SUBCASE("Task with TaskSpaceCoordinate = (2,)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({2_n}); + + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/5_n, + DeviceType::GPU, + }; + + CHECK(result == correct); + } + + SUBCASE("TaskSpaceCoordinate is out of bounds") { + TaskSpaceCoordinate coord = make_task_space_coordinate({4_n}); + + CHECK_THROWS(get_machine_space_coordinate(task, mv, coord)); + } + } + + SUBCASE("2D case - projection on different dimensions") { + /** + * This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. + * The first dimension is projected onto the INTER (node) dimension with + * stride 1, while the second dimension is projected onto the INTRA + * (device) dimension with stride 2. 
The start of the projection defined + * by MachineView is at MachineSpaceCoordinates (1, 2), and the machine + * space has 3 nodes and 5 devices per node. + * + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+ + * | | | | | | + * +-------+-------+-------+-------+-------+ + * | | | (0,0) | | (0,1) | + * +-------+-------+-------+-------+-------+ + * | | | (1,0) | | (1,1) | + * +-------+-------+-------+-------+-------+ + * Where the (x,y) are the `TaskSpaceCoordinate`s, and the underlying + * grid is the machine space. + */ + + OperatorTaskSpace task = OperatorTaskSpace{ + MinimalOrthotope{{ + 2_ge2, + 2_ge2, + }}, + }; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/1_n, + /*device_idx=*/2_n, + DeviceType::GPU, + }, + { + MachineViewDimension{ + stride_t{1_p}, + MachineSpecificationDimension::INTER_NODE, + }, + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 0_n}); + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1_n, /*device_idx=*/2_n, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 1_n}); + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1_n, /*device_idx=*/4_n, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 0_n}); + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/2_n, /*device_idx=*/2_n, DeviceType::GPU}; + MachineSpaceCoordinate result = + 
get_machine_space_coordinate(task, mv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 1_n}); + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/2_n, /*device_idx=*/4_n, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + CHECK(correct == result); + } + } + + SUBCASE("2D case - projection on same dimension") { + /** + * This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. + * Both dimensions are projected on the INTRA (device) dimension, with + * strides 1 and 2 respectively. The start of the projection defined by + * MachineView is at MachineSpaceCoordinates (1, 0), and the machine + * space has 2 nodes and 6 devices per node. + * + * +-------+-------+-------+-------+-------+-------+ + * | (0,0) | (1,0) | | | (0,1) | (1,1) | + * +-------+-------+-------+-------+-------+-------+ + * Where the (x,y) are the `TaskSpaceCoordinate`s, and the underlying + * grid is the machine space. 
+ */ + + OperatorTaskSpace task = OperatorTaskSpace{ + MinimalOrthotope{{ + 2_ge2, + 2_ge2, + }}, + }; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/1_n, + /*device_idx=*/0_n, + DeviceType::GPU, + }, + { + MachineViewDimension{ + stride_t{1_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 0_n}); + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1_n, /*device_idx=*/0_n, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 1_n}); + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1_n, /*device_idx=*/4_n, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 0_n}); + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1_n, /*device_idx=*/1_n, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 1_n}); + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1_n, /*device_idx=*/5_n, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + CHECK(correct == result); + } + } + + SUBCASE("3D case") { + /** + * This operator has shape (2, 2, 2), and thus 2 * 2 * 2 = 8 tasks. 
+ * - The first dimension is projected onto the INTER (node) dimension + * with stride 1, + * - The second dimension is projected onto the INTRA (device) dimension + * with stride 2, + * - The third dimension is projected onto the INTRA (device) dimension + * with stride 1. The start of the projection defined by MachineView is + * at MachineSpaceCoordinates (0, 1), and the machine space has 2 nodes + * and 8 devices per node. + * + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+-------+-------+-------+ + * | |(0,0,0)| |(0,0,1)| |(0,1,0)| |(0,1,1)| + * +-------+-------+-------+-------+-------+-------+-------+-------+ + * | |(1,0,0)| |(1,0,1)| |(1,1,0)| |(1,1,1)| + * +-------+-------+-------+-------+-------+-------+-------+-------+ + * Where the (x,y,z) are the `TaskSpaceCoordinate`s, and the underlying + * grid is the machine space. + */ + + OperatorTaskSpace task = OperatorTaskSpace{ + MinimalOrthotope{{ + 2_ge2, + 2_ge2, + 2_ge2, + }}, + }; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0_n, /*device_idx=*/1_n, DeviceType::GPU}, + {MachineViewDimension{stride_t{1_p}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE}, + MachineViewDimension{stride_t{1_p}, + MachineSpecificationDimension::INTRA_NODE}}}; + + SUBCASE("Task with TaskSpaceCoordinate = (0,0,1)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 1_n, 0_n}); + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0_n, /*device_idx=*/3_n, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1,0)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 0_n, 1_n}); + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1_n, /*device_idx=*/5_n, DeviceType::GPU}; + 
MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1,1)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 1_n, 1_n}); + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1_n, /*device_idx=*/7_n, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord); + CHECK(correct == result); + } + } + } + + TEST_CASE("get_device_ids") { + SUBCASE("1D machine view") { + /** + * This operator has shape (3,), and thus 3 tasks. + * The (only) dimension is projected onto the INTRA (device) dimension + * with a stride of 2. The start of the projection defined by MachineView + * is at MachineSpaceCoordinate (0, 1), and the machine space has 1 node + * and 6 devices per node. + * + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+-------+ + * | 0 | ((1)) | 2 | ((3)) | 4 | ((5)) | + * +-------+-------+-------+-------+-------+-------+ + * Where the integers are the device ids and ((x)) are the devices we + * select + */ + MachineComputeSpecification ms = MachineComputeSpecification{ + /*num_nodes=*/1_p, + /*num_cpus_per_node=*/6_p, + /*num_gpus_per_node=*/6_p, + }; + + OperatorTaskSpace task = OperatorTaskSpace{ + MinimalOrthotope{{ + 3_ge2, + }}, + }; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0_n, /*device_idx=*/1_n, DeviceType::GPU}, + {MachineViewDimension{stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE}}}; + + std::unordered_set correct = { + device_id_t{gpu_id_t{1_n}}, + device_id_t{gpu_id_t{3_n}}, + device_id_t{gpu_id_t{5_n}}, + }; + std::unordered_set result = get_device_ids(task, mv, ms); + CHECK(result == correct); + } + + SUBCASE("2D machine view") { + /** + * This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. 
+ * - The first dimension is projected onto the INTER (node) dimension with + * stride 1, + * - The second dimension is projected onto the INTRA (device) dimension + * with stride 2. The start of the projection defined by MachineView is at + * MachineSpaceCoordinate (1, 2), and the machine space has 3 nodes and 5 + * devices per node. + * + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+ + * | 0 | 1 | 2 | 3 | 4 | + * +-------+-------+-------+-------+-------+ + * | 5 | 6 | ((7)) | 8 | ((9)) | + * +-------+-------+-------+-------+-------+ + * | 10 | 11 | ((12))| 13 | ((14))| + * +-------+-------+-------+-------+-------+ + * Where the integers are the device ids and ((x)) are the devices we + * select + */ + + MachineComputeSpecification ms = MachineComputeSpecification{ + /*num_nodes=*/3_p, + /*num_cpus_per_node=*/5_p, + /*num_gpus_per_node=*/5_p, + }; + + OperatorTaskSpace task = OperatorTaskSpace{ + MinimalOrthotope{{ + 2_ge2, + 2_ge2, + }}, + }; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/1_n, /*device_idx=*/2_n, DeviceType::GPU}, + {MachineViewDimension{stride_t{1_p}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE}}}; + + std::unordered_set correct = { + device_id_t{gpu_id_t{7_n}}, + device_id_t{gpu_id_t{9_n}}, + device_id_t{gpu_id_t{12_n}}, + device_id_t{gpu_id_t{14_n}}, + }; + std::unordered_set result = get_device_ids(task, mv, ms); + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc index 96b11e6d33..54717d6699 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc +++ 
b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc @@ -1,14 +1,17 @@ #include "compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.h" +#include "compiler/cost_estimator/tensor_set_movement.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/machine_view.h" #include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.h" #include "internal/cost_estimator_for_test.h" #include "op-attrs/parallel_tensor_shape.h" -#include "pcg/machine_view.h" +#include "op-attrs/task_space_coordinate.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "utils/containers/get_only.h" +#include "utils/containers/map_from_pairs.h" #include "utils/full_binary_tree/binary_tree_path.h" #include "utils/nonnegative_int/nonnegative_int.h" #include @@ -46,6 +49,15 @@ TEST_SUITE(FF_TEST_SUITE) { }; MachineView mv1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/{}, + }; + + MachineView mv2 = MachineView{ /*start=*/MachineSpaceCoordinate{ /*node_idx=*/0_n, /*device_idx=*/0_n, @@ -54,13 +66,13 @@ TEST_SUITE(FF_TEST_SUITE) { /*dimensions=*/ { MachineViewDimension{ - stride_t{1_p}, - MachineSpecificationDimension::INTRA_NODE, + /*stride=*/stride_t{1_p}, + /*projection=*/MachineSpecificationDimension::INTER_NODE, }, }, }; - MachineView mv2 = MachineView{ + MachineView mv3 = MachineView{ /*start=*/MachineSpaceCoordinate{ /*node_idx=*/0_n, /*device_idx=*/0_n, @@ -69,36 +81,49 @@ TEST_SUITE(FF_TEST_SUITE) { 
/*dimensions=*/ { MachineViewDimension{ - stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE, + /*stride=*/stride_t{2_p}, + /*projection=*/MachineSpecificationDimension::INTER_NODE, }, }, }; - MachineSpecification full_machine_spec = MachineSpecification{ - /*num_nodes=*/2_p, - /*num_cpus_per_node=*/1_p, - /*num_gpus_per_node=*/1_p, - /*inter_node_bandwidth=*/1, - /*intra_node_bandwidth=*/1, + MachineView mv4 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/1_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + /*stride=*/stride_t{1_p}, + /*projection=*/MachineSpecificationDimension::INTER_NODE, + }, + }, }; - MachineSpecification split_machine_spec = MachineSpecification{ - /*num_nodes=*/1_p, - /*num_cpus_per_node=*/1_p, - /*num_gpus_per_node=*/1_p, - /*inter_node_bandwidth=*/1, - /*intra_node_bandwidth=*/1, - }; + MachineComputeResourceSlice four_nodes_resources = + MachineComputeResourceSlice{ + /*num_nodes=*/4_p, + /*num_gpus_per_node=*/1_p, + }; - auto allowed_machine_views1 = - [&](UnmappedRuntimeOnlyOpCostEstimateKey const &, - MachineSpecification const &resources) { - if (resources == full_machine_spec) { - return std::unordered_set{mv1, mv2}; - } else { - return std::unordered_set{mv2}; - } + MachineComputeResourceSlice three_nodes_resources = + MachineComputeResourceSlice{ + /*num_nodes=*/3_p, + /*num_gpus_per_node=*/1_p, + }; + + MachineComputeResourceSlice two_nodes_resources = + MachineComputeResourceSlice{ + /*num_nodes=*/2_p, + /*num_gpus_per_node=*/1_p, + }; + + MachineComputeResourceSlice one_node_resources = + MachineComputeResourceSlice{ + /*num_nodes=*/1_p, + /*num_gpus_per_node=*/1_p, }; TensorShape tensor_shape = TensorShape{ @@ -111,7 +136,20 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - ParallelTensorShape par_tensor_shape = lift_to_parallel(tensor_shape); + ParallelTensorShape pre_partition_par_tensor_shape = + lift_to_parallel(tensor_shape); + 
ParallelTensorShape post_partition_par_tensor_shape = + lift_to_parallel_with_degrees( + tensor_shape, + ParallelTensorDimDegrees{ + /*sum_degree=*/SumDegree{1_p}, + /*discard_copy_degree=*/DiscardCopyDegree{1_p}, + /*shard_degrees=*/ + FFOrdered{ + 2_p, + 1_p, + }, + }); OptimizerAttrs optimizer_attrs = OptimizerAttrs{ SGDOptimizerAttrs{ @@ -126,30 +164,72 @@ TEST_SUITE(FF_TEST_SUITE) { /*op_attrs=*/PCGOperatorAttrs{InputAttrs{tensor_shape}}, /*input_shapes=*/{}, /*weight_shapes=*/{}, - /*output_shapes=*/{}, + /*output_shapes=*/ + { + { + TensorSlotName::OUTPUT, + pre_partition_par_tensor_shape, + }, + }, /*optimizer_attrs=*/optimizer_attrs, }; UnmappedOpCostEstimateKey k2 = UnmappedOpCostEstimateKey{ - /*op_attrs=*/PCGOperatorAttrs{ElementBinaryAttrs{ - /*type=*/OperatorType::EW_ADD, - /*compute_type=*/DataType::FLOAT, - /*should_broadcast_lhs=*/false, - /*should_broadcast_rhs=*/false, + /*op_attrs=*/PCGOperatorAttrs{ElementUnaryAttrs{ + /*type=*/OperatorType::GELU, + /*scalar=*/std::nullopt, }}, - /*input_shapes=*/{}, + /*input_shapes=*/ + { + { + TensorSlotName::INPUT, + post_partition_par_tensor_shape, + }, + }, /*weight_shapes=*/{}, - /*output_shapes=*/{}, + /*output_shapes=*/ + { + { + TensorSlotName::OUTPUT, + post_partition_par_tensor_shape, + }, + }, /*optimizer_attrs=*/optimizer_attrs, }; - AbstractedTensorSetMovement movement1 = AbstractedTensorSetMovement{{ - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/par_tensor_shape, - /*src_machine_views=*/{}, - /*dst_machine_views=*/{}, + UnmappedOpCostEstimateKey k3 = UnmappedOpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{ElementUnaryAttrs{ + /*type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }}, + /*input_shapes=*/ + { + { + TensorSlotName::INPUT, + post_partition_par_tensor_shape, + }, }, - }}; + /*weight_shapes=*/{}, + /*output_shapes=*/ + { + { + TensorSlotName::OUTPUT, + post_partition_par_tensor_shape, + }, + }, + /*optimizer_attrs=*/optimizer_attrs, + }; + + TaskSpaceCoordinate 
empty_task_space_coord = + TaskSpaceCoordinate{OrthotopeCoord{{}}}; + + BinaryTreePath src_path = binary_tree_root_path(); + TaskSpaceCoordinate src_coord = empty_task_space_coord; + + AbstractedDevice dst_device = AbstractedDevice{ + /*operator_tree_path=*/binary_tree_root_path(), + /*task_space_coordinate=*/empty_task_space_coord, + }; ParallelLayerGuidObliviousMachineMapping mm1 = ParallelLayerGuidObliviousMachineMapping{{ @@ -160,128 +240,164 @@ TEST_SUITE(FF_TEST_SUITE) { {binary_tree_root_path(), mv2}, }}; - CostEstimator cost_estimator = make_fake_cost_estimator( - std::unordered_map{{ - {map_unmapped_op_cost_estimate_key(k1, mv1), - OpCostMetrics{/*forward_runtime=*/1_ms, - /*backward_runtime=*/1_ms, - /*memory_usage=*/2_bytes}}, - {map_unmapped_op_cost_estimate_key(k2, mv1), - OpCostMetrics{/*forward_runtime=*/2_ms, - /*backward_runtime=*/2_ms, - /*memory_usage=*/3_bytes}}, - {map_unmapped_op_cost_estimate_key(k1, mv2), - OpCostMetrics{/*forward_runtime=*/1.5_ms, - /*backward_runtime=*/1.5_ms, - /*memory_usage=*/1_bytes}}, - {map_unmapped_op_cost_estimate_key(k2, mv2), - OpCostMetrics{/*forward_runtime=*/2.5_ms, - /*backward_runtime=*/2.5_ms, - /*memory_usage=*/2_bytes}}, + OperatorTaskSpace unparallel_task_space = + OperatorTaskSpace{MinimalOrthotope{{}}}; + OperatorTaskSpace parallel_task_space = OperatorTaskSpace{ + MinimalOrthotope{{ + 2_ge2, }}, - std::unordered_map{{ - {TensorSetMovement{/*movements=*/{}}, /*cost=*/0.0_ms}, - {concretize_abstracted_tensor_set_movement(movement1, mm1, mm1), - /*cost=*/0.1_ms}, - {concretize_abstracted_tensor_set_movement(movement1, mm2, mm2), - /*cost=*/0.2_ms}, - {concretize_abstracted_tensor_set_movement(movement1, mm1, mm2), - /*cost=*/0.3_ms}, - {concretize_abstracted_tensor_set_movement(movement1, mm2, mm1), - /*cost=*/0.4_ms}, - }}); - - MachineMappingWithMemoryContext context = MachineMappingWithMemoryContext{ - cost_estimator, - optimizer_attrs, - allowed_machine_views1, }; - MachineMappingWithMemoryCache 
cache = - empty_machine_mapping_with_memory_cache(); + auto get_corresponding_task_space = [&](MachineView const &mv) { + if (mv == mv1) { + return unparallel_task_space; + } else { + ASSERT(mv == mv2 || mv == mv3); + + return parallel_task_space; + } + }; + + SUBCASE("single layer with single option") { + OpCostMetrics k1_on_mv1_cost = OpCostMetrics{ + /*forward_runtime=*/1_ms, + /*backward_runtime=*/1_ms, + /*memory_usage=*/2_bytes, + }; + + CostEstimator cost_estimator = make_fake_cost_estimator( + std::unordered_map{{ + { + map_unmapped_op_cost_estimate_key(k1, mv1), + k1_on_mv1_cost, + }, + }}, + std::unordered_map{ + { + empty_tensor_set_movement(), + 0_ms, + }, + }); - SUBCASE("single layer") { MachineMappingProblemTree problem_tree = make_leaf(k1); MachineMappingConstraints constraints = get_unconstrained_solution_for_layers( get_all_leaf_paths(problem_tree)); + auto allowed_machine_views = + [&](UnmappedRuntimeOnlyOpCostEstimateKey const &k, + MachineComputeResourceSlice const &resources) { + ASSERT(k == runtime_only_from_unmapped_op_cost_estimate_key(k1)); + ASSERT(resources == four_nodes_resources); + return std::unordered_set{mv1}; + }; + + MachineMappingWithMemoryContext context = MachineMappingWithMemoryContext{ + cost_estimator, + optimizer_attrs, + allowed_machine_views, + }; + + MachineMappingWithMemoryCache cache = + empty_machine_mapping_with_memory_cache(); + MachineMappingWithMemoryResult result = get_optimal_machine_mapping_with_memory( - cache, context, problem_tree, full_machine_spec, constraints); + cache, context, problem_tree, four_nodes_resources, constraints); + MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{ - MachineMappingForSingleLayer{ - OpCostMetrics{/*forward_runtime=*/1_ms, - /*backward_runtime=*/1_ms, - /*memory_usage=*/2_bytes}, + ParetoOptimalMachineMapping{ + k1_on_mv1_cost, ParallelLayerGuidObliviousMachineMapping{{ {binary_tree_root_path(), mv1}, }}, }, - MachineMappingForSingleLayer{ - 
OpCostMetrics{/*forward_runtime=*/1.5_ms, - /*backward_runtime=*/1.5_ms, - /*memory_usage=*/1_bytes}, - ParallelLayerGuidObliviousMachineMapping{{ - {binary_tree_root_path(), mv2}, - }}, - }, }}; CHECK(result == correct); } - SUBCASE("pair of layers in sequence") { - MachineMappingProblemTree problem_tree = - make_series_split(movement1, make_leaf(k1), make_leaf(k2)); + SUBCASE("single layer with multiple options") { + + auto allowed_machine_views = + [&](UnmappedRuntimeOnlyOpCostEstimateKey const &k, + MachineComputeResourceSlice const &resources) { + ASSERT(k == runtime_only_from_unmapped_op_cost_estimate_key(k3)); + ASSERT(resources == four_nodes_resources); + return std::unordered_set{mv2, mv3, mv4}; + }; + + OpCostMetrics k3_on_mv2_cost = OpCostMetrics{ + /*forward_runtime=*/2.5_ms, + /*backward_runtime=*/2.5_ms, + /*memory_usage=*/2_bytes, + }; + + OpCostMetrics k3_on_mv3_cost = OpCostMetrics{ + /*forward_runtime=*/2_ms, + /*backward_runtime=*/2_ms, + /*memory_usage=*/2_bytes, + }; + + OpCostMetrics k3_on_mv4_cost = OpCostMetrics{ + /*forward_runtime=*/3_ms, + /*backward_runtime=*/3_ms, + /*memory_usage=*/3_bytes, + }; + + CostEstimator cost_estimator = make_fake_cost_estimator( + std::unordered_map{{ + { + map_unmapped_op_cost_estimate_key(k3, mv2), + k3_on_mv2_cost, + }, + { + map_unmapped_op_cost_estimate_key(k3, mv3), + k3_on_mv3_cost, + }, + { + map_unmapped_op_cost_estimate_key(k3, mv4), + k3_on_mv4_cost, + }, + }}, + std::unordered_map{ + { + empty_tensor_set_movement(), + 0_ms, + }, + }); + + MachineMappingProblemTree problem_tree = make_leaf(k3); MachineMappingConstraints constraints = get_unconstrained_solution_for_layers( get_all_leaf_paths(problem_tree)); + MachineMappingWithMemoryCache cache = + empty_machine_mapping_with_memory_cache(); + + MachineMappingWithMemoryContext context = MachineMappingWithMemoryContext{ + cost_estimator, + optimizer_attrs, + allowed_machine_views, + }; + MachineMappingWithMemoryResult result = 
get_optimal_machine_mapping_with_memory( - cache, context, problem_tree, full_machine_spec, constraints); + cache, context, problem_tree, four_nodes_resources, constraints); + MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{ - MachineMappingForSingleLayer{ - OpCostMetrics{ - /*forward_runtime=*/1.0_ms + 2.0_ms + 0.1_ms, - /*backward_runtime=*/1.0_ms + 2.0_ms + 0.1_ms, - /*memory_usage=*/2_bytes + 3_bytes, - }, + ParetoOptimalMachineMapping{ + k3_on_mv2_cost, ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - mv1, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - mv1, - }, + {binary_tree_root_path(), mv2}, }}, }, - MachineMappingForSingleLayer{ - OpCostMetrics{/*forward_runtime=*/1.5_ms + 2.5_ms + 0.1_ms, - /*backward_runtime=*/1.5_ms + 2.5_ms + 0.1_ms, - /*memory_usage=*/1_bytes + 2_bytes}, + ParetoOptimalMachineMapping{ + k3_on_mv3_cost, ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - mv2, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - mv2, - }, + {binary_tree_root_path(), mv3}, }}, }, }}; @@ -289,40 +405,384 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } + SUBCASE("pair of layers in sequence") { + AbstractedTensorSetMovement k2_to_k3 = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*src_op_tree_path=*/binary_tree_root_path(), + /*edge_to_size=*/ + { + { + AbstractedSingleTensorCommunicationEdge{ + /*src_coord=*/make_task_space_coordinate({0_n}), + /*dst=*/ + AbstractedDevice{ + /*operator_tree_path=*/ + binary_tree_root_path(), + /*task_space_coordinate=*/ + make_task_space_coordinate({0_n}), + }, + }, + get_size_in_bytes(tensor_shape), + }, + { + AbstractedSingleTensorCommunicationEdge{ + /*src_coord=*/make_task_space_coordinate({1_n}), + /*dst=*/ + AbstractedDevice{ + /*operator_tree_path=*/ + 
binary_tree_root_path(), + /*task_space_coordinate=*/ + make_task_space_coordinate({1_n}), + }, + }, + get_size_in_bytes(tensor_shape), + }, + }, + }, + }, + }; + + auto mk_tensor_set_movement = [&](MachineView const &src_mv, + MachineView const &dst_mv) { + MachineSpaceStencil src_stencil = MachineSpaceStencil{ + /*operator_task_space=*/get_corresponding_task_space(src_mv), + /*machine_view=*/src_mv, + }; + + MachineSpaceStencil dst_stencil = MachineSpaceStencil{ + /*operator_task_space=*/get_corresponding_task_space(dst_mv), + /*machine_view=*/dst_mv, + }; + + return concretize_abstracted_tensor_set_movement( + k2_to_k3, + /*pre_machine_stencils=*/{{binary_tree_root_path(), src_stencil}}, + /*post_machine_stencils=*/{{binary_tree_root_path(), dst_stencil}}); + }; + + auto mk_cost_estimator = [&](milliseconds_t k2_on_mv2_cost, + num_bytes_t k2_on_mv2_mem_usage, + milliseconds_t k2_on_mv3_cost, + num_bytes_t k2_on_mv3_mem_usage, + milliseconds_t k3_on_mv2_cost, + num_bytes_t k3_on_mv2_mem_usage, + milliseconds_t k3_on_mv3_cost, + num_bytes_t k3_on_mv3_mem_usage, + milliseconds_t mv2_to_mv2_cost, + milliseconds_t mv2_to_mv3_cost, + milliseconds_t mv3_to_mv2_cost, + milliseconds_t mv3_to_mv3_cost) { + return make_fake_cost_estimator( + std::unordered_map{{ + { + map_unmapped_op_cost_estimate_key(k2, mv2), + OpCostMetrics{ + /*forward_runtime=*/k2_on_mv2_cost, + /*backward_runtime=*/k2_on_mv2_cost, + /*memory_usage=*/k2_on_mv2_mem_usage, + }, + }, + { + map_unmapped_op_cost_estimate_key(k2, mv3), + OpCostMetrics{ + /*forward_runtime=*/k2_on_mv3_cost, + /*backward_runtime=*/k2_on_mv3_cost, + /*memory_usage=*/k2_on_mv3_mem_usage, + }, + }, + { + map_unmapped_op_cost_estimate_key(k3, mv2), + OpCostMetrics{ + /*forward_runtime=*/k3_on_mv2_cost, + /*backward_runtime=*/k3_on_mv2_cost, + /*memory_usage=*/k3_on_mv2_mem_usage, + }, + }, + { + map_unmapped_op_cost_estimate_key(k3, mv3), + OpCostMetrics{ + /*forward_runtime=*/k3_on_mv3_cost, + 
/*backward_runtime=*/k3_on_mv3_cost, + /*memory_usage=*/k3_on_mv3_mem_usage, + }, + }, + }}, + std::unordered_map{{ + { + empty_tensor_set_movement(), + 0_ms, + }, + { + mk_tensor_set_movement(mv2, mv2), + mv2_to_mv2_cost, + }, + { + mk_tensor_set_movement(mv2, mv3), + mv2_to_mv3_cost, + }, + { + mk_tensor_set_movement(mv3, mv2), + mv3_to_mv2_cost, + }, + { + mk_tensor_set_movement(mv3, mv3), + mv3_to_mv3_cost, + }, + }}); + }; + + MachineMappingProblemTree problem_tree = + make_series_split(k2_to_k3, make_leaf(k2), make_leaf(k3)); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingWithMemoryCache cache = + empty_machine_mapping_with_memory_cache(); + + SUBCASE("solution is mv2, mv3 due to runtime") { + CostEstimator cost_estimator = mk_cost_estimator( + /*k2_on_mv2_cost=*/2_ms, + /*k2_on_mv2_mem_usage=*/2_bytes, + /*k2_on_mv3_cost=*/2.4_ms, + /*k2_on_mv3_mem_usage=*/2_bytes, + /*k3_on_mv2_cost=*/3.6_ms, + /*k3_on_mv2_mem_usage=*/2_bytes, + /*k3_on_mv3_cost=*/3_ms, + /*k3_on_mv3_mem_usage=*/2_bytes, + /*mv2_to_mv2_cost=*/0.1_ms, + /*mv2_to_mv3_cost=*/1.0_ms, + /*mv3_to_mv2_cost=*/0.3_ms, + /*mv3_to_mv3_cost=*/0.1_ms); + + auto allowed_machine_views = + [&](UnmappedRuntimeOnlyOpCostEstimateKey const &k, + MachineComputeResourceSlice const &resources) { + if (k == runtime_only_from_unmapped_op_cost_estimate_key(k1)) { + return std::unordered_set{ + mv1, + }; + } else { + if (resources == four_nodes_resources) { + return std::unordered_set{mv2, mv3}; + } else if (resources == three_nodes_resources) { + return std::unordered_set{mv2, mv3}; + } else if (resources == two_nodes_resources) { + return std::unordered_set{mv2}; + } else { + return std::unordered_set{}; + } + }; + }; + + MachineMappingWithMemoryContext context = + MachineMappingWithMemoryContext{ + cost_estimator, + optimizer_attrs, + allowed_machine_views, + }; + + MachineMappingWithMemoryResult result = + 
get_optimal_machine_mapping_with_memory(cache, + context, + problem_tree, + four_nodes_resources, + constraints); + + MachineMappingWithMemoryResult correct = + MachineMappingWithMemoryResult{{ + ParetoOptimalMachineMapping{ + OpCostMetrics{ + /*forward_runtime=*/2_ms + 0.3_ms + 3_ms, + /*backward_runtime=*/2_ms + 0.3_ms + 3_ms, + /*memory_usage=*/4_bytes, + }, + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv1, + }, + }}, + }, + }}; + } + } + SUBCASE("pair of layers in parallel") { + auto mk_cost_estimator = [&](milliseconds_t k2_on_mv2_cost, + num_bytes_t k2_on_mv2_mem_usage, + milliseconds_t k2_on_mv3_cost, + num_bytes_t k2_on_mv3_mem_usage, + milliseconds_t k3_on_mv2_cost, + num_bytes_t k3_on_mv2_mem_usage, + milliseconds_t k3_on_mv3_cost, + num_bytes_t k3_on_mv3_mem_usage) { + return make_fake_cost_estimator( + std::unordered_map{{ + { + map_unmapped_op_cost_estimate_key(k2, mv2), + OpCostMetrics{ + /*forward_runtime=*/k2_on_mv2_cost, + /*backward_runtime=*/k2_on_mv2_cost, + /*memory_usage=*/k2_on_mv2_mem_usage, + }, + }, + { + map_unmapped_op_cost_estimate_key(k2, mv3), + OpCostMetrics{ + /*forward_runtime=*/k2_on_mv3_cost, + /*backward_runtime=*/k2_on_mv3_cost, + /*memory_usage=*/k2_on_mv3_mem_usage, + }, + }, + { + map_unmapped_op_cost_estimate_key(k3, mv2), + OpCostMetrics{ + /*forward_runtime=*/k3_on_mv2_cost, + /*backward_runtime=*/k3_on_mv2_cost, + /*memory_usage=*/k3_on_mv2_mem_usage, + }, + }, + { + map_unmapped_op_cost_estimate_key(k3, mv3), + OpCostMetrics{ + /*forward_runtime=*/k3_on_mv3_cost, + /*backward_runtime=*/k3_on_mv3_cost, + /*memory_usage=*/k3_on_mv3_mem_usage, + }, + }, + }}, + std::unordered_map{ + { + empty_tensor_set_movement(), + 0_ms, + }, + }); + }; + + CostEstimator cost_estimator = mk_cost_estimator( + /*k2_on_mv2_cost=*/2_ms, + /*k2_on_mv2_mem_usage=*/3_bytes, + /*k2_on_mv3_cost=*/2.5_ms, + 
/*k2_on_mv3_mem_usage=*/2_bytes, + /*k3_on_mv2_cost=*/2.5_ms, + /*k3_on_mv2_mem_usage=*/2_bytes, + /*k3_on_mv3_cost=*/2_ms, + /*k3_on_mv3_mem_usage=*/1_bytes); + + auto allowed_machine_views = + [&](UnmappedRuntimeOnlyOpCostEstimateKey const &k, + MachineComputeResourceSlice const &resources) { + if (k == runtime_only_from_unmapped_op_cost_estimate_key(k1)) { + return std::unordered_set{ + mv1, + }; + } else { + if (resources == four_nodes_resources) { + return std::unordered_set{mv2, mv3}; + } else if (resources == three_nodes_resources) { + return std::unordered_set{mv2, mv3}; + } else if (resources == two_nodes_resources) { + return std::unordered_set{mv2}; + } else { + return std::unordered_set{}; + } + }; + }; + + MachineMappingWithMemoryContext context = MachineMappingWithMemoryContext{ + cost_estimator, + optimizer_attrs, + allowed_machine_views, + }; + MachineMappingProblemTree problem_tree = - make_parallel_split(make_leaf(k1), make_leaf(k2)); + make_parallel_split(make_leaf(k2), make_leaf(k3)); MachineMappingConstraints constraints = get_unconstrained_solution_for_layers( get_all_leaf_paths(problem_tree)); + MachineMappingWithMemoryCache cache = + empty_machine_mapping_with_memory_cache(); + MachineMappingWithMemoryResult result = get_optimal_machine_mapping_with_memory( - cache, context, problem_tree, full_machine_spec, constraints); - MachineMappingWithMemoryResult correct = - MachineMappingWithMemoryResult{{MachineMappingForSingleLayer{ - OpCostMetrics{/*forward_runtime=*/2.5_ms, - /*backward_runtime=*/2.5_ms, - /*memory_usage=*/2_bytes}, - ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - mv2, + cache, context, problem_tree, four_nodes_resources, constraints); + + MachineView translated_mv2 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/2_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + /*stride=*/stride_t{1_p}, 
+ /*projection=*/MachineSpecificationDimension::INTER_NODE, + }, + }, + }; + + MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{ + /*pareto_frontier=*/{ + ParetoOptimalMachineMapping{ + OpCostMetrics{ + /*forward_runtime=*/2.5_ms, + /*backward_runtime=*/2.5_ms, + /*memory_usage=*/3_bytes, }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - mv2, + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv2, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + translated_mv2, + }, + }}, + }, + ParetoOptimalMachineMapping{ + OpCostMetrics{ + /*forward_runtime=*/4.5_ms, + /*backward_runtime=*/4.5_ms, + /*memory_usage=*/3_bytes, }, - }}, - - }}}; + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv3, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv3, + }, + }}, + }, + }, + }; - CHECK(result == correct); + ASSERT(result == correct); } } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_result_with_memory.cc b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_result_with_memory.cc deleted file mode 100644 index 2192b442cd..0000000000 --- a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_result_with_memory.cc +++ /dev/null @@ -1,617 +0,0 @@ -#include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h" -#include "pcg/machine_view.h" -#include "utils/nonnegative_int/nonnegative_int.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("remove_non_pareto_optimal_machine_mapping_result") { - MachineView machine_view_0 = MachineView{ - /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - /*device_type=*/DeviceType::GPU, - }, - /*dimensions=*/ - { - MachineViewDimension{ - 
stride_t{1_p}, - MachineSpecificationDimension::INTRA_NODE, - }, - }, - }; - - MachineView machine_view_1 = MachineView{ - /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - /*device_type=*/DeviceType::GPU, - }, - /*dimensions=*/ - { - MachineViewDimension{ - stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE, - }, - }, - }; - - MachineView machine_view_2 = MachineView{ - /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - /*device_type=*/DeviceType::GPU, - }, - /*dimensions=*/ - { - MachineViewDimension{ - stride_t{4_p}, - MachineSpecificationDimension::INTRA_NODE, - }, - }, - }; - - OpCostMetrics cost1 = OpCostMetrics{ - /*forward_runtime=*/2_ms, - /*backward_runtime=*/2_ms, - /*memory_usage=*/2_bytes, - }; - - OpCostMetrics cost2 = OpCostMetrics{ - /*forward_runtime=*/4_ms, - /*backward_runtime=*/4_ms, - /*memory_usage=*/1_bytes, - }; - - OpCostMetrics cost3 = OpCostMetrics{ - /*forward_runtime=*/2_ms, - /*backward_runtime=*/2_ms, - /*memory_usage=*/3_bytes, - }; - - MachineMappingForSingleLayer mm1 = MachineMappingForSingleLayer{ - cost1, - ParallelLayerGuidObliviousMachineMapping{ - { - { - BinaryTreePath{{}}, - machine_view_0, - }, - }, - }, - }; - - MachineMappingForSingleLayer mm2 = MachineMappingForSingleLayer{ - cost2, - ParallelLayerGuidObliviousMachineMapping{ - { - { - BinaryTreePath{{}}, - machine_view_1, - }, - }, - }, - }; - - MachineMappingForSingleLayer mm3 = MachineMappingForSingleLayer{ - cost3, - ParallelLayerGuidObliviousMachineMapping{ - { - { - BinaryTreePath{{}}, - machine_view_2, - }, - }, - }, - }; - - SUBCASE("empty") { - MachineMappingWithMemoryResult before_remove = - empty_machine_mapping_with_memory_result(); - MachineMappingWithMemoryResult result = - remove_non_pareto_optimal_machine_mapping_result(before_remove); - MachineMappingWithMemoryResult correct = - empty_machine_mapping_with_memory_result(); - - CHECK(result == correct); - } - - SUBCASE("all solutions are 
pareto-optimal") { - MachineMappingWithMemoryResult before_remove = - MachineMappingWithMemoryResult{ - { - mm1, - mm2, - }, - }; - MachineMappingWithMemoryResult result = - remove_non_pareto_optimal_machine_mapping_result(before_remove); - MachineMappingWithMemoryResult correct = before_remove; - - CHECK(result == correct); - } - - SUBCASE("there exists a non-pareto-optimal solution") { - MachineMappingWithMemoryResult before_remove = - MachineMappingWithMemoryResult{ - { - mm1, - mm2, - mm3, - }, - }; - MachineMappingWithMemoryResult result = - remove_non_pareto_optimal_machine_mapping_result(before_remove); - MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{ - { - mm1, - mm2, - }, - }; - - CHECK(result == correct); - } - } - - TEST_CASE("series_combine(float, MachineMappingWithMemoryResult const &, " - "MachineMappingWithMemoryResult const &, " - "std::optional const&)") { - MachineView machine_view_0 = MachineView{ - /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - /*device_type=*/DeviceType::GPU, - }, - /*dimensions=*/ - { - MachineViewDimension{ - stride_t{1_p}, - MachineSpecificationDimension::INTRA_NODE, - }, - }, - }; - - MachineView machine_view_1 = MachineView{ - /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - /*device_type=*/DeviceType::GPU, - }, - /*dimensions=*/ - { - MachineViewDimension{ - stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE, - }, - }, - }; - - OpCostMetrics pre_cost = OpCostMetrics{ - /*forward_runtime=*/2_ms, - /*backward_runtime=*/2_ms, - /*memory_usage=*/2_bytes, - }; - MachineMappingWithMemoryResult pre = MachineMappingWithMemoryResult{{ - MachineMappingForSingleLayer{ - pre_cost, - ParallelLayerGuidObliviousMachineMapping{ - { - { - BinaryTreePath{ - {BinaryTreePathEntry::LEFT_CHILD}, - }, - machine_view_0, - }, - { - BinaryTreePath{ - {BinaryTreePathEntry::RIGHT_CHILD}, - }, - machine_view_1, - }, - }, - }, - }, - }}; - - OpCostMetrics 
post_cost = OpCostMetrics{ - /*forward_runtime=*/4_ms, - /*backward_runtime=*/4_ms, - /*memory_usage=*/1_bytes, - }; - - MachineMappingWithMemoryResult post = MachineMappingWithMemoryResult{{ - MachineMappingForSingleLayer{ - post_cost, - ParallelLayerGuidObliviousMachineMapping{ - { - { - BinaryTreePath{{}}, - machine_view_1, - }, - }, - }, - }, - }}; - - MachineMappingWithMemoryResult empty = - empty_machine_mapping_with_memory_result(); - - milliseconds_t comm_cost = 3_ms; - - SUBCASE("pre is empty") { - MachineMappingWithMemoryResult result = series_combine( - comm_cost, empty, post, ParallelSplitTransformation::LthenR); - MachineMappingWithMemoryResult correct = empty; - - CHECK(result == correct); - } - - SUBCASE("post is empty") { - MachineMappingWithMemoryResult result = series_combine( - comm_cost, pre, empty, ParallelSplitTransformation::LthenR); - MachineMappingWithMemoryResult correct = empty; - - CHECK(result == correct); - } - - SUBCASE("both are nonempty") { - MachineMappingWithMemoryResult no_parallel_split_transform = - MachineMappingWithMemoryResult{ - { - MachineMappingForSingleLayer{ - /*cost=*/OpCostMetrics{ - /*forward_runtime=*/pre_cost.forward_runtime + - comm_cost + post_cost.forward_runtime, - /*backward_runtime=*/pre_cost.backward_runtime + - comm_cost + post_cost.backward_runtime, - /*memory_usage=*/pre_cost.memory_usage + - post_cost.memory_usage, - }, - /*machine_mapping=*/ - ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - BinaryTreePathEntry::LEFT_CHILD, - }}, - machine_view_0, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - BinaryTreePathEntry::RIGHT_CHILD, - }}, - machine_view_1, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - machine_view_1, - }, - }}, - }, - }, - }; - - SUBCASE("parallel_split_transformation = std::nullopt") { - MachineMappingWithMemoryResult result = - series_combine(comm_cost, pre, post, std::nullopt); - 
MachineMappingWithMemoryResult correct = no_parallel_split_transform; - - CHECK(result == correct); - } - - SUBCASE("parallel_split_transformation = LthenR") { - MachineMappingWithMemoryResult result = series_combine( - comm_cost, pre, post, ParallelSplitTransformation::LthenR); - MachineMappingWithMemoryResult correct = no_parallel_split_transform; - - CHECK(result == correct); - } - - SUBCASE("parallel_split_transformation = RthenL") { - MachineMappingWithMemoryResult result = series_combine( - comm_cost, pre, post, ParallelSplitTransformation::RthenL); - MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{ - { - MachineMappingForSingleLayer{ - /*cost=*/OpCostMetrics{ - /*forward_runtime=*/pre_cost.forward_runtime + - comm_cost + post_cost.forward_runtime, - /*backward_runtime=*/pre_cost.backward_runtime + - comm_cost + post_cost.backward_runtime, - /*memory_usage=*/pre_cost.memory_usage + - post_cost.memory_usage, - }, - /*machine_mapping=*/ - ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::LEFT_CHILD, - }}, - machine_view_0, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::RIGHT_CHILD, - }}, - machine_view_1, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - machine_view_1, - }, - }}, - }, - }, - }; - - CHECK(result == correct); - } - } - } - - TEST_CASE("parallel_combine(float, MachineMappingWithMemoryResult const &, " - "MachineMappingWithMemoryResult const &, " - "std::optional const&)") { - MachineView machine_view_0 = MachineView{ - /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - /*device_type=*/DeviceType::GPU, - }, - /*dimensions=*/ - { - MachineViewDimension{ - stride_t{1_p}, - MachineSpecificationDimension::INTRA_NODE, - }, - }, - }; - - MachineView machine_view_1 = MachineView{ - /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - 
/*device_type=*/DeviceType::GPU, - }, - /*dimensions=*/ - { - MachineViewDimension{ - stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE, - }, - }, - }; - - OpCostMetrics lhs_cost = OpCostMetrics{ - /*forward_runtime=*/2_ms, - /*backward_runtime=*/2_ms, - /*memory_usage=*/2_bytes, - }; - MachineMappingWithMemoryResult lhs = MachineMappingWithMemoryResult{{ - MachineMappingForSingleLayer{ - lhs_cost, - ParallelLayerGuidObliviousMachineMapping{ - { - { - BinaryTreePath{ - {BinaryTreePathEntry::LEFT_CHILD}, - }, - machine_view_0, - }, - { - BinaryTreePath{ - {BinaryTreePathEntry::RIGHT_CHILD}, - }, - machine_view_1, - }, - }, - }, - }, - }}; - - OpCostMetrics rhs_cost = OpCostMetrics{ - /*forward_runtime=*/4_ms, - /*backward_runtime=*/4_ms, - /*memory_usage=*/1_bytes, - }; - MachineMappingWithMemoryResult rhs = MachineMappingWithMemoryResult{{ - MachineMappingForSingleLayer{ - rhs_cost, - ParallelLayerGuidObliviousMachineMapping{ - { - { - BinaryTreePath{{}}, - machine_view_1, - }, - }, - }, - }, - }}; - - MachineMappingWithMemoryResult empty = - empty_machine_mapping_with_memory_result(); - - SUBCASE("lhs is empty") { - MachineMappingWithMemoryResult result = parallel_combine(empty, rhs); - MachineMappingWithMemoryResult correct = empty; - - CHECK(result == correct); - } - - SUBCASE("rhs is empty") { - MachineMappingWithMemoryResult result = parallel_combine(lhs, empty); - MachineMappingWithMemoryResult correct = empty; - - CHECK(result == correct); - } - - SUBCASE("both are nonempty") { - MachineMappingWithMemoryResult result = parallel_combine(lhs, rhs); - MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{ - MachineMappingForSingleLayer{ - /*cost=*/OpCostMetrics{ - /*forward_runtime=*/std::max(lhs_cost.forward_runtime, - rhs_cost.forward_runtime), - /*backward_runtime=*/ - std::max(lhs_cost.backward_runtime, - rhs_cost.backward_runtime), - /*memory_usage=*/ - std::max(lhs_cost.memory_usage, rhs_cost.memory_usage), - }, - 
/*machine_mapping=*/ - ParallelLayerGuidObliviousMachineMapping{ - { - { - BinaryTreePath{{BinaryTreePathEntry::LEFT_CHILD, - BinaryTreePathEntry::LEFT_CHILD}}, - machine_view_0, - }, - { - BinaryTreePath{{BinaryTreePathEntry::LEFT_CHILD, - BinaryTreePathEntry::RIGHT_CHILD}}, - machine_view_1, - }, - { - BinaryTreePath{{BinaryTreePathEntry::RIGHT_CHILD}}, - machine_view_1, - }, - }, - }, - }, - }}; - - CHECK(result == correct); - } - } - - TEST_CASE("minimize_runtime(memory)") { - MachineView machine_view_0 = MachineView{ - /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - /*device_type=*/DeviceType::GPU, - }, - /*dimensions=*/ - { - MachineViewDimension{ - stride_t{1_p}, - MachineSpecificationDimension::INTRA_NODE, - }, - }, - }; - - MachineView machine_view_1 = MachineView{ - /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - /*device_type=*/DeviceType::GPU, - }, - /*dimensions=*/ - { - MachineViewDimension{ - stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE, - }, - }, - }; - - MachineView machine_view_2 = MachineView{ - /*start=*/MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - /*device_type=*/DeviceType::GPU, - }, - /*dimensions=*/ - { - MachineViewDimension{ - stride_t{4_p}, - MachineSpecificationDimension::INTRA_NODE, - }, - }, - }; - - OpCostMetrics cost1 = OpCostMetrics{ - /*forward_runtime=*/2_ms, - /*backward_runtime=*/2_ms, - /*memory_usage=*/2_bytes, - }; - OpCostMetrics cost2 = OpCostMetrics{ - /*forward_runtime=*/4_ms, - /*backward_runtime=*/4_ms, - /*memory_usage=*/1_bytes, - }; - OpCostMetrics cost3 = OpCostMetrics{ - /*forward_runtime=*/2_ms, - /*backward_runtime=*/2_ms, - /*memory_usage=*/3_bytes, - }; - - MachineMappingForSingleLayer mm1 = MachineMappingForSingleLayer{ - cost1, - ParallelLayerGuidObliviousMachineMapping{ - { - { - BinaryTreePath{{}}, - machine_view_0, - }, - }, - }, - }; - - MachineMappingForSingleLayer mm2 = MachineMappingForSingleLayer{ - cost2, 
- ParallelLayerGuidObliviousMachineMapping{ - { - { - BinaryTreePath{{}}, - machine_view_1, - }, - }, - }, - }; - - MachineMappingForSingleLayer mm3 = MachineMappingForSingleLayer{ - cost3, - ParallelLayerGuidObliviousMachineMapping{ - { - { - BinaryTreePath{{}}, - machine_view_2, - }, - }, - }, - }; - - MachineMappingWithMemoryResult result1 = MachineMappingWithMemoryResult{ - { - mm1, - mm2, - }, - }; - - MachineMappingWithMemoryResult result2 = MachineMappingWithMemoryResult{ - { - mm2, - mm3, - }, - }; - - MachineMappingWithMemoryResult result = minimize_runtime(result1, result2); - MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{ - { - mm1, - mm2, - }, - }; - - CHECK(result == correct); - } -} diff --git a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.cc b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.cc new file mode 100644 index 0000000000..402dbe66d7 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.cc @@ -0,0 +1,587 @@ +#include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h" +#include "compiler/machine_mapping/machine_view.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/rapidcheck/some.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MachineMappingWithMemoryResult") { + SUBCASE("initialization") { + SUBCASE("throws if initialized with non-pareto-optimal elements") { + CHECK_THROWS(MachineMappingWithMemoryResult{{ + ParetoOptimalMachineMapping{ + /*cost=*/OpCostMetrics{ + /*forward_runtime=*/5_ms, + /*backward_runtime=*/5_ms, + /*memory_usage=*/6_bytes, + }, + /*machine_mapping=*/ + some(), + }, + ParetoOptimalMachineMapping{ + /*cost=*/OpCostMetrics{ + /*forward_runtime=*/2_ms, + 
/*backward_runtime=*/4_ms, + /*memory_usage=*/5_bytes, + }, + /*machine_mapping=*/ + some(), + }, + }}); + } + + SUBCASE("allows elements with identical performance") { + ParetoOptimalMachineMapping mapping1 = ParetoOptimalMachineMapping{ + /*cost=*/OpCostMetrics{ + /*forward_runtime=*/5_ms, + /*backward_runtime=*/5_ms, + /*memory_usage=*/6_bytes, + }, + /*machine_mapping=*/ + some(), + }; + + ParetoOptimalMachineMapping mapping2 = ParetoOptimalMachineMapping{ + /*cost=*/OpCostMetrics{ + /*forward_runtime=*/5_ms, + /*backward_runtime=*/5_ms, + /*memory_usage=*/5_bytes, + }, + /*machine_mapping=*/ + some(), + }; + + ParetoOptimalMachineMapping mapping3 = ParetoOptimalMachineMapping{ + /*cost=*/OpCostMetrics{ + /*forward_runtime=*/5_ms, + /*backward_runtime=*/5_ms, + /*memory_usage=*/6_bytes, + }, + /*machine_mapping=*/ + some(), + }; + + MachineMappingWithMemoryResult mapping_result = + MachineMappingWithMemoryResult{{ + mapping1, + mapping2, + mapping3, + }}; + + std::unordered_set result = + mapping_result.get_pareto_frontier(); + + std::unordered_set correct = { + mapping1, + mapping2, + mapping3, + }; + + CHECK(result == correct); + } + + SUBCASE("allows empty set") { + MachineMappingWithMemoryResult mapping_result = + MachineMappingWithMemoryResult{{}}; + + std::unordered_set result = + mapping_result.get_pareto_frontier(); + + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + } + + TEST_CASE("series_combine(float, MachineMappingWithMemoryResult const &, " + "MachineMappingWithMemoryResult const &, " + "std::optional const&)") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + 
/*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + OpCostMetrics pre_cost = OpCostMetrics{ + /*forward_runtime=*/2_ms, + /*backward_runtime=*/2_ms, + /*memory_usage=*/2_bytes, + }; + MachineMappingWithMemoryResult pre = MachineMappingWithMemoryResult{{ + ParetoOptimalMachineMapping{ + pre_cost, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{ + {BinaryTreePathEntry::LEFT_CHILD}, + }, + machine_view_0, + }, + { + BinaryTreePath{ + {BinaryTreePathEntry::RIGHT_CHILD}, + }, + machine_view_1, + }, + }, + }, + }, + }}; + + OpCostMetrics post_cost = OpCostMetrics{ + /*forward_runtime=*/4_ms, + /*backward_runtime=*/4_ms, + /*memory_usage=*/1_bytes, + }; + + MachineMappingWithMemoryResult post = MachineMappingWithMemoryResult{{ + ParetoOptimalMachineMapping{ + post_cost, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_1, + }, + }, + }, + }, + }}; + + MachineMappingWithMemoryResult empty = + empty_machine_mapping_with_memory_result(); + + milliseconds_t comm_cost = 3_ms; + + SUBCASE("pre is empty") { + MachineMappingWithMemoryResult result = series_combine( + comm_cost, empty, post, ParallelSplitTransformation::LthenR); + MachineMappingWithMemoryResult correct = empty; + + CHECK(result == correct); + } + + SUBCASE("post is empty") { + MachineMappingWithMemoryResult result = series_combine( + comm_cost, pre, empty, ParallelSplitTransformation::LthenR); + MachineMappingWithMemoryResult correct = empty; + + CHECK(result == correct); + } + + SUBCASE("both are nonempty") { + MachineMappingWithMemoryResult no_parallel_split_transform = + MachineMappingWithMemoryResult{ + { + ParetoOptimalMachineMapping{ + /*cost=*/OpCostMetrics{ + /*forward_runtime=*/pre_cost.forward_runtime + + comm_cost + post_cost.forward_runtime, + /*backward_runtime=*/pre_cost.backward_runtime + + comm_cost + post_cost.backward_runtime, 
+ /*memory_usage=*/pre_cost.memory_usage + + post_cost.memory_usage, + }, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }, + }; + + SUBCASE("parallel_split_transformation = std::nullopt") { + MachineMappingWithMemoryResult result = + series_combine(comm_cost, pre, post, std::nullopt); + MachineMappingWithMemoryResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = LthenR") { + MachineMappingWithMemoryResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::LthenR); + MachineMappingWithMemoryResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = RthenL") { + MachineMappingWithMemoryResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::RthenL); + MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{ + { + ParetoOptimalMachineMapping{ + /*cost=*/OpCostMetrics{ + /*forward_runtime=*/pre_cost.forward_runtime + + comm_cost + post_cost.forward_runtime, + /*backward_runtime=*/pre_cost.backward_runtime + + comm_cost + post_cost.backward_runtime, + /*memory_usage=*/pre_cost.memory_usage + + post_cost.memory_usage, + }, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + 
machine_view_1, + }, + }}, + }, + }, + }; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel_combine(float, MachineMappingWithMemoryResult const &, " + "MachineMappingWithMemoryResult const &, " + "std::optional const&)") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + OpCostMetrics lhs_cost = OpCostMetrics{ + /*forward_runtime=*/2_ms, + /*backward_runtime=*/2_ms, + /*memory_usage=*/2_bytes, + }; + MachineMappingWithMemoryResult lhs = MachineMappingWithMemoryResult{{ + ParetoOptimalMachineMapping{ + lhs_cost, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{ + {BinaryTreePathEntry::LEFT_CHILD}, + }, + machine_view_0, + }, + { + BinaryTreePath{ + {BinaryTreePathEntry::RIGHT_CHILD}, + }, + machine_view_1, + }, + }, + }, + }, + }}; + + OpCostMetrics rhs_cost = OpCostMetrics{ + /*forward_runtime=*/4_ms, + /*backward_runtime=*/4_ms, + /*memory_usage=*/1_bytes, + }; + MachineMappingWithMemoryResult rhs = MachineMappingWithMemoryResult{{ + ParetoOptimalMachineMapping{ + rhs_cost, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_1, + }, + }, + }, + }, + }}; + + MachineMappingWithMemoryResult empty = + empty_machine_mapping_with_memory_result(); + + MachineResourceSplit split = MachineResourceSplit{ + /*offset=*/3_p, + /*dimension=*/MachineSpecificationDimension::INTER_NODE, + }; + + SUBCASE("lhs is empty") { + MachineMappingWithMemoryResult result = + parallel_combine(split, 
empty, rhs); + MachineMappingWithMemoryResult correct = empty; + + CHECK(result == correct); + } + + SUBCASE("rhs is empty") { + MachineMappingWithMemoryResult result = + parallel_combine(split, lhs, empty); + MachineMappingWithMemoryResult correct = empty; + + CHECK(result == correct); + } + + SUBCASE("both are nonempty") { + MachineMappingWithMemoryResult result = parallel_combine(split, lhs, rhs); + + MachineView translated_machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/3_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{ + ParetoOptimalMachineMapping{ + /*cost=*/OpCostMetrics{ + /*forward_runtime=*/std::max(lhs_cost.forward_runtime, + rhs_cost.forward_runtime), + /*backward_runtime=*/ + std::max(lhs_cost.backward_runtime, + rhs_cost.backward_runtime), + /*memory_usage=*/ + std::max(lhs_cost.memory_usage, rhs_cost.memory_usage), + }, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD}}, + machine_view_0, + }, + { + BinaryTreePath{{BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD}}, + machine_view_1, + }, + { + BinaryTreePath{{BinaryTreePathEntry::RIGHT_CHILD}}, + translated_machine_view_1, + }, + }, + }, + }, + }}; + + CHECK(result == correct); + } + } + + TEST_CASE("minimize_runtime(MachineMappingWithMemoryResult, " + "MachineMappingWithMemoryResult)") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + 
/*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_2 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/0_n, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{4_p}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + OpCostMetrics cost1 = OpCostMetrics{ + /*forward_runtime=*/2_ms, + /*backward_runtime=*/2_ms, + /*memory_usage=*/2_bytes, + }; + OpCostMetrics cost2 = OpCostMetrics{ + /*forward_runtime=*/4_ms, + /*backward_runtime=*/4_ms, + /*memory_usage=*/1_bytes, + }; + OpCostMetrics cost3 = OpCostMetrics{ + /*forward_runtime=*/2.5_ms, + /*backward_runtime=*/2.5_ms, + /*memory_usage=*/3_bytes, + }; + + ParetoOptimalMachineMapping mm1 = ParetoOptimalMachineMapping{ + cost1, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_0, + }, + }, + }, + }; + + ParetoOptimalMachineMapping mm2 = ParetoOptimalMachineMapping{ + cost2, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_1, + }, + }, + }, + }; + + ParetoOptimalMachineMapping mm3 = ParetoOptimalMachineMapping{ + cost3, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_2, + }, + }, + }, + }; + + MachineMappingWithMemoryResult mapping_result1 = + MachineMappingWithMemoryResult{ + { + mm1, + mm2, + }, + }; + + MachineMappingWithMemoryResult mapping_result2 = + MachineMappingWithMemoryResult{ + { + mm2, + mm3, + }, + }; + + MachineMappingWithMemoryResult result = + minimize_runtime(mapping_result1, mapping_result2); + MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{ + { + mm1, + mm2, + }, + }; + + CHECK(result == correct); + } +} diff --git 
a/lib/compiler/test/src/compiler/machine_mapping/start_invariant_machine_view.cc b/lib/compiler/test/src/compiler/machine_mapping/start_invariant_machine_view.cc new file mode 100644 index 0000000000..3159f49118 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/start_invariant_machine_view.cc @@ -0,0 +1,237 @@ +#include "compiler/machine_mapping/start_invariant_machine_view.h" +#include "op-attrs/task_space_coordinate.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("StartInvariantMachineView - utility functions") { + StartInvariantMachineView simv = StartInvariantMachineView{ + {MachineViewDimension{stride_t{2_p}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2_p}, + MachineSpecificationDimension::INTER_NODE}}, + DeviceType::GPU}; + + SUBCASE("num_dims") { + nonnegative_int result = num_dims(simv); + nonnegative_int correct = 2_n; + CHECK(result == correct); + } + + SUBCASE("get_device_type") { + DeviceType result = get_device_type(simv); + DeviceType correct = DeviceType::GPU; + CHECK(result == correct); + } + + SUBCASE("get_strides") { + std::vector result = get_strides(simv); + std::vector correct = {stride_t{2_p}, stride_t{2_p}}; + CHECK(result == correct); + } + + SUBCASE("get_dimensions") { + std::vector result = get_dimensions(simv); + std::vector correct = { + MachineSpecificationDimension::INTER_NODE, + MachineSpecificationDimension::INTER_NODE}; + CHECK(result == correct); + } + } + + TEST_CASE("StartInvariantMachineView - conversions") { + MachineSpaceCoordinate start = + MachineSpaceCoordinate{1_n, 2_n, DeviceType::GPU}; + std::vector dimensions = { + MachineViewDimension{stride_t{2_p}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{3_p}, + MachineSpecificationDimension::INTRA_NODE}}; + + MachineView mv = MachineView{start, dimensions}; + StartInvariantMachineView simv = 
+ StartInvariantMachineView{dimensions, DeviceType::GPU}; + + SUBCASE("start_invariant_from_machine_view") { + StartInvariantMachineView result = start_invariant_from_machine_view(mv); + StartInvariantMachineView correct = simv; + CHECK(result == correct); + } + + SUBCASE("machine_view_from_start_invariant") { + MachineView result = machine_view_from_start_invariant(simv, start); + MachineView correct = mv; + CHECK(result == correct); + } + + SUBCASE("conversion is invertible") { + SUBCASE("MachineView -> StartInvariant -> MachineView") { + MachineView result = machine_view_from_start_invariant( + start_invariant_from_machine_view(mv), start); + MachineView correct = mv; + CHECK(result == correct); + } + + SUBCASE("StartInvariant -> MachineView -> StartInvariant") { + StartInvariantMachineView result = start_invariant_from_machine_view( + machine_view_from_start_invariant(simv, start)); + StartInvariantMachineView correct = simv; + CHECK(result == correct); + } + } + } + + TEST_CASE("StartInvariantMachineView - get_machine_space_offset") { + SUBCASE("1D case") { + // This operator has shape (3,), and thus 3 tasks. + // The (only) dimension is projected on the INTRA (device) dimension with + // a stride of 2. The machine space has 1 node and 6 devices per node. 
+ /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+-------+ + * | (0,) | | (1,) | | (2,) | | + * +-------+-------+-------+-------+-------+-------+ + */ + OperatorTaskSpace task = OperatorTaskSpace{ + MinimalOrthotope{{ + 3_ge2, + }}, + }; + StartInvariantMachineView simv = StartInvariantMachineView{ + {MachineViewDimension{stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE}}, + DeviceType::GPU}; + MachineComputeSpecification ms = MachineComputeSpecification{ + /*num_nodes=*/1_p, + /*num_cpus_per_node=*/6_p, + /*num_gpus_per_node=*/6_p, + }; + + SUBCASE("get_machine_space_offset") { + SUBCASE("Task with TaskSpaceCoordinate = (0,)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n}); + MachineSpaceOffset correct = + MachineSpaceOffset{0, 0, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n}); + MachineSpaceOffset correct = + MachineSpaceOffset{0, 2, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (2,)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({2_n}); + MachineSpaceOffset correct = + MachineSpaceOffset{0, 4, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord); + CHECK(correct == result); + } + } + + SUBCASE("get_machine_space_offsets") { + std::unordered_set correct = { + MachineSpaceOffset{0, 0, DeviceType::GPU}, + MachineSpaceOffset{0, 2, DeviceType::GPU}, + MachineSpaceOffset{0, 4, DeviceType::GPU}}; + std::unordered_set result = + get_machine_space_offsets(task, simv); + CHECK(correct == result); + } + } + + SUBCASE("2D case") { + // This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. 
+ // The first dimension is projected onto the INTER (node) dimension with + // stride 1, while the second dimension is projected onto the INTRA + // (device) dimension with stride 2. The machine space has 2 nodes and 4 + // devices per node. + + /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+ + * | (0,0) | | (0,1) | | + * +-------+-------+-------+-------+ + * | (1,0) | | (1,1) | | + * +-------+-------+-------+-------+ + */ + + OperatorTaskSpace task = OperatorTaskSpace{ + MinimalOrthotope{{ + 2_ge2, + 2_ge2, + }}, + }; + StartInvariantMachineView simv = StartInvariantMachineView{ + {MachineViewDimension{stride_t{1_p}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2_p}, + MachineSpecificationDimension::INTRA_NODE}}, + DeviceType::GPU}; + MachineComputeSpecification ms = MachineComputeSpecification{ + /*num_nodes=*/2_p, + /*num_cpus_per_node=*/4_p, + /*num_gpus_per_node=*/4_p, + }; + + SUBCASE("get_machine_space_offset") { + SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 0_n}); + MachineSpaceOffset correct = + MachineSpaceOffset{0, 0, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 1_n}); + MachineSpaceOffset correct = + MachineSpaceOffset{0, 2, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 0_n}); + MachineSpaceOffset correct = + MachineSpaceOffset{1, 0, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { + 
TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 1_n}); + MachineSpaceOffset correct = + MachineSpaceOffset{1, 2, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord); + CHECK(correct == result); + } + } + + SUBCASE("get_machine_space_offsets") { + std::unordered_set correct = { + MachineSpaceOffset{0, 0, DeviceType::GPU}, + MachineSpaceOffset{0, 2, DeviceType::GPU}, + MachineSpaceOffset{1, 0, DeviceType::GPU}, + MachineSpaceOffset{1, 2, DeviceType::GPU}}; + std::unordered_set result = + get_machine_space_offsets(task, simv); + CHECK(correct == result); + } + } + } +} diff --git a/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc index 1625d79f80..4fce096818 100644 --- a/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc @@ -9,6 +9,7 @@ #include "pcg/computation_graph.h" #include "pcg/computation_graph_builder.h" #include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include using namespace ::FlexFlow; @@ -87,20 +88,40 @@ TEST_SUITE(FF_TEST_SUITE) { LayerAddedResult input_added = add_layer(cg, make_layer_attrs(input_attrs), {}, {}); - tensor_guid_t t_input = get_only(input_added.outputs); + tensor_guid_t t_input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); LayerAddedResult projection_weight_added = add_layer(cg, make_layer_attrs(projection_weight_attrs), {}, {}); - tensor_guid_t t_projection = get_only(projection_weight_added.outputs); + tensor_guid_t t_projection = require_only_key( + projection_weight_added.outputs, TensorSlotName::OUTPUT); LayerAddedResult bias_weight_added = 
add_layer(cg, make_layer_attrs(bias_weight_attrs), {}, {}); - tensor_guid_t t_bias = get_only(bias_weight_added.outputs); - - LayerAddedResult linear_added = add_layer(cg, - make_layer_attrs(linear_attrs), - {t_input}, - {t_projection, t_bias}); + tensor_guid_t t_bias = + require_only_key(bias_weight_added.outputs, TensorSlotName::OUTPUT); + + LayerAddedResult linear_added = add_layer( + /*computation_graph=*/cg, + /*attrs=*/make_layer_attrs(linear_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_projection, + }, + { + TensorSlotName::BIAS, + t_bias, + }, + }); std::optional result = get_computation_graph_series_parallel_decomposition(cg); @@ -143,21 +164,54 @@ TEST_SUITE(FF_TEST_SUITE) { LayerAddedResult input_added = add_layer(cg, make_layer_attrs(input_attrs), {}, {}); - tensor_guid_t t_input = get_only(input_added.outputs); + tensor_guid_t t_input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); LayerAddedResult w1_added = add_layer(cg, make_layer_attrs(projection_weight_attrs), {}, {}); - tensor_guid_t t_w1 = get_only(w1_added.outputs); + tensor_guid_t t_w1 = + require_only_key(w1_added.outputs, TensorSlotName::OUTPUT); LayerAddedResult w2_added = add_layer(cg, make_layer_attrs(projection_weight_attrs), {}, {}); - tensor_guid_t t_w2 = get_only(w2_added.outputs); - - LayerAddedResult op1_added = - add_layer(cg, make_layer_attrs(linear_attrs), {t_input}, {t_w1}); - - LayerAddedResult op2_added = - add_layer(cg, make_layer_attrs(linear_attrs), {t_input}, {t_w2}); + tensor_guid_t t_w2 = + require_only_key(w2_added.outputs, TensorSlotName::OUTPUT); + + LayerAddedResult op1_added = add_layer( + /*computation_graph=*/cg, + /*layer_attrs=*/make_layer_attrs(linear_attrs), + /*input=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_w1, + }, + }); + + LayerAddedResult op2_added = add_layer( + 
/*computation_graph=*/cg, + /*layer_attrs=*/make_layer_attrs(linear_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_w2, + }, + }); std::optional result = get_computation_graph_series_parallel_decomposition(cg); @@ -193,17 +247,37 @@ TEST_SUITE(FF_TEST_SUITE) { LayerAddedResult input1_added = add_layer(cg, make_layer_attrs(input_attrs), {}, {}); - tensor_guid_t t_input1 = get_only(input1_added.outputs); + tensor_guid_t t_input1 = + require_only_key(input1_added.outputs, TensorSlotName::OUTPUT); LayerAddedResult input2_added = add_layer(cg, make_layer_attrs(input_attrs), {}, {}); - tensor_guid_t t_input2 = get_only(input2_added.outputs); - - LayerAddedResult op1_added = - add_layer(cg, make_layer_attrs(relu_attrs), {t_input1}, {}); - - LayerAddedResult op2_added = - add_layer(cg, make_layer_attrs(relu_attrs), {t_input2}, {}); + tensor_guid_t t_input2 = + require_only_key(input2_added.outputs, TensorSlotName::OUTPUT); + + LayerAddedResult op1_added = add_layer( + /*computation_graph=*/cg, + /*layer_attrs=*/make_layer_attrs(relu_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input1, + }, + }, + /*weights=*/{}); + + LayerAddedResult op2_added = add_layer( + /*computation_graph=*/cg, + /*layer_attrs=*/make_layer_attrs(relu_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input2, + }, + }, + /*weights=*/{}); std::optional result = get_computation_graph_series_parallel_decomposition(cg); @@ -234,7 +308,8 @@ TEST_SUITE(FF_TEST_SUITE) { LayerAddedResult input1_added = add_layer(cg, make_layer_attrs(input_attrs), {}, {}); - tensor_guid_t t_input1 = get_only(input1_added.outputs); + tensor_guid_t t_input1 = + require_only_key(input1_added.outputs, TensorSlotName::OUTPUT); ElementBinaryAttrs ew_add_attrs = ElementBinaryAttrs{ /*type=*/OperatorType::EW_ADD, @@ -243,19 +318,61 @@ TEST_SUITE(FF_TEST_SUITE) { /*should_broadcast_rhs=*/false, }; - LayerAddedResult op1_added = - 
add_layer(cg, make_layer_attrs(relu_attrs), {t_input1}, {}); - tensor_guid_t t_op1 = get_only(op1_added.outputs); - - LayerAddedResult op2_added = - add_layer(cg, make_layer_attrs(relu_attrs), {t_input1}, {}); - tensor_guid_t t_op2 = get_only(op2_added.outputs); - - LayerAddedResult op3_added = - add_layer(cg, make_layer_attrs(relu_attrs), {t_op1}, {}); - - LayerAddedResult op4_added = - add_layer(cg, make_layer_attrs(ew_add_attrs), {t_op1, t_op2}, {}); + LayerAddedResult op1_added = add_layer( + /*computation_graph=*/cg, + /*layer_attrs=*/make_layer_attrs(relu_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input1, + }, + }, + /*weights=*/{}); + tensor_guid_t t_op1 = + require_only_key(op1_added.outputs, TensorSlotName::OUTPUT); + + LayerAddedResult op2_added = add_layer( + /*computation_graph=*/cg, + /*layer_attrs=*/make_layer_attrs(relu_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input1, + }, + }, + /*weights=*/{}); + tensor_guid_t t_op2 = + require_only_key(op2_added.outputs, TensorSlotName::OUTPUT); + + LayerAddedResult op3_added = add_layer( + /*computation_graph=*/cg, + /*layer_attrs=*/make_layer_attrs(relu_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_op1, + }, + }, + /*weights=*/{}); + + LayerAddedResult op4_added = add_layer( + /*computation_graph=*/cg, + /*layer_attrs=*/make_layer_attrs(ew_add_attrs), + /*inputs=*/ + { + { + TensorSlotName::LHS_INPUT, + t_op1, + }, + { + TensorSlotName::RHS_INPUT, + t_op2, + }, + }, + /*weights=*/{}); std::optional result = get_computation_graph_series_parallel_decomposition(cg); diff --git a/lib/compiler/test/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc index 81531d7073..a2230f9a86 100644 --- a/lib/compiler/test/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc +++ 
b/lib/compiler/test/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc @@ -4,6 +4,7 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include using namespace ::FlexFlow; @@ -55,7 +56,8 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); LinearAttrs linear_attrs = LinearAttrs{ /*out_channels=*/14_p, @@ -79,8 +81,8 @@ TEST_SUITE(FF_TEST_SUITE) { /*layer_attrs=*/make_layer_attrs(projection_weight_attrs), /*inputs=*/{}, /*weights=*/{}); - parallel_tensor_guid_t t_projection_weights = - get_only(projection_weights_added.outputs); + parallel_tensor_guid_t t_projection_weights = require_only_key( + projection_weights_added.outputs, TensorSlotName::OUTPUT); WeightAttrs bias_weight_attrs = WeightAttrs{ /*shape=*/bias_weights_shape, @@ -92,13 +94,29 @@ TEST_SUITE(FF_TEST_SUITE) { /*inputs=*/{}, /*weights=*/{}); parallel_tensor_guid_t t_bias_weights = - get_only(bias_weights_added.outputs); - - ParallelLayerAddedResult linear_added = add_parallel_layer( - pcg, - /*layer_attrs=*/make_layer_attrs(linear_attrs), - /*inputs=*/{t_input}, - /*weights=*/{t_projection_weights, t_bias_weights}); + require_only_key(bias_weights_added.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult linear_added = + add_parallel_layer(pcg, + /*layer_attrs=*/make_layer_attrs(linear_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_projection_weights, + }, + { + TensorSlotName::BIAS, + t_bias_weights, + }, + }); std::optional result = get_pcg_series_parallel_decomposition(pcg); @@ 
-117,14 +135,16 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("SP without weight nodes but non-SP with weight nodes (parallel op " "chain following is not necessary)") { - // A minimal computation graph where without weights (w1 and w2) the - // computation graph is series-parallel, but with weight nodes it is not, - // but parallel op chain following is not necessary - // (in this case because there are no parallel ops involved) - // - // w1 input w2 - // \ / \ / - // op1 op2 + /** + * A minimal computation graph where without weights (w1 and w2) the + * computation graph is series-parallel, but with weight nodes it is not, + * but parallel op chain following is not necessary + * (in this case because there are no parallel ops involved) + * + * w1 input w2 + * \ / \ / + * op1 op2 + */ ParallelComputationGraph pcg = empty_parallel_computation_graph(); @@ -150,21 +170,54 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input_added = add_parallel_layer(pcg, make_layer_attrs(input_attrs), {}, {}); - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult w1_added = add_parallel_layer( pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_w1 = get_only(w1_added.outputs); + parallel_tensor_guid_t t_w1 = + require_only_key(w1_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult w2_added = add_parallel_layer( pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_w2 = get_only(w2_added.outputs); + parallel_tensor_guid_t t_w2 = + require_only_key(w2_added.outputs, TensorSlotName::OUTPUT); - ParallelLayerAddedResult op1_added = add_parallel_layer( - pcg, make_layer_attrs(linear_attrs), {t_input}, {t_w1}); + ParallelLayerAddedResult op1_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + 
/*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_w1, + }, + }); - ParallelLayerAddedResult op2_added = add_parallel_layer( - pcg, make_layer_attrs(linear_attrs), {t_input}, {t_w2}); + ParallelLayerAddedResult op2_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_w2, + }, + }); std::optional result = get_pcg_series_parallel_decomposition(pcg); @@ -186,20 +239,23 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("SP without weight nodes but non-SP with weight node (parallel op " "chain following necessary)") { - // A minimal computation graph where without weights (w1 and w2) the - // computation graph is series-parallel, but with weight nodes it is not - // and parallel op chain following is necessary - // - // w1 input w2 - // | | | - // | p2 p4 - // | | | - // p1 p3 p5 - // | | | - // | |\ / - // | op0 \ | - // \ / | / - // op1 op2 + + /** + * A minimal computation graph where without weights (w1 and w2) the + * computation graph is series-parallel, but with weight nodes it is not + * and parallel op chain following is necessary + * + * w1 input w2 + * | | | + * | p2 p4 + * | | | + * p1 p3 p5 + * | | | + * | |\ / + * | op0 \ | + * \ / | / + * op1 op2 + */ ParallelComputationGraph pcg = empty_parallel_computation_graph(); @@ -214,15 +270,20 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t layer_input = input_added.parallel_layer; - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); RepartitionAttrs p2_attrs = RepartitionAttrs{ /*repartition_dim=*/ff_dim_t{0_n}, /*repartition_degree=*/3_p, }; ParallelLayerAddedResult p2_added = - add_parallel_layer(pcg, make_layer_attrs(p2_attrs), {t_input}, {}); - parallel_tensor_guid_t t_p2 = 
get_only(p2_added.outputs); + add_parallel_layer(pcg, + make_layer_attrs(p2_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); + parallel_tensor_guid_t t_p2 = + require_only_key(p2_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAttrs p3_attrs = ParallelLayerAttrs{ PCGOperatorAttrs{RepartitionAttrs{ @@ -231,16 +292,21 @@ TEST_SUITE(FF_TEST_SUITE) { }}, /*name=*/std::nullopt, }; - ParallelLayerAddedResult p3_added = - add_parallel_layer(pcg, p3_attrs, {t_p2}, {}); - parallel_tensor_guid_t t_p3 = get_only(p3_added.outputs); + ParallelLayerAddedResult p3_added = add_parallel_layer( + pcg, p3_attrs, {{TensorSlotName::INPUT, t_p2}}, {}); + parallel_tensor_guid_t t_p3 = + require_only_key(p3_added.outputs, TensorSlotName::OUTPUT); CastAttrs op0_attrs = CastAttrs{ /*dtype=*/DataType::INT32, }; ParallelLayerAddedResult op0_added = - add_parallel_layer(pcg, make_layer_attrs(op0_attrs), {t_p3}, {}); - parallel_tensor_guid_t t_op0 = get_only(op0_added.outputs); + add_parallel_layer(pcg, + make_layer_attrs(op0_attrs), + {{TensorSlotName::INPUT, t_p3}}, + {}); + parallel_tensor_guid_t t_op0 = + require_only_key(op0_added.outputs, TensorSlotName::OUTPUT); EmbeddingAttrs op1_attrs = EmbeddingAttrs{ /*num_entires=*/100_p, @@ -259,17 +325,34 @@ TEST_SUITE(FF_TEST_SUITE) { }; ParallelLayerAddedResult w1_added = add_parallel_layer(pcg, make_layer_attrs(w1_attrs), {}, {}); - parallel_tensor_guid_t t_w1 = get_only(w1_added.outputs); + parallel_tensor_guid_t t_w1 = + require_only_key(w1_added.outputs, TensorSlotName::OUTPUT); ReplicateAttrs p1_attrs = ReplicateAttrs{ /*replicate_degree=*/6_p, }; - ParallelLayerAddedResult p1_added = - add_parallel_layer(pcg, make_layer_attrs(p1_attrs), {t_w1}, {}); - parallel_tensor_guid_t t_p1 = get_only(p1_added.outputs); + ParallelLayerAddedResult p1_added = add_parallel_layer( + pcg, make_layer_attrs(p1_attrs), {{TensorSlotName::INPUT, t_w1}}, {}); + parallel_tensor_guid_t t_p1 = + require_only_key(p1_added.outputs, TensorSlotName::OUTPUT); - 
ParallelLayerAddedResult op1_added = - add_parallel_layer(pcg, make_layer_attrs(op1_attrs), {t_op0}, {t_p1}); + ParallelLayerAddedResult op1_added = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(op1_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_op0, + }, + }, + /*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_p1, + }, + }); LinearAttrs op2_attrs = LinearAttrs{ /*out_channels=*/14_p, @@ -286,25 +369,43 @@ TEST_SUITE(FF_TEST_SUITE) { }; ParallelLayerAddedResult w2_added = add_parallel_layer(pcg, make_layer_attrs(w2_attrs), {}, {}); - parallel_tensor_guid_t t_w2 = get_only(w2_added.outputs); + parallel_tensor_guid_t t_w2 = + require_only_key(w2_added.outputs, TensorSlotName::OUTPUT); ReplicateAttrs p4_attrs = ReplicateAttrs{ /*replicate_degree=*/3_p, }; - ParallelLayerAddedResult p4_added = - add_parallel_layer(pcg, make_layer_attrs(p4_attrs), {t_w2}, {}); - parallel_tensor_guid_t t_p4 = get_only(p4_added.outputs); + ParallelLayerAddedResult p4_added = add_parallel_layer( + pcg, make_layer_attrs(p4_attrs), {{TensorSlotName::INPUT, t_w2}}, {}); + parallel_tensor_guid_t t_p4 = + require_only_key(p4_added.outputs, TensorSlotName::OUTPUT); RepartitionAttrs p5_attrs = RepartitionAttrs{ /*repartition_dim=*/ff_dim_t{1_n}, /*repartition_degree=*/2_p, }; - ParallelLayerAddedResult p5_added = - add_parallel_layer(pcg, make_layer_attrs(p5_attrs), {t_p4}, {}); - parallel_tensor_guid_t t_p5 = get_only(p5_added.outputs); + ParallelLayerAddedResult p5_added = add_parallel_layer( + pcg, make_layer_attrs(p5_attrs), {{TensorSlotName::INPUT, t_p4}}, {}); + parallel_tensor_guid_t t_p5 = + require_only_key(p5_added.outputs, TensorSlotName::OUTPUT); - ParallelLayerAddedResult op2_added = - add_parallel_layer(pcg, make_layer_attrs(op2_attrs), {t_p3}, {t_p5}); + ParallelLayerAddedResult op2_added = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(op2_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_p3, + }, + }, + 
/*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_p5, + }, + }); std::optional result = get_pcg_series_parallel_decomposition(pcg); @@ -350,25 +451,36 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("SP with or without preprocessing, but preprocessing would change " "resulting SP " "decomposition") { - // parallel computation graph: - // - // input1 input2 - // | | - // op1 op2 + + /** + * parallel computation graph: + * + * input1 input2 + * | | + * op1 op2 + */ ParallelLayerAddedResult input1_added = add_parallel_layer(pcg, make_layer_attrs(input_attrs), {}, {}); - parallel_tensor_guid_t t_input1 = get_only(input1_added.outputs); + parallel_tensor_guid_t t_input1 = + require_only_key(input1_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult input2_added = add_parallel_layer(pcg, make_layer_attrs(input_attrs), {}, {}); - parallel_tensor_guid_t t_input2 = get_only(input2_added.outputs); + parallel_tensor_guid_t t_input2 = + require_only_key(input2_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult op1_added = - add_parallel_layer(pcg, make_layer_attrs(relu_attrs), {t_input1}, {}); + add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + {{TensorSlotName::INPUT, t_input1}}, + {}); ParallelLayerAddedResult op2_added = - add_parallel_layer(pcg, make_layer_attrs(relu_attrs), {t_input2}, {}); + add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + {{TensorSlotName::INPUT, t_input2}}, + {}); std::optional result = get_pcg_series_parallel_decomposition(pcg); @@ -388,18 +500,22 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("not SP with or without weight nodes") { - // parallel computation graph: - // - // input1 - // / \ - // op1 op2 - // | \ | - // | \ | - // op3 op4 + + /** + * parallel computation graph: + * + * input1 + * / \ + * op1 op2 + * | \ | + * | \ | + * op3 op4 + */ ParallelLayerAddedResult input1_added = add_parallel_layer(pcg, make_layer_attrs(input_attrs), {}, {}); - parallel_tensor_guid_t t_input1 = get_only(input1_added.outputs); + 
parallel_tensor_guid_t t_input1 = + require_only_key(input1_added.outputs, TensorSlotName::OUTPUT); ElementBinaryAttrs ew_add_attrs = ElementBinaryAttrs{ /*type=*/OperatorType::EW_ADD, @@ -409,18 +525,42 @@ TEST_SUITE(FF_TEST_SUITE) { }; ParallelLayerAddedResult op1_added = - add_parallel_layer(pcg, make_layer_attrs(relu_attrs), {t_input1}, {}); - parallel_tensor_guid_t t_op1 = get_only(op1_added.outputs); + add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + {{TensorSlotName::INPUT, t_input1}}, + {}); + parallel_tensor_guid_t t_op1 = + require_only_key(op1_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult op2_added = - add_parallel_layer(pcg, make_layer_attrs(relu_attrs), {t_input1}, {}); - parallel_tensor_guid_t t_op2 = get_only(op2_added.outputs); + add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + {{TensorSlotName::INPUT, t_input1}}, + {}); + parallel_tensor_guid_t t_op2 = + require_only_key(op2_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult op3_added = - add_parallel_layer(pcg, make_layer_attrs(relu_attrs), {t_op1}, {}); + add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + {{TensorSlotName::INPUT, t_op1}}, + {}); ParallelLayerAddedResult op4_added = add_parallel_layer( - pcg, make_layer_attrs(ew_add_attrs), {t_op1, t_op2}, {}); + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(ew_add_attrs), + /*inputs=*/ + { + { + TensorSlotName::LHS_INPUT, + t_op1, + }, + { + TensorSlotName::RHS_INPUT, + t_op2, + }, + }, + /*=*/{}); std::optional result = get_pcg_series_parallel_decomposition(pcg); diff --git a/lib/compiler/test/src/compiler/task_graph_simulator/task_simulator.cc b/lib/compiler/test/src/compiler/task_graph_simulator/task_simulator.cc index 6571b78540..2846de6559 100644 --- a/lib/compiler/test/src/compiler/task_graph_simulator/task_simulator.cc +++ b/lib/compiler/test/src/compiler/task_graph_simulator/task_simulator.cc @@ -4,6 +4,10 @@ #include "compiler/machine_mapping/machine_mapping.dtg.h" #include 
"compiler/machine_mapping/machine_mapping.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/machine_view.dtg.h" +#include "compiler/machine_mapping/machine_view.h" +#include "compiler/machine_mapping/machine_view_dimension.dtg.h" +#include "compiler/machine_mapping/stride_t.dtg.h" #include "internal/runtime_only_cost_estimator_for_test.h" #include "op-attrs/ops/input_attrs.dtg.h" #include "op-attrs/parallel_tensor_dims.dtg.h" @@ -12,16 +16,11 @@ #include "pcg/device_id.h" #include "pcg/device_type.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" -#include "pcg/machine_specification.h" #include "pcg/machine_specification_dimension.dtg.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/machine_view.h" -#include "pcg/machine_view_dimension.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" -#include "pcg/stride_t.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/sub_parallel_computation_graph.h" #include "utils/containers/get_only.h" @@ -37,12 +36,17 @@ namespace FlexFlow { TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("task_simulator_estimate_forward_pass_time") { - MachineSpecification machine_spec = - MachineSpecification{/*num_nodes=*/3_p, - /*num_cpus_per_node=*/3_p, - /*num_gpus_per_node=*/3_p, - /*inter_node_bandwidth=*/1.0f, - /*intra_node_bandwidth=*/1.0f}; + MachineSpecification machine_spec = MachineSpecification{ + MachineComputeSpecification{ + /*num_nodes=*/3_p, + /*num_cpus_per_node=*/3_p, + /*num_gpus_per_node=*/3_p, + }, + MachineInterconnectSpecification{ + /*inter_node_bandwidth=*/bytes_per_second_t{1.0f}, + /*intra_node_bandwidth=*/bytes_per_second_t{1.0f}, + }, + }; SUBCASE("linear 
graph") { ParallelComputationGraphBuilder b; @@ -61,16 +65,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer0 = get_source_layer(tensor0); parallel_layer_guid_t layer1 = get_source_layer(tensor1); - std::vector dims = { - MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - }; + std::vector dims = {}; ParallelComputationGraph pcg = b.pcg; MachineView mv1 = MachineView{MachineSpaceCoordinate{0_n, 0_n, DeviceType::GPU}, dims}; @@ -150,16 +145,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer3 = get_source_layer(tensor3); ParallelComputationGraph pcg = b.pcg; - std::vector dims = { - MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - }; + std::vector dims = {}; SUBCASE("all different devices") { MachineView mv0 = MachineView{ diff --git a/lib/compiler/test/src/graph_optimize_state.cc b/lib/compiler/test/src/graph_optimize_state.cc deleted file mode 100644 index e7060ef421..0000000000 --- a/lib/compiler/test/src/graph_optimize_state.cc +++ /dev/null @@ -1,103 +0,0 @@ -#include "compiler/graph_optimize_state.h" -#include "doctest/doctest.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("GraphOptimizeState::operator==") { - TensorShape input_shape = TensorShape{ - TensorDims{ - FFOrdered{ - 32_p, - 16_p, - }, - }, - DataType::FLOAT, - }; - - // `machine_mapping` is determined by the PCG 
and the device mapping - // algorithm, and `runtime` is determined by the PCG and the device mapping, - // so their values here do not matter. - std::unordered_map empty_machine_views; - MachineMapping empty_machine_mapping(empty_machine_views); - - InitializerAttrs zero_init = InitializerAttrs{ZeroInitializerAttrs{}}; - - auto create_pcg = [&]() -> ParallelComputationGraph { - ParallelComputationGraphBuilder builder; - - parallel_tensor_guid_t input0 = - builder.create_input_tensor(input_shape, "input0"); - parallel_tensor_guid_t dense0 = - builder.dense(/*input=*/input0, - /*outDim=*/8_p, - /*activation=*/Activation::RELU, - /*use_bias=*/true, - /*data_type=*/DataType::FLOAT, - /*projection_initializer=*/zero_init, - /*bias_initializer=*/zero_init, - /*name=*/"dense0"); - - parallel_tensor_guid_t dense1 = - builder.dense(/*input=*/dense0, - /*outDim=*/4_p, - /*activation=*/Activation::RELU, - /*use_bias=*/true, - /*data_type=*/DataType::FLOAT, - /*projection_initializer=*/zero_init, - /*bias_initializer=*/zero_init, - /*name=*/"dense1"); - - return builder.pcg; - }; - - ParallelComputationGraph pcg1 = create_pcg(); - - SUBCASE("returns true if the PCGs are isomorphic") { - ParallelComputationGraph pcg2 = create_pcg(); - - GraphOptimizeState state1 = GraphOptimizeState{ - GraphOptimizeResult{pcg1, empty_machine_mapping}, - 0, - }; - - GraphOptimizeState state2 = GraphOptimizeState{ - GraphOptimizeResult{pcg2, empty_machine_mapping}, - 0, - }; - - CHECK(state1 == state2); - } - - SUBCASE("returns false it the PCGs are not isomorphic") { - ParallelComputationGraphBuilder builder_; - - parallel_tensor_guid_t input0_ = - builder_.create_input_tensor(input_shape, "input0"); - parallel_tensor_guid_t dense0_ = - builder_.dense(/*input=*/input0_, - /*outDim=*/8_p, - /*activation=*/Activation::RELU, - /*use_bias=*/true, - /*data_type=*/DataType::FLOAT, - /*projection_initializer=*/zero_init, - /*bias_initializer=*/zero_init, - /*name=*/"dense0"); - - 
ParallelComputationGraph pcg_ = builder_.pcg; - - GraphOptimizeState state1 = GraphOptimizeState{ - GraphOptimizeResult{pcg1, empty_machine_mapping}, - 0, - }; - - GraphOptimizeState state_ = GraphOptimizeState{ - GraphOptimizeResult{pcg_, empty_machine_mapping}, - 0, - }; - - CHECK_FALSE(state1 == state_); - } - } -} diff --git a/lib/compiler/test/src/internal/cost_estimator_for_test.cc b/lib/compiler/test/src/internal/cost_estimator_for_test.cc index 60bf6ba7a4..7092a3848b 100644 --- a/lib/compiler/test/src/internal/cost_estimator_for_test.cc +++ b/lib/compiler/test/src/internal/cost_estimator_for_test.cc @@ -2,6 +2,7 @@ #include "compiler/cost_estimator/op_cost_metrics.dtg.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "utils/containers/contains_key.h" #include "utils/nonnegative_int/nonnegative_int.h" namespace FlexFlow { @@ -38,8 +39,12 @@ CostEstimator make_fake_cost_estimator( std::unordered_map const &comm_cost_map) { return make_fake_cost_estimator( - [op_cost_map](OpCostEstimateKey const &k) { return op_cost_map.at(k); }, + [op_cost_map](OpCostEstimateKey const &k) { + ASSERT(contains_key(op_cost_map, k), k); + return op_cost_map.at(k); + }, [comm_cost_map](TensorSetMovement const &m) { + ASSERT(contains_key(comm_cost_map, m), m); return comm_cost_map.at(m); }); } diff --git a/lib/compiler/test/src/internal/runtime_only_cost_estimator_for_test.cc b/lib/compiler/test/src/internal/runtime_only_cost_estimator_for_test.cc index c52344c6b3..59bf08a399 100644 --- a/lib/compiler/test/src/internal/runtime_only_cost_estimator_for_test.cc +++ b/lib/compiler/test/src/internal/runtime_only_cost_estimator_for_test.cc @@ -5,6 +5,7 @@ #include "compiler/cost_estimator/op_cost_metrics.h" #include "compiler/cost_estimator/runtime_only_cost_estimator_from_cost_estimator.h" #include 
"internal/cost_estimator_for_test.h" +#include "utils/containers/contains_key.h" namespace FlexFlow { @@ -31,9 +32,13 @@ RuntimeOnlyCostEstimator make_fake_runtime_only_cost_estimator( &comm_cost_map) { return make_fake_runtime_only_cost_estimator( [op_cost_map](RuntimeOnlyOpCostEstimateKey const &k) { + ASSERT(contains_key(op_cost_map, k), k); + return op_cost_map.at(k); }, [comm_cost_map](TensorSetMovement const &m) { + ASSERT(contains_key(comm_cost_map, m), m); + return comm_cost_map.at(m); }); } diff --git a/lib/compiler/test/src/unity_algorithm.cc b/lib/compiler/test/src/unity_algorithm.cc deleted file mode 100644 index 8ff0978ea5..0000000000 --- a/lib/compiler/test/src/unity_algorithm.cc +++ /dev/null @@ -1,26 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "doctest/doctest.h" - -TEST_SUITE(FF_TEST_SUITE) { - // Rapidcheck does not work for now - // TEST_CASE("graph_optimize") { - // RC_SUBCASE([](ComputationGraph const &g, - // float alpha, - // int budget, - // float threshold, - // int max_num_ops) { - // Strategy s = graph_optimize( - // g, - // TestCostEstimator{}, - // MachineSpecification{1, 1, 4, 0.1, 0.2}, - // [](Operator const &, MachineSpecification const &) { - // return std::unordered_set{make_1d_machine_view(0, 1, - // 1)}; - // }, - // OptimizerConfig{alpha, budget, threshold, max_num_ops}); - // RC_ASSERT(get_nodes(s.pcg).size() > 0); - // RC_ASSERT(s.machine_mapping.runtime > 0); - // RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); - // }); - // } -} diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index ec0d6fde0d..27d3693e7e 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -66,6 +66,8 @@ class GenericTensorAccessorR { decltype(ptr) const &, decltype(device_type) const &> tie() const; + + friend ::std::hash; }; std::string format_as(GenericTensorAccessorR const &); @@ -133,6 +135,8 @@ class GenericTensorAccessorW { 
decltype(ptr) const &, decltype(device_type) const &> tie() const; + + friend ::std::hash; }; std::string format_as(GenericTensorAccessorW const &); @@ -222,12 +226,18 @@ real_type_t
accessor_get_only_value(GenericTensorAccessorR const &acc) { } // namespace FlexFlow -namespace FlexFlow { -static_assert(is_well_behaved_value_type_no_hash::value, - ""); -static_assert(is_well_behaved_value_type_no_hash::value, - ""); +namespace std { -} // namespace FlexFlow +template <> +struct hash<::FlexFlow::GenericTensorAccessorR> { + size_t operator()(::FlexFlow::GenericTensorAccessorR const &) const; +}; + +template <> +struct hash<::FlexFlow::GenericTensorAccessorW> { + size_t operator()(::FlexFlow::GenericTensorAccessorW const &) const; +}; + +} // namespace std #endif diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index db377162b6..d54663f110 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -1,37 +1,31 @@ #ifndef _FLEXFLOW_OPS_KERNELS_BATCH_MATMUL_KERNELS_H #define _FLEXFLOW_OPS_KERNELS_BATCH_MATMUL_KERNELS_H +#include "kernels/accessor.h" #include "kernels/device_handle_t.dtg.h" #include "kernels/device_stream_t.dtg.h" #include "kernels/ff_handle.h" +#include "utils/nonnegative_int/nonnegative_int.h" namespace FlexFlow::Kernels::BatchMatmul { void forward_kernel(device_stream_t const &stream, device_handle_t const &handle, - float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, - int m, - int n, - int k, - int batch, - int seq_length, - int a_seq_length_dim, - int b_seq_length_dim); + GenericTensorAccessorW const &output, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorR const &input_b, + positive_int seq_length, + std::optional a_seq_length_dim, + std::optional b_seq_length_dim); void backward_kernel(device_stream_t const &stream, device_handle_t const &handle, - float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, - int m, - int n, - int k, - int batch); + GenericTensorAccessorR const 
&output, + GenericTensorAccessorR const &output_grad, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorW const &input_a_grad, + GenericTensorAccessorR const &input_b, + GenericTensorAccessorW const &input_b_grad); } // namespace FlexFlow::Kernels::BatchMatmul diff --git a/lib/kernels/include/kernels/batch_matmul_kernels_cpu.h b/lib/kernels/include/kernels/batch_matmul_kernels_cpu.h index fdef3d7fa1..6d9c804be2 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels_cpu.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels_cpu.h @@ -5,27 +5,19 @@ namespace FlexFlow::Kernels::BatchMatmul { -void cpu_forward_kernel(float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, - int m, - int n, - int k, - int batch, - int seq_length, - int a_seq_length_dim, - int b_seq_length_dim); +void cpu_forward_kernel(GenericTensorAccessorW const &output, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorR const &input_b, + positive_int seq_length, + std::optional a_seq_length_dim, + std::optional b_seq_length_dim); -void cpu_backward_kernel(float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, - int m, - int n, - int k, - int batch); +void cpu_backward_kernel(GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorW const &input_a_grad, + GenericTensorAccessorR const &input_b, + GenericTensorAccessorW const &input_b_grad); } // namespace FlexFlow::Kernels::BatchMatmul diff --git a/lib/kernels/include/kernels/batch_matmul_kernels_gpu.h b/lib/kernels/include/kernels/batch_matmul_kernels_gpu.h index 4a35c000c3..1e13755b81 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels_gpu.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels_gpu.h @@ -10,8 +10,8 @@ namespace FlexFlow::Kernels::BatchMatmul { void gpu_forward_kernel(ffStream_t stream, 
PerDeviceFFHandle const &handle, float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, + float const *input_a_ptr, + float const *input_b_ptr, int m, int n, int k, @@ -22,12 +22,12 @@ void gpu_forward_kernel(ffStream_t stream, void gpu_backward_kernel(ffStream_t stream, PerDeviceFFHandle const &handle, - float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, + float const *output_ptr, + float const *output_grad_ptr, + float const *input_a_ptr, + float *input_a_grad_ptr, + float const *input_b_ptr, + float *input_b_grad_ptr, int m, int n, int k, diff --git a/lib/kernels/include/kernels/batch_norm_per_device_state.dtg.toml b/lib/kernels/include/kernels/batch_norm_per_device_state.dtg.toml new file mode 100644 index 0000000000..bdf9e1ed51 --- /dev/null +++ b/lib/kernels/include/kernels/batch_norm_per_device_state.dtg.toml @@ -0,0 +1,69 @@ +namespace = "FlexFlow" +name = "BatchNormPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/device.h", + "kernels/ff_handle.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::PerDeviceFFHandle" + +[[fields]] +name = "inputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "outputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "biasTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "actiDesc" +type = "ffActivationDescriptor_t" + +[[fields]] +name = "mode" +type = "ffBatchNormMode_t" + +[[fields]] +name = "runningMean" +type = "float *" + +[[fields]] +name = "runningVar" +type = "float *" + +[[fields]] +name = "saveMean" +type = "float *" + +[[fields]] +name = "saveVar" +type = "float *" + +[[fields]] +name = "output_n" +type = "int" + +[[fields]] +name = "output_c" +type = "int" + +[[fields]] +name = "output_h" +type = "int" + +[[fields]] +name = "output_w" +type = "int" + +[[fields]] +name = "relu" +type = "bool" diff --git 
a/lib/kernels/include/kernels/batch_norm_per_device_state.struct.toml b/lib/kernels/include/kernels/batch_norm_per_device_state.struct.toml deleted file mode 100644 index 6d2f04f60c..0000000000 --- a/lib/kernels/include/kernels/batch_norm_per_device_state.struct.toml +++ /dev/null @@ -1,68 +0,0 @@ -namespace = "FlexFlow" -name = "BatchNormPerDeviceState" -features = [] - -includes = [ - "kernels/device.h", - "kernels/ff_handle.h", -] - -[[fields]] -name = "handle" -type = "::FlexFlow::PerDeviceFFHandle" - -[[fields]] -name = "inputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "outputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "biasTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "actiDesc" -type = "ffActivationDescriptor_t" - -[[fields]] -name = "mode" -type = "ffBatchNormMode_t" - -[[fields]] -name = "runningMean" -type = "float *" - -[[fields]] -name = "runningVar" -type = "float *" - -[[fields]] -name = "saveMean" -type = "float *" - -[[fields]] -name = "saveVar" -type = "float *" - -[[fields]] -name = "output_n" -type = "int" - -[[fields]] -name = "output_c" -type = "int" - -[[fields]] -name = "output_h" -type = "int" - -[[fields]] -name = "output_w" -type = "int" - -[[fields]] -name = "relu" -type = "bool" diff --git a/lib/kernels/include/kernels/conv_2d_per_device_state.dtg.toml b/lib/kernels/include/kernels/conv_2d_per_device_state.dtg.toml new file mode 100644 index 0000000000..cdbb2fea38 --- /dev/null +++ b/lib/kernels/include/kernels/conv_2d_per_device_state.dtg.toml @@ -0,0 +1,49 @@ +namespace = "FlexFlow" +name = "Conv2DPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/device.h", + "kernels/ff_handle.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::PerDeviceFFHandle" + +[[fields]] +name = "inputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "biasTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "outputTensor" +type = "ffTensorDescriptor_t" + 
+[[fields]] +name = "filterDesc" +type = "ffFilterDescriptor_t" + +[[fields]] +name = "actiDesc" +type = "ffActivationDescriptor_t" + +[[fields]] +name = "convDesc" +type = "ffConvolutionDescriptor_t" + +[[fields]] +name = "fwdAlgo" +type = "ffConvolutionFwdAlgo_t" + +[[fields]] +name = "bwdFilterAlgo" +type = "ffConvolutionBwdFilterAlgo_t" + +[[fields]] +name = "bwdDataAlgo" +type = "ffConvolutionBwdDataAlgo_t" diff --git a/lib/kernels/include/kernels/conv_2d_per_device_state.struct.toml b/lib/kernels/include/kernels/conv_2d_per_device_state.struct.toml deleted file mode 100644 index d76dbc89d0..0000000000 --- a/lib/kernels/include/kernels/conv_2d_per_device_state.struct.toml +++ /dev/null @@ -1,48 +0,0 @@ -namespace = "FlexFlow" -name = "Conv2DPerDeviceState" -features = [] - -includes = [ - "kernels/device.h", - "kernels/ff_handle.h", -] - -[[fields]] -name = "handle" -type = "::FlexFlow::PerDeviceFFHandle" - -[[fields]] -name = "inputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "biasTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "outputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "filterDesc" -type = "ffFilterDescriptor_t" - -[[fields]] -name = "actiDesc" -type = "ffActivationDescriptor_t" - -[[fields]] -name = "convDesc" -type = "ffConvolutionDescriptor_t" - -[[fields]] -name = "fwdAlgo" -type = "ffConvolutionFwdAlgo_t" - -[[fields]] -name = "bwdFilterAlgo" -type = "ffConvolutionBwdFilterAlgo_t" - -[[fields]] -name = "bwdDataAlgo" -type = "ffConvolutionBwdDataAlgo_t" diff --git a/lib/kernels/include/kernels/create_accessor_with_contents.h b/lib/kernels/include/kernels/create_accessor_with_contents.h index 3574ad0c88..dfe4428bdc 100644 --- a/lib/kernels/include/kernels/create_accessor_with_contents.h +++ b/lib/kernels/include/kernels/create_accessor_with_contents.h @@ -41,10 +41,10 @@ GenericTensorAccessorW create_2d_accessor_w_with_contents( std::vector> const &contents, Allocator &allocator) { positive_int nrows = 
positive_int{num_elements(contents)}; - positive_int ncols = throw_if_unexpected( + positive_int ncols = require_all_same1(transform(contents, [](std::vector const &row) { return positive_int{num_elements(row)}; - }))); + })); TensorShape shape = TensorShape{ TensorDims{FFOrdered{nrows, ncols}}, @@ -78,18 +78,17 @@ GenericTensorAccessorW create_3d_accessor_w_with_contents( Allocator &allocator) { positive_int dim0_size = positive_int{num_elements(contents)}; - positive_int dim1_size = throw_if_unexpected(require_all_same1( + positive_int dim1_size = require_all_same1( transform(contents, [](std::vector> const &m) { return positive_int{num_elements(m)}; - }))); + })); - positive_int dim2_size = throw_if_unexpected(require_all_same1( + positive_int dim2_size = require_all_same1( transform(contents, [](std::vector> const &m) { - return throw_if_unexpected( - require_all_same1(transform(m, [](std::vector const &vec) { - return positive_int{num_elements(vec)}; - }))); - }))); + return require_all_same1(transform(m, [](std::vector const &vec) { + return positive_int{num_elements(vec)}; + })); + })); TensorShape shape = TensorShape{ TensorDims{FFOrdered{dim0_size, dim1_size, dim2_size}}, @@ -127,29 +126,29 @@ GenericTensorAccessorW create_4d_accessor_w_with_contents( Allocator &allocator) { positive_int dim0_size = positive_int{num_elements(contents)}; - positive_int dim1_size = throw_if_unexpected(require_all_same1(transform( + positive_int dim1_size = require_all_same1(transform( contents, [](std::vector>> const &t) { return positive_int{num_elements(t)}; - }))); + })); - positive_int dim2_size = throw_if_unexpected(require_all_same1(transform( + positive_int dim2_size = require_all_same1(transform( contents, [](std::vector>> const &m) { - return throw_if_unexpected(require_all_same1( + return require_all_same1( transform(m, [](std::vector> const &vec) { return positive_int{num_elements(vec)}; - }))); - }))); + })); + })); - positive_int dim3_size = 
throw_if_unexpected(require_all_same1(transform( + positive_int dim3_size = require_all_same1(transform( contents, [](std::vector>> const &t) { - return throw_if_unexpected(require_all_same1( + return require_all_same1( transform(t, [](std::vector> const &mat) { - return throw_if_unexpected(require_all_same1( + return require_all_same1( transform(mat, [](std::vector const &vec) { return positive_int{num_elements(vec)}; - }))); - }))); - }))); + })); + })); + })); TensorShape shape = TensorShape{ TensorDims{FFOrdered{dim0_size, dim1_size, dim2_size, dim3_size}}, diff --git a/lib/kernels/include/kernels/device_handle_t.dtg.toml b/lib/kernels/include/kernels/device_handle_t.dtg.toml new file mode 100644 index 0000000000..f26c0f0321 --- /dev/null +++ b/lib/kernels/include/kernels/device_handle_t.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "device_handle_t" +type = "variant" +features = [] + +includes = [ + "", + "kernels/ff_handle.h", +] + +[[values]] +type = "::FlexFlow::PerDeviceFFHandle" +key = "for_gpu" + +[[values]] +type = "std::monostate" +key = "for_cpu" diff --git a/lib/kernels/include/kernels/device_handle_t.variant.toml b/lib/kernels/include/kernels/device_handle_t.variant.toml deleted file mode 100644 index ef574e0745..0000000000 --- a/lib/kernels/include/kernels/device_handle_t.variant.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "device_handle_t" -features = [] - -includes = [ - "", - "kernels/ff_handle.h", -] - -[[values]] -type = "::FlexFlow::PerDeviceFFHandle" -key = "for_gpu" - -[[values]] -type = "std::monostate" -key = "for_cpu" diff --git a/lib/kernels/include/kernels/device_stream_t.dtg.toml b/lib/kernels/include/kernels/device_stream_t.dtg.toml new file mode 100644 index 0000000000..9460a30ea1 --- /dev/null +++ b/lib/kernels/include/kernels/device_stream_t.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "device_stream_t" +type = "variant" +features = [] + +includes = [ + "", + "kernels/device.h", +] 
+ +[[values]] +type = "ffStream_t" +key = "gpu" + +[[values]] +type = "std::monostate" +key = "cpu" diff --git a/lib/kernels/include/kernels/device_stream_t.variant.toml b/lib/kernels/include/kernels/device_stream_t.variant.toml deleted file mode 100644 index b3f8e77171..0000000000 --- a/lib/kernels/include/kernels/device_stream_t.variant.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "device_stream_t" -features = [] - -includes = [ - "", - "kernels/device.h", -] - -[[values]] -type = "ffStream_t" -key = "gpu" - -[[values]] -type = "std::monostate" -key = "cpu" diff --git a/lib/kernels/include/kernels/dropout_per_device_state.dtg.toml b/lib/kernels/include/kernels/dropout_per_device_state.dtg.toml new file mode 100644 index 0000000000..e8ada7212b --- /dev/null +++ b/lib/kernels/include/kernels/dropout_per_device_state.dtg.toml @@ -0,0 +1,41 @@ +namespace = "FlexFlow" +name = "DropoutPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/device.h", + "kernels/ff_handle.h", +] + +[[fields]] +name = "handle" +type = "PerDeviceFFHandle" + +[[fields]] +name = "inputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "outputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "dropoutDesc" +type = "ffDropoutDescriptor_t" + +[[fields]] +name = "reserveSpace" +type = "void *" + +[[fields]] +name = "dropoutStates" +type = "void *" + +[[fields]] +name = "reserveSpaceSize" +type = "size_t" + +[[fields]] +name = "dropoutStateSize" +type = "size_t" diff --git a/lib/kernels/include/kernels/dropout_per_device_state.struct.toml b/lib/kernels/include/kernels/dropout_per_device_state.struct.toml deleted file mode 100644 index ffd8bf37e9..0000000000 --- a/lib/kernels/include/kernels/dropout_per_device_state.struct.toml +++ /dev/null @@ -1,40 +0,0 @@ -namespace = "FlexFlow" -name = "DropoutPerDeviceState" -features = [] - -includes = [ - "kernels/device.h", - "kernels/ff_handle.h", -] - -[[fields]] -name = "handle" -type = 
"PerDeviceFFHandle" - -[[fields]] -name = "inputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "outputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "dropoutDesc" -type = "ffDropoutDescriptor_t" - -[[fields]] -name = "reserveSpace" -type = "void *" - -[[fields]] -name = "dropoutStates" -type = "void *" - -[[fields]] -name = "reserveSpaceSize" -type = "size_t" - -[[fields]] -name = "dropoutStateSize" -type = "size_t" diff --git a/lib/kernels/include/kernels/element_binary_per_device_state.dtg.toml b/lib/kernels/include/kernels/element_binary_per_device_state.dtg.toml new file mode 100644 index 0000000000..d0fb28dab4 --- /dev/null +++ b/lib/kernels/include/kernels/element_binary_per_device_state.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "ElementBinaryPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/ff_handle.h", + "kernels/device.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::PerDeviceFFHandle" + +[[fields]] +name = "inputLHSTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "inputRHSTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "outputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "opDesc" +type = "ffOpTensorDescriptor_t" + +[[fields]] +name = "reduceAddDesc" +type = "ffReduceTensorDescriptor_t" diff --git a/lib/kernels/include/kernels/element_binary_per_device_state.struct.toml b/lib/kernels/include/kernels/element_binary_per_device_state.struct.toml deleted file mode 100644 index 2cae58f847..0000000000 --- a/lib/kernels/include/kernels/element_binary_per_device_state.struct.toml +++ /dev/null @@ -1,32 +0,0 @@ -namespace = "FlexFlow" -name = "ElementBinaryPerDeviceState" -features = [] - -includes = [ - "kernels/ff_handle.h", - "kernels/device.h", -] - -[[fields]] -name = "handle" -type = "::FlexFlow::PerDeviceFFHandle" - -[[fields]] -name = "inputLHSTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "inputRHSTensor" -type = 
"ffTensorDescriptor_t" - -[[fields]] -name = "outputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "opDesc" -type = "ffOpTensorDescriptor_t" - -[[fields]] -name = "reduceAddDesc" -type = "ffReduceTensorDescriptor_t" diff --git a/lib/kernels/include/kernels/element_unary_per_device_state.dtg.toml b/lib/kernels/include/kernels/element_unary_per_device_state.dtg.toml new file mode 100644 index 0000000000..a2c1718fa5 --- /dev/null +++ b/lib/kernels/include/kernels/element_unary_per_device_state.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "ElementUnaryPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/device.h", +] + +[[fields]] +name = "inputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "outputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "actiDesc" +type = "ffActivationDescriptor_t" diff --git a/lib/kernels/include/kernels/element_unary_per_device_state.struct.toml b/lib/kernels/include/kernels/element_unary_per_device_state.struct.toml deleted file mode 100644 index 019df40315..0000000000 --- a/lib/kernels/include/kernels/element_unary_per_device_state.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "ElementUnaryPerDeviceState" -features = [] - -includes = [ - "kernels/device.h", -] - -[[fields]] -name = "inputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "outputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "actiDesc" -type = "ffActivationDescriptor_t" diff --git a/lib/kernels/include/kernels/embedding_kernels.h b/lib/kernels/include/kernels/embedding_kernels.h index e9c158598a..9a23386efe 100644 --- a/lib/kernels/include/kernels/embedding_kernels.h +++ b/lib/kernels/include/kernels/embedding_kernels.h @@ -14,8 +14,8 @@ void forward_kernel(device_stream_t const &stream, DataType input_data_type, DataType output_data_type, std::optional aggr, - int in_dim, - int out_dim, + num_tensor_dims_t in_dim, + num_tensor_dims_t out_dim, int 
batch_size); void backward_kernel(device_stream_t const &stream, @@ -25,8 +25,8 @@ void backward_kernel(device_stream_t const &stream, DataType output_data_type, DataType input_data_type, std::optional aggr, - int in_dim, - int out_dim, + num_tensor_dims_t in_dim, + num_tensor_dims_t out_dim, int batch_size); } // namespace FlexFlow::Kernels::Embedding diff --git a/lib/kernels/include/kernels/embedding_kernels_cpu.h b/lib/kernels/include/kernels/embedding_kernels_cpu.h index 23e32589ae..c2430ba987 100644 --- a/lib/kernels/include/kernels/embedding_kernels_cpu.h +++ b/lib/kernels/include/kernels/embedding_kernels_cpu.h @@ -12,8 +12,8 @@ void cpu_forward_kernel(GenericTensorAccessorR const &input, DataType input_data_type, DataType output_data_type, std::optional aggr, - int in_dim, - int out_dim, + num_tensor_dims_t in_dim, + num_tensor_dims_t out_dim, int batch_size); void cpu_backward_kernel(GenericTensorAccessorR const &output, @@ -22,8 +22,8 @@ void cpu_backward_kernel(GenericTensorAccessorR const &output, DataType output_data_type, DataType input_data_type, std::optional aggr, - int in_dim, - int out_dim, + num_tensor_dims_t in_dim, + num_tensor_dims_t out_dim, int batch_size); } // namespace FlexFlow::Kernels::Embedding diff --git a/lib/kernels/include/kernels/gather_per_device_state.dtg.toml b/lib/kernels/include/kernels/gather_per_device_state.dtg.toml new file mode 100644 index 0000000000..dae8f23137 --- /dev/null +++ b/lib/kernels/include/kernels/gather_per_device_state.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "GatherPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/ff_handle.h", + "op-attrs/ff_dim_t.dtg.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::PerDeviceFFHandle" + +[[fields]] +name = "dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/kernels/include/kernels/gather_per_device_state.struct.toml b/lib/kernels/include/kernels/gather_per_device_state.struct.toml deleted file mode 100644 index 
c5163f0ddc..0000000000 --- a/lib/kernels/include/kernels/gather_per_device_state.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "GatherPerDeviceState" -features = [] - -includes = [ - "kernels/ff_handle.h", - "op-attrs/ff_dim_t.dtg.h", -] - -[[fields]] -name = "handle" -type = "::FlexFlow::PerDeviceFFHandle" - -[[fields]] -name = "dim" -type = "::FlexFlow::ff_dim_t" diff --git a/lib/kernels/include/kernels/layer_norm_per_device_state.dtg.toml b/lib/kernels/include/kernels/layer_norm_per_device_state.dtg.toml new file mode 100644 index 0000000000..b31bba81d8 --- /dev/null +++ b/lib/kernels/include/kernels/layer_norm_per_device_state.dtg.toml @@ -0,0 +1,58 @@ +namespace = "FlexFlow" +name = "LayerNormPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/ff_handle.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::PerDeviceFFHandle" + +[[fields]] +name = "elementwise_affine" +type = "bool" + +[[fields]] +name = "effective_num_elements" +type = "int64_t" + +[[fields]] +name = "effective_batch_size" +type = "int64_t" + +[[fields]] +name = "eps" +type = "float" + +[[fields]] +name = "mean" +type = "float *" + +[[fields]] +name = "rstd" +type = "float *" + +[[fields]] +name = "ds" +type = "float *" + +[[fields]] +name = "db" +type = "float *" + +[[fields]] +name = "scale" +type = "float *" + +[[fields]] +name = "bias" +type = "float *" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" + diff --git a/lib/kernels/include/kernels/layer_norm_per_device_state.struct.toml b/lib/kernels/include/kernels/layer_norm_per_device_state.struct.toml deleted file mode 100644 index 0a482d5395..0000000000 --- a/lib/kernels/include/kernels/layer_norm_per_device_state.struct.toml +++ /dev/null @@ -1,57 +0,0 @@ -namespace = "FlexFlow" -name = "LayerNormPerDeviceState" -features = [] - -includes = [ - "kernels/ff_handle.h", - "op-attrs/datatype.dtg.h", -] - -[[fields]] -name = "handle" -type = 
"::FlexFlow::PerDeviceFFHandle" - -[[fields]] -name = "elementwise_affine" -type = "bool" - -[[fields]] -name = "effective_num_elements" -type = "int64_t" - -[[fields]] -name = "effective_batch_size" -type = "int64_t" - -[[fields]] -name = "eps" -type = "float" - -[[fields]] -name = "mean" -type = "float *" - -[[fields]] -name = "rstd" -type = "float *" - -[[fields]] -name = "ds" -type = "float *" - -[[fields]] -name = "db" -type = "float *" - -[[fields]] -name = "scale" -type = "float *" - -[[fields]] -name = "bias" -type = "float *" - -[[fields]] -name = "data_type" -type = "::FlexFlow::DataType" - diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index 24eff46e22..8649d2cbeb 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -5,6 +5,7 @@ #include "kernels/legion_ordered/legion_ordered.h" #include "op-attrs/ff_dim_t.dtg.h" #include "op-attrs/ff_ordered/ff_ordered.h" +#include "op-attrs/num_tensor_dims_t.h" #include "op-attrs/tensor_dims.dtg.h" #include "utils/containers/set_of.h" #include "utils/containers/transform.h" @@ -19,9 +20,9 @@ positive_int &dim_at_idx(TensorDims &, legion_dim_t); legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value); -legion_dim_t legion_dim_from_ff_dim(ff_dim_t, nonnegative_int num_dimensions); +legion_dim_t legion_dim_from_ff_dim(ff_dim_t, num_tensor_dims_t num_dimensions); -ff_dim_t ff_dim_from_legion_dim(legion_dim_t, nonnegative_int num_dimensions); +ff_dim_t ff_dim_from_legion_dim(legion_dim_t, num_tensor_dims_t num_dimensions); template std::set key_range(LegionOrdered const &d) { diff --git a/lib/kernels/include/kernels/legion_dim_t.dtg.toml b/lib/kernels/include/kernels/legion_dim_t.dtg.toml new file mode 100644 index 0000000000..1578ce4e20 --- /dev/null +++ b/lib/kernels/include/kernels/legion_dim_t.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "legion_dim_t" +type = "struct" +features = [ + "eq", + 
"ord", + "hash", + "json", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "value" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/kernels/include/kernels/legion_dim_t.struct.toml b/lib/kernels/include/kernels/legion_dim_t.struct.toml deleted file mode 100644 index 6c047f096b..0000000000 --- a/lib/kernels/include/kernels/legion_dim_t.struct.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "legion_dim_t" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -includes = [ - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "value" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/kernels/include/kernels/linear_per_device_state.dtg.toml b/lib/kernels/include/kernels/linear_per_device_state.dtg.toml new file mode 100644 index 0000000000..ec4f93b3dc --- /dev/null +++ b/lib/kernels/include/kernels/linear_per_device_state.dtg.toml @@ -0,0 +1,57 @@ +namespace = "FlexFlow" +name = "LinearPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/ff_handle.h", + "kernels/device.h", + "", + "op-attrs/activation.dtg.h", + "op-attrs/regularizer_attrs.dtg.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::PerDeviceFFHandle" + +[[fields]] +name = "outputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "actiDesc" +type = "ffActivationDescriptor_t" + +[[fields]] +name = "one_ptr" +type = "float const *" + +[[fields]] +name = "activation_mode" +type = "cudnnActivationMode_t" + +[[fields]] +name = "activation" +type = "std::optional<::FlexFlow::Activation>" + +[[fields]] +name = "regularizer" +type = "std::optional<::FlexFlow::RegularizerAttrs>" + +[[fields]] +name = "use_bias" +type = "bool" + +[[fields]] +name = "input_type" +type = "::FlexFlow::DataType" + +[[fields]] +name = "weight_type" +type = "::FlexFlow::DataType" + +[[fields]] +name = "output_type" +type = "::FlexFlow::DataType" diff --git 
a/lib/kernels/include/kernels/linear_per_device_state.struct.toml b/lib/kernels/include/kernels/linear_per_device_state.struct.toml deleted file mode 100644 index 3ed534a23f..0000000000 --- a/lib/kernels/include/kernels/linear_per_device_state.struct.toml +++ /dev/null @@ -1,56 +0,0 @@ -namespace = "FlexFlow" -name = "LinearPerDeviceState" -features = [] - -includes = [ - "kernels/ff_handle.h", - "kernels/device.h", - "", - "op-attrs/activation.dtg.h", - "op-attrs/regularizer_attrs.dtg.h", - "op-attrs/datatype.dtg.h", -] - -[[fields]] -name = "handle" -type = "::FlexFlow::PerDeviceFFHandle" - -[[fields]] -name = "outputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "actiDesc" -type = "ffActivationDescriptor_t" - -[[fields]] -name = "one_ptr" -type = "float const *" - -[[fields]] -name = "activation_mode" -type = "cudnnActivationMode_t" - -[[fields]] -name = "activation" -type = "std::optional<::FlexFlow::Activation>" - -[[fields]] -name = "regularizer" -type = "std::optional<::FlexFlow::RegularizerAttrs>" - -[[fields]] -name = "use_bias" -type = "bool" - -[[fields]] -name = "input_type" -type = "::FlexFlow::DataType" - -[[fields]] -name = "weight_type" -type = "::FlexFlow::DataType" - -[[fields]] -name = "output_type" -type = "::FlexFlow::DataType" diff --git a/lib/kernels/include/kernels/map_tensor_accessors.h b/lib/kernels/include/kernels/map_tensor_accessors.h index f7aa6a1001..7473ee26e1 100644 --- a/lib/kernels/include/kernels/map_tensor_accessors.h +++ b/lib/kernels/include/kernels/map_tensor_accessors.h @@ -99,11 +99,11 @@ struct CPUMapTensorAccessors2 { GenericTensorAccessorW &output, F &&f) { - TensorDims dims = throw_if_unexpected(require_all_same1(std::vector{ + TensorDims dims = require_all_same1(std::vector{ lhs.shape.dims, rhs.shape.dims, output.shape.dims, - })); + }); ASSERT(lhs.device_type == DeviceType::CPU); ASSERT(rhs.device_type == DeviceType::CPU); diff --git a/lib/kernels/include/kernels/mha_per_device_state.dtg.toml 
b/lib/kernels/include/kernels/mha_per_device_state.dtg.toml new file mode 100644 index 0000000000..616e43e9cf --- /dev/null +++ b/lib/kernels/include/kernels/mha_per_device_state.dtg.toml @@ -0,0 +1,66 @@ +namespace = "FlexFlow" +name = "MHAPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/device.h", + "kernels/ff_handle.h", + "kernels/allocation.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::PerDeviceFFHandle" + +[[fields]] +name = "weightSize" +type = "size_t" + +[[fields]] +name = "reserveSpaceSize" +type = "size_t" + +[[fields]] +name = "attnDesc" +type = "ffAttnDescriptor_t" + +[[fields]] +name = "qDesc" +type = "ffSeqDataDescriptor_t" + +[[fields]] +name = "kDesc" +type = "ffSeqDataDescriptor_t" + +[[fields]] +name = "vDesc" +type = "ffSeqDataDescriptor_t" + +[[fields]] +name = "oDesc" +type = "ffSeqDataDescriptor_t" + +[[fields]] +name = "devQoSeqArray" +type = "int *" + +[[fields]] +name = "devKvSeqArray" +type = "int *" + +[[fields]] +name = "loWinIdx" +type = "int *" + +[[fields]] +name = "hiWinIdx" +type = "int *" + +[[fields]] +name = "reserveSpace" +type = "void *" + +[[fields]] +name = "allocator" +type = "::FlexFlow::Allocator" diff --git a/lib/kernels/include/kernels/mha_per_device_state.struct.toml b/lib/kernels/include/kernels/mha_per_device_state.struct.toml deleted file mode 100644 index 324e8d1184..0000000000 --- a/lib/kernels/include/kernels/mha_per_device_state.struct.toml +++ /dev/null @@ -1,65 +0,0 @@ -namespace = "FlexFlow" -name = "MHAPerDeviceState" -features = [] - -includes = [ - "kernels/device.h", - "kernels/ff_handle.h", - "kernels/allocation.h", -] - -[[fields]] -name = "handle" -type = "::FlexFlow::PerDeviceFFHandle" - -[[fields]] -name = "weightSize" -type = "size_t" - -[[fields]] -name = "reserveSpaceSize" -type = "size_t" - -[[fields]] -name = "attnDesc" -type = "ffAttnDescriptor_t" - -[[fields]] -name = "qDesc" -type = "ffSeqDataDescriptor_t" - -[[fields]] -name = "kDesc" -type = 
"ffSeqDataDescriptor_t" - -[[fields]] -name = "vDesc" -type = "ffSeqDataDescriptor_t" - -[[fields]] -name = "oDesc" -type = "ffSeqDataDescriptor_t" - -[[fields]] -name = "devQoSeqArray" -type = "int *" - -[[fields]] -name = "devKvSeqArray" -type = "int *" - -[[fields]] -name = "loWinIdx" -type = "int *" - -[[fields]] -name = "hiWinIdx" -type = "int *" - -[[fields]] -name = "reserveSpace" -type = "void *" - -[[fields]] -name = "allocator" -type = "::FlexFlow::Allocator" diff --git a/lib/kernels/include/kernels/partition_per_device_state.dtg.toml b/lib/kernels/include/kernels/partition_per_device_state.dtg.toml new file mode 100644 index 0000000000..5a51d06736 --- /dev/null +++ b/lib/kernels/include/kernels/partition_per_device_state.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "RepartitionPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/ff_handle.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::PerDeviceFFHandle" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/kernels/include/kernels/partition_per_device_state.struct.toml b/lib/kernels/include/kernels/partition_per_device_state.struct.toml deleted file mode 100644 index a008e422cd..0000000000 --- a/lib/kernels/include/kernels/partition_per_device_state.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "RepartitionPerDeviceState" -features = [] - -includes = [ - "kernels/ff_handle.h", - "op-attrs/datatype.dtg.h", -] - -[[fields]] -name = "handle" -type = "::FlexFlow::PerDeviceFFHandle" - -[[fields]] -name = "data_type" -type = "::FlexFlow::DataType" diff --git a/lib/kernels/include/kernels/perf_metrics.dtg.toml b/lib/kernels/include/kernels/perf_metrics.dtg.toml new file mode 100644 index 0000000000..87188ba516 --- /dev/null +++ b/lib/kernels/include/kernels/perf_metrics.dtg.toml @@ -0,0 +1,60 @@ +namespace = "FlexFlow" +name = "PerfMetrics" +type = "struct" +features = [ + "eq", + 
"hash", + "json", +] + +includes = [ + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", +] + +[[fields]] +name = "train_all" +type = "int" +docstring = "measure_accuracy denominator" + +[[fields]] +name = "train_correct" +type = "std::optional" +docstring = "measure_accuracy numerator" + +[[fields]] +name = "cce_loss" +type = "std::optional" +docstring = "measure_categorical_crossentropy" + +[[fields]] +name = "sparse_cce_loss" +type = "std::optional" +docstring = "measure_sparse_categorical_crossentropy" + +[[fields]] +name = "mse_loss" +type = "std::optional" +docstring = "measure_mean_squared_error" + +[[fields]] +name = "rmse_loss" +type = "std::optional" +docstring = "measure_root_mean_squared_error" + +[[fields]] +name = "mae_loss" +type = "std::optional" +docstring = "measure_mean_absolute_error" + +[[fields]] +name = "start_time" +type = "double" + +[[fields]] +name = "current_time" +type = "double" diff --git a/lib/kernels/include/kernels/perf_metrics.h b/lib/kernels/include/kernels/perf_metrics.h index c4a34e4f79..69f96491e0 100644 --- a/lib/kernels/include/kernels/perf_metrics.h +++ b/lib/kernels/include/kernels/perf_metrics.h @@ -1,37 +1,11 @@ #ifndef _FLEXFLOW_KERNELS_INCLUDE_KERNELS_PERF_METRICS_H #define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_PERF_METRICS_H -#include "utils/fmt.h" -#include "utils/visitable.h" +#include "kernels/perf_metrics.dtg.h" +#include namespace FlexFlow { -struct PerfMetrics : public use_visitable_cmp { - PerfMetrics() = delete; - PerfMetrics(double start_time); - PerfMetrics(int train_all, - std::optional train_correct, - std::optional cce_loss, - std::optional sparse_cce_loss, - std::optional mse_loss, - std::optional rmse_loss, - std::optional mae_loss, - double start_time_micro, - double current_time_micro); - - int train_all = 0; // measure_accuracy_denominator - std::optional train_correct = 0; // measure_accuracy numerator - std::optional cce_loss = - std::nullopt; // 
measure_categorical_crossentropy - std::optional sparse_cce_loss = - 0.0f; // measure_sparse_categorical_crossentropy - std::optional mse_loss = 0.0f; // measure_mean_squared_error - std::optional rmse_loss = 0.0f; // measure_root_mean_squared_error - std::optional mae_loss = 0.0f; // measure_mean_absolute_error - double start_time; - double current_time; -}; - float get_throughput(PerfMetrics const &); float get_accuracy(PerfMetrics const &); @@ -40,16 +14,6 @@ PerfMetrics apply_scale(PerfMetrics const &, float scale); } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::PerfMetrics, - train_all, - train_correct, - cce_loss, - sparse_cce_loss, - mse_loss, - rmse_loss, - mae_loss, - start_time); - namespace fmt { template <> diff --git a/lib/kernels/include/kernels/pool_2d_per_device_state.dtg.toml b/lib/kernels/include/kernels/pool_2d_per_device_state.dtg.toml new file mode 100644 index 0000000000..afc8bad21a --- /dev/null +++ b/lib/kernels/include/kernels/pool_2d_per_device_state.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "Pool2DPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/ff_handle.h", + "kernels/device.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::PerDeviceFFHandle" + +[[fields]] +name = "inputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "outputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "actiDesc" +type = "ffActivationDescriptor_t" + +[[fields]] +name = "poolDesc" +type = "ffPoolingDescriptor_t" + +[[fields]] +name = "relu" +type = "bool" diff --git a/lib/kernels/include/kernels/pool_2d_per_device_state.struct.toml b/lib/kernels/include/kernels/pool_2d_per_device_state.struct.toml deleted file mode 100644 index 63e98cca85..0000000000 --- a/lib/kernels/include/kernels/pool_2d_per_device_state.struct.toml +++ /dev/null @@ -1,32 +0,0 @@ -namespace = "FlexFlow" -name = "Pool2DPerDeviceState" -features = [] - -includes = [ - "kernels/ff_handle.h", - "kernels/device.h", -] - 
-[[fields]] -name = "handle" -type = "::FlexFlow::PerDeviceFFHandle" - -[[fields]] -name = "inputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "outputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "actiDesc" -type = "ffActivationDescriptor_t" - -[[fields]] -name = "poolDesc" -type = "ffPoolingDescriptor_t" - -[[fields]] -name = "relu" -type = "bool" diff --git a/lib/kernels/include/kernels/profiling.h b/lib/kernels/include/kernels/profiling.h index c0a0e794e3..6b79f40359 100644 --- a/lib/kernels/include/kernels/profiling.h +++ b/lib/kernels/include/kernels/profiling.h @@ -5,15 +5,16 @@ #include "kernels/device_stream_t.h" #include "kernels/profiling_settings.dtg.h" #include "pcg/device_type.dtg.h" +#include "utils/units/milliseconds_t.h" #include namespace FlexFlow { template -std::optional profiling_wrapper(F const &f, - bool enable_profiling, - DeviceType device_type, - Ts &&...ts) { +std::optional profiling_wrapper(F const &f, + bool enable_profiling, + DeviceType device_type, + Ts &&...ts) { if (enable_profiling) { ProfilingSettings settings = ProfilingSettings{ /*warmup_iters=*/0, @@ -27,10 +28,11 @@ std::optional profiling_wrapper(F const &f, } template -std::optional profiling_wrapper(F const &f, - ProfilingSettings const &settings, - DeviceType device_type, - Ts &&...ts) { +std::optional + profiling_wrapper(F const &f, + ProfilingSettings const &settings, + DeviceType device_type, + Ts &&...ts) { if (settings.measure_iters <= 0) { return std::nullopt; } @@ -44,9 +46,9 @@ std::optional profiling_wrapper(F const &f, } template -float cpu_profiling_wrapper(F const &f, - ProfilingSettings const &settings, - Ts &&...ts) { +milliseconds_t cpu_profiling_wrapper(F const &f, + ProfilingSettings const &settings, + Ts &&...ts) { ASSERT(settings.measure_iters > 0); device_stream_t stream = get_cpu_device_stream(); @@ -67,13 +69,15 @@ float cpu_profiling_wrapper(F const &f, std::chrono::duration avg_duration = (end.value() - start.value()) / 
settings.measure_iters; - return avg_duration.count(); + return milliseconds_t{ + static_cast(avg_duration.count()), + }; } template -float gpu_profiling_wrapper(F const &f, - ProfilingSettings const &settings, - Ts &&...ts) { +milliseconds_t gpu_profiling_wrapper(F const &f, + ProfilingSettings const &settings, + Ts &&...ts) { ASSERT(settings.measure_iters > 0); device_stream_t stream = get_gpu_device_stream(); @@ -95,7 +99,9 @@ float gpu_profiling_wrapper(F const &f, checkCUDA(ffEventElapsedTime(&elapsed, t_start, t_end)); checkCUDA(ffEventDestroy(t_start)); checkCUDA(ffEventDestroy(t_end)); - return elapsed / settings.measure_iters; + return milliseconds_t{ + elapsed / settings.measure_iters, + }; } } // namespace FlexFlow diff --git a/lib/kernels/include/kernels/profiling_settings.dtg.toml b/lib/kernels/include/kernels/profiling_settings.dtg.toml new file mode 100644 index 0000000000..c9f19c3a50 --- /dev/null +++ b/lib/kernels/include/kernels/profiling_settings.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ProfilingSettings" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "warmup_iters" +type = "int" + +[[fields]] +name = "measure_iters" +type = "int" diff --git a/lib/kernels/include/kernels/profiling_settings.struct.toml b/lib/kernels/include/kernels/profiling_settings.struct.toml deleted file mode 100644 index 694dfac76a..0000000000 --- a/lib/kernels/include/kernels/profiling_settings.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "ProfilingSettings" - -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -[[fields]] -name = "warmup_iters" -type = "int" - -[[fields]] -name = "measure_iters" -type = "int" diff --git a/lib/kernels/include/kernels/reduce_per_device_state.dtg.toml b/lib/kernels/include/kernels/reduce_per_device_state.dtg.toml new file mode 100644 index 0000000000..56cf70fd03 --- /dev/null +++ 
b/lib/kernels/include/kernels/reduce_per_device_state.dtg.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "ReducePerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/device.h", + "kernels/ff_handle.h", + "op-attrs/operator_type.dtg.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::PerDeviceFFHandle" + +[[fields]] +name = "inputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "outputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "reduceDesc" +type = "ffReduceTensorDescriptor_t" + +[[fields]] +name = "op_type" +type = "::FlexFlow::OperatorType" + +[[fields]] +name = "reduction_size" +type = "size_t" diff --git a/lib/kernels/include/kernels/reduce_per_device_state.struct.toml b/lib/kernels/include/kernels/reduce_per_device_state.struct.toml deleted file mode 100644 index e82099ad25..0000000000 --- a/lib/kernels/include/kernels/reduce_per_device_state.struct.toml +++ /dev/null @@ -1,33 +0,0 @@ -namespace = "FlexFlow" -name = "ReducePerDeviceState" -features = [] - -includes = [ - "kernels/device.h", - "kernels/ff_handle.h", - "op-attrs/operator_type.dtg.h", -] - -[[fields]] -name = "handle" -type = "::FlexFlow::PerDeviceFFHandle" - -[[fields]] -name = "inputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "outputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "reduceDesc" -type = "ffReduceTensorDescriptor_t" - -[[fields]] -name = "op_type" -type = "::FlexFlow::OperatorType" - -[[fields]] -name = "reduction_size" -type = "size_t" diff --git a/lib/kernels/include/kernels/reduce_tensor_accessor.h b/lib/kernels/include/kernels/reduce_tensor_accessor.h index a06afbf5f6..02ff63544f 100644 --- a/lib/kernels/include/kernels/reduce_tensor_accessor.h +++ b/lib/kernels/include/kernels/reduce_tensor_accessor.h @@ -33,15 +33,15 @@ struct CPUReduceTensorAccessorInDims { return contains(dims_to_reduce, dim); }; - std::unordered_map> - output_coord_from_input_coord = group_by( - 
get_tensor_dims_coord_set(input.shape.dims), - [&](TensorDimsCoord const &input_coord) { - return tensor_dims_coord_drop_dims(input_coord, should_drop_dim); - }); + OneToMany output_coord_from_input_coord = + group_by(get_tensor_dims_coord_set(input.shape.dims), + [&](TensorDimsCoord const &input_coord) { + return tensor_dims_coord_drop_dims(input_coord, + should_drop_dim); + }); for (auto const &[output_coord, input_coords] : - output_coord_from_input_coord) { + output_coord_from_input_coord.l_to_r()) { std::vector input_values = transform( sorted(input_coords), [&](TensorDimsCoord const &input_coord) -> T { return input.at
(input_coord); diff --git a/lib/kernels/include/kernels/reverse_kernels_params.dtg.toml b/lib/kernels/include/kernels/reverse_kernels_params.dtg.toml new file mode 100644 index 0000000000..0c69722863 --- /dev/null +++ b/lib/kernels/include/kernels/reverse_kernels_params.dtg.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "ReverseKernelsParams" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "num_out_blks" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "reverse_dim_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "in_blk_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "out_size" +type = "::FlexFlow::positive_int" diff --git a/lib/kernels/include/kernels/reverse_kernels_params.struct.toml b/lib/kernels/include/kernels/reverse_kernels_params.struct.toml deleted file mode 100644 index 1689594491..0000000000 --- a/lib/kernels/include/kernels/reverse_kernels_params.struct.toml +++ /dev/null @@ -1,28 +0,0 @@ -namespace = "FlexFlow" -name = "ReverseKernelsParams" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "num_out_blks" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "reverse_dim_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "in_blk_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "out_size" -type = "::FlexFlow::positive_int" diff --git a/lib/kernels/include/kernels/softmax_per_device_state.dtg.toml b/lib/kernels/include/kernels/softmax_per_device_state.dtg.toml new file mode 100644 index 0000000000..abf144631e --- /dev/null +++ b/lib/kernels/include/kernels/softmax_per_device_state.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "SoftmaxPerDeviceState" +type = "struct" +features = [] + +includes = [ + "kernels/ff_handle.h", + "kernels/device.h", + "op-attrs/ff_dim_t.dtg.h", +] + +[[fields]] 
+name = "handle" +type = "::FlexFlow::PerDeviceFFHandle" + +[[fields]] +name = "inputTensor" +type = "ffTensorDescriptor_t" + +[[fields]] +name = "dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/kernels/include/kernels/softmax_per_device_state.struct.toml b/lib/kernels/include/kernels/softmax_per_device_state.struct.toml deleted file mode 100644 index 374dd28c63..0000000000 --- a/lib/kernels/include/kernels/softmax_per_device_state.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "SoftmaxPerDeviceState" -features = [] - -includes = [ - "kernels/ff_handle.h", - "kernels/device.h", - "op-attrs/ff_dim_t.dtg.h", -] - -[[fields]] -name = "handle" -type = "::FlexFlow::PerDeviceFFHandle" - -[[fields]] -name = "inputTensor" -type = "ffTensorDescriptor_t" - -[[fields]] -name = "dim" -type = "::FlexFlow::ff_dim_t" diff --git a/lib/kernels/src/cuda/ops/element_binary_kernels.cu b/lib/kernels/src/cuda/ops/element_binary_kernels.cu index 7e13486429..e4c698ad02 100644 --- a/lib/kernels/src/cuda/ops/element_binary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_binary_kernels.cu @@ -18,6 +18,7 @@ #include "kernels/ff_handle.h" #include "op-attrs/datatype.h" #include "op-attrs/operator_type.h" +#include "utils/exception.h" namespace FlexFlow { namespace Kernels { diff --git a/lib/kernels/src/cuda/ops/pool_2d_kernels.cu b/lib/kernels/src/cuda/ops/pool_2d_kernels.cu index ec185a360e..4e06f2da02 100644 --- a/lib/kernels/src/cuda/ops/pool_2d_kernels.cu +++ b/lib/kernels/src/cuda/ops/pool_2d_kernels.cu @@ -15,6 +15,7 @@ #include "internal/device.h" #include "kernels/pool_2d_kernels_gpu.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/kernels/src/cuda/ops/softmax_kernels.cu b/lib/kernels/src/cuda/ops/softmax_kernels.cu index 85575d7bf6..1ecf42dfd8 100644 --- a/lib/kernels/src/cuda/ops/softmax_kernels.cu +++ b/lib/kernels/src/cuda/ops/softmax_kernels.cu @@ -15,6 +15,7 @@ #include "internal/device.h" #include 
"kernels/softmax_kernels_gpu.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/kernels/src/cuda/ops/transpose_kernels.cu b/lib/kernels/src/cuda/ops/transpose_kernels.cu index 85a259769c..f243c13703 100644 --- a/lib/kernels/src/cuda/ops/transpose_kernels.cu +++ b/lib/kernels/src/cuda/ops/transpose_kernels.cu @@ -48,47 +48,44 @@ __global__ void transpose_simple_kernel(std::size_t volume, } } -static LegionOrdered - legion_ordered_perm_from_ff_ordered(FFOrdered const &perm) { - nonnegative_int perm_size = num_elements(perm); - LegionOrdered legion_ordered_perm = - transform(legion_ordered_from_ff_ordered(perm), [&](ff_dim_t d) { - return legion_dim_from_ff_dim(d, perm_size); - }); - - return legion_ordered_perm; -} - -void gpu_forward_kernel(cudaStream_t stream, - TransposeAttrs const &m, - GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output) { +static TransposeStrides make_strides(TransposeAttrs const &m, + TensorDims const &input_dims, + TensorDims const &output_dims) { + ASSERT(get_num_dims(input_dims) == m.permutation.num_tensor_dims()); TransposeStrides info; - info.num_dim = get_num_dims(input.shape.dims).unwrap_nonnegative(); - assert(info.num_dim == m.perm.size()); - - LegionOrdered legion_ordered_perm = - legion_ordered_perm_from_ff_ordered(m.perm); + num_tensor_dims_t num_dims = m.permutation.num_tensor_dims(); + info.num_dim = num_dims.int_from_num_tensor_dims(); for (int i = 0; i < info.num_dim; i++) { + legion_dim_t legion_dim = legion_dim_t{nonnegative_int{i}}; + ff_dim_t ff_dim = ff_dim_from_legion_dim(legion_dim, num_dims); + if (i == 0) { info.in_strides[i] = 1; info.out_strides[i] = 1; } else { int in_dim_size = - dim_at_idx(input.shape.dims, legion_dim_t{nonnegative_int{i}}) - .int_from_positive_int(); + dim_at_idx(input_dims, legion_dim).int_from_positive_int(); int out_dim_size = - dim_at_idx(output.shape.dims, legion_dim_t{nonnegative_int{i}}) - .int_from_positive_int(); + dim_at_idx(output_dims, 
legion_dim).int_from_positive_int(); info.in_strides[i] = info.in_strides[i - 1] * in_dim_size; info.out_strides[i] = info.out_strides[i - 1] * out_dim_size; } - info.perm[i] = legion_ordered_perm.at(legion_dim_t{nonnegative_int{i}}) - .value.unwrap_nonnegative(); + ff_dim_t ff_permuted_dim = m.permutation.at_l(ff_dim); + legion_dim_t legion_permuted_dim = + legion_dim_from_ff_dim(ff_permuted_dim, num_dims); + info.perm[i] = legion_permuted_dim.value.unwrap_nonnegative(); } + + return info; +} + +void gpu_forward_kernel(cudaStream_t stream, + TransposeAttrs const &m, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { transpose_simple_kernel<<< GET_BLOCKS(get_num_elements(output.shape.dims).int_from_positive_int()), CUDA_NUM_THREADS, @@ -96,7 +93,7 @@ void gpu_forward_kernel(cudaStream_t stream, stream>>>(get_num_elements(output.shape.dims).int_from_positive_int(), input.get_float_ptr(), output.get_float_ptr(), - info, + make_strides(m, input.shape.dims, output.shape.dims), /*beta=*/0.0f); } @@ -104,31 +101,8 @@ void gpu_backward_kernel(cudaStream_t stream, TransposeAttrs const &m, GenericTensorAccessorR const &out_grad, GenericTensorAccessorW const &in_grad) { + ASSERT(get_num_dims(in_grad.shape.dims) == m.permutation.num_tensor_dims()); - TransposeStrides info; - info.num_dim = get_num_dims(in_grad.shape.dims).unwrap_nonnegative(); - assert(info.num_dim == m.perm.size()); - - LegionOrdered legion_ordered_perm = - legion_ordered_perm_from_ff_ordered(m.perm); - - for (int i = 0; i < info.num_dim; i++) { - if (i == 0) { - info.in_strides[i] = 1; - info.out_strides[i] = 1; - } else { - int in_dim_size = - dim_at_idx(out_grad.shape.dims, legion_dim_t{nonnegative_int{i}}) - .int_from_positive_int(); - int out_dim_size = - dim_at_idx(in_grad.shape.dims, legion_dim_t{nonnegative_int{i}}) - .int_from_positive_int(); - info.in_strides[i] = info.in_strides[i - 1] * in_dim_size; - info.out_strides[i] = info.out_strides[i - 1] * out_dim_size; - } 
- info.perm[legion_ordered_perm.at(legion_dim_t{nonnegative_int{i}}) - .value.unwrap_nonnegative()] = i; - } transpose_simple_kernel<<< GET_BLOCKS(get_num_elements(in_grad.shape.dims).int_from_positive_int()), CUDA_NUM_THREADS, @@ -136,7 +110,7 @@ void gpu_backward_kernel(cudaStream_t stream, stream>>>(get_num_elements(in_grad.shape.dims).int_from_positive_int(), out_grad.get_float_ptr(), in_grad.get_float_ptr(), - info, + make_strides(m, out_grad.shape.dims, in_grad.shape.dims), /*beta=*/1.0f); } diff --git a/lib/kernels/src/kernels/accessor.cc b/lib/kernels/src/kernels/accessor.cc index 868940bf6c..bfa2169b0d 100644 --- a/lib/kernels/src/kernels/accessor.cc +++ b/lib/kernels/src/kernels/accessor.cc @@ -6,6 +6,7 @@ #include "op-attrs/tensor_shape.h" #include "utils/containers/reversed.h" #include "utils/containers/vector_of.h" +#include "utils/hash/tuple.h" #include "utils/nonnegative_int/nonnegative_range.h" #include @@ -19,7 +20,7 @@ nonnegative_int calculate_accessor_offset(TensorDimsCoord const &coord, nonnegative_int offset = 0_n; positive_int multiplier = 1_p; - for (ff_dim_t dim : reversed(get_idxs(tensor_dims.ff_ordered))) { + for (ff_dim_t dim : reversed(vector_of(get_idxs(tensor_dims.ff_ordered)))) { ASSERT(coord.ff_ordered.at(dim) < dim_at_idx(tensor_dims, dim), "Out of bounds access", dim); @@ -293,3 +294,19 @@ template int32_t accessor_get_only_value(GenericTensorAccessorR const &); } // namespace FlexFlow + +namespace std { + +using namespace ::FlexFlow; + +size_t hash::operator()( + GenericTensorAccessorR const &a) const { + return get_std_hash(a.tie()); +} + +size_t hash::operator()( + GenericTensorAccessorW const &a) const { + return get_std_hash(a.tie()); +} + +} // namespace std diff --git a/lib/kernels/src/kernels/batch_matmul_kernels.cc b/lib/kernels/src/kernels/batch_matmul_kernels.cc index 652d4fb137..a6ac364900 100644 --- a/lib/kernels/src/kernels/batch_matmul_kernels.cc +++ b/lib/kernels/src/kernels/batch_matmul_kernels.cc @@ -1,46 +1,74 
@@ #include "kernels/batch_matmul_kernels.h" #include "kernels/batch_matmul_kernels_cpu.h" #include "kernels/batch_matmul_kernels_gpu.h" +#include "utils/containers/require_same.h" namespace FlexFlow::Kernels::BatchMatmul { +static std::tuple + get_params(TensorDims const &input_a_dims, + TensorDims const &input_b_dims, + TensorDims const &output_dims) { + positive_int m = require_same(dim_at_idx(input_b_dims, relative_ff_dim_t{-1}), + dim_at_idx(output_dims, relative_ff_dim_t{-1})); + + positive_int n = require_same(dim_at_idx(input_a_dims, relative_ff_dim_t{-2}), + dim_at_idx(output_dims, relative_ff_dim_t{-2})); + + positive_int k = + require_same(dim_at_idx(input_a_dims, relative_ff_dim_t{-1}), + dim_at_idx(input_b_dims, relative_ff_dim_t{-2})); + + TensorDims leading_dims = require_same( + slice_tensor_dims( + input_a_dims, relative_ff_dim_t{0}, relative_ff_dim_t{-2}), + slice_tensor_dims( + input_b_dims, relative_ff_dim_t{0}, relative_ff_dim_t{-2})); + + positive_int batch = get_num_elements(leading_dims); + + return {m, n, k, batch}; +} + void forward_kernel(device_stream_t const &stream, device_handle_t const &handle, - float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, - int m, - int n, - int k, - int batch, - int seq_length, - int a_seq_length_dim, - int b_seq_length_dim) { + GenericTensorAccessorW const &output, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorR const &input_b, + positive_int seq_length, + std::optional a_seq_length_dim, + std::optional b_seq_length_dim) { + + auto [m, n, k, batch] = + get_params(input_a.shape.dims, input_b.shape.dims, output.shape.dims); + + auto get_raw_seq_len = [](std::optional seq_len) -> int { + return transform(seq_len, + [](positive_int x) { return x.int_from_positive_int(); }) + .value_or(-1); + }; + if (stream.is_gpu()) { gpu_forward_kernel( /*stream=*/stream.require_gpu(), /*handle=*/handle.require_for_gpu(), - /*output_ptr=*/output_ptr, - /*a_input_ptr=*/a_input_ptr, - 
/*b_input_ptr=*/b_input_ptr, - /*m=*/m, - /*n=*/n, - /*k=*/k, - /*batch=*/batch, - /*seq_length=*/seq_length, - /*a_seq_length_dim=*/a_seq_length_dim, - /*b_seq_length_dim=*/b_seq_length_dim); + /*output_ptr=*/output.get_float_ptr(), + /*a_input_ptr=*/input_a.get_float_ptr(), + /*b_input_ptr=*/input_b.get_float_ptr(), + /*m=*/m.int_from_positive_int(), + /*n=*/n.int_from_positive_int(), + /*k=*/k.int_from_positive_int(), + /*batch=*/batch.int_from_positive_int(), + /*seq_length=*/seq_length.int_from_positive_int(), + /*a_seq_length_dim=*/get_raw_seq_len(a_seq_length_dim), + /*b_seq_length_dim=*/get_raw_seq_len(b_seq_length_dim)); } else { ASSERT(stream.is_cpu()); ASSERT(handle.is_for_cpu()); cpu_forward_kernel( - /*output_ptr=*/output_ptr, - /*a_input_ptr=*/a_input_ptr, - /*b_input_ptr=*/b_input_ptr, - /*m=*/m, - /*n=*/n, - /*k=*/k, - /*batch=*/batch, + /*output=*/output, + /*input_a=*/input_a, + /*input_b=*/input_b, /*seq_length=*/seq_length, /*a_seq_length_dim=*/a_seq_length_dim, /*b_seq_length_dim=*/b_seq_length_dim); @@ -49,44 +77,43 @@ void forward_kernel(device_stream_t const &stream, void backward_kernel(device_stream_t const &stream, device_handle_t const &handle, - float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, - int m, - int n, - int k, - int batch) { + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorW const &input_a_grad, + GenericTensorAccessorR const &input_b, + GenericTensorAccessorW const &input_b_grad) { + TensorShape input_a_shape = require_same(input_a.shape, input_a_grad.shape); + TensorShape input_b_shape = require_same(input_b.shape, input_b_grad.shape); + TensorShape output_shape = require_same(output.shape, output_grad.shape); + + auto [m, n, k, batch] = + get_params(input_a_shape.dims, input_b_shape.dims, output_shape.dims); + if (stream.is_gpu()) { 
gpu_backward_kernel( /*stream=*/stream.require_gpu(), /*handle=*/handle.require_for_gpu(), - /*o_ptr=*/o_ptr, - /*o_grad_ptr=*/o_grad_ptr, - /*a_ptr=*/a_ptr, - /*a_grad_ptr=*/a_grad_ptr, - /*b_ptr=*/b_ptr, - /*b_grad_ptr=*/b_grad_ptr, - /*m=*/m, - /*n=*/n, - /*k=*/k, - /*batch=*/batch); + /*output_ptr=*/output.get_float_ptr(), + /*output_grad_ptr=*/output_grad.get_float_ptr(), + /*input_a_ptr=*/input_a.get_float_ptr(), + /*input_a_grad_ptr=*/input_a_grad.get_float_ptr(), + /*input_b_ptr=*/input_b.get_float_ptr(), + /*input_b_grad_ptr=*/input_b_grad.get_float_ptr(), + /*m=*/m.int_from_positive_int(), + /*n=*/n.int_from_positive_int(), + /*k=*/k.int_from_positive_int(), + /*batch=*/batch.int_from_positive_int()); } else { ASSERT(stream.is_cpu()); ASSERT(handle.is_for_cpu()); cpu_backward_kernel( - /*o_ptr=*/o_ptr, - /*o_grad_ptr=*/o_grad_ptr, - /*a_ptr=*/a_ptr, - /*a_grad_ptr=*/a_grad_ptr, - /*b_ptr=*/b_ptr, - /*b_grad_ptr=*/b_grad_ptr, - /*m=*/m, - /*n=*/n, - /*k=*/k, - /*batch=*/batch); + /*output=*/output, + /*output_grad=*/output_grad, + /*input_a=*/input_a, + /*input_a_grad=*/input_a_grad, + /*input_b=*/input_b, + /*input_b_grad=*/input_b_grad); } } diff --git a/lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc b/lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc index f139d42992..292841d19f 100644 --- a/lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc +++ b/lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc @@ -2,29 +2,21 @@ namespace FlexFlow::Kernels::BatchMatmul { -void cpu_forward_kernel(float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, - int m, - int n, - int k, - int batch, - int seq_length, - int a_seq_length_dim, - int b_seq_length_dim) { +void cpu_forward_kernel(GenericTensorAccessorW const &output, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorR const &input_b, + positive_int seq_length, + std::optional a_seq_length_dim, + std::optional b_seq_length_dim) { NOT_IMPLEMENTED(); } -void 
cpu_backward_kernel(float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, - int m, - int n, - int k, - int batch) { +void cpu_backward_kernel(GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorW const &input_a_grad, + GenericTensorAccessorR const &input_b, + GenericTensorAccessorW const &input_b_grad) { NOT_IMPLEMENTED(); } diff --git a/lib/kernels/src/kernels/element_binary_kernels.cc b/lib/kernels/src/kernels/element_binary_kernels.cc index bea317dfec..1d8fbaaf77 100644 --- a/lib/kernels/src/kernels/element_binary_kernels.cc +++ b/lib/kernels/src/kernels/element_binary_kernels.cc @@ -1,6 +1,7 @@ #include "kernels/element_binary_kernels.h" #include "kernels/element_binary_kernels_cpu.h" #include "kernels/element_binary_kernels_gpu.h" +#include namespace FlexFlow::Kernels::ElementBinary { diff --git a/lib/kernels/src/kernels/embedding_kernels.cc b/lib/kernels/src/kernels/embedding_kernels.cc index 957d297b9e..cd6ad051bc 100644 --- a/lib/kernels/src/kernels/embedding_kernels.cc +++ b/lib/kernels/src/kernels/embedding_kernels.cc @@ -11,8 +11,8 @@ void forward_kernel(device_stream_t const &stream, DataType input_data_type, DataType output_data_type, std::optional aggr, - int in_dim, - int out_dim, + num_tensor_dims_t in_dim, + num_tensor_dims_t out_dim, int batch_size) { if (stream.is_gpu()) { gpu_forward_kernel( @@ -23,8 +23,8 @@ void forward_kernel(device_stream_t const &stream, /*input_data_type=*/input_data_type, /*output_data_type=*/output_data_type, /*aggr=*/aggr, - /*in_dim=*/in_dim, - /*out_dim=*/out_dim, + /*in_dim=*/in_dim.int_from_num_tensor_dims(), + /*out_dim=*/out_dim.int_from_num_tensor_dims(), /*batch_size=*/batch_size); } else { ASSERT(stream.is_cpu()); @@ -48,8 +48,8 @@ void backward_kernel(device_stream_t const &stream, DataType output_data_type, DataType input_data_type, 
std::optional aggr, - int in_dim, - int out_dim, + num_tensor_dims_t in_dim, + num_tensor_dims_t out_dim, int batch_size) { if (stream.is_gpu()) { gpu_backward_kernel( @@ -60,8 +60,8 @@ void backward_kernel(device_stream_t const &stream, /*output_data_type=*/output_data_type, /*input_data_type=*/input_data_type, /*aggr=*/aggr, - /*in_dim=*/in_dim, - /*out_dim=*/out_dim, + /*in_dim=*/in_dim.int_from_num_tensor_dims(), + /*out_dim=*/out_dim.int_from_num_tensor_dims(), /*batch_size=*/batch_size); } else { ASSERT(stream.is_cpu()); diff --git a/lib/kernels/src/kernels/embedding_kernels_cpu.cc b/lib/kernels/src/kernels/embedding_kernels_cpu.cc index f5df53e322..db1a696ebb 100644 --- a/lib/kernels/src/kernels/embedding_kernels_cpu.cc +++ b/lib/kernels/src/kernels/embedding_kernels_cpu.cc @@ -8,8 +8,8 @@ void cpu_forward_kernel(GenericTensorAccessorR const &input, DataType input_data_type, DataType output_data_type, std::optional aggr, - int in_dim, - int out_dim, + num_tensor_dims_t in_dim, + num_tensor_dims_t out_dim, int batch_size) { NOT_IMPLEMENTED(); } @@ -20,8 +20,8 @@ void cpu_backward_kernel(GenericTensorAccessorR const &output, DataType output_data_type, DataType input_data_type, std::optional aggr, - int in_dim, - int out_dim, + num_tensor_dims_t in_dim, + num_tensor_dims_t out_dim, int batch_size) { NOT_IMPLEMENTED(); } diff --git a/lib/kernels/src/kernels/format_accessor_contents.cc b/lib/kernels/src/kernels/format_accessor_contents.cc index cbdf2870dd..40d56254b8 100644 --- a/lib/kernels/src/kernels/format_accessor_contents.cc +++ b/lib/kernels/src/kernels/format_accessor_contents.cc @@ -14,7 +14,8 @@ struct Print1DCPUAccessorR { void operator()(GenericTensorAccessorR const &accessor, std::ostream &stream) { ASSERT(accessor.device_type == DeviceType::CPU); - nonnegative_int dims = get_num_dims(accessor.shape.dims); + nonnegative_int dims = get_num_dims(accessor.shape.dims) + .nonnegative_int_from_num_tensor_dims(); ASSERT(dims == 1_n); positive_int ncols = 
dim_at_idx(accessor.shape.dims, ff_dim_t{0_n}); @@ -47,7 +48,8 @@ struct Print2DCPUAccessorR { void operator()(GenericTensorAccessorR const &accessor, std::ostream &stream) { ASSERT(accessor.device_type == DeviceType::CPU); - nonnegative_int dims = get_num_dims(accessor.shape.dims); + nonnegative_int dims = get_num_dims(accessor.shape.dims) + .nonnegative_int_from_num_tensor_dims(); ASSERT(dims == 2_n); positive_int dim0_size = dim_at_idx(accessor.shape.dims, ff_dim_t{0_n}); positive_int dim1_size = dim_at_idx(accessor.shape.dims, ff_dim_t{1_n}); @@ -91,7 +93,8 @@ struct Print3DCPUAccessorR { void operator()(GenericTensorAccessorR const &accessor, std::ostream &stream) { ASSERT(accessor.device_type == DeviceType::CPU); - nonnegative_int dims = get_num_dims(accessor.shape.dims); + nonnegative_int dims = get_num_dims(accessor.shape.dims) + .nonnegative_int_from_num_tensor_dims(); ASSERT(dims == 3_n); positive_int dim0_size = dim_at_idx(accessor.shape.dims, ff_dim_t{0_n}); @@ -150,7 +153,8 @@ struct Print4DCPUAccessorR { void operator()(GenericTensorAccessorR const &accessor, std::ostream &stream) { ASSERT(accessor.device_type == DeviceType::CPU); - nonnegative_int dims = get_num_dims(accessor.shape.dims); + nonnegative_int dims = get_num_dims(accessor.shape.dims) + .nonnegative_int_from_num_tensor_dims(); ASSERT(dims == 4_n); positive_int dim0_size = dim_at_idx(accessor.shape.dims, ff_dim_t{0_n}); @@ -248,7 +252,8 @@ std::string format_accessor_r_contents(GenericTensorAccessorR const &accessor) { GenericTensorAccessorR cpu_accessor = copy_tensor_accessor_r_to_cpu_if_necessary(accessor, cpu_allocator); - int num_dims = get_num_dims(cpu_accessor.shape.dims).unwrap_nonnegative(); + int num_dims = + get_num_dims(cpu_accessor.shape.dims).int_from_num_tensor_dims(); switch (num_dims) { case 1: return format_1d_accessor_r_contents(cpu_accessor); @@ -268,7 +273,8 @@ std::string format_accessor_w_contents(GenericTensorAccessorW const &accessor) { GenericTensorAccessorW 
cpu_accessor = copy_tensor_accessor_w_to_cpu_if_necessary(accessor, cpu_allocator); - int num_dims = get_num_dims(cpu_accessor.shape.dims).unwrap_nonnegative(); + int num_dims = + get_num_dims(cpu_accessor.shape.dims).int_from_num_tensor_dims(); switch (num_dims) { case 1: return format_1d_accessor_w_contents(cpu_accessor); diff --git a/lib/kernels/src/kernels/legion_dim.cc b/lib/kernels/src/kernels/legion_dim.cc index f3fa67387a..c1dffbbdf5 100644 --- a/lib/kernels/src/kernels/legion_dim.cc +++ b/lib/kernels/src/kernels/legion_dim.cc @@ -26,15 +26,16 @@ legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value) { } legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, - nonnegative_int num_dimensions) { - return legion_dim_t{nonnegative_int{num_dimensions.unwrap_nonnegative() - - ff_dim.value.unwrap_nonnegative() - 1}}; + num_tensor_dims_t num_dimensions) { + return legion_dim_t{ + nonnegative_int{num_dimensions.int_from_num_tensor_dims() - + ff_dim.value.unwrap_nonnegative() - 1}}; ; } ff_dim_t ff_dim_from_legion_dim(legion_dim_t legion_dim, - nonnegative_int num_dimensions) { - return ff_dim_t{nonnegative_int{num_dimensions.unwrap_nonnegative() - + num_tensor_dims_t num_dimensions) { + return ff_dim_t{nonnegative_int{num_dimensions.int_from_num_tensor_dims() - legion_dim.value.unwrap_nonnegative() - 1}}; } diff --git a/lib/kernels/src/kernels/pool_2d_kernels.cc b/lib/kernels/src/kernels/pool_2d_kernels.cc index 6ebfc68c86..f8f5571716 100644 --- a/lib/kernels/src/kernels/pool_2d_kernels.cc +++ b/lib/kernels/src/kernels/pool_2d_kernels.cc @@ -1,6 +1,7 @@ #include "kernels/pool_2d_kernels.h" #include "kernels/pool_2d_kernels_cpu.h" #include "kernels/pool_2d_kernels_gpu.h" +#include namespace FlexFlow::Kernels::Pool2D { diff --git a/lib/kernels/src/kernels/reduce_kernels.cc b/lib/kernels/src/kernels/reduce_kernels.cc index bd3d6a8cd1..284d07dd96 100644 --- a/lib/kernels/src/kernels/reduce_kernels.cc +++ b/lib/kernels/src/kernels/reduce_kernels.cc @@ -1,6 +1,7 
@@ #include "kernels/reduce_kernels.h" #include "kernels/reduce_kernels_cpu.h" #include "kernels/reduce_kernels_gpu.h" +#include namespace FlexFlow::Kernels::Reduce { diff --git a/lib/kernels/src/kernels/reverse_kernels_params.cc b/lib/kernels/src/kernels/reverse_kernels_params.cc index cf72fb3eef..162f67d782 100644 --- a/lib/kernels/src/kernels/reverse_kernels_params.cc +++ b/lib/kernels/src/kernels/reverse_kernels_params.cc @@ -7,17 +7,17 @@ namespace FlexFlow { ReverseKernelsParams compute_reverse_kernels_params(TensorDims const &output_dims, ReverseAttrs const &attrs) { - auto axis = attrs.axis; + ff_dim_t axis = attrs.axis; positive_int in_blk_size = 1_p; positive_int reverse_dim_size = 1_p; positive_int num_out_blks = 1_p; - for (nonnegative_int i : nonnegative_range(get_num_dims(output_dims))) { - if (i < axis.value) { - in_blk_size *= dim_at_idx(output_dims, ff_dim_t{i}); - } else if (i == axis.value) { - reverse_dim_size = dim_at_idx(output_dims, ff_dim_t{i}); + for (ff_dim_t i : tensor_dims_range(get_num_dims(output_dims))) { + if (i < axis) { + in_blk_size *= dim_at_idx(output_dims, i); + } else if (i == axis) { + reverse_dim_size = dim_at_idx(output_dims, i); } else { - num_out_blks *= dim_at_idx(output_dims, ff_dim_t{i}); + num_out_blks *= dim_at_idx(output_dims, i); } } diff --git a/lib/kernels/src/perf_metrics.cc b/lib/kernels/src/perf_metrics.cc index 2036ddd35a..c9161e55b1 100644 --- a/lib/kernels/src/perf_metrics.cc +++ b/lib/kernels/src/perf_metrics.cc @@ -2,22 +2,6 @@ namespace FlexFlow { -PerfMetrics::PerfMetrics(double _start_time) - : start_time(_start_time), current_time(_start_time) {} - -PerfMetrics::PerfMetrics(int _train_all, - std::optional _train_correct, - std::optional _cce_loss, - std::optional _sparse_cce_loss, - std::optional _mse_loss, - std::optional _rmse_loss, - std::optional _mae_loss, - double _start_time_micro, - double _current_time_micro) - : train_all(_train_all), train_correct(_train_correct), cce_loss(_cce_loss), - 
mse_loss(_mse_loss), rmse_loss(_rmse_loss), mae_loss(_mae_loss), - start_time(_start_time_micro), current_time(_current_time_micro) {} - float get_throughput(PerfMetrics const &m) { return m.train_all / (m.current_time - m.start_time); } diff --git a/lib/kernels/test/src/test_gather_kernels.cc b/lib/kernels/test/src/test_gather_kernels.cc index d08058b063..05ae61b889 100644 --- a/lib/kernels/test/src/test_gather_kernels.cc +++ b/lib/kernels/test/src/test_gather_kernels.cc @@ -19,7 +19,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { TensorShape output_shape) { ff_dim_t dim = ff_dim_t{ nonnegative_int{ - get_num_dims(input_shape.dims).unwrap_nonnegative() - 1}, + get_num_dims(input_shape.dims).int_from_num_tensor_dims() - 1}, }; GatherPerDeviceState state = Kernels::Gather::gpu_init_kernel(managed_handle.raw_handle(), dim); @@ -79,7 +79,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { TensorShape output_shape) { ff_dim_t dim = ff_dim_t{ nonnegative_int{ - get_num_dims(input_shape.dims).unwrap_nonnegative() - 1}, + get_num_dims(input_shape.dims).int_from_num_tensor_dims() - 1}, }; GatherPerDeviceState state = Kernels::Gather::gpu_init_kernel(managed_handle.raw_handle(), dim); diff --git a/lib/kernels/test/src/test_transpose_kernel.cc b/lib/kernels/test/src/test_transpose_kernel.cc index 9d4809b2cf..d3aa9262f3 100644 --- a/lib/kernels/test/src/test_transpose_kernel.cc +++ b/lib/kernels/test/src/test_transpose_kernel.cc @@ -6,9 +6,11 @@ using namespace ::FlexFlow; TEST_SUITE(FF_CUDA_TEST_SUITE) { TEST_CASE("Test Transpose Kernel Operations") { TransposeAttrs attrs = TransposeAttrs{ - FFOrdered{ - ff_dim_t{1_n}, - ff_dim_t{0_n}, + TensorDimPermutation{ + bidict{ + {ff_dim_t{1_n}, ff_dim_t{0_n}}, + {ff_dim_t{0_n}, ff_dim_t{1_n}}, + }, }, }; diff --git a/lib/local-execution/include/local-execution/atomic_training_tensor_guid_t.dtg.toml b/lib/local-execution/include/local-execution/atomic_training_tensor_guid_t.dtg.toml new file mode 100644 index 0000000000..12380d80ba --- /dev/null +++ 
b/lib/local-execution/include/local-execution/atomic_training_tensor_guid_t.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "atomic_training_tensor_guid_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h" +] + +[[fields]] +name = "raw_index" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/local-execution/include/local-execution/computation_graph_instance/README.md b/lib/local-execution/include/local-execution/computation_graph_instance/README.md new file mode 100644 index 0000000000..6b7f4b43db --- /dev/null +++ b/lib/local-execution/include/local-execution/computation_graph_instance/README.md @@ -0,0 +1 @@ +The primary external-facing interface of local-execution diff --git a/lib/local-execution/include/local-execution/computation_graph_instance/computation_graph_instance.h b/lib/local-execution/include/local-execution/computation_graph_instance/computation_graph_instance.h new file mode 100644 index 0000000000..f28552603f --- /dev/null +++ b/lib/local-execution/include/local-execution/computation_graph_instance/computation_graph_instance.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_COMPUTATION_GRAPH_INSTANCE_H +#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_COMPUTATION_GRAPH_INSTANCE_H + +#include "kernels/accessor.h" +#include "local-execution/computation_graph_training_tensor_ref_t.dtg.h" +#include "local-execution/local_task_registry.dtg.h" +#include "local-execution/local_tensor_backing.dtg.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/layer_guid_t.dtg.h" +#include "task-spec/symbolic/training_symbolic_computation_graph_from_cg_conversion.dtg.h" +#include "utils/units/milliseconds_t.h" +#include + +namespace FlexFlow { + +struct ComputationGraphInstance { +public: + ComputationGraphInstance() = delete; + + explicit ComputationGraphInstance( + TrainingSymbolicComputationGraphFromCgConversion const &, 
+ LocalTensorBacking const &, + LocalTaskRegistry const &); + +public: + TrainingSymbolicComputationGraphFromCgConversion const & + get_symbolic_training_graph_for_cg() const; + LocalTensorBacking const &get_tensor_backing() const; + LocalTaskRegistry const &get_task_registry() const; + +private: + TrainingSymbolicComputationGraphFromCgConversion + symbolic_training_graph_for_cg; + LocalTensorBacking tensor_backing; + LocalTaskRegistry task_registry; +}; + +ComputationGraphInstance create_computation_graph_instance( + ComputationGraph const &, + bidict> const + &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/computation_graph_instance/initialized_computation_graph_instance.dtg.toml b/lib/local-execution/include/local-execution/computation_graph_instance/initialized_computation_graph_instance.dtg.toml new file mode 100644 index 0000000000..8589d5edec --- /dev/null +++ b/lib/local-execution/include/local-execution/computation_graph_instance/initialized_computation_graph_instance.dtg.toml @@ -0,0 +1,35 @@ +namespace = "FlexFlow" +name = "InitializedComputationGraphInstance" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + # "local-execution/computation_graph_instance.dtg.h", + # "local-execution/local_device_states_backing.dtg.h", +] + +src_includes = [] + +fields = [] +# [[fields]] +# name = "per_device_op_states" +# type = "::FlexFlow::LocalDeviceStatesBacking" +# +# [[fields]] +# name = "allocator" +# type = "::FlexFlow::Allocator" +# +# [[fields]] +# name = "atomic_tensor_backing" +# type = "::FlexFlow::LocalAtomicTensorBacking" +# +# [[fields]] +# name = "computation_graph_instance" +# type = "::FlexFlow::ComputationGraphInstance" diff --git a/lib/local-execution/include/local-execution/computation_graph_instance/initialized_computation_graph_instance.h 
b/lib/local-execution/include/local-execution/computation_graph_instance/initialized_computation_graph_instance.h new file mode 100644 index 0000000000..a014ff596d --- /dev/null +++ b/lib/local-execution/include/local-execution/computation_graph_instance/initialized_computation_graph_instance.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_INITIALIZED_COMPUTATION_GRAPH_INSTANCE_H +#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_INITIALIZED_COMPUTATION_GRAPH_INSTANCE_H + +#include "local-execution/computation_graph_instance/computation_graph_instance.h" +#include "local-execution/local_atomic_tensor_backing.dtg.h" +#include "local-execution/local_device_states_backing.dtg.h" +#include "local-execution/local_task_registry.dtg.h" +#include "local-execution/local_tensor_backing.dtg.h" +#include "task-spec/runtime_task_invocation/runtime_arg_config.dtg.h" +#include "task-spec/symbolic/training_symbolic_computation_graph_from_cg_conversion.dtg.h" +#include "utils/units/milliseconds_t.h" + +namespace FlexFlow { + +struct InitializedComputationGraphInstance { +public: + LocalTensorBacking const &get_tensor_backing() const; + LocalTaskRegistry const &get_task_registry() const; + TrainingSymbolicComputationGraphFromCgConversion const & + get_symbolic_training_graph_for_cg() const; + LocalAtomicTensorBacking const &get_atomic_tensor_backing() const; + Allocator &get_allocator() const; + RuntimeArgConfig const &get_runtime_arg_config() const; + +private: + LocalDeviceStatesBacking per_device_op_states; + Allocator &allocator; + LocalAtomicTensorBacking atomic_tensor_backing; + ComputationGraphInstance computation_graph_instance; +}; + +InitializedComputationGraphInstance + initialize_computation_graph_instance(ComputationGraphInstance const &, + Allocator &); + +std::unordered_map> + perform_forward_pass_for_computation_graph_instance( + InitializedComputationGraphInstance const &); + +std::unordered_map> + 
perform_backward_pass_for_computation_graph_instance( + InitializedComputationGraphInstance const &); + +void perform_update_pass_for_computation_graph_instance( + InitializedComputationGraphInstance const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/computation_graph_training_tensor_ref_t.dtg.toml b/lib/local-execution/include/local-execution/computation_graph_training_tensor_ref_t.dtg.toml new file mode 100644 index 0000000000..d25dc407e2 --- /dev/null +++ b/lib/local-execution/include/local-execution/computation_graph_training_tensor_ref_t.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "computation_graph_training_tensor_ref_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "pcg/tensor_guid_t.dtg.h" , + "task-spec/op_training_tensor_type.dtg.h", +] + +[[fields]] +name = "tensor_guid" +type = "::FlexFlow::tensor_guid_t" + +[[fields]] +name = "tensor_type" +type = "::FlexFlow::OpTrainingTensorType" diff --git a/lib/local-execution/include/local-execution/cost_estimator/local_cost_estimator.h b/lib/local-execution/include/local-execution/cost_estimator/local_cost_estimator.h new file mode 100644 index 0000000000..ba5b511227 --- /dev/null +++ b/lib/local-execution/include/local-execution/cost_estimator/local_cost_estimator.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_COST_ESTIMATOR_LOCAL_COST_ESTIMATOR_H +#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_COST_ESTIMATOR_LOCAL_COST_ESTIMATOR_H + +#include "compiler/cost_estimator/cost_estimator.h" +#include "pcg/machine_interconnect_specification.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "task-spec/runtime_task_invocation/runtime_arg_config.dtg.h" + +namespace FlexFlow { + +struct LocalCostEstimator : public ICostEstimator { + explicit LocalCostEstimator(RuntimeArgConfig const &, + MachineInterconnectSpecification const &, + 
DeviceType); + + LocalCostEstimator(LocalCostEstimator const &) = delete; + LocalCostEstimator(LocalCostEstimator &&) = delete; + ~LocalCostEstimator() = default; + + OpCostMetrics estimate_cost(OpCostEstimateKey const &) const override; + + milliseconds_t estimate_cost(TensorSetMovement const &) const override; + +private: + RuntimeArgConfig runtime_arg_config; + MachineInterconnectSpecification interconnect_specification; + DeviceType device_type; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(LocalCostEstimator); + +CostEstimator get_local_cost_estimator(RuntimeArgConfig const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/tracked_allocator.h b/lib/local-execution/include/local-execution/cost_estimator/tracked_allocator.h similarity index 100% rename from lib/local-execution/include/local-execution/tracked_allocator.h rename to lib/local-execution/include/local-execution/cost_estimator/tracked_allocator.h diff --git a/lib/local-execution/include/local-execution/execute_task_for_layer.h b/lib/local-execution/include/local-execution/execute_task_for_layer.h new file mode 100644 index 0000000000..587ff96687 --- /dev/null +++ b/lib/local-execution/include/local-execution/execute_task_for_layer.h @@ -0,0 +1,87 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_EXECUTE_TASK_FOR_LAYER_H +#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_EXECUTE_TASK_FOR_LAYER_H + +#include "local-execution/local_atomic_tensor_backing.dtg.h" +#include "local-execution/local_ready_to_launch_task.dtg.h" +#include "local-execution/local_task_registry.dtg.h" +#include "local-execution/local_tensor_backing.dtg.h" +#include "pcg/layer_guid_t.dtg.h" +#include "task-spec/runtime_task_invocation/runtime_arg_config.dtg.h" +#include "task-spec/runtime_task_invocation/runtime_task_invocation.dtg.h" +#include "task-spec/symbolic/symbolic_cg_op_attrs_and_training_signature_with_shapes.dtg.h" +#include 
"task-spec/symbolic/training_symbolic_computation_graph.dtg.h" +#include "task-spec/symbolic/training_symbolic_computation_graph_from_cg_conversion.dtg.h" +#include "utils/units/milliseconds_t.h" + +namespace FlexFlow { + +LocalReadyToLaunchTask + prepare_runtime_task_invocation(RuntimeTaskInvocation const &, + LocalTensorBacking const &, + LocalAtomicTensorBacking const &, + Allocator &, + RuntimeArgConfig const &); + +std::optional execute_init_for_layer( + symbolic_layer_guid_t, + SymbolicCgOpAttrsAndTrainingSignatureWithShapes const &, + LocalTensorBacking const &, + LocalAtomicTensorBacking const &, + Allocator &, + LocalTaskRegistry const &, + RuntimeArgConfig const &); + +std::optional execute_forward_for_layer( + symbolic_layer_guid_t, + SymbolicCgOpAttrsAndTrainingSignatureWithShapes const &, + LocalTensorBacking const &, + LocalAtomicTensorBacking const &, + Allocator &, + LocalTaskRegistry const &, + RuntimeArgConfig const &); + +std::optional execute_backward_for_layer( + symbolic_layer_guid_t, + SymbolicCgOpAttrsAndTrainingSignatureWithShapes const &, + LocalTensorBacking const &, + LocalAtomicTensorBacking const &, + Allocator &, + LocalTaskRegistry const &, + RuntimeArgConfig const &); + +void execute_compute_loss(TrainingSymbolicComputationGraph const &, + LocalTensorBacking const &, + LocalAtomicTensorBacking const &, + Allocator &, + LocalTaskRegistry const &, + RuntimeArgConfig const &); + +void execute_update_for_layer(symbolic_layer_guid_t, + TrainingSymbolicComputationGraph const &, + LocalTensorBacking const &, + LocalAtomicTensorBacking const &, + OptimizerAttrs const &, + Allocator &, + RuntimeArgConfig const &); + +std::unordered_map> + execute_forward_pass( + TrainingSymbolicComputationGraphFromCgConversion const &training_cg, + LocalTensorBacking const &local_tensor_backing, + LocalAtomicTensorBacking const &local_atomic_tensor_backing, + Allocator &, + LocalTaskRegistry const &, + RuntimeArgConfig const &); + +std::unordered_map> + 
execute_backward_pass( + TrainingSymbolicComputationGraphFromCgConversion const &training_cg, + LocalTensorBacking const &local_tensor_backing, + LocalAtomicTensorBacking const &local_atomic_tensor_backing, + Allocator &, + LocalTaskRegistry const &, + RuntimeArgConfig const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/local_args_backing.h b/lib/local-execution/include/local-execution/local_args_backing.h deleted file mode 100644 index 94748cf7ed..0000000000 --- a/lib/local-execution/include/local-execution/local_args_backing.h +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_LOCAL_ARGS_BACKING_H -#define _FLEXFLOW_LOCAL_EXECUTION_LOCAL_ARGS_BACKING_H - -#include "local-execution/local_args_backing.dtg.h" -#include "local-execution/local_task_argument_accessor.h" -#include "local-execution/local_task_registry.dtg.h" -#include "local-execution/local_tensor_backing.dtg.h" -#include "pcg/computation_graph.h" -#include "task-spec/per_device_op_state.h" -#include "task-spec/task_binding.h" -#include "task-spec/task_invocation.dtg.h" -#include "task-spec/training_computation_graph.dtg.h" -#include "task-spec/training_layer_plus_context.dtg.h" - -namespace FlexFlow { - -LocalArgsBacking make_local_computation_args_backing_with_empty_device_states( - RuntimeArgConfig const &); - -std::optional - get_per_device_op_state_if_exists(LocalArgsBacking const &, - layer_guid_t const &); - -std::unordered_map - construct_arg_slots_backing(TaskBinding const &, RuntimeArgConfig const &); - -std::optional - create_per_device_op_state(LocalTaskRegistry const &, - LocalTensorBacking const &, - RuntimeArgConfig const &, - Allocator &, - TrainingLayerPlusContext const &); - -TaskArgumentAccessor get_task_arg_accessor(LocalTensorBacking const &, - RuntimeArgConfig const &, - TaskInvocation const &, - Allocator &); - -LocalArgsBacking make_local_args_backing_for_computation_graph( - LocalTaskRegistry const &, - 
TrainingComputationGraph const &, - RuntimeArgConfig const &, - LocalTensorBacking const &, - Allocator &); - -} // namespace FlexFlow - -#endif diff --git a/lib/local-execution/include/local-execution/local_args_backing.struct.toml b/lib/local-execution/include/local-execution/local_args_backing.struct.toml deleted file mode 100644 index 449f883194..0000000000 --- a/lib/local-execution/include/local-execution/local_args_backing.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "LocalArgsBacking" -features = [] - -includes = [ - "task-spec/runtime_arg_config.dtg.h", - "task-spec/device_specific_device_states.dtg.h", - "pcg/layer_guid_t.dtg.h", - "", -] - -[[fields]] -name = "runtime_arg_config" -type = "::FlexFlow::RuntimeArgConfig" - -[[fields]] -name = "per_device_op_states" -type = "std::unordered_map<::FlexFlow::layer_guid_t, std::optional<::FlexFlow::DeviceSpecificDeviceStates>>" diff --git a/lib/local-execution/include/local-execution/local_atomic_tensor_backing.dtg.toml b/lib/local-execution/include/local-execution/local_atomic_tensor_backing.dtg.toml new file mode 100644 index 0000000000..5fe6b05b52 --- /dev/null +++ b/lib/local-execution/include/local-execution/local_atomic_tensor_backing.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "LocalAtomicTensorBacking" +type = "struct" +features = [ + "eq", + "fmt", +] + +includes = [ + "kernels/accessor.h", + "local-execution/atomic_training_tensor_guid_t.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", +] + + +[[fields]] +name = "accessor_from_atomic_tensor_map" +type = "std::unordered_map<::FlexFlow::atomic_training_tensor_guid_t, ::FlexFlow::GenericTensorAccessorW>" diff --git a/lib/local-execution/include/local-execution/local_atomic_tensor_backing.h b/lib/local-execution/include/local-execution/local_atomic_tensor_backing.h new file mode 100644 index 0000000000..11f9f3e56a --- /dev/null +++ 
b/lib/local-execution/include/local-execution/local_atomic_tensor_backing.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_LOCAL_ATOMIC_TENSOR_BACKING_H +#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_LOCAL_ATOMIC_TENSOR_BACKING_H + +#include "kernels/allocation.h" +#include "local-execution/atomic_task_invocation.dtg.h" +#include "local-execution/local_atomic_tensor_backing.dtg.h" +#include "local-execution/tensor_slot_backing.dtg.h" +#include "task-spec/runtime_task_invocation/runtime_arg_config.dtg.h" +#include "task-spec/task_argument_accessor/task_argument_accessor.h" + +namespace FlexFlow { + +std::unordered_map + construct_tensor_slots_backing_for_binding(LocalAtomicTensorBacking const &, + AtomicTaskBinding const &); + +TaskArgumentAccessor get_task_arg_accessor_for_atomic_task_binding( + LocalAtomicTensorBacking const &, AtomicTaskBinding const &, Allocator &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/local_concrete_task_graph.dtg.toml b/lib/local-execution/include/local-execution/local_concrete_task_graph.dtg.toml new file mode 100644 index 0000000000..8dde33a49a --- /dev/null +++ b/lib/local-execution/include/local-execution/local_concrete_task_graph.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "LocalConcreteTaskGraph" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "local-execution/local_concrete_task_invocation.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "task_invocations" +type = "std::unordered_set<::FlexFlow::LocalConcreteTaskInvocation>" diff --git a/lib/local-execution/include/local-execution/local_concrete_task_graph.h b/lib/local-execution/include/local-execution/local_concrete_task_graph.h new file mode 100644 index 0000000000..c2f8c405b0 --- /dev/null +++ 
b/lib/local-execution/include/local-execution/local_concrete_task_graph.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_LOCAL_CONCRETE_TASK_GRAPH_H +#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_LOCAL_CONCRETE_TASK_GRAPH_H + +#include "local-execution/local_concrete_task_graph.dtg.h" + +namespace FlexFlow { + +std::vector + local_concrete_task_graph_topological_ordering( + LocalConcreteTaskGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/local_concrete_task_invocation.dtg.toml b/lib/local-execution/include/local-execution/local_concrete_task_invocation.dtg.toml new file mode 100644 index 0000000000..ce0b64dd6b --- /dev/null +++ b/lib/local-execution/include/local-execution/local_concrete_task_invocation.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "LocalConcreteTaskInvocation" +type = "struct" +features = [] + +includes = [ + "task-spec/task_id_t.dtg.h", + "task-spec/task_argument_accessor/task_argument_accessor.h", +] + +[[fields]] +name = "task_id" +type = "::FlexFlow::task_id_t" + +[[fields]] +name = "task_arg_accessor" +type = "::FlexFlow::TaskArgumentAccessor" diff --git a/lib/local-execution/include/local-execution/local_cost_estimator.h b/lib/local-execution/include/local-execution/local_cost_estimator.h deleted file mode 100644 index c42876bbd6..0000000000 --- a/lib/local-execution/include/local-execution/local_cost_estimator.h +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_LOCAL_COST_ESTIMATOR_H -#define _FLEXFLOW_LOCAL_EXECUTION_LOCAL_COST_ESTIMATOR_H - -#include "compiler/cost_estimator/cost_estimator.h" -#include "pcg/optimizer_attrs.dtg.h" -#include "task-spec/runtime_arg_config.dtg.h" - -namespace FlexFlow { - -struct LocalCostEstimator : public ICostEstimator { - LocalCostEstimator(RuntimeArgConfig const &); - - LocalCostEstimator(LocalCostEstimator const &) = delete; - LocalCostEstimator(LocalCostEstimator 
&&) = delete; - ~LocalCostEstimator() = default; - - OpCostMetrics estimate_cost(OpCostEstimateKey const &) const override; - - milliseconds_t estimate_cost(TensorSetMovement const &) const override; - -private: - RuntimeArgConfig runtime_arg_config; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(LocalCostEstimator); - -CostEstimator get_local_cost_estimator(RuntimeArgConfig const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/local-execution/include/local-execution/local_device_states_backing.dtg.toml b/lib/local-execution/include/local-execution/local_device_states_backing.dtg.toml new file mode 100644 index 0000000000..350bf7756f --- /dev/null +++ b/lib/local-execution/include/local-execution/local_device_states_backing.dtg.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "LocalDeviceStatesBacking" +type = "struct" +features = [] + +includes = [ + "task-spec/device_specific_per_device_op_state.dtg.h", + "pcg/layer_guid_t.dtg.h", + "", +] + +[[fields]] +name = "per_device_op_states" +type = "std::unordered_map<::FlexFlow::layer_guid_t, std::optional<::FlexFlow::DeviceSpecificPerDeviceOpState>>" diff --git a/lib/local-execution/include/local-execution/local_device_states_backing.h b/lib/local-execution/include/local-execution/local_device_states_backing.h new file mode 100644 index 0000000000..5650197e44 --- /dev/null +++ b/lib/local-execution/include/local-execution/local_device_states_backing.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_LOCAL_DEVICE_STATES_BACKING_H +#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_LOCAL_DEVICE_STATES_BACKING_H + +#include "local-execution/local_device_states_backing.dtg.h" +#include "local-execution/local_task_argument_accessor.h" +#include "local-execution/local_task_registry.dtg.h" +#include "local-execution/local_tensor_backing.dtg.h" +#include "pcg/computation_graph.h" +#include "task-spec/per_device_op_state.h" +#include 
"task-spec/symbolic/symbolic_layer_training_tensor_group_signature_with_shapes.dtg.h" + +namespace FlexFlow { + +LocalDeviceStatesBacking make_local_device_states_backing_for_computation_graph( + LocalTaskRegistry const &, + std::unordered_map< + layer_guid_t, + SymbolicLayerTrainingTensorGroupSignatureWithShapes> const &, + RuntimeArgConfig const &runtime_arg_config, + LocalTensorBacking const &, + Allocator &); + +std::optional + get_per_device_op_state_if_exists(LocalDeviceStatesBacking const &, + layer_guid_t const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/local_task_argument_accessor.h b/lib/local-execution/include/local-execution/local_task_argument_accessor.h index 0ab66234eb..53026f81fd 100644 --- a/lib/local-execution/include/local-execution/local_task_argument_accessor.h +++ b/lib/local-execution/include/local-execution/local_task_argument_accessor.h @@ -2,9 +2,10 @@ #define _FLEXFLOW_LOCAL_EXECUTION_LOCAL_TASK_ARGUMENT_ACCESSOR_H #include "local-execution/tensor_slot_backing.dtg.h" -#include "task-spec/runtime_arg_config.dtg.h" -#include "task-spec/task_argument_accessor.h" -#include "task-spec/tensor_sub_slot_id_t.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "task-spec/runtime_task_invocation/runtime_arg_config.dtg.h" +#include "task-spec/task_argument_accessor/task_argument_accessor.h" +#include "task-spec/task_argument_accessor/task_tensor_parameter.dtg.h" #include #include @@ -13,30 +14,54 @@ namespace FlexFlow { struct LocalTaskArgumentAccessor : public ITaskArgumentAccessor { explicit LocalTaskArgumentAccessor( Allocator const &allocator, - std::unordered_map const + std::unordered_map const &tensor_slots_backing, - std::unordered_map const &arg_slots_backing); + ProfilingSettings const &profiling_settings, + device_handle_t const &ff_handle, + DeviceType kernel_device_type, + PCGOperatorAttrs const &op_attrs, + std::optional const &loss_attrs, + std::optional const &per_device_op_state, 
+ FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + size_t device_idx); LocalTaskArgumentAccessor(LocalTaskArgumentAccessor const &) = delete; LocalTaskArgumentAccessor(LocalTaskArgumentAccessor &&) = delete; - ConcreteArgSpec const &get_concrete_arg(slot_id_t) const override; + ConcreteArgSpec const &get_concrete_arg(arg_slot_id_t) const override; - GenericTensorAccessor get_tensor(slot_id_t slot, - Permissions priv, - TensorType tensor_type) const override; - VariadicGenericTensorAccessor get_variadic_tensor( - slot_id_t slot, Permissions priv, TensorType tensor_type) const override; + GenericTensorAccessor get_tensor(TaskTensorParameter slot, + Permissions priv) const override; + + ProfilingSettings get_profiling_settings() const override; + device_handle_t get_ff_handle() const override; + DeviceType get_kernel_device_type() const override; + PCGOperatorAttrs get_op_attrs() const override; + LossAttrs get_loss_attrs() const override; + PerDeviceOpState get_per_device_op_state() const override; + FFIterationConfig get_iteration_config() const override; + OptimizerAttrs get_optimizer_attrs() const override; Allocator get_allocator() const override; - size_t get_device_idx() const override; + device_id_t get_device_idx() const override; private: Allocator allocator; - std::unordered_map + std::unordered_map tensor_slots_backing; - std::unordered_map arg_slots_backing; + + ProfilingSettings profiling_settings; + device_handle_t ff_handle; + DeviceType kernel_device_type; + PCGOperatorAttrs op_attrs; + std::optional loss_attrs; + std::optional per_device_op_state; + FFIterationConfig iteration_config; + std::optional optimizer_attrs; + + device_id_t device_idx; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(LocalTaskArgumentAccessor); diff --git a/lib/local-execution/include/local-execution/local_task_registry.dtg.toml b/lib/local-execution/include/local-execution/local_task_registry.dtg.toml new file mode 100644 index 0000000000..056fe39ca7 
--- /dev/null +++ b/lib/local-execution/include/local-execution/local_task_registry.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "LocalTaskRegistry" +type = "struct" +features = [ + "eq", + "fmt", + "hash" +] + +includes = [ + "task-spec/task_impl_function.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "task_mapping" +type = "std::unordered_map<::FlexFlow::task_id_t, ::FlexFlow::TaskImplFunction>" diff --git a/lib/local-execution/include/local-execution/local_task_registry.h b/lib/local-execution/include/local-execution/local_task_registry.h index 142433ba53..6adacab0a9 100644 --- a/lib/local-execution/include/local-execution/local_task_registry.h +++ b/lib/local-execution/include/local-execution/local_task_registry.h @@ -2,22 +2,29 @@ #define _FLEXFLOW_LOCAL_EXECUTION_TASK_REGISTRY_H #include "local-execution/local_task_registry.dtg.h" -#include "local-execution/registered_task_t.dtg.h" #include "pcg/layer_attrs.dtg.h" -#include "task-spec/op_task_type.dtg.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/ops/op_task_type.dtg.h" #include "utils/units/milliseconds_t.h" namespace FlexFlow { LocalTaskRegistry construct_local_task_registry_for_layers( - std::unordered_map const &); + std::unordered_set const &); -std::optional try_get_registered_task( - LocalTaskRegistry const &, layer_guid_t const &, OpTaskType const &); +std::optional + call_init_task_impl(LocalTaskRegistry const &local_task_registry, + task_id_t task_id, + TaskArgumentAccessor const &arg_accessor); -std::optional call_task_impl(LocalTaskRegistry const &, - task_id_t const &task_id, - TaskArgumentAccessor const &acc); +std::optional + call_fwb_task_impl(LocalTaskRegistry const &local_task_registry, + task_id_t task_id, + TaskArgumentAccessor const &arg_accessor); + +void call_generic_task_impl(LocalTaskRegistry const &local_task_registry, + task_id_t task_id, + 
TaskArgumentAccessor const &arg_accessor); } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/local_task_registry.struct.toml b/lib/local-execution/include/local-execution/local_task_registry.struct.toml deleted file mode 100644 index 84abc7aa0c..0000000000 --- a/lib/local-execution/include/local-execution/local_task_registry.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "LocalTaskRegistry" -features = [ - "eq", - "fmt", - "hash" -] - -includes = [ - "task-spec/task_signature_impl.dtg.h", - "pcg/layer_guid_t.dtg.h", - "local-execution/operator_task_set.dtg.h" -] - -src_includes = [ - "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", -] - -[[fields]] -name = "task_sets" -type = "std::unordered_map<::FlexFlow::layer_guid_t, ::FlexFlow::OperatorTaskSet>" - -[[fields]] -name = "task_mapping" -type = "std::unordered_map<::FlexFlow::task_id_t, ::FlexFlow::TaskSignatureAndImpl>" diff --git a/lib/local-execution/include/local-execution/local_tensor_backing.h b/lib/local-execution/include/local-execution/local_tensor_backing.h deleted file mode 100644 index 479ad4734a..0000000000 --- a/lib/local-execution/include/local-execution/local_tensor_backing.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_LOCAL_TENSOR_BACKING_H -#define _FLEXFLOW_LOCAL_EXECUTION_LOCAL_TENSOR_BACKING_H - -#include "kernels/accessor.h" -#include "kernels/allocation.h" -#include "local-execution/local_tensor_backing.dtg.h" -#include "local-execution/tensor_slot_backing.dtg.h" -#include "task-spec/task_binding.h" -#include "task-spec/training_computation_graph.dtg.h" -#include "task-spec/training_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -LocalTensorBacking construct_local_tensor_backing( - std::unordered_map const - &training_tensor_shapes, - std::unordered_map const - &preallocated_tensors, - Allocator &); - -GenericTensorAccessorW - get_accessor_for_training_tensor(LocalTensorBacking const &, - 
training_tensor_guid_t); - -std::unordered_map - construct_tensor_slots_backing_for_binding(LocalTensorBacking const &, - TaskBinding const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/local-execution/include/local-execution/local_tensor_backing.struct.toml b/lib/local-execution/include/local-execution/local_tensor_backing.struct.toml deleted file mode 100644 index 48a7a7fa90..0000000000 --- a/lib/local-execution/include/local-execution/local_tensor_backing.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "LocalTensorBacking" -features = [ - "eq", - "fmt", -] - -includes = [ - "kernels/accessor.h", - "task-spec/training_tensor_guid_t.dtg.h", -] - -src_includes = [ - "utils/fmt/unordered_map.h", -] - -[[fields]] -name = "backing_for_training_tensor_map" -type = "std::unordered_map<::FlexFlow::training_tensor_guid_t, ::FlexFlow::GenericTensorAccessorW>" diff --git a/lib/local-execution/include/local-execution/local_training_backing.h b/lib/local-execution/include/local-execution/local_training_backing.h deleted file mode 100644 index f2177016fa..0000000000 --- a/lib/local-execution/include/local-execution/local_training_backing.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_LOCAL_TRAINING_BACKING_H -#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_LOCAL_TRAINING_BACKING_H - -#include "local-execution/local_training_backing.dtg.h" -#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" -#include "pcg/optimizer_attrs.dtg.h" -#include "task-spec/training_computation_graph.dtg.h" -#include "task-spec/training_tensor_guid_t.dtg.h" -#include "utils/units/milliseconds_t.h" - -namespace FlexFlow { - -LocalTrainingBacking make_local_training_backing_for_computation_graph( - Allocator &allocator, - std::unordered_map const - &preallocated_tensors, - TrainingComputationGraph const &training_computation_graph, - RuntimeArgConfig const &runtime_arg_config, - OptimizerAttrs const 
&optimizer_attrs); - -std::optional execute_forward(LocalTaskRegistry const &, - LocalTensorBacking const &, - LocalArgsBacking const &, - TrainingLayerPlusContext const &, - Allocator &); - -std::optional execute_backward(LocalTaskRegistry const &, - LocalTensorBacking const &, - LocalArgsBacking const &, - TrainingLayerPlusContext const &, - Allocator &); - -void compute_loss(LocalTrainingBacking const &, LossAttrs const &, Allocator &); - -void execute_update(LocalTrainingBacking const &, - layer_guid_t const &, - OptimizerAttrs const &, - Allocator &); - -} // namespace FlexFlow - -#endif diff --git a/lib/local-execution/include/local-execution/local_training_backing.struct.toml b/lib/local-execution/include/local-execution/local_training_backing.struct.toml deleted file mode 100644 index 7da8c3bed6..0000000000 --- a/lib/local-execution/include/local-execution/local_training_backing.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "LocalTrainingBacking" -features = [] - -includes = [ - "task-spec/training_computation_graph.dtg.h", - "local-execution/local_task_registry.h", - "local-execution/local_tensor_backing.h", - "local-execution/local_args_backing.h", -] - -[[fields]] -name = "training_computation_graph" -type = "::FlexFlow::TrainingComputationGraph" - -[[fields]] -name = "local_task_registry" -type = "::FlexFlow::LocalTaskRegistry" - -[[fields]] -name = "local_tensor_backing" -type = "::FlexFlow::LocalTensorBacking" - -[[fields]] -name = "local_args_backing" -type = "::FlexFlow::LocalArgsBacking" diff --git a/lib/local-execution/include/local-execution/model_training_instance.h b/lib/local-execution/include/local-execution/model_training_instance.h deleted file mode 100644 index bfd279fde5..0000000000 --- a/lib/local-execution/include/local-execution/model_training_instance.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_MODEL_TRAINING_INSTANCE_H -#define _FLEXFLOW_LOCAL_EXECUTION_MODEL_TRAINING_INSTANCE_H - 
-#include "local-execution/local_training_backing.h" -#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" -#include "pcg/tensor_guid_t.dtg.h" -#include "task-spec/loss_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -struct ModelTrainingInstance { - ModelTrainingInstance(Allocator const &, - LocalTrainingBacking const &, - LossAttrs const &, - OptimizerAttrs const &); - - Allocator allocator; - LocalTrainingBacking training_backing; - LossAttrs loss_attrs; - OptimizerAttrs optimizer_attrs; - -public: - std::unordered_map> forward(); - std::unordered_map> backward(); - void update(); - GenericTensorAccessorR get_loss_tensor_accessor() const; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/local-execution/include/local-execution/operator_task_set.dtg.toml b/lib/local-execution/include/local-execution/operator_task_set.dtg.toml new file mode 100644 index 0000000000..b074d981d1 --- /dev/null +++ b/lib/local-execution/include/local-execution/operator_task_set.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "OperatorTaskSet" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "local-execution/task_id_with_noop_default_t.dtg.h" +] + +[[fields]] +name = "init_task" +type = "::FlexFlow::task_id_with_noop_default_t" + +[[fields]] +name = "fwd_task" +type = "::FlexFlow::task_id_with_noop_default_t" + +[[fields]] +name = "bwd_task" +type = "::FlexFlow::task_id_with_noop_default_t" diff --git a/lib/local-execution/include/local-execution/operator_task_set.h b/lib/local-execution/include/local-execution/operator_task_set.h index bbe9da5d7f..b94ed9ac47 100644 --- a/lib/local-execution/include/local-execution/operator_task_set.h +++ b/lib/local-execution/include/local-execution/operator_task_set.h @@ -3,18 +3,19 @@ #include "local-execution/operator_task_set.dtg.h" #include "op-attrs/computation_graph_op_attrs.dtg.h" -#include "task-spec/op_task_type.dtg.h" +#include "task-spec/ops/op_task_type.dtg.h" #include 
"utils/bidict/bidict.h" namespace FlexFlow { -bidict +bidict get_map_from_task_type_to_task(OperatorTaskSet const &); -std::unordered_set +std::unordered_set get_all_tasks_in_task_set(OperatorTaskSet const &); -registered_task_t get_task_for_task_type(OperatorTaskSet const &op_task_set, - OpTaskType task_type); +task_id_with_noop_default_t + get_task_for_task_type(OperatorTaskSet const &op_task_set, + OpTaskType task_type); OperatorTaskSet get_task_set_for_operator(ComputationGraphOpAttrs const &op_attrs); diff --git a/lib/local-execution/include/local-execution/operator_task_set.struct.toml b/lib/local-execution/include/local-execution/operator_task_set.struct.toml deleted file mode 100644 index dda2a1478d..0000000000 --- a/lib/local-execution/include/local-execution/operator_task_set.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorTaskSet" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "local-execution/registered_task_t.dtg.h" -] - -[[fields]] -name = "init_task" -type = "::FlexFlow::registered_task_t" - -[[fields]] -name = "fwd_task" -type = "::FlexFlow::registered_task_t" - -[[fields]] -name = "bwd_task" -type = "::FlexFlow::registered_task_t" diff --git a/lib/local-execution/include/local-execution/per_device_op_state_initialization.h b/lib/local-execution/include/local-execution/per_device_op_state_initialization.h new file mode 100644 index 0000000000..31f8958a1c --- /dev/null +++ b/lib/local-execution/include/local-execution/per_device_op_state_initialization.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_PER_DEVICE_OP_STATE_INITIALIZATION_H +#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_PER_DEVICE_OP_STATE_INITIALIZATION_H + +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" +namespace FlexFlow { + +DynamicOpenDataflowGraph perform_per_device_op_state_initialization( + DynamicOpenDataflowGraph const &); + +} // namespace 
FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/registered_task.h b/lib/local-execution/include/local-execution/registered_task.h deleted file mode 100644 index d6e8a87b18..0000000000 --- a/lib/local-execution/include/local-execution/registered_task.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_REGISTERED_TASK_H -#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_REGISTERED_TASK_H - -#include "local-execution/registered_task_t.dtg.h" - -namespace FlexFlow { - -registered_task_t make_noop_registered_task(); - -} // namespace FlexFlow - -#endif diff --git a/lib/local-execution/include/local-execution/registered_task_t.variant.toml b/lib/local-execution/include/local-execution/registered_task_t.variant.toml deleted file mode 100644 index d4bab60ec9..0000000000 --- a/lib/local-execution/include/local-execution/registered_task_t.variant.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "registered_task_t" -features = [ - "eq", - "ord", - "hash", - "fmt", - "rapidcheck", -] - -includes = [ - "task-spec/task_id_t.dtg.h", - "", -] - -src_includes = [ - "utils/rapidcheck/monostate.h", - "utils/fmt/monostate.h", -] - -[[values]] -type = "::FlexFlow::task_id_t" -key = "real_task" - -[[values]] -type = "std::monostate" -key = "noop_task" diff --git a/lib/local-execution/include/local-execution/task_execution.h b/lib/local-execution/include/local-execution/task_execution.h new file mode 100644 index 0000000000..215f1dbc08 --- /dev/null +++ b/lib/local-execution/include/local-execution/task_execution.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_TASK_EXECUTION_H +#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_TASK_EXECUTION_H + +#include "kernels/profiling_settings.dtg.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include 
"task-spec/per_device_op_state.dtg.h" +#include "task-spec/task_argument_accessor/task_argument_accessor.h" + +namespace FlexFlow { + +TaskArgumentAccessor make_task_argument_accessor_for_invocation( + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + DeviceType kernel_device_type, + PCGOperatorAttrs op_attrs, + std::optional const &loss_attrs, + std::optional const &per_device_op_state, + FFIterationConfig iteration_config, + std::optional const &optimizer_attrs); + +void execute_dynamic_node_invocation( + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + DeviceType kernel_device_type, + PCGOperatorAttrs op_attrs, + std::optional const &loss_attrs, + std::optional const &per_device_op_state, + FFIterationConfig iteration_config, + std::optional const &optimizer_attrs); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/task_id_with_noop_default_t.h b/lib/local-execution/include/local-execution/task_id_with_noop_default_t.h new file mode 100644 index 0000000000..72e151bcc8 --- /dev/null +++ b/lib/local-execution/include/local-execution/task_id_with_noop_default_t.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_TASK_ID_WITH_NOOP_DEFAULT_T_H +#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_TASK_ID_WITH_NOOP_DEFAULT_T_H + +#include "local-execution/task_id_with_noop_default_t.dtg.h" + +namespace FlexFlow { + +task_id_with_noop_default_t make_default_noop_task(); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/tasks.h b/lib/local-execution/include/local-execution/tasks.h deleted file mode 100644 index aae3b3fe44..0000000000 --- a/lib/local-execution/include/local-execution/tasks.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_TASKS_H -#define _FLEXFLOW_LOCAL_EXECUTION_TASKS_H - -#include "task-spec/task_id_t.dtg.h" -#include -#include 
-#include - -namespace FlexFlow { -// PYTHON_TOP_LEVEL_TASK_ID = 11111, - -void register_flexflow_internal_tasks(); - -} // namespace FlexFlow - -#endif diff --git a/lib/local-execution/include/local-execution/tensor_allocation.h b/lib/local-execution/include/local-execution/tensor_allocation.h new file mode 100644 index 0000000000..67acb3de70 --- /dev/null +++ b/lib/local-execution/include/local-execution/tensor_allocation.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_TENSOR_ALLOCATION_H +#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_TENSOR_ALLOCATION_H + +#include "kernels/allocation.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +bool no_tensors_are_allocated(DynamicOpenDataflowGraph const &); +bool all_tensors_are_allocated(DynamicOpenDataflowGraph const &); + +DynamicValueAttrs perform_tensor_allocation_for_value(DynamicValueAttrs const &, + Allocator &); + +DynamicOpenDataflowGraph perform_tensor_allocation( + DynamicOpenDataflowGraph const &, + std::unordered_map const + &preallocated, + Allocator &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/include/local-execution/tensor_slot_backing.dtg.toml b/lib/local-execution/include/local-execution/tensor_slot_backing.dtg.toml new file mode 100644 index 0000000000..4d8c817461 --- /dev/null +++ b/lib/local-execution/include/local-execution/tensor_slot_backing.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "TensorSlotBacking" +type = "variant" +features = [ + "eq", + "fmt", +] + +includes = [ + "kernels/accessor.h", + "", +] + +src_includes = [ + "utils/fmt/vector.h", +] + +[[values]] +type = "::FlexFlow::GenericTensorAccessorW" +key = "single" + +[[values]] +type = "std::vector<::FlexFlow::GenericTensorAccessorW>" +key = "variadic" diff --git a/lib/local-execution/include/local-execution/tensor_slot_backing.variant.toml 
b/lib/local-execution/include/local-execution/tensor_slot_backing.variant.toml deleted file mode 100644 index 434988fa21..0000000000 --- a/lib/local-execution/include/local-execution/tensor_slot_backing.variant.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "TensorSlotBacking" -features = [ - "eq", - "fmt", -] - -includes = [ - "kernels/accessor.h", - "", -] - -src_includes = [ - "utils/fmt/vector.h", -] - -[[values]] -type = "::FlexFlow::GenericTensorAccessorW" -key = "single" - -[[values]] -type = "std::vector<::FlexFlow::GenericTensorAccessorW>" -key = "variadic" diff --git a/lib/local-execution/src/local-execution/computation_graph_instance/initialized_computation_graph_instance.cc b/lib/local-execution/src/local-execution/computation_graph_instance/initialized_computation_graph_instance.cc new file mode 100644 index 0000000000..a9f7018bb2 --- /dev/null +++ b/lib/local-execution/src/local-execution/computation_graph_instance/initialized_computation_graph_instance.cc @@ -0,0 +1,19 @@ +#include "local-execution/computation_graph_instance/initialized_computation_graph_instance.h" + +namespace FlexFlow { + +std::unordered_map> + perform_forward_pass_for_computation_graph_instance( + InitializedComputationGraphInstance const &instance) { + + NOT_IMPLEMENTED(); +} + +std::unordered_map> + perform_backward_pass_for_computation_graph_instance( + InitializedComputationGraphInstance const &instance) { + + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/cost_estimator/local_cost_estimator.cc b/lib/local-execution/src/local-execution/cost_estimator/local_cost_estimator.cc new file mode 100644 index 0000000000..fc181d26b0 --- /dev/null +++ b/lib/local-execution/src/local-execution/cost_estimator/local_cost_estimator.cc @@ -0,0 +1,153 @@ +#include "local-execution/cost_estimator/local_cost_estimator.h" +#include "compiler/machine_mapping/machine_view.dtg.h" +#include 
"kernels/create_local_allocator_for_device_type.h" +#include "kernels/device.h" +#include "kernels/local_cpu_allocator.h" +#include "kernels/local_cuda_allocator.h" +#include "local-execution/cost_estimator/tracked_allocator.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph/layer_added_result.dtg.h" +#include "pcg/parallel_tensor_attrs.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/get_only.h" +#include "utils/containers/maximum.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +LocalCostEstimator::LocalCostEstimator(RuntimeArgConfig const &config) + : runtime_arg_config(config) {} + +static ComputationGraph computation_graph_for_local_cost_estimation( + ComputationGraphOpAttrs const &op, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { + ComputationGraph computation_graph = make_empty_computation_graph(); + + std::vector input_tensors; + for (ParallelTensorShape const &input : inputs) { + LayerAddedResult inputs_layer = add_layer( + computation_graph, + LayerAttrs{ComputationGraphOpAttrs{InputAttrs{get_piece_shape(input)}}, + std::nullopt}, + {}, + {}); + input_tensors.push_back(get_only(inputs_layer.outputs)); + } + + std::vector weight_tensors; + for (ParallelTensorShape const &weight : weights) { + LayerAddedResult weights_layer = + add_layer(computation_graph, + LayerAttrs{ComputationGraphOpAttrs{WeightAttrs{ + get_piece_shape(weight), + InitializerAttrs{ZeroInitializerAttrs{}}}}, + std::nullopt}, + {}, + {}); + weight_tensors.push_back(get_only(weights_layer.outputs)); + } + + // create operator layer + LayerAddedResult operator_layer = add_layer(computation_graph, + LayerAttrs{ + /*op_attrs=*/op, + /*name=*/"operator", + }, + 
input_tensors, + weight_tensors); + + return computation_graph; +} + +OpCostMetrics LocalCostEstimator::estimate_cost( + OpCostEstimateKey const &op_cost_estimate_key) const { + + PCGOperatorAttrs op = op_cost_estimate_key.op_attrs; + std::vector inputs = op_cost_estimate_key.input_shapes; + std::vector weights = op_cost_estimate_key.weight_shapes; + std::vector outputs = op_cost_estimate_key.output_shapes; + MachineView mv = op_cost_estimate_key.machine_view; + + if (is_parallel_op(op) || op.has() || op.has() || + op.has()) { + return OpCostMetrics{ + /*forward_runtime=*/0_ms, + /*backward_runtime=*/0_ms, + /*memory=*/0_bytes, + }; + } + + // allocate memory + std::shared_ptr tracked_allocator_ptr = + std::make_shared(create_local_allocator_for_device_type( + runtime_arg_config.kernel_device_type)); + + layer_guid_t layer_guid = layer_guid_t{Node{0}}; + + Allocator allocator = Allocator(tracked_allocator_ptr); + + // execute layer + layer_guid_t operator_layer_guid = + get_layer_by_name(training_cg.computation_graph, "operator"); + + milliseconds_t fwd = execute_forward(local_backing.local_task_registry, + local_backing.local_tensor_backing, + local_backing.local_args_backing, + get_training_layer_plus_context( + training_cg, operator_layer_guid), + allocator) + .value(); + milliseconds_t bwd = execute_backward(local_backing.local_task_registry, + local_backing.local_tensor_backing, + local_backing.local_args_backing, + get_training_layer_plus_context( + training_cg, operator_layer_guid), + allocator) + .value(); + + return OpCostMetrics{ + /*forward_runtime=*/fwd, + /*backward_runtime=*/bwd, + /*memory=*/tracked_allocator_ptr->get_current_mem_usage(), + }; +} + +milliseconds_t LocalCostEstimator::estimate_cost( + TensorSetMovement const &tensor_set_movement) const { + + auto estimate_single_comm_cost = + [&](MachineSpaceCoordinate const &src, + MachineSpaceCoordinate const &dst, + num_bytes_t num_bytes) -> milliseconds_t { + if (src == dst) { + return 0_ms; + } 
else if (src.node_idx == dst.node_idx) { + return (num_bytes / + this->interconnect_specification.intra_node_bandwidth); + } else { + return (num_bytes / + this->interconnect_specification.inter_node_bandwidth); + } + }; + + return maximum( + transform(unordered_set_of(tensor_set_movement.edge_to_size), + [&](std::pair const &p) { + return estimate_single_comm_cost( + p.first.get_src(), p.first.get_dst(), p.second); + })); +} + +CostEstimator + get_local_cost_estimator(RuntimeArgConfig const &runtime_arg_config) { + return CostEstimator::create(runtime_arg_config); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/tracked_allocator.cc b/lib/local-execution/src/local-execution/cost_estimator/tracked_allocator.cc similarity index 100% rename from lib/local-execution/src/tracked_allocator.cc rename to lib/local-execution/src/local-execution/cost_estimator/tracked_allocator.cc diff --git a/lib/local-execution/src/local-execution/execute_task_for_layer.cc b/lib/local-execution/src/local-execution/execute_task_for_layer.cc new file mode 100644 index 0000000000..5a7ea74e52 --- /dev/null +++ b/lib/local-execution/src/local-execution/execute_task_for_layer.cc @@ -0,0 +1,274 @@ +#include "local-execution/execute_task_for_layer.h" +#include "local-execution/atomic_task_binding.dtg.h" +#include "local-execution/local_atomic_tensor_backing.h" +#include "local-execution/local_ready_to_launch_task.dtg.h" +#include "local-execution/local_task_registry.h" +#include "local-execution/local_tensor_backing.h" +#include "task-spec/fwb_op_task_type.h" +#include "task-spec/runtime_task_invocation/runtime_task_invocation.dtg.h" +#include "task-spec/symbolic/training_symbolic_computation_graph.h" +#include "utils/containers/flatmap.h" + +namespace FlexFlow { + +LocalReadyToLaunchTask prepare_runtime_task_invocation( + RuntimeTaskInvocation const &runtime_task_invocation, + LocalTensorBacking const &local_tensor_backing, + LocalAtomicTensorBacking const 
&local_atomic_tensor_backing, + Allocator &allocator, + RuntimeArgConfig const &runtime_arg_config) { + + AtomicTaskInvocation atomic_task_invocation = + lower_local_runtime_task_invocation_to_atomic_task_invocation( + local_tensor_backing, runtime_task_invocation, runtime_arg_config); + + TaskArgumentAccessor task_arg_accessor = + get_task_arg_accessor_for_atomic_task_invocation( + local_atomic_tensor_backing, atomic_task_invocation, allocator); + + return LocalReadyToLaunchTask{ + atomic_task_invocation.task_id, + task_arg_accessor, + }; +} + +std::optional execute_init_for_layer( + symbolic_layer_guid_t symbolic_layer_guid, + TrainingSymbolicComputationGraph const &g, + LocalTensorBacking const &tensor_backing, + LocalAtomicTensorBacking const &atomic_tensor_backing, + Allocator &allocator, + LocalTaskRegistry const &task_registry, + RuntimeArgConfig const &runtime_arg_config) { + + SymbolicCgOpAttrsAndTrainingSignatureWithShapes attrs_and_signature = + get_attrs_and_signature_for_layer(g, symbolic_layer_guid); + + RuntimeTaskInvocation runtime_task_invocation = ({ + std::optional maybe_runtime_task_invocation = + get_init_runtime_task_invocation_for_layer(symbolic_layer_guid, + attrs_and_signature); + if (!maybe_runtime_task_invocation.has_value()) { + return std::nullopt; + } + maybe_runtime_task_invocation.value(); + }); + + LocalReadyToLaunchTask prepared_task = + prepare_runtime_task_invocation(runtime_task_invocation, + tensor_backing, + atomic_tensor_backing, + allocator, + runtime_arg_config); + + std::optional per_device_op_state = + call_init_task_impl(task_registry, + prepared_task.task_id, + prepared_task.task_arg_accessor); + + return per_device_op_state; +} + +static std::optional execute_fwb_for_layer( + symbolic_layer_guid_t symbolic_layer_guid, + SymbolicCgOpAttrsAndTrainingSignatureWithShapes const &attrs_and_signature, + LocalTensorBacking const &local_tensor_backing, + LocalAtomicTensorBacking const &local_atomic_tensor_backing, + Allocator 
&allocator, + LocalTaskRegistry const &local_task_registry, + RuntimeArgConfig const &runtime_arg_config, + FwbOpTaskType task_type) { + + OpTaskType op_task_type = + assert_unwrap(op_task_type_from_fwb_op_task_type(task_type)); + + RuntimeTaskInvocation runtime_task_invocation = ({ + std::optional maybe_runtime_task_invocation = + get_runtime_task_invocation_for_layer_and_type( + symbolic_layer_guid, attrs_and_signature, op_task_type); + if (!maybe_runtime_task_invocation.has_value()) { + return std::nullopt; + } + maybe_runtime_task_invocation.value(); + }); + + task_id_t task_id = runtime_task_invocation.task_id; + + RuntimeTaskBinding runtime_task_binding = runtime_task_invocation.binding; + + AtomicTaskBinding atomic_task_binding = + lower_local_runtime_task_binding_to_atomic_task_binding( + local_tensor_backing, runtime_task_binding, runtime_arg_config); + + TaskArgumentAccessor task_arg_accessor = + get_task_arg_accessor_for_atomic_task_binding( + local_atomic_tensor_backing, atomic_task_binding, allocator); + + std::optional execution_time = + call_fwb_task_impl(local_task_registry, task_id, task_arg_accessor); + + return execution_time; +} + +std::optional execute_forward_for_layer( + symbolic_layer_guid_t layer, + SymbolicCgOpAttrsAndTrainingSignatureWithShapes const &attrs_and_signature, + LocalTensorBacking const &tensor_backing, + LocalAtomicTensorBacking const &atomic_tensor_backing, + Allocator &allocator, + LocalTaskRegistry const &task_registry, + RuntimeArgConfig const &runtime_arg_config) { + + return execute_fwb_for_layer(layer, + attrs_and_signature, + tensor_backing, + atomic_tensor_backing, + allocator, + task_registry, + runtime_arg_config, + FwbOpTaskType::FWD); +} + +std::optional execute_backward_for_layer( + symbolic_layer_guid_t layer, + SymbolicCgOpAttrsAndTrainingSignatureWithShapes const &attrs_and_signature, + LocalTensorBacking const &tensor_backing, + LocalAtomicTensorBacking const &atomic_tensor_backing, + Allocator &allocator, + 
LocalTaskRegistry const &task_registry, + RuntimeArgConfig const &runtime_arg_config) { + + return execute_fwb_for_layer(layer, + attrs_and_signature, + tensor_backing, + atomic_tensor_backing, + allocator, + task_registry, + runtime_arg_config, + FwbOpTaskType::BWD); +} + +void execute_compute_loss(LossAttrs const &loss_attrs, + symbolic_forward_tensor_guid_t logit_fwd_tensor, + symbolic_gradient_tensor_guid_t logit_grad_tensor, + symbolic_loss_tensor_guid_t loss_tensor, + LocalTensorBacking const &tensor_backing, + LocalAtomicTensorBacking const &atomic_tensor_backing, + Allocator &allocator, + LocalTaskRegistry const &task_registry, + RuntimeArgConfig const &runtime_arg_config) { + + RuntimeTaskInvocation invocation = get_compute_loss_runtime_task_invocation( + loss_attrs, logit_fwd_tensor, logit_grad_tensor, loss_tensor); + + LocalReadyToLaunchTask prepared_task = + prepare_runtime_task_invocation(invocation, + tensor_backing, + atomic_tensor_backing, + allocator, + runtime_arg_config); + + call_generic_task_impl( + task_registry, prepared_task.task_id, prepared_task.task_arg_accessor); +} + +void execute_update_for_layer( + symbolic_layer_guid_t symbolic_layer_guid, + TrainingSymbolicComputationGraph const &graph, + LocalTensorBacking const &tensor_backing, + LocalAtomicTensorBacking const &atomic_tensor_backing, + OptimizerAttrs const &optimizer_attrs, + Allocator &allocator, + LocalTaskRegistry const &task_registry, + RuntimeArgConfig const &runtime_arg_config) { + + SymbolicTrainingLayerAttrsPlusContext attrs_plus_context = + get_symbolic_training_layer_attrs_plus_context(graph, + symbolic_layer_guid); + + RuntimeTaskInvocation invocation = ({ + std::optional maybe_invocation = + get_update_runtime_task_invocation_for_layer(attrs_plus_context, + optimizer_attrs); + if (!maybe_invocation.has_value()) { + return; + } + maybe_invocation.value(); + }); + + LocalReadyToLaunchTask prepared_task = + prepare_runtime_task_invocation(invocation, + tensor_backing, + 
atomic_tensor_backing, + allocator, + runtime_arg_config); + + call_generic_task_impl( + task_registry, prepared_task.task_id, prepared_task.task_arg_accessor); +} + +std::unordered_map> + execute_forward_pass( + TrainingSymbolicComputationGraphFromCgConversion const &training_cg, + LocalTensorBacking const &local_tensor_backing, + LocalAtomicTensorBacking const &local_atomic_tensor_backing, + Allocator &allocator, + LocalTaskRegistry const &local_task_registry, + RuntimeArgConfig const &runtime_arg_config) { + std::unordered_map> + per_layer_elapsed_time; + + for (symbolic_layer_guid_t symbolic_layer_guid : + symbolic_cg_topological_ordering( + training_cg.training_symbolic_computation_graph)) { + + std::optional elapsed_time = execute_forward_for_layer( + symbolic_layer_guid, + training_cg.training_symbolic_computation_graph, + local_tensor_backing, + local_atomic_tensor_backing, + allocator, + local_task_registry, + runtime_arg_config); + + layer_guid_t layer_guid = + training_cg.layer_mapping.at_r(symbolic_layer_guid); + per_layer_elapsed_time.insert({layer_guid, elapsed_time}); + } + + return per_layer_elapsed_time; +} + +std::unordered_map> + execute_backward_pass( + TrainingSymbolicComputationGraphFromCgConversion const &training_cg, + LocalTensorBacking const &local_tensor_backing, + LocalAtomicTensorBacking const &local_atomic_tensor_backing, + Allocator &allocator, + LocalTaskRegistry const &local_task_registry, + RuntimeArgConfig const &runtime_arg_config) { + std::unordered_map> + per_layer_elapsed_time; + + for (symbolic_layer_guid_t symbolic_layer_guid : + reversed(symbolic_cg_topological_ordering( + training_cg.training_symbolic_computation_graph))) { + + std::optional elapsed_time = execute_backward_for_layer( + symbolic_layer_guid, + training_cg.training_symbolic_computation_graph, + local_tensor_backing, + local_atomic_tensor_backing, + allocator, + local_task_registry, + runtime_arg_config); + + layer_guid_t layer_guid = + 
training_cg.layer_mapping.at_r(symbolic_layer_guid); + per_layer_elapsed_time.insert({layer_guid, elapsed_time}); + } + + return per_layer_elapsed_time; +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/local_args_backing.cc b/lib/local-execution/src/local-execution/local_args_backing.cc deleted file mode 100644 index eb1c7b067e..0000000000 --- a/lib/local-execution/src/local-execution/local_args_backing.cc +++ /dev/null @@ -1,109 +0,0 @@ -#include "local-execution/local_args_backing.h" -#include "local-execution/local_task_registry.h" -#include "local-execution/local_tensor_backing.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "task-spec/op_task_to_task_invocation.h" -#include "task-spec/task_signature_impl.h" -#include "task-spec/training_computation_graph.h" -#include "task-spec/training_layer_plus_context.h" -#include "utils/containers/contains_key.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/map_values.h" -#include "utils/containers/try_at.h" -#include "utils/overload.h" - -namespace FlexFlow { - -std::optional get_per_device_op_state_if_exists( - LocalArgsBacking const &local_args_backing, - layer_guid_t const &layer_guid) { - - return local_args_backing.per_device_op_states.at(layer_guid); -} - -std::unordered_map - construct_arg_slots_backing(TaskBinding const &binding, - RuntimeArgConfig const &runtime_arg_config) { - return map_values( - binding.get_arg_bindings(), [&](TaskArgSpec const &arg_binding) { - return arg_binding.template visit( - overload{[&](RuntimeArgRefSpec const &s) { - return lower_to_concrete_arg_spec(s, runtime_arg_config); - }, - [](ConcreteArgSpec const &s) { return s; }}); - }); - ; -} - -std::optional - create_per_device_op_state(LocalTaskRegistry const &local_task_registry, - LocalTensorBacking const &tensor_backing, - RuntimeArgConfig const &runtime_arg_config, - Allocator &allocator, - TrainingLayerPlusContext const &training_layer) { - std::optional 
maybe_registered_task = try_get_registered_task( - local_task_registry, training_layer.layer_guid, OpTaskType::INIT); - - ASSERT(maybe_registered_task.has_value()); - - registered_task_t registered_task = maybe_registered_task.value(); - if (registered_task.is_noop_task()) { - return std::nullopt; - } - - TaskInvocation invocation = lower_to_task_invocation( - /*op_task_invocation=*/get_init_op_task_invocation( - training_layer.layer_attrs.op_attrs), - /*training_layer=*/training_layer, - /*device_specific_device_states=*/std::nullopt); - - TaskArgumentAccessor accessor = get_task_arg_accessor( - tensor_backing, runtime_arg_config, invocation, allocator); - TaskSignatureAndImpl task_sig_impl = - local_task_registry.task_mapping.at(invocation.task_id); - auto fn = - task_sig_impl.impl_function.get().function_ptr; - DeviceSpecificDeviceStates device_state = fn(accessor); - return device_state; -} - -TaskArgumentAccessor - get_task_arg_accessor(LocalTensorBacking const &local_tensor_backing, - RuntimeArgConfig const &runtime_arg_config, - TaskInvocation const &invocation, - Allocator &allocator) { - std::unordered_map - tensor_slots_backing = construct_tensor_slots_backing_for_binding( - local_tensor_backing, invocation.binding); - std::unordered_map arg_slots_backing = - construct_arg_slots_backing(invocation.binding, runtime_arg_config); - return TaskArgumentAccessor::create( - allocator, tensor_slots_backing, arg_slots_backing); -} - -LocalArgsBacking make_local_args_backing_for_computation_graph( - LocalTaskRegistry const &task_registry, - TrainingComputationGraph const &training_computation_graph, - RuntimeArgConfig const &runtime_arg_config, - LocalTensorBacking const &local_tensor_backing, - Allocator &allocator) { - std::unordered_map> - per_device_op_states = generate_map( - topological_ordering(training_computation_graph.computation_graph), - [&](layer_guid_t const &layer_guid) { - return create_per_device_op_state( - task_registry, - local_tensor_backing, - 
runtime_arg_config, - allocator, - get_training_layer_plus_context(training_computation_graph, - layer_guid)); - }); - - return LocalArgsBacking{ - runtime_arg_config, - per_device_op_states, - }; -} - -} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/local_atomic_tensor_backing.cc b/lib/local-execution/src/local-execution/local_atomic_tensor_backing.cc new file mode 100644 index 0000000000..c43fd6bdf3 --- /dev/null +++ b/lib/local-execution/src/local-execution/local_atomic_tensor_backing.cc @@ -0,0 +1,35 @@ +#include "local-execution/local_atomic_tensor_backing.h" +#include "local-execution/local_task_argument_accessor.h" +#include "utils/containers/map_values.h" + +namespace FlexFlow { + +std::unordered_map + construct_tensor_slots_backing_for_binding( + LocalAtomicTensorBacking const &tensor_backing, + AtomicTaskBinding const &binding) { + return map_values(binding.tensor_bindings, + [&](atomic_training_tensor_guid_t t) -> TensorSlotBacking { + return TensorSlotBacking{ + tensor_backing.accessor_from_atomic_tensor_map.at(t), + }; + }); +} + +TaskArgumentAccessor get_task_arg_accessor_for_atomic_task_invocation( + LocalAtomicTensorBacking const &local_tensor_backing, + AtomicTaskBinding const &atomic_task_binding, + Allocator &allocator) { + + std::unordered_map + tensor_slots_backing = construct_tensor_slots_backing_for_binding( + local_tensor_backing, atomic_task_binding); + + std::unordered_map arg_slots_backing = + atomic_task_binding.arg_bindings; + + return TaskArgumentAccessor::create( + allocator, tensor_slots_backing, arg_slots_backing, 0); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/local_concrete_task_graph.cc b/lib/local-execution/src/local-execution/local_concrete_task_graph.cc new file mode 100644 index 0000000000..9806758f06 --- /dev/null +++ b/lib/local-execution/src/local-execution/local_concrete_task_graph.cc @@ -0,0 +1,12 @@ +#include 
"local-execution/local_concrete_task_graph.h" + +namespace FlexFlow { + +std::vector + local_concrete_task_graph_topological_ordering( + LocalConcreteTaskGraph const &) { + + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/local_cost_estimator.cc b/lib/local-execution/src/local-execution/local_cost_estimator.cc deleted file mode 100644 index 6517dbfdbc..0000000000 --- a/lib/local-execution/src/local-execution/local_cost_estimator.cc +++ /dev/null @@ -1,165 +0,0 @@ -#include "local-execution/local_cost_estimator.h" -#include "kernels/create_local_allocator_for_device_type.h" -#include "kernels/device.h" -#include "kernels/local_cpu_allocator.h" -#include "kernels/local_cuda_allocator.h" -#include "local-execution/local_training_backing.h" -#include "local-execution/tracked_allocator.h" -#include "op-attrs/computation_graph_op_attrs.h" -#include "op-attrs/pcg_operator_attrs.h" -#include "pcg/computation_graph.h" -#include "pcg/computation_graph/layer_added_result.dtg.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/parallel_tensor_attrs.h" -#include "task-spec/forward_tensor_source.h" -#include "task-spec/gradient_tensor_source.h" -#include "task-spec/optimizer_tensor_source.h" -#include "task-spec/training_computation_graph.h" -#include "utils/containers/concat_vectors.h" -#include "utils/containers/get_only.h" -#include "utils/containers/sum.h" -#include "utils/containers/transform.h" -#include "utils/containers/values.h" - -namespace FlexFlow { - -LocalCostEstimator::LocalCostEstimator(RuntimeArgConfig const &config) - : runtime_arg_config(config) {} - -static TrainingComputationGraph - create_computation_graph_for_local_cost_estimation( - PCGOperatorAttrs const &op, - OptimizerAttrs const &optimizer_attrs, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { - ComputationGraph computation_graph = make_empty_computation_graph(); - - std::vector input_tensors; - for 
(ParallelTensorShape const &input : inputs) { - LayerAddedResult inputs_layer = add_layer( - computation_graph, - LayerAttrs{ComputationGraphOpAttrs{InputAttrs{get_piece_shape(input)}}, - std::nullopt}, - {}, - {}); - input_tensors.push_back(get_only(inputs_layer.outputs)); - } - - std::vector weight_tensors; - for (ParallelTensorShape const &weight : weights) { - LayerAddedResult weights_layer = - add_layer(computation_graph, - LayerAttrs{ComputationGraphOpAttrs{WeightAttrs{ - get_piece_shape(weight), - InitializerAttrs{ZeroInitializerAttrs{}}}}, - std::nullopt}, - {}, - {}); - weight_tensors.push_back(get_only(weights_layer.outputs)); - } - - // create operator layer - LayerAddedResult operator_layer = add_layer( - computation_graph, - LayerAttrs{compgraph_op_attrs_from_pcg_op_attrs(op), "operator"}, - input_tensors, - weight_tensors); - - ForwardTensorSource forward_tensor_source; - GradientTensorSource gradient_tensor_source; - OptimizerTensorSource optimizer_tensor_source; - LossTensorSource loss_tensor_source; - - TrainingComputationGraph training_cg = generate_training_computation_graph( - /*computation_graph=*/computation_graph, - /*optimizer_attrs=*/optimizer_attrs, - /*logit_tensor=*/operator_layer.outputs.at(0), - /*forward_tensor_source=*/forward_tensor_source, - /*gradient_tensor_source=*/gradient_tensor_source, - /*optimizer_tensor_source=*/optimizer_tensor_source, - /*loss_tensor_source=*/loss_tensor_source); - - return training_cg; -} - -OpCostMetrics LocalCostEstimator::estimate_cost( - OpCostEstimateKey const &op_cost_estimate_key) const { - - PCGOperatorAttrs op = op_cost_estimate_key.op_attrs; - std::vector inputs = op_cost_estimate_key.input_shapes; - std::vector weights = op_cost_estimate_key.weight_shapes; - std::vector outputs = op_cost_estimate_key.output_shapes; - MachineView mv = op_cost_estimate_key.machine_view; - - if (is_parallel_op(op) || op.has() || op.has() || - op.has()) { - return OpCostMetrics{ - /*forward_runtime=*/0_ms, - 
/*backward_runtime=*/0_ms, - /*memory=*/0_bytes, - }; - } - - TrainingComputationGraph training_cg = - create_computation_graph_for_local_cost_estimation( - /*op=*/op, - /*optimizer_attrs=*/op_cost_estimate_key.optimizer_attrs, - /*inputs=*/inputs, - /*weights=*/weights, - /*outputs=*/outputs); - - // allocate memory - std::shared_ptr tracked_allocator_ptr = - std::make_shared(create_local_allocator_for_device_type( - runtime_arg_config.kernel_device_type)); - Allocator allocator = Allocator(tracked_allocator_ptr); - - LocalTrainingBacking local_backing = - make_local_training_backing_for_computation_graph( - /*allocator=*/allocator, - /*preallocated_tensors=*/{}, - /*training_computation_graph=*/training_cg, - /*runtime_arg_config=*/this->runtime_arg_config, - /*optimizer_attrs=*/op_cost_estimate_key.optimizer_attrs); - - // execute layer - layer_guid_t operator_layer_guid = - get_layer_by_name(training_cg.computation_graph, "operator"); - - milliseconds_t fwd = execute_forward(local_backing.local_task_registry, - local_backing.local_tensor_backing, - local_backing.local_args_backing, - get_training_layer_plus_context( - training_cg, operator_layer_guid), - allocator) - .value(); - milliseconds_t bwd = execute_backward(local_backing.local_task_registry, - local_backing.local_tensor_backing, - local_backing.local_args_backing, - get_training_layer_plus_context( - training_cg, operator_layer_guid), - allocator) - .value(); - - return OpCostMetrics{ - /*forward_runtime=*/fwd, - /*backward_runtime=*/bwd, - /*memory=*/tracked_allocator_ptr->get_current_mem_usage(), - }; -} - -milliseconds_t LocalCostEstimator::estimate_cost( - TensorSetMovement const &tensor_set_movement) const { - // TODO: model communication cost analytically - // https://github.com/flexflow/FlexFlow/issues/1414 - - NOT_IMPLEMENTED(); -} - -CostEstimator - get_local_cost_estimator(RuntimeArgConfig const &runtime_arg_config) { - return CostEstimator::create(runtime_arg_config); -} - -} // namespace 
FlexFlow diff --git a/lib/local-execution/src/local-execution/local_device_states_backing.cc b/lib/local-execution/src/local-execution/local_device_states_backing.cc new file mode 100644 index 0000000000..1dc34b120d --- /dev/null +++ b/lib/local-execution/src/local-execution/local_device_states_backing.cc @@ -0,0 +1,48 @@ +#include "local-execution/local_device_states_backing.h" +#include "local-execution/local_task_registry.h" +#include "local-execution/local_tensor_backing.h" +#include "task-spec/task_signature_impl.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/overload.h" + +namespace FlexFlow { + +// LocalDeviceStatesBacking +// make_local_device_states_backing_for_computation_graph( +// LocalTaskRegistry const &task_registry, +// std::unordered_map const &layers, +// std::unordered_map const +// &op_attrs, RuntimeArgConfig const &runtime_arg_config, LocalTensorBacking +// const &local_tensor_backing, Allocator &allocator) { +// +// std::unordered_map> +// per_device_op_states = generate_map( +// keys(layers), +// [&](layer_guid_t const &layer_guid) -> +// std::optional { +// return create_per_device_op_state( +// task_registry, +// local_tensor_backing, +// runtime_arg_config, +// allocator, +// op_attrs, +// layers.at(layer_guid)); +// }); +// +// return LocalDeviceStatesBacking{ +// per_device_op_states, +// }; +// } + +// std::optional +// get_per_device_op_state_if_exists( +// LocalArgsBacking const &local_args_backing, +// layer_guid_t const &layer_guid) { +// +// return local_args_backing.per_device_op_states.at(layer_guid); +// } + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/local_task_argument_accessor.cc b/lib/local-execution/src/local-execution/local_task_argument_accessor.cc new file mode 100644 index 0000000000..c9bdb84fbf --- /dev/null +++ b/lib/local-execution/src/local-execution/local_task_argument_accessor.cc @@ -0,0 +1,88 @@ +#include 
"local-execution/local_task_argument_accessor.h" +#include "pcg/device_id_t.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/transform.h" +#include "utils/hash/pair.h" +#include "utils/overload.h" + +namespace FlexFlow { + +LocalTaskArgumentAccessor::LocalTaskArgumentAccessor( + Allocator const &allocator, + std::unordered_map const + &tensor_slots_backing, + ProfilingSettings const &profiling_settings, + device_handle_t const &ff_handle, + DeviceType kernel_device_type, + PCGOperatorAttrs const &op_attrs, + std::optional const &loss_attrs, + std::optional const &per_device_op_state, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + size_t device_idx) + : allocator(allocator), tensor_slots_backing(tensor_slots_backing), + profiling_settings(profiling_settings), ff_handle(ff_handle), + kernel_device_type(kernel_device_type), op_attrs(op_attrs), + loss_attrs(loss_attrs), per_device_op_state(per_device_op_state), + iteration_config(iteration_config), optimizer_attrs(optimizer_attrs), + device_idx(make_device_id_t_from_idx(nonnegative_int{device_idx}, + kernel_device_type)) {} + +GenericTensorAccessor LocalTaskArgumentAccessor::get_tensor( + TensorSlotName slot, + Permissions priv, + TrainingTensorType tensor_type) const { + GenericTensorAccessorW tensor_backing = + this->tensor_slots_backing.at(slot_tensor_type).require_single(); + if (priv == Permissions::RO) { + GenericTensorAccessorR readonly_tensor_backing = + read_only_accessor_from_write_accessor(tensor_backing); + return readonly_tensor_backing; + } else if (priv == Permissions::RW || priv == Permissions::WO) { + return tensor_backing; + } else { + PANIC(fmt::format("Unhandled privilege mode {}", priv)); + } +} + +ProfilingSettings LocalTaskArgumentAccessor::get_profiling_settings() const { + return this->profiling_settings; +} + +device_handle_t LocalTaskArgumentAccessor::get_ff_handle() const { + return this->ff_handle; +} + +DeviceType 
LocalTaskArgumentAccessor::get_kernel_device_type() const { + return this->kernel_device_type; +} + +PCGOperatorAttrs LocalTaskArgumentAccessor::get_op_attrs() const { + return this->op_attrs; +} + +LossAttrs LocalTaskArgumentAccessor::get_loss_attrs() const { + return assert_unwrap(this->loss_attrs); +} + +PerDeviceOpState LocalTaskArgumentAccessor::get_per_device_op_state() const { + return assert_unwrap(this->per_device_op_state); +} + +FFIterationConfig LocalTaskArgumentAccessor::get_iteration_config() const { + return this->iteration_config; +} + +OptimizerAttrs LocalTaskArgumentAccessor::get_optimizer_attrs() const { + return assert_unwrap(this->optimizer_attrs); +} + +Allocator LocalTaskArgumentAccessor::get_allocator() const { + return this->allocator; +} + +size_t LocalTaskArgumentAccessor::get_device_idx() const { + return this->device_idx; +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/local_task_registry.cc b/lib/local-execution/src/local-execution/local_task_registry.cc index d482736a5b..fb6936425d 100644 --- a/lib/local-execution/src/local-execution/local_task_registry.cc +++ b/lib/local-execution/src/local-execution/local_task_registry.cc @@ -1,6 +1,5 @@ #include "local-execution/local_task_registry.h" #include "local-execution/operator_task_set.h" -#include "local-execution/registered_task.h" #include "pcg/computation_graph.h" #include "task-spec/task_signature_impl.h" #include "utils/containers/contains_key.h" @@ -14,51 +13,62 @@ namespace FlexFlow { LocalTaskRegistry construct_local_task_registry_for_layers( - std::unordered_map const &layer_attrs_mapping) { + std::unordered_set const &op_attrs) { - std::unordered_map task_sets = - map_values(layer_attrs_mapping, [](LayerAttrs const &layer_attrs) { - return get_task_set_for_operator(layer_attrs.op_attrs); - }); - - std::unordered_set all_tasks = - flatmap(unordered_set_of(values(task_sets)), get_all_tasks_in_task_set); - - std::unordered_set all_real_tasks = - 
filtrans(all_tasks, [](registered_task_t const &t) { - return t.try_require_real_task(); - }); + std::unordered_set task_ids = flatmap( + op_attrs, + [](ComputationGraphOpAttrs const &op_attrs) + -> std::unordered_set { return get_task_ids(op_attrs); }); std::unordered_map task_mapping = - generate_map(all_real_tasks, get_task_signature_and_impl_for_task_id); + generate_map(task_ids, get_task_signature_and_impl_for_task_id); return LocalTaskRegistry{ - /*task_sets=*/task_sets, /*task_mapping=*/task_mapping, }; } -std::optional - try_get_registered_task(LocalTaskRegistry const &task_registry, - layer_guid_t const &layer_guid, - OpTaskType const &op_task_type) { - if (!contains_key(task_registry.task_sets, layer_guid)) { +std::optional + call_init_task_impl(LocalTaskRegistry const &local_task_registry, + task_id_with_noop_default_t registered_task, + TaskArgumentAccessor const &arg_accessor) { + + if (registered_task.is_noop_task()) { return std::nullopt; } - return get_task_for_task_type(task_registry.task_sets.at(layer_guid), - op_task_type); + task_id_t task_id = registered_task.require_real_task(); + + TaskSignatureAndImpl task_sig_impl = + local_task_registry.task_mapping.at(task_id); + + auto fn = + task_sig_impl.impl_function.get().function_ptr; + + std::optional device_state = fn(arg_accessor); + + return device_state; } std::optional - call_task_impl(LocalTaskRegistry const &task_registry, - task_id_t const &task_id, - TaskArgumentAccessor const &acc) { + call_fwb_task_impl(LocalTaskRegistry const &task_registry, + task_id_t const &task_id, + TaskArgumentAccessor const &acc) { TaskSignatureAndImpl task_sig_impl = task_registry.task_mapping.at(task_id); auto fn = task_sig_impl.impl_function.get().function_ptr; - return transform( - fn(acc), [](float running_time) { return milliseconds_t{running_time}; }); + + return fn(acc); +} + +void call_generic_task_impl(LocalTaskRegistry const &task_registry, + task_id_t const &task_id, + TaskArgumentAccessor const &acc) 
{ + TaskSignatureAndImpl task_sig_impl = task_registry.task_mapping.at(task_id); + auto fn = + task_sig_impl.impl_function.get().function_ptr; + + fn(acc); } } // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/local_tensor_backing.cc b/lib/local-execution/src/local-execution/local_tensor_backing.cc deleted file mode 100644 index be8e44736c..0000000000 --- a/lib/local-execution/src/local-execution/local_tensor_backing.cc +++ /dev/null @@ -1,74 +0,0 @@ -#include "local-execution/local_tensor_backing.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/computation_graph.h" -#include "pcg/optimizer_attrs.h" -#include "task-spec/slot_grad_id.dtg.h" -#include "task-spec/training_computation_graph.h" -#include "utils/containers/contains_key.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/is_submapeq_of.h" -#include "utils/containers/is_subseteq_of.h" -#include "utils/containers/keys.h" -#include "utils/containers/map_values.h" -#include "utils/containers/merge_maps.h" -#include "utils/containers/set_minus.h" -#include "utils/containers/set_of.h" -#include "utils/overload.h" - -namespace FlexFlow { - -LocalTensorBacking construct_local_tensor_backing( - std::unordered_map const - &training_tensor_shapes, - std::unordered_map const - &preallocated, - Allocator &allocator) { - - ASSERT(is_subseteq_of(keys(preallocated), keys(training_tensor_shapes))); - - std::unordered_set to_allocate = - set_minus(keys(training_tensor_shapes), keys(preallocated)); - - std::unordered_map allocated = - generate_map(to_allocate, [&](training_tensor_guid_t t) { - TensorShape shape = training_tensor_shapes.at(t); - return allocator.allocate_tensor(shape); - }); - - std::unordered_map - backing_for_training_tensor_map = - merge_disjoint_maps(allocated, preallocated); - - ASSERT(is_submapeq_of(preallocated, backing_for_training_tensor_map)); - - ASSERT(keys(backing_for_training_tensor_map) == keys(training_tensor_shapes), - 
backing_for_training_tensor_map.size(), - training_tensor_shapes.size(), - keys(preallocated)); - - return LocalTensorBacking{ - backing_for_training_tensor_map, - }; -} - -GenericTensorAccessorW get_accessor_for_training_tensor( - LocalTensorBacking const &local_tensor_backing, - training_tensor_guid_t training_tensor) { - return local_tensor_backing.backing_for_training_tensor_map.at( - training_tensor); -} - -std::unordered_map - construct_tensor_slots_backing_for_binding( - LocalTensorBacking const &local_tensor_backing, - TaskBinding const &binding) { - - return map_values( - binding.get_tensor_bindings(), [&](training_tensor_guid_t t) { - return TensorSlotBacking{ - get_accessor_for_training_tensor(local_tensor_backing, t), - }; - }); -} - -} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/local_training_backing.cc b/lib/local-execution/src/local-execution/local_training_backing.cc deleted file mode 100644 index 1aac8506f2..0000000000 --- a/lib/local-execution/src/local-execution/local_training_backing.cc +++ /dev/null @@ -1,179 +0,0 @@ -#include "local-execution/local_training_backing.h" -#include "local-execution/local_args_backing.h" -#include "pcg/computation_graph.h" -#include "pcg/optimizer_attrs.h" -#include "task-spec/loss_functions.h" -#include "task-spec/op_task_to_task_invocation.h" -#include "task-spec/optimizer.h" -#include "task-spec/task_invocation.h" -#include "task-spec/task_signature_impl.h" -#include "task-spec/training_computation_graph.h" -#include "utils/containers/contains.h" -#include "utils/containers/contains_key.h" -#include "utils/containers/get_only.h" -#include "utils/containers/is_subseteq_of.h" -#include "utils/containers/keys.h" -#include "utils/containers/values.h" -#include "utils/exception.h" - -namespace FlexFlow { - -LocalTrainingBacking make_local_training_backing_for_computation_graph( - Allocator &allocator, - std::unordered_map const - &preallocated, - TrainingComputationGraph const 
&training_computation_graph, - RuntimeArgConfig const &runtime_arg_config, - OptimizerAttrs const &optimizer_attrs) { - - ASSERT(is_subseteq_of( - keys(preallocated), - keys(get_all_training_tensor_shapes(training_computation_graph)))); - - LocalTaskRegistry local_task_registry = - construct_local_task_registry_for_layers(get_layer_attrs_mapping( - training_computation_graph.computation_graph)); - - LocalTensorBacking local_tensor_backing = construct_local_tensor_backing( - get_all_training_tensor_shapes(training_computation_graph), - preallocated, - allocator); - - LocalArgsBacking local_args_backing = - make_local_args_backing_for_computation_graph(local_task_registry, - training_computation_graph, - runtime_arg_config, - local_tensor_backing, - allocator); - - return LocalTrainingBacking{ - /*computation_graph=*/training_computation_graph, - /*local_task_registry=*/local_task_registry, - /*local_tensor_backing=*/local_tensor_backing, - /*local_args_backing=*/local_args_backing, - }; -} - -std::optional - execute_forward(LocalTaskRegistry const &local_task_registry, - LocalTensorBacking const &local_tensor_backing, - LocalArgsBacking const &local_args_backing, - TrainingLayerPlusContext const &training_layer, - Allocator &allocator) { - - std::optional maybe_registered_task = try_get_registered_task( - local_task_registry, training_layer.layer_guid, OpTaskType::BWD); - - ASSERT(maybe_registered_task.has_value()); - - registered_task_t registered_task = maybe_registered_task.value(); - if (registered_task.is_noop_task()) { - return std::nullopt; - } - - std::optional device_state = - get_per_device_op_state_if_exists(local_args_backing, - training_layer.layer_guid); - - TaskInvocation invocation = lower_to_task_invocation( - /*op_task_invocation=*/get_forward_op_task_invocation( - training_layer.layer_attrs.op_attrs), - /*training_layer=*/training_layer, - /*device_specific_device_states=*/device_state); - - TaskArgumentAccessor accessor = - 
get_task_arg_accessor(local_tensor_backing, - local_args_backing.runtime_arg_config, - invocation, - allocator); - return call_task_impl(local_task_registry, invocation.task_id, accessor); -} - -void compute_loss(LocalTrainingBacking const &local_training_backing, - LossAttrs const &loss_attrs, - Allocator &allocator) { - - TrainingComputationGraph training_cg = - local_training_backing.training_computation_graph; - tensor_guid_t logit_tensor = training_cg.logit_tensor; - loss_tensor_guid_t label_tensor = training_cg.label_tensor; - - TaskInvocation loss_invocation = backward( - loss_attrs, - get_forward_tensor_guid_for_tensor_guid(training_cg, logit_tensor), - get_gradient_tensor_guid_for_tensor_guid(training_cg, logit_tensor), - label_tensor); - // TODO: https://github.com/flexflow/flexflow-train/issues/1442 - // assert(is_invocation_valid(get_loss_bwd_signature(), loss_invocation)); - TaskArgumentAccessor loss_accessor = get_task_arg_accessor( - local_training_backing.local_tensor_backing, - local_training_backing.local_args_backing.runtime_arg_config, - loss_invocation, - allocator); - TaskImplFunction loss_impl_fn = get_loss_bwd_task_impl(); - loss_impl_fn.get().function_ptr(loss_accessor); -} - -std::optional - execute_backward(LocalTaskRegistry const &local_task_registry, - LocalTensorBacking const &local_tensor_backing, - LocalArgsBacking const &local_args_backing, - TrainingLayerPlusContext const &training_layer, - Allocator &allocator) { - - std::optional maybe_registered_task = try_get_registered_task( - local_task_registry, training_layer.layer_guid, OpTaskType::BWD); - - ASSERT(maybe_registered_task.has_value()); - - registered_task_t registered_task = maybe_registered_task.value(); - if (registered_task.is_noop_task()) { - return std::nullopt; - } - - std::optional device_state = - get_per_device_op_state_if_exists(local_args_backing, - training_layer.layer_guid); - TaskInvocation invocation = lower_to_task_invocation( - 
get_backward_op_task_invocation(training_layer.layer_attrs.op_attrs), - training_layer, - device_state); - TaskArgumentAccessor accessor = - get_task_arg_accessor(local_tensor_backing, - local_args_backing.runtime_arg_config, - invocation, - allocator); - return call_task_impl(local_task_registry, invocation.task_id, accessor); -} - -void execute_update(LocalTrainingBacking const &local_training_backing, - layer_guid_t const &layer_guid, - OptimizerAttrs const &optimizer_attrs, - Allocator &allocator) { - TrainingLayerPlusContext training_layer = get_training_layer_plus_context( - local_training_backing.training_computation_graph, layer_guid); - - if (training_layer.layer_attrs.op_attrs.has()) { - TrainingTensorGroupWithAttrs weight_tensor_group = - get_only(training_layer.output_tensor_groups); - - TaskInvocation invocation = - get_update_invocation(optimizer_attrs, - weight_tensor_group.forward_tensor, - weight_tensor_group.gradient_tensor, - weight_tensor_group.optimizer_tensors); - - // TODO: https://github.com/flexflow/flexflow-train/issues/1442 - // assert(is_invocation_valid(get_update_signature(attrs), invocation)); - - TaskArgumentAccessor accessor = get_task_arg_accessor( - local_training_backing.local_tensor_backing, - local_training_backing.local_args_backing.runtime_arg_config, - invocation, - allocator); - TaskImplFunction update_impl_fn = get_update_task_impl(optimizer_attrs); - update_impl_fn.get().function_ptr(accessor); - } -} - -} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/model_training_instance.cc b/lib/local-execution/src/local-execution/model_training_instance.cc deleted file mode 100644 index be2791a365..0000000000 --- a/lib/local-execution/src/local-execution/model_training_instance.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "local-execution/model_training_instance.h" -#include "pcg/computation_graph.h" -#include "pcg/optimizer_attrs.h" -#include "task-spec/training_computation_graph.h" -#include 
"utils/containers/reversed.h" - -namespace FlexFlow { - -ModelTrainingInstance::ModelTrainingInstance( - Allocator const &allocator, - LocalTrainingBacking const &local_training_backing, - LossAttrs const &loss_attrs, - OptimizerAttrs const &optimizer_attrs) - : allocator(allocator), training_backing(local_training_backing), - loss_attrs(loss_attrs), optimizer_attrs(optimizer_attrs) {} - -std::unordered_map> - ModelTrainingInstance::forward() { - - std::unordered_map> - per_layer_elapsed_time; - - for (layer_guid_t const &layer_guid : - topological_ordering(this->training_backing.training_computation_graph - .computation_graph)) { - std::optional elapsed_time = execute_forward( - this->training_backing.local_task_registry, - this->training_backing.local_tensor_backing, - this->training_backing.local_args_backing, - get_training_layer_plus_context( - this->training_backing.training_computation_graph, layer_guid), - this->allocator); - - per_layer_elapsed_time.insert({layer_guid, elapsed_time}); - } - - return per_layer_elapsed_time; -} - -std::unordered_map> - ModelTrainingInstance::backward() { - compute_loss(this->training_backing, this->loss_attrs, this->allocator); - - std::unordered_map> - per_layer_elapsed_time; - for (layer_guid_t const &layer_guid : reversed(topological_ordering( - this->training_backing.training_computation_graph - .computation_graph))) { - std::optional elapsed_time = execute_backward( - this->training_backing.local_task_registry, - this->training_backing.local_tensor_backing, - this->training_backing.local_args_backing, - get_training_layer_plus_context( - this->training_backing.training_computation_graph, layer_guid), - this->allocator); - per_layer_elapsed_time.insert({layer_guid, elapsed_time}); - } - return per_layer_elapsed_time; -} - -void ModelTrainingInstance::update() { - for (layer_guid_t const &layer_guid : - topological_ordering(this->training_backing.training_computation_graph - .computation_graph)) { - 
execute_update(this->training_backing, - layer_guid, - this->optimizer_attrs, - this->allocator); - } - this->optimizer_attrs = - get_optimizer_attrs_for_next_iter(this->optimizer_attrs); -} - -GenericTensorAccessorR ModelTrainingInstance::get_loss_tensor_accessor() const { - gradient_tensor_guid_t loss_tensor = get_gradient_tensor_guid_for_tensor_guid( - this->training_backing.training_computation_graph, - this->training_backing.training_computation_graph.logit_tensor); - GenericTensorAccessorW loss_tensor_backing = - this->training_backing.local_tensor_backing - .backing_for_training_tensor_map.at( - training_tensor_guid_t{loss_tensor}); - return read_only_accessor_from_write_accessor(loss_tensor_backing); -} - -} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/operator_task_set.cc b/lib/local-execution/src/local-execution/operator_task_set.cc index 8dbc8791c6..a1b1d0817b 100644 --- a/lib/local-execution/src/local-execution/operator_task_set.cc +++ b/lib/local-execution/src/local-execution/operator_task_set.cc @@ -1,12 +1,12 @@ #include "local-execution/operator_task_set.h" -#include "local-execution/registered_task.h" +#include "local-execution/task_id_with_noop_default_t.h" #include "task-spec/task_signature_impl.h" #include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/values.h" namespace FlexFlow { -bidict +bidict get_map_from_task_type_to_task(OperatorTaskSet const &op_task_set) { return { {OpTaskType::INIT, op_task_set.init_task}, @@ -15,21 +15,22 @@ bidict }; } -std::unordered_set +std::unordered_set get_all_tasks_in_task_set(OperatorTaskSet const &op_task_set) { return right_entries(get_map_from_task_type_to_task(op_task_set)); } -registered_task_t get_task_for_task_type(OperatorTaskSet const &op_task_set, - OpTaskType task_type) { +task_id_with_noop_default_t + get_task_for_task_type(OperatorTaskSet const &op_task_set, + OpTaskType task_type) { return 
get_map_from_task_type_to_task(op_task_set).at_l(task_type); } OperatorTaskSet get_task_set_for_operator(ComputationGraphOpAttrs const &attrs) { - registered_task_t init_task = make_noop_registered_task(); - registered_task_t fwd_task = make_noop_registered_task(); - registered_task_t bwd_task = make_noop_registered_task(); + task_id_with_noop_default_t init_task = make_default_noop_task(); + task_id_with_noop_default_t fwd_task = make_default_noop_task(); + task_id_with_noop_default_t bwd_task = make_default_noop_task(); std::vector task_ids = get_task_ids(attrs); @@ -37,24 +38,23 @@ OperatorTaskSet TaskSignatureAndImpl task_signature_and_impl = get_task_signature_and_impl_for_task_id(task_id); - TaskImplFunction task_impl_function = task_signature_and_impl.impl_function; OpTaskSignature task_signature = task_signature_and_impl.task_signature; switch (task_signature.type) { case OpTaskType::INIT: ASSERT(is_invocation_valid(task_signature, get_init_op_task_invocation(attrs))); - init_task = registered_task_t{task_id}; + init_task = task_id_with_noop_default_t{task_id}; break; case OpTaskType::FWD: ASSERT(is_invocation_valid(task_signature, get_forward_op_task_invocation(attrs))); - fwd_task = registered_task_t{task_id}; + fwd_task = task_id_with_noop_default_t{task_id}; break; case OpTaskType::BWD: ASSERT(is_invocation_valid(task_signature, get_backward_op_task_invocation(attrs))); - bwd_task = registered_task_t{task_id}; + bwd_task = task_id_with_noop_default_t{task_id}; break; default: PANIC("Unhandled OpTaskType", fmt::to_string(task_signature.type)); diff --git a/lib/local-execution/src/local-execution/registered_task.cc b/lib/local-execution/src/local-execution/registered_task.cc deleted file mode 100644 index 84b116273a..0000000000 --- a/lib/local-execution/src/local-execution/registered_task.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "local-execution/registered_task.h" - -namespace FlexFlow { - -registered_task_t make_noop_registered_task() { - return 
registered_task_t{std::monostate{}}; -} - -} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/task_execution.cc b/lib/local-execution/src/local-execution/task_execution.cc new file mode 100644 index 0000000000..09276aa218 --- /dev/null +++ b/lib/local-execution/src/local-execution/task_execution.cc @@ -0,0 +1,24 @@ +#include "local-execution/task_execution.h" +#include "local-execution/local_task_argument_accessor.h" + +namespace FlexFlow { + +TaskArgumentAccessor make_task_argument_accessor_for_invocation( + DynamicNodeInvocation const &invocation, + Allocator &allocator, + ProfilingSettings const &profiling_settings, + DeviceType kernel_device_type, + PCGOperatorAttrs op_attrs, + std::optional const &loss_attrs, + std::optional const &per_device_op_state, + FFIterationConfig iteration_config, + std::optional const &optimizer_attrs) { + std::unordered_map < + + return TaskArgumentAccessor::create( + /*allocator=*/allocator, + /*tensor_slots_backing=*/ + ); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/task_id_with_noop_default_t.cc b/lib/local-execution/src/local-execution/task_id_with_noop_default_t.cc new file mode 100644 index 0000000000..15b2fe786b --- /dev/null +++ b/lib/local-execution/src/local-execution/task_id_with_noop_default_t.cc @@ -0,0 +1,9 @@ +#include "local-execution/task_id_with_noop_default_t.h" + +namespace FlexFlow { + +task_id_with_noop_default_t make_noop_registered_task() { + return task_id_with_noop_default_t{std::monostate{}}; +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local-execution/tensor_allocation.cc b/lib/local-execution/src/local-execution/tensor_allocation.cc new file mode 100644 index 0000000000..16d6712616 --- /dev/null +++ b/lib/local-execution/src/local-execution/tensor_allocation.cc @@ -0,0 +1,85 @@ +#include "local-execution/tensor_allocation.h" +#include "op-attrs/parallel_tensor_shape.h" +#include 
"task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/all_are_true.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/map_values.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/optional.h" + +namespace FlexFlow { + +bool no_tensors_are_allocated(DynamicOpenDataflowGraph const &g) { + return all_are_true( + transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { + return !v.accessor.has_value(); + })); +} + +bool all_tensors_are_allocated(DynamicOpenDataflowGraph const &g) { + return all_are_true( + transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { + return v.accessor.has_value(); + })); +} + +DynamicValueAttrs + perform_tensor_allocation_for_value(DynamicValueAttrs const &value, + Allocator &allocator) { + ASSERT(value.accessor == std::nullopt); + + TensorShape shape = + get_piece_shape(assert_unwrap(value.parallel_tensor_shape)); + + GenericTensorAccessorW accessor = allocator.allocate_tensor(shape); + + DynamicValueAttrs result = value; + result.accessor = accessor; + + return result; +} + +DynamicOpenDataflowGraph perform_tensor_allocation( + DynamicOpenDataflowGraph const &g, + std::unordered_map const + &preallocated, + Allocator &allocator) { + for (DynamicValueAttrs const &v : keys(preallocated)) { + ASSERT(v.accessor == std::nullopt); + } + + std::unordered_set all_values = + unordered_set_of(get_dynamic_values(g)); + + bidict unallocated_to_allocated = + generate_bidict( + all_values, [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { + if (contains_key(preallocated, v)) { + DynamicValueAttrs result = v; + result.accessor = preallocated.at(v); + return result; + } else { + return perform_tensor_allocation_for_value(v, allocator); + } + }); + + return transform_dynamic_invocation_set( + g, [&](DynamicNodeInvocation const &i) -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + 
/*inputs=*/map_values( + i.inputs, + [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { + return unallocated_to_allocated.at_l(v); + }), + /*node_attrs=*/i.node_attrs, + /*outputs=*/ + map_values(i.outputs, + [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { + return unallocated_to_allocated.at_l(v); + }), + }; + }); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local_task_argument_accessor.cc b/lib/local-execution/src/local_task_argument_accessor.cc deleted file mode 100644 index 207305a8db..0000000000 --- a/lib/local-execution/src/local_task_argument_accessor.cc +++ /dev/null @@ -1,68 +0,0 @@ -#include "local-execution/local_task_argument_accessor.h" -#include "utils/containers/contains_key.h" -#include "utils/containers/transform.h" -#include "utils/hash/pair.h" -#include "utils/overload.h" - -namespace FlexFlow { - -LocalTaskArgumentAccessor::LocalTaskArgumentAccessor( - Allocator const &allocator, - std::unordered_map const - &tensor_slots_backing, - std::unordered_map const &arg_slots_backing) - : allocator(allocator), tensor_slots_backing(tensor_slots_backing), - arg_slots_backing(arg_slots_backing){}; - -ConcreteArgSpec const & - LocalTaskArgumentAccessor::get_concrete_arg(slot_id_t name) const { - return this->arg_slots_backing.at(name); -} - -GenericTensorAccessor LocalTaskArgumentAccessor::get_tensor( - slot_id_t slot, Permissions priv, TensorType tensor_type) const { - tensor_sub_slot_id_t slot_tensor_type = - tensor_sub_slot_id_t{slot, tensor_type}; - GenericTensorAccessorW tensor_backing = - this->tensor_slots_backing.at(slot_tensor_type).require_single(); - if (priv == Permissions::RO) { - GenericTensorAccessorR readonly_tensor_backing = - read_only_accessor_from_write_accessor(tensor_backing); - return readonly_tensor_backing; - } else if (priv == Permissions::RW || priv == Permissions::WO) { - return tensor_backing; - } else { - PANIC(fmt::format("Unhandled privilege mode {}", priv)); - } -} - 
-VariadicGenericTensorAccessor LocalTaskArgumentAccessor::get_variadic_tensor( - slot_id_t slot, Permissions priv, TensorType tensor_type) const { - tensor_sub_slot_id_t slot_tensor_type = - tensor_sub_slot_id_t{slot, tensor_type}; - std::vector variadic_tensor_backing = - this->tensor_slots_backing.at(slot_tensor_type).require_variadic(); - if (priv == Permissions::RO) { - std::vector readonly_variadic_tensor_backing = {}; - for (GenericTensorAccessorW const &tensor_backing : - variadic_tensor_backing) { - readonly_variadic_tensor_backing.push_back( - read_only_accessor_from_write_accessor(tensor_backing)); - } - return readonly_variadic_tensor_backing; - } else if (priv == Permissions::RW || priv == Permissions::WO) { - return variadic_tensor_backing; - } else { - PANIC(fmt::format("Unhandled privilege mode {}", priv)); - } -} - -Allocator LocalTaskArgumentAccessor::get_allocator() const { - return this->allocator; -} - -size_t LocalTaskArgumentAccessor::get_device_idx() const { - return 0; -} - -} // namespace FlexFlow diff --git a/lib/local-execution/src/per_device_op_state.cc b/lib/local-execution/src/per_device_op_state.cc deleted file mode 100644 index a959f4a8c9..0000000000 --- a/lib/local-execution/src/per_device_op_state.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "task-spec/per_device_op_state.h" -#include "utils/overload.h" - -namespace FlexFlow { - -PerDeviceOpState get_device_state_from_device_specific( - DeviceSpecificDeviceStates const &device_specific, size_t device_idx) { - return device_specific.visit( - [&](auto const &x) { return PerDeviceOpState{*(x.get(device_idx))}; }); -} - -} // namespace FlexFlow diff --git a/lib/local-execution/src/task_binding.cc b/lib/local-execution/src/task_binding.cc deleted file mode 100644 index 3894fb8d34..0000000000 --- a/lib/local-execution/src/task_binding.cc +++ /dev/null @@ -1,111 +0,0 @@ -#include "task-spec/task_binding.h" -#include "pcg/tensor_guid_t.dtg.h" -#include 
"task-spec/training_tensor_guid_t.dtg.h" -#include "utils/containers/contains_key.h" -#include "utils/fmt/unordered_map.h" -#include "utils/hash/tuple.h" -#include "utils/hash/unordered_map.h" - -namespace FlexFlow { - -TaskBinding::TaskBinding() : tensor_bindings(), arg_bindings() {} - -TaskBinding::TaskBinding( - std::unordered_map const - &tensor_bindings, - std::unordered_map const &arg_bindings) - : tensor_bindings(tensor_bindings), arg_bindings(arg_bindings) {} - -void TaskBinding::bind(int name, forward_tensor_guid_t const &binding) { - this->bind(slot_id_t{name}, binding); -} - -void TaskBinding::bind(slot_id_t name, forward_tensor_guid_t const &binding) { - this->tensor_bindings.insert({tensor_sub_slot_id_t{name, TensorType::FORWARD}, - training_tensor_guid_t{binding}}); -} - -void TaskBinding::bind_grad(int name, gradient_tensor_guid_t const &binding) { - this->bind_grad(slot_id_t{name}, binding); -} - -void TaskBinding::bind_grad(slot_id_t name, - gradient_tensor_guid_t const &binding) { - this->tensor_bindings.insert( - {tensor_sub_slot_id_t{name, TensorType::GRADIENT}, - training_tensor_guid_t{binding}}); -} - -void TaskBinding::bind_optimizer(int name, - optimizer_tensor_guid_t const &binding) { - this->bind_optimizer(slot_id_t{name}, binding); -} - -void TaskBinding::bind_optimizer(slot_id_t name, - optimizer_tensor_guid_t const &binding) { - this->tensor_bindings.insert( - {tensor_sub_slot_id_t{name, TensorType::OPTIMIZER}, - training_tensor_guid_t{binding}}); -} - -void TaskBinding::bind_loss(int name, loss_tensor_guid_t const &binding) { - this->bind_loss(slot_id_t{name}, binding); -} - -void TaskBinding::bind_loss(slot_id_t name, loss_tensor_guid_t const &binding) { - this->tensor_bindings.insert({tensor_sub_slot_id_t{name, TensorType::LOSS}, - training_tensor_guid_t{binding}}); -} - -void TaskBinding::insert_arg_spec(slot_id_t name, TaskArgSpec const &arg_spec) { - assert(!contains_key(this->arg_bindings, name)); - 
this->arg_bindings.insert({name, arg_spec}); -} - -bool TaskBinding::operator==(TaskBinding const &other) const { - return this->tie() == other.tie(); -} - -bool TaskBinding::operator!=(TaskBinding const &other) const { - return this->tie() != other.tie(); -} - -std::tuple< - std::unordered_map const &, - std::unordered_map const &> - TaskBinding::tie() const { - return std::tie(this->tensor_bindings, this->arg_bindings); -} - -std::unordered_map const & - TaskBinding::get_tensor_bindings() const { - return this->tensor_bindings; -} - -std::unordered_map const & - TaskBinding::get_arg_bindings() const { - return this->arg_bindings; -} - -std::string format_as(TaskBinding const &x) { - std::ostringstream oss; - oss << " using namespace ::FlexFlow; diff --git a/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc b/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc index 482795b278..ff90abcde7 100644 --- a/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc +++ b/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc @@ -1,8 +1,8 @@ #include "local-execution/local_task_argument_accessor.h" -#include "doctest/doctest.h" #include "kernels/local_cpu_allocator.h" #include "task-spec/task_signature_impl.h" #include "utils/fmt/variant.h" +#include using namespace ::FlexFlow; @@ -36,24 +36,26 @@ TEST_SUITE(FF_TEST_SUITE) { VARIADIC_TENSORS, }; - std::unordered_map + std::unordered_map tensor_slots_backing = { { - tensor_sub_slot_id_t{slot_id_t{INPUT}, TensorType::FORWARD}, + training_tensor_slot_id_t{TensorSlotName::LHS_INPUT, + TrainingTensorType::FORWARD}, TensorSlotBacking{input}, }, { - tensor_sub_slot_id_t{slot_id_t{INPUT}, TensorType::GRADIENT}, + training_tensor_slot_id_t{TensorSlotName::LHS_INPUT, + TrainingTensorType::GRADIENT}, TensorSlotBacking{input_grad}, }, { - tensor_sub_slot_id_t{slot_id_t{VARIADIC_TENSORS}, - TensorType::FORWARD}, + 
training_tensor_slot_id_t{TensorSlotName::INPUT, + TrainingTensorType::FORWARD}, TensorSlotBacking{variadic_tensors}, }, { - tensor_sub_slot_id_t{slot_id_t{VARIADIC_TENSORS}, - TensorType::GRADIENT}, + training_tensor_slot_id_t{TensorSlotName::INPUT, + TrainingTensorType::GRADIENT}, TensorSlotBacking{variadic_tensors_grad}, }, }; @@ -62,113 +64,144 @@ TEST_SUITE(FF_TEST_SUITE) { /*allocator=*/allocator, /*tensor_slots_backing=*/tensor_slots_backing, /*arg_slots_backing=*/{}, + /*device_idx=*/0, }; SUBCASE("get_tensor") { - SUBCASE("get_tensor(slot_id_t, Permissions::RO, TensorType::FORWARD)") { + SUBCASE("get_tensor(TensorSlotName, Permissions::RO, " + "TrainingTensorType::FORWARD)") { GenericTensorAccessor correct = GenericTensorAccessor{ read_only_accessor_from_write_accessor(input)}; - GenericTensorAccessor result = acc.get_tensor( - slot_id_t{INPUT}, Permissions::RO, TensorType::FORWARD); + GenericTensorAccessor result = + acc.get_tensor(TensorSlotName::LHS_INPUT, + Permissions::RO, + TrainingTensorType::FORWARD); CHECK(correct == result); } - SUBCASE("get_tensor(slot_id_t, Permissions::RO, TensorType::GRADIENT)") { + SUBCASE("get_tensor(TensorSlotName, Permissions::RO, " + "TrainingTensorType::GRADIENT)") { GenericTensorAccessor correct = GenericTensorAccessor{ read_only_accessor_from_write_accessor(input_grad)}; - GenericTensorAccessor result = acc.get_tensor( - slot_id_t{INPUT}, Permissions::RO, TensorType::GRADIENT); + GenericTensorAccessor result = + acc.get_tensor(TensorSlotName::LHS_INPUT, + Permissions::RO, + TrainingTensorType::GRADIENT); CHECK(correct == result); } - SUBCASE("get_tensor(slot_id_t, Permissions::WO, TensorType::FORWARD)") { + SUBCASE("get_tensor(TensorSlotName, Permissions::WO, " + "TrainingTensorType::FORWARD)") { GenericTensorAccessor correct = GenericTensorAccessor{input}; - GenericTensorAccessor result = acc.get_tensor( - slot_id_t{INPUT}, Permissions::WO, TensorType::FORWARD); + GenericTensorAccessor result = + 
acc.get_tensor(TensorSlotName::LHS_INPUT, + Permissions::WO, + TrainingTensorType::FORWARD); CHECK(correct == result); } - SUBCASE("get_tensor(slot_id_t, Permissions::WO, TensorType::GRADIENT)") { + SUBCASE("get_tensor(TensorSlotName, Permissions::WO, " + "TrainingTensorType::GRADIENT)") { GenericTensorAccessor correct = GenericTensorAccessor{input_grad}; - GenericTensorAccessor result = acc.get_tensor( - slot_id_t{INPUT}, Permissions::WO, TensorType::GRADIENT); + GenericTensorAccessor result = + acc.get_tensor(TensorSlotName::LHS_INPUT, + Permissions::WO, + TrainingTensorType::GRADIENT); CHECK(correct == result); } - SUBCASE("get_tensor(slot_id_t, Permissions::RW, TensorType::FORWARD)") { + SUBCASE("get_tensor(TensorSlotName, Permissions::RW, " + "TrainingTensorType::FORWARD)") { GenericTensorAccessor correct = GenericTensorAccessor{input}; - GenericTensorAccessor result = acc.get_tensor( - slot_id_t{INPUT}, Permissions::RW, TensorType::FORWARD); + GenericTensorAccessor result = + acc.get_tensor(TensorSlotName::LHS_INPUT, + Permissions::RW, + TrainingTensorType::FORWARD); CHECK(correct == result); } - SUBCASE("get_tensor(slot_id_t, Permissions::RW, TensorType::GRADIENT)") { + SUBCASE("get_tensor(TensorSlotName, Permissions::RW, " + "TrainingTensorType::GRADIENT)") { GenericTensorAccessor correct = GenericTensorAccessor{input_grad}; - GenericTensorAccessor result = acc.get_tensor( - slot_id_t{INPUT}, Permissions::RW, TensorType::GRADIENT); + GenericTensorAccessor result = + acc.get_tensor(TensorSlotName::LHS_INPUT, + Permissions::RW, + TrainingTensorType::GRADIENT); CHECK(correct == result); } } SUBCASE("get_variadic_tensor") { - SUBCASE("get_variadic_tensor(slot_id_t, Permissions::RO, " - "TensorType::FORWARD)") { + SUBCASE("get_variadic_tensor(TensorSlotName, Permissions::RO, " + "TrainingTensorType::FORWARD)") { VariadicGenericTensorAccessor correct = VariadicGenericTensorAccessor{std::vector{ read_only_accessor_from_write_accessor(variadic_tensors.at(0)), 
read_only_accessor_from_write_accessor( variadic_tensors.at(1))}}; - VariadicGenericTensorAccessor result = acc.get_variadic_tensor( - slot_id_t{VARIADIC_TENSORS}, Permissions::RO, TensorType::FORWARD); + VariadicGenericTensorAccessor result = + acc.get_variadic_tensor(TensorSlotName::INPUT, + Permissions::RO, + TrainingTensorType::FORWARD); CHECK(result == correct); } - SUBCASE("get_variadic_tensor(slot_id_t, Permissions::RO, " - "TensorType::GRADIENT)") { + SUBCASE("get_variadic_tensor(TensorSlotName, Permissions::RO, " + "TrainingTensorType::GRADIENT)") { VariadicGenericTensorAccessor correct = VariadicGenericTensorAccessor{std::vector{ read_only_accessor_from_write_accessor( variadic_tensors_grad.at(0)), read_only_accessor_from_write_accessor( variadic_tensors_grad.at(1))}}; - VariadicGenericTensorAccessor result = acc.get_variadic_tensor( - slot_id_t{VARIADIC_TENSORS}, Permissions::RO, TensorType::GRADIENT); + VariadicGenericTensorAccessor result = + acc.get_variadic_tensor(TensorSlotName::INPUT, + Permissions::RO, + TrainingTensorType::GRADIENT); CHECK(result == correct); } - SUBCASE("get_variadic_tensor(slot_id_t, Permissions::WO, " - "TensorType::FORWARD)") { + SUBCASE("get_variadic_tensor(TensorSlotName, Permissions::WO, " + "TrainingTensorType::FORWARD)") { VariadicGenericTensorAccessor correct = VariadicGenericTensorAccessor{variadic_tensors}; - VariadicGenericTensorAccessor result = acc.get_variadic_tensor( - slot_id_t{VARIADIC_TENSORS}, Permissions::WO, TensorType::FORWARD); + VariadicGenericTensorAccessor result = + acc.get_variadic_tensor(TensorSlotName::INPUT, + Permissions::WO, + TrainingTensorType::FORWARD); CHECK(result == correct); } - SUBCASE("get_variadic_tensor(slot_id_t, Permissions::WO, " - "TensorType::GRADIENT)") { + SUBCASE("get_variadic_tensor(TensorSlotName, Permissions::WO, " + "TrainingTensorType::GRADIENT)") { VariadicGenericTensorAccessor correct = VariadicGenericTensorAccessor{variadic_tensors_grad}; - 
VariadicGenericTensorAccessor result = acc.get_variadic_tensor( - slot_id_t{VARIADIC_TENSORS}, Permissions::WO, TensorType::GRADIENT); + VariadicGenericTensorAccessor result = + acc.get_variadic_tensor(TensorSlotName::INPUT, + Permissions::WO, + TrainingTensorType::GRADIENT); CHECK(result == correct); } - SUBCASE("get_variadic_tensor(slot_id_t, Permissions::WO, " - "TensorType::FORWARD)") { + SUBCASE("get_variadic_tensor(TensorSlotName, Permissions::WO, " + "TrainingTensorType::FORWARD)") { VariadicGenericTensorAccessor correct = VariadicGenericTensorAccessor{variadic_tensors}; - VariadicGenericTensorAccessor result = acc.get_variadic_tensor( - slot_id_t{VARIADIC_TENSORS}, Permissions::RW, TensorType::FORWARD); + VariadicGenericTensorAccessor result = + acc.get_variadic_tensor(TensorSlotName::INPUT, + Permissions::RW, + TrainingTensorType::FORWARD); CHECK(result == correct); } - SUBCASE("get_variadic_tensor(slot_id_t, Permissions::WO, " - "TensorType::GRADIENT)") { + SUBCASE("get_variadic_tensor(TensorSlotName, Permissions::WO, " + "TrainingTensorType::GRADIENT)") { VariadicGenericTensorAccessor correct = VariadicGenericTensorAccessor{variadic_tensors_grad}; - VariadicGenericTensorAccessor result = acc.get_variadic_tensor( - slot_id_t{VARIADIC_TENSORS}, Permissions::RW, TensorType::GRADIENT); + VariadicGenericTensorAccessor result = + acc.get_variadic_tensor(TensorSlotName::INPUT, + Permissions::RW, + TrainingTensorType::GRADIENT); CHECK(result == correct); } } diff --git a/lib/local-execution/test/src/local-execution/local_task_registry.cc b/lib/local-execution/test/src/local-execution/local_task_registry.cc index 27cd74b2a6..5dc66c8ebc 100644 --- a/lib/local-execution/test/src/local-execution/local_task_registry.cc +++ b/lib/local-execution/test/src/local-execution/local_task_registry.cc @@ -206,8 +206,12 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("Partial task does not exist") { ComputationGraphOpAttrs bmm_attrs = ComputationGraphOpAttrs{ - 
BatchMatmulAttrs{/*a_seq_length_dim=*/10_n, - /*b_seq_length_dim=*/20_n}}; + BatchMatmulAttrs{ + /*a_seq_length_dim=*/10_p, + /*b_seq_length_dim=*/20_p, + }, + }; + LocalTaskRegistry task_registry = construct_local_task_registry_for_layers({ {layer_guid, LayerAttrs{bmm_attrs, std::nullopt}}, diff --git a/lib/local-execution/test/src/local-execution/local_tensor_backing.cc b/lib/local-execution/test/src/local-execution/local_tensor_backing.cc deleted file mode 100644 index 2f5bf493d6..0000000000 --- a/lib/local-execution/test/src/local-execution/local_tensor_backing.cc +++ /dev/null @@ -1,285 +0,0 @@ -#include "local-execution/local_tensor_backing.h" -#include "internal/test_utils.h" -#include "kernels/local_cpu_allocator.h" -#include "task-spec/gradient_tensor_source.h" -#include "task-spec/loss_tensor_source.h" -#include "task-spec/optimizer_tensor_source.h" -#include "test/utils/doctest/check_kv.h" -#include "test/utils/doctest/fmt/unordered_map.h" -#include "utils/containers/keys.h" -#include - -using namespace ::FlexFlow; - -bool is_shape_and_dtype_equal_for_tensor_backings( - LocalTensorBacking const &b1, LocalTensorBacking const &b2) { - - std::unordered_map m1 = - b1.backing_for_training_tensor_map; - std::unordered_map m2 = - b2.backing_for_training_tensor_map; - - if (keys(m1) == keys(m2)) { - for (std::pair const - &tensor_type_backing : m1) { - if (tensor_type_backing.second.shape == - m2.at(tensor_type_backing.first).shape) { - continue; - } else { - return false; - } - } - return true; - } else { - return false; - } -} - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("construct_local_tensor_backing") { - Allocator allocator = create_local_cpu_memory_allocator(); - - training_tensor_guid_t t1 = - training_tensor_guid_t{forward_tensor_guid_t{4}}; - training_tensor_guid_t t2 = - training_tensor_guid_t{gradient_tensor_guid_t{4}}; - training_tensor_guid_t t3 = - training_tensor_guid_t{gradient_tensor_guid_t{5}}; - training_tensor_guid_t t4 = - 
training_tensor_guid_t{gradient_tensor_guid_t{6}}; - - TensorShape tensor_shape_1 = TensorShape{ - TensorDims{FFOrdered{ - 4_p, - 5_p, - }}, - DataType::FLOAT, - }; - - TensorShape tensor_shape_2 = TensorShape{ - TensorDims{FFOrdered{ - 4_p, - 5_p, - }}, - DataType::FLOAT, - }; - - std::unordered_map - training_tensor_shapes = { - {t1, tensor_shape_1}, - {t2, tensor_shape_2}, - {t3, tensor_shape_1}, - }; - - GenericTensorAccessorW t3_accessor = - allocator.allocate_tensor(tensor_shape_2); - SUBCASE("allocates all non-preallocated tensors and does not re-allocate " - "the preallocated ones") { - std::unordered_map - preallocated_tensors = { - {t3, t3_accessor}, - }; - - LocalTensorBacking result = construct_local_tensor_backing( - /*training_tensor_shapes=*/training_tensor_shapes, - /*preallocated_tensors=*/preallocated_tensors, - /*allocator=*/allocator); - LocalTensorBacking correct = LocalTensorBacking{ - /*backing_for_training_tensor_map=*/{ - {t3, t3_accessor}, - {t1, allocator.allocate_tensor(tensor_shape_1)}, - {t2, allocator.allocate_tensor(tensor_shape_2)}, - }, - }; - - CHECK_MESSAGE( - is_shape_and_dtype_equal_for_tensor_backings(result, correct), - check_kv("result", fmt::to_string(result)), - check_kv("correct", fmt::to_string(correct))); - - CHECK(get_accessor_for_training_tensor(result, t3) == t3_accessor); - } - - SUBCASE("fails if a preallocated tensor is not in training_tensor_shapes") { - std::unordered_map - preallocated_tensors = { - {t4, t3_accessor}, - }; - - CHECK_THROWS(construct_local_tensor_backing( - /*training_tensor_shapes=*/training_tensor_shapes, - /*preallocated_tensors=*/preallocated_tensors, - /*allocator=*/allocator)); - } - } - - TEST_CASE("get_accessor_for_training_tensor") { - Allocator allocator = create_local_cpu_memory_allocator(); - - TensorShape tensor_shape = TensorShape{ - TensorDims{FFOrdered{ - 4_p, - 5_p, - }}, - DataType::FLOAT, - }; - - training_tensor_guid_t t1 = - training_tensor_guid_t{forward_tensor_guid_t{4}}; 
- training_tensor_guid_t t2 = - training_tensor_guid_t{gradient_tensor_guid_t{4}}; - - GenericTensorAccessorW t1_accessor = - allocator.allocate_tensor(tensor_shape); - GenericTensorAccessorW t2_accessor = - allocator.allocate_tensor(tensor_shape); - - LocalTensorBacking local_tensor_backing = LocalTensorBacking{ - /*backing_for_training_tensor_map=*/{{ - t1, - t1_accessor, - }, - { - t2, - t2_accessor, - }}, - }; - - SUBCASE("returns corresponding accessor if training tensor is present") { - GenericTensorAccessorW result = - get_accessor_for_training_tensor(local_tensor_backing, t1); - GenericTensorAccessorW correct = t1_accessor; - - CHECK(result == correct); - } - - SUBCASE("fails if the training tensor is not present") { - training_tensor_guid_t t3 = - training_tensor_guid_t{optimizer_tensor_guid_t{4}}; - training_tensor_guid_t t4 = - training_tensor_guid_t{forward_tensor_guid_t{3}}; - - CHECK_THROWS(get_accessor_for_training_tensor(local_tensor_backing, t3)); - CHECK_THROWS(get_accessor_for_training_tensor(local_tensor_backing, t4)); - } - } - - TEST_CASE("construct_tensor_slots_backing_for_binding") { - enum Slots { - TENSOR_SLOT_1, - TENSOR_SLOT_2, - TENSOR_SLOT_3, - ARG_SLOT, - }; - - Allocator allocator = create_local_cpu_memory_allocator(); - - TensorShape tensor_shape = TensorShape{ - TensorDims{FFOrdered{ - 4_p, - 5_p, - }}, - DataType::FLOAT, - }; - - training_tensor_guid_t t1 = - training_tensor_guid_t{forward_tensor_guid_t{4}}; - training_tensor_guid_t t2 = - training_tensor_guid_t{forward_tensor_guid_t{5}}; - training_tensor_guid_t t3 = - training_tensor_guid_t{forward_tensor_guid_t{6}}; - training_tensor_guid_t t4 = - training_tensor_guid_t{gradient_tensor_guid_t{5}}; - - GenericTensorAccessorW t1_accessor = - allocator.allocate_tensor(tensor_shape); - GenericTensorAccessorW t2_accessor = - allocator.allocate_tensor(tensor_shape); - GenericTensorAccessorW t3_accessor = - allocator.allocate_tensor(tensor_shape); - GenericTensorAccessorW t4_accessor 
= - allocator.allocate_tensor(tensor_shape); - - tensor_sub_slot_id_t tensor_slot_1_forward = tensor_sub_slot_id_t{ - slot_id_t{TENSOR_SLOT_1}, - TensorType::FORWARD, - }; - tensor_sub_slot_id_t tensor_slot_1_gradient = tensor_sub_slot_id_t{ - slot_id_t{TENSOR_SLOT_1}, - TensorType::GRADIENT, - }; - tensor_sub_slot_id_t tensor_slot_2_forward = tensor_sub_slot_id_t{ - slot_id_t{TENSOR_SLOT_2}, - TensorType::FORWARD, - }; - tensor_sub_slot_id_t tensor_slot_3_forward = tensor_sub_slot_id_t{ - slot_id_t{TENSOR_SLOT_3}, - TensorType::FORWARD, - }; - - LocalTensorBacking local_tensor_backing = LocalTensorBacking{ - /*backing_for_training_tensor_map=*/{{ - t1, - t1_accessor, - }, - { - t2, - t2_accessor, - }, - { - t3, - t3_accessor, - }, - { - t4, - t4_accessor, - }}, - }; - - TaskBinding task_binding = TaskBinding{ - /*tensor_bindings=*/{ - { - tensor_slot_1_forward, - t1, - }, - { - tensor_slot_2_forward, - t2, - }, - { - tensor_slot_1_gradient, - t4, - }, - }, - /*arg_bindings=*/ - { - { - slot_id_t{ARG_SLOT}, - TaskArgSpec{ - ConcreteArgSpec::create(4), - }, - }, - }, - }; - - std::unordered_map result = - construct_tensor_slots_backing_for_binding(local_tensor_backing, - task_binding); - std::unordered_map correct = { - { - tensor_slot_1_forward, - TensorSlotBacking{t1_accessor}, - }, - { - tensor_slot_2_forward, - TensorSlotBacking{t2_accessor}, - }, - { - tensor_slot_1_gradient, - TensorSlotBacking{t4_accessor}, - }, - }; - - CHECK(result == correct); - } -} diff --git a/lib/local-execution/test/src/local-execution/local_training_backing.cc b/lib/local-execution/test/src/local-execution/local_training_backing.cc index 5436dbdbb7..393cfab9dc 100644 --- a/lib/local-execution/test/src/local-execution/local_training_backing.cc +++ b/lib/local-execution/test/src/local-execution/local_training_backing.cc @@ -9,7 +9,7 @@ #include "task-spec/forward_tensor_source.h" #include "task-spec/gradient_tensor_source.h" #include "task-spec/optimizer_tensor_source.h" -#include 
"task-spec/runtime_arg_config.h" +#include "task-spec/runtime_task_invocation/runtime_arg_config.h" #include "task-spec/training_computation_graph.h" #include "utils/containers/get_only.h" #include diff --git a/lib/local-execution/test/src/local-execution/loss_functions.cc b/lib/local-execution/test/src/local-execution/loss_functions.cc index e5fffb980c..939bcec43d 100644 --- a/lib/local-execution/test/src/local-execution/loss_functions.cc +++ b/lib/local-execution/test/src/local-execution/loss_functions.cc @@ -1,4 +1,3 @@ -#include "doctest/doctest.h" #include "internal/test_utils.h" #include "kernels/local_cuda_allocator.h" #include "kernels/managed_ff_stream.h" @@ -12,9 +11,10 @@ #include "task-spec/gradient_tensor_source.h" #include "task-spec/loss_tensor_source.h" #include "task-spec/optimizer_tensor_source.h" -#include "task-spec/runtime_arg_config.h" +#include "task-spec/runtime_task_invocation/runtime_arg_config.h" #include "task-spec/training_computation_graph.h" #include "utils/containers/get_only.h" +#include using namespace ::FlexFlow; diff --git a/lib/local-execution/test/src/local-execution/tensor_allocation.cc b/lib/local-execution/test/src/local-execution/tensor_allocation.cc new file mode 100644 index 0000000000..e2c2869700 --- /dev/null +++ b/lib/local-execution/test/src/local-execution/tensor_allocation.cc @@ -0,0 +1,10 @@ +#include "local-execution/tensor_allocation.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("perform_tensor_allocation") { + CHECK_MESSAGE(false, "TODO: perform_tensor_allocation"); + } +} diff --git a/lib/local-execution/test/src/test_e2e.cc b/lib/local-execution/test/src/local-execution/test_e2e.cc similarity index 99% rename from lib/local-execution/test/src/test_e2e.cc rename to lib/local-execution/test/src/local-execution/test_e2e.cc index f8d34fc5ff..bc70195eef 100644 --- a/lib/local-execution/test/src/test_e2e.cc +++ b/lib/local-execution/test/src/local-execution/test_e2e.cc @@ 
-17,7 +17,7 @@ #include "task-spec/gradient_tensor_source.h" #include "task-spec/loss_tensor_source.h" #include "task-spec/optimizer_tensor_source.h" -#include "task-spec/runtime_arg_config.h" +#include "task-spec/runtime_task_invocation/runtime_arg_config.h" #include "task-spec/training_computation_graph.h" #include "test/utils/doctest/check_kv.h" #include "utils/containers/get_only.h" diff --git a/lib/local-pcg-execution/CMakeLists.txt b/lib/local-pcg-execution/CMakeLists.txt new file mode 100644 index 0000000000..5fadff777b --- /dev/null +++ b/lib/local-pcg-execution/CMakeLists.txt @@ -0,0 +1,21 @@ +ff_add_library( + NAME + local-pcg-execution + SRC_PATTERNS + src/*.cc + PUBLIC_INCLUDE + include/ + PRIVATE_INCLUDE + src/ + DEPS + op-attrs + utils + kernels + task-spec + local-execution + pcg + spdlog + compiler +) + +add_subdirectory(test) diff --git a/lib/local-pcg-execution/include/local-pcg-execution/execute_tasks_for_parallel_layer.h b/lib/local-pcg-execution/include/local-pcg-execution/execute_tasks_for_parallel_layer.h new file mode 100644 index 0000000000..e6c5945c77 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/execute_tasks_for_parallel_layer.h @@ -0,0 +1,59 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_EXECUTE_TASKS_FOR_PARALLEL_LAYER_H +#define _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_EXECUTE_TASKS_FOR_PARALLEL_LAYER_H + +#include "compiler/mapped_operator_task_group.h" +#include "local-execution/local_atomic_tensor_backing.dtg.h" +#include "local-execution/local_ready_to_launch_task.dtg.h" +#include "local-execution/local_task_registry.dtg.h" +#include "local-pcg-execution/local_parallel_tensor_backing.dtg.h" +#include "local-pcg-execution/mapped_per_device_op_states_group.h" +#include "local-pcg-execution/mapped_runtime_task_group.h" +#include "local-pcg-execution/task_group_execution_times.dtg.h" +#include "task-spec/runtime_task_invocation/runtime_arg_config.dtg.h" 
+#include "task-spec/runtime_task_invocation/runtime_task_invocation.dtg.h" +#include "task-spec/symbolic/training_symbolic_computation_graph.dtg.h" + +namespace FlexFlow { + +std::unordered_map + prepare_parallel_runtime_task_invocations( + RuntimeTaskInvocation const &, + LocalParallelTensorBacking const &, + LocalAtomicTensorBacking const &, + Allocator &, + RuntimeArgConfig const &, + MappedRuntimeTaskGroup const &); + +std::optional + execute_init_for_parallel_layer(symbolic_layer_guid_t, + TrainingSymbolicComputationGraph const &, + LocalParallelTensorBacking const &, + LocalAtomicTensorBacking const &, + Allocator &, + LocalTaskRegistry const &, + RuntimeArgConfig const &, + MappedRuntimeTaskGroup const &); + +std::optional + execute_forward_for_parallel_layer(symbolic_layer_guid_t, + TrainingSymbolicComputationGraph const &, + LocalParallelTensorBacking const &, + LocalAtomicTensorBacking const &, + Allocator &, + LocalTaskRegistry const &, + RuntimeArgConfig const &, + MappedRuntimeTaskGroup const &); + +std::optional + execute_forward_for_parallel_layer(symbolic_layer_guid_t, + TrainingSymbolicComputationGraph const &, + LocalParallelTensorBacking const &, + LocalAtomicTensorBacking const &, + Allocator &, + LocalTaskRegistry const &, + RuntimeArgConfig const &, + MappedRuntimeTaskGroup const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-pcg-execution/include/local-pcg-execution/local_parallel_tensor_backing.dtg.toml b/lib/local-pcg-execution/include/local-pcg-execution/local_parallel_tensor_backing.dtg.toml new file mode 100644 index 0000000000..257e5ad4c0 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/local_parallel_tensor_backing.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "LocalParallelTensorBacking" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "task-spec/symbolic_training_tensor_guid_t.dtg.h", + 
"local-pcg-execution/training_parallel_tensor_shard_group.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", + "utils/ord/unordered_map.h", +] + +[[fields]] +name = "parallel_tensor_map" +type = "std::unordered_map<::FlexFlow::symbolic_training_tensor_guid_t, ::FlexFlow::TrainingParallelTensorShardGroup>" diff --git a/lib/local-pcg-execution/include/local-pcg-execution/local_parallel_tensor_backing.h b/lib/local-pcg-execution/include/local-pcg-execution/local_parallel_tensor_backing.h new file mode 100644 index 0000000000..0af2502dc7 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/local_parallel_tensor_backing.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_LOCAL_PARALLEL_TENSOR_BACKING_H +#define _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_LOCAL_PARALLEL_TENSOR_BACKING_H + +#include "kernels/allocation.h" +#include "local-execution/atomic_task_invocation.dtg.h" +#include "local-execution/tensor_slot_backing.dtg.h" +#include "local-pcg-execution/local_parallel_tensor_backing.dtg.h" +#include "local-pcg-execution/mapped_runtime_task_group.h" +#include "local-pcg-execution/parallel_tensor_accessors_w.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "task-spec/runtime_task_invocation/runtime_arg_config.dtg.h" +#include "task-spec/runtime_task_invocation/runtime_task_invocation.dtg.h" +#include "task-spec/task_argument_accessor/task_tensor_parameter.dtg.h" + +namespace FlexFlow { + +std::unordered_map + lower_parallel_runtime_task_invocation_to_atomic_task_invocation_group( + LocalParallelTensorBacking const &, + RuntimeTaskInvocation const &, + RuntimeArgConfig const &, + MappedRuntimeTaskGroup const &); + +AtomicTaskInvocation + lower_parallel_runtime_task_invocation_to_atomic_task_invocation( + LocalParallelTensorBacking const &, + RuntimeTaskInvocation const &, + 
RuntimeArgConfig const &, + MachineSpaceCoordinate const &, + RuntimeAtomicTaskShardBinding const &); + +// LocalParallelTensorBacking construct_local_parallel_tensor_backing( +// std::unordered_map +// const &training_ptensor_shapes, +// std::unordered_map const &preallocated_ptensors, Allocator &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_args_backing.dtg.toml b/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_args_backing.dtg.toml new file mode 100644 index 0000000000..ad332327d8 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_args_backing.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "LocalPcgArgsBacking" +type = "struct" +features = [] + +includes = [ + "task-spec/runtime_task_invocation/runtime_arg_config.dtg.h", + "task-spec/device_specific_device_states.dtg.h", + "local-pcg-execution/parallel_layer_instance_id_t.dtg.h", + "", + "", + "local-pcg-execution/mapped_per_device_op_states_group.h", +] + +[[fields]] +name = "runtime_arg_config" +type = "::FlexFlow::RuntimeArgConfig" + +[[fields]] +name = "per_device_op_states" +type = "std::unordered_map<::FlexFlow::symbolic_layer_guid_t, std::optional<::FlexFlow::MappedPerDeviceOpStatesGroup>>" diff --git a/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_args_backing.h b/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_args_backing.h new file mode 100644 index 0000000000..d755760ce6 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_args_backing.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_LOCAL_PCG_ARGS_BACKING_H +#define _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_LOCAL_PCG_ARGS_BACKING_H + +#include "local-pcg-execution/local_pcg_args_backing.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" 
+#include "task-spec/symbolic/symbolic_layer_guid_t.dtg.h" +#include +#include + +namespace FlexFlow { + +std::unordered_map> + get_op_states_for_machine_space_coord(LocalPcgArgsBacking const &, + MachineSpaceCoordinate const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_training_backing.dtg.toml b/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_training_backing.dtg.toml new file mode 100644 index 0000000000..21b5afde73 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_training_backing.dtg.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "LocalPcgTrainingBacking" +type = "struct" +features = [] + +includes = [ + "task-spec/training_parallel_computation_graph.dtg.h", + "local-execution/local_task_registry.dtg.h", + "local-pcg-execution/local_parallel_tensor_backing.dtg.h", + "local-pcg-execution/local_pcg_args_backing.dtg.h", + "pcg/machine_compute_specification.dtg.h", +] + +[[fields]] +name = "training_pcg" +type = "::FlexFlow::TrainingParallelComputationGraph" + +[[fields]] +name = "local_task_registry" +type = "::FlexFlow::LocalTaskRegistry" + +[[fields]] +name = "local_parallel_tensor_backing" +type = "::FlexFlow::LocalParallelTensorBacking" + +[[fields]] +name = "local_parallel_args_backing" +type = "::FlexFlow::LocalPcgArgsBacking" + +[[fields]] +name = "machine_compute_specification" +type = "::FlexFlow::MachineComputeSpecification" diff --git a/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_training_backing.h b/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_training_backing.h new file mode 100644 index 0000000000..dc4b1ad350 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/local_pcg_training_backing.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_LOCAL_PCG_TRAINING_BACKING_H +#define 
_FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_LOCAL_PCG_TRAINING_BACKING_H + +#include "local-pcg-execution/local_pcg_training_backing.dtg.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "task-spec/training_parallel_layer_plus_context.dtg.h" +#include "utils/units/milliseconds_t.h" + +namespace FlexFlow { + +LocalPcgTrainingBacking make_local_pcg_training_backing_for_pcg( + Allocator &allocator, + std::unordered_map const &preallocated_tensors, + TrainingParallelComputationGraph const &training_pcg, + RuntimeArgConfig const &runtime_arg_config, + OptimizerAttrs const &optimizer_attrs, + MachineComputeSpecification const &machine_compute_specification); + +std::optional> + execute_forward(LocalTaskRegistry const &, + LocalParallelTensorBacking const &, + LocalPcgArgsBacking const &, + TrainingParallelLayerPlusContext const &, + Allocator &); + +std::optional> execute_backward(); + +void compute_loss(LocalPcgTrainingBacking const &, + LossAttrs const &, + Allocator &); + +void execute_update(LocalPcgTrainingBacking const &, + parallel_layer_guid_t const &, + OptimizerAttrs const &, + Allocator &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-pcg-execution/include/local-pcg-execution/mapped_per_device_op_states_group.h b/lib/local-pcg-execution/include/local-pcg-execution/mapped_per_device_op_states_group.h new file mode 100644 index 0000000000..da4a954d93 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/mapped_per_device_op_states_group.h @@ -0,0 +1,50 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_MAPPED_PER_DEVICE_OP_STATES_GROUP_H +#define _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_MAPPED_PER_DEVICE_OP_STATES_GROUP_H + +#include "compiler/mapped_operator_task_group.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" +#include 
"utils/bidict/bidict.h" + +namespace FlexFlow { + +struct MappedPerDeviceOpStatesGroup { + MappedPerDeviceOpStatesGroup() = delete; + + explicit MappedPerDeviceOpStatesGroup( + std::unordered_map const + &per_device_op_states); + + [[nodiscard]] bool operator==(MappedPerDeviceOpStatesGroup const &) const; + [[nodiscard]] bool operator!=(MappedPerDeviceOpStatesGroup const &) const; + + [[nodiscard]] std::unordered_map const & + get_per_device_op_states() const; + +private: + std::unordered_map + shard_bindings; + +private: + [[nodiscard]] std::tuple tie() const; + + friend struct ::std::hash; +}; + +std::string format_as(::FlexFlow::MappedPerDeviceOpStatesGroup const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::MappedPerDeviceOpStatesGroup const &); + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::MappedPerDeviceOpStatesGroup> { + size_t operator()(::FlexFlow::MappedPerDeviceOpStatesGroup const &) const; +}; + +} // namespace std +#endif diff --git a/lib/local-pcg-execution/include/local-pcg-execution/mapped_runtime_task_group.h b/lib/local-pcg-execution/include/local-pcg-execution/mapped_runtime_task_group.h new file mode 100644 index 0000000000..550da0cafc --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/mapped_runtime_task_group.h @@ -0,0 +1,57 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_MAPPED_RUNTIME_TASK_GROUP_H +#define _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_MAPPED_RUNTIME_TASK_GROUP_H + +#include "compiler/mapped_operator_task_group.h" +#include "local-pcg-execution/runtime_atomic_task_shard_binding.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "task-spec/fwb_op_task_type.dtg.h" +#include "task-spec/symbolic/symbolic_layer_training_tensor_group_signature.dtg.h" +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +struct MappedRuntimeTaskGroup { + MappedRuntimeTaskGroup() = delete; + + explicit 
MappedRuntimeTaskGroup( + bidict const + &shard_bindings); + + [[nodiscard]] bool operator==(MappedRuntimeTaskGroup const &) const; + [[nodiscard]] bool operator!=(MappedRuntimeTaskGroup const &) const; + + [[nodiscard]] bidict const & + get_shard_bindings() const; + +private: + bidict shard_bindings; + +private: + [[nodiscard]] std::tuple tie() const; + + friend struct ::std::hash; +}; + +std::string format_as(::FlexFlow::MappedRuntimeTaskGroup const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::MappedRuntimeTaskGroup const &); + +MappedRuntimeTaskGroup + lower_mapped_operator_task_group_to_mapped_runtime_task_group( + MappedOperatorTaskGroup const &, + SymbolicLayerTrainingTensorGroupSignature const &, + FwbOpTaskType); + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::MappedRuntimeTaskGroup> { + size_t operator()(::FlexFlow::MappedRuntimeTaskGroup const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/local-pcg-execution/include/local-pcg-execution/parallel_forward_tensor_group.dtg.toml b/lib/local-pcg-execution/include/local-pcg-execution/parallel_forward_tensor_group.dtg.toml new file mode 100644 index 0000000000..d20d046c50 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/parallel_forward_tensor_group.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ParallelForwardTensorGroup" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "task-spec/forward_tensor_guid_t.dtg.h", + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "utils/bidict/bidict.h", +] + +[[fields]] +name = "forward_training_tensors_by_coord" +type = "::FlexFlow::bidict<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::forward_tensor_guid_t>" + diff --git a/lib/local-pcg-execution/include/local-pcg-execution/parallel_layer_instance_id_t.dtg.toml 
b/lib/local-pcg-execution/include/local-pcg-execution/parallel_layer_instance_id_t.dtg.toml new file mode 100644 index 0000000000..dcbc9d97ee --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/parallel_layer_instance_id_t.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "parallel_layer_instance_id_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "pcg/gpu_id_t.dtg.h", +] + +[[fields]] +name = "parallel_layer_guid" +type = "::FlexFlow::parallel_layer_guid_t" + +[[fields]] +name = "gpu_id" +type = "::FlexFlow::gpu_id_t" diff --git a/lib/local-pcg-execution/include/local-pcg-execution/parallel_loss_tensor_group.dtg.toml b/lib/local-pcg-execution/include/local-pcg-execution/parallel_loss_tensor_group.dtg.toml new file mode 100644 index 0000000000..6a2e2619b1 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/parallel_loss_tensor_group.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ParallelLossTensorGroup" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "task-spec/loss_tensor_guid_t.dtg.h", + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "utils/bidict/bidict.h", +] + +[[fields]] +name = "loss_training_tensors_by_coord" +type = "::FlexFlow::bidict<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::loss_tensor_guid_t>" + diff --git a/lib/local-pcg-execution/include/local-pcg-execution/parallel_model_training_instance.h b/lib/local-pcg-execution/include/local-pcg-execution/parallel_model_training_instance.h new file mode 100644 index 0000000000..8cfc261774 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/parallel_model_training_instance.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_PARALLEL_MODEL_TRAINING_INSTANCE_H +#define 
_FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_PARALLEL_MODEL_TRAINING_INSTANCE_H + +#include "compiler/mapped_parallel_computation_graph.dtg.h" +#include "kernels/allocation.h" +#include "local-execution/local_atomic_tensor_backing.dtg.h" +#include "local-execution/local_task_registry.dtg.h" +#include "local-pcg-execution/local_parallel_tensor_backing.dtg.h" +#include "local-pcg-execution/task_group_execution_times.dtg.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "task-spec/runtime_task_invocation/runtime_arg_config.dtg.h" +#include "task-spec/symbolic/training_symbolic_computation_graph_from_pcg_conversion.dtg.h" + +namespace FlexFlow { + +struct ParallelModelTrainingInstance { + ParallelModelTrainingInstance(Allocator const &, + LossAttrs const &, + OptimizerAttrs const &); + +public: + std::unordered_map> + forward(); + std::unordered_map> + backward(); + void update(); + GenericTensorAccessorR get_loss_tensor_accessor() const; + +private: + Allocator allocator; + LossAttrs loss_attrs; + OptimizerAttrs optimizer_attrs; + TrainingSymbolicComputationGraphFromPcgConversion symbolic_cg; + MappedParallelComputationGraph mapped_pcg; + LocalParallelTensorBacking local_tensor_backing; + LocalAtomicTensorBacking local_atomic_tensor_backing; + LocalTaskRegistry local_task_registry; + RuntimeArgConfig runtime_arg_config; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/local-pcg-execution/include/local-pcg-execution/parallel_tensor_accessors_w.dtg.toml b/lib/local-pcg-execution/include/local-pcg-execution/parallel_tensor_accessors_w.dtg.toml new file mode 100644 index 0000000000..d75dda9f68 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/parallel_tensor_accessors_w.dtg.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ParallelTensorAccessorsW" +type = "struct" +features = [] + +includes = 
[ + "", + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "kernels/accessor.h", +] + +[[fields]] +name = "shard_map" +type = "std::unordered_map<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::GenericTensorAccessorW>" diff --git a/lib/local-pcg-execution/include/local-pcg-execution/runtime_atomic_task_shard_binding.dtg.toml b/lib/local-pcg-execution/include/local-pcg-execution/runtime_atomic_task_shard_binding.dtg.toml new file mode 100644 index 0000000000..6ec06d4d64 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/runtime_atomic_task_shard_binding.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "RuntimeAtomicTaskShardBinding" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "", + "task-spec/symbolic_training_tensor_guid_t.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", + "utils/ord/unordered_map.h", +] + +[[fields]] +name = "raw_binding" +type = "std::unordered_map<::FlexFlow::symbolic_training_tensor_guid_t, ::FlexFlow::ParallelTensorSpaceCoordinate>" diff --git a/lib/local-pcg-execution/include/local-pcg-execution/runtime_atomic_task_shard_binding.h b/lib/local-pcg-execution/include/local-pcg-execution/runtime_atomic_task_shard_binding.h new file mode 100644 index 0000000000..49631fd94b --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/runtime_atomic_task_shard_binding.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_RUNTIME_ATOMIC_TASK_SHARD_BINDING_H +#define _FLEXFLOW_LIB_LOCAL_PCG_EXECUTION_INCLUDE_LOCAL_PCG_EXECUTION_RUNTIME_ATOMIC_TASK_SHARD_BINDING_H + +#include "compiler/operator_atomic_task_shard_binding.dtg.h" +#include "local-pcg-execution/runtime_atomic_task_shard_binding.dtg.h" +#include "task-spec/fwb_op_task_type.dtg.h" +#include "task-spec/symbolic/symbolic_layer_training_tensor_group_signature.dtg.h" + 
+namespace FlexFlow { + +RuntimeAtomicTaskShardBinding + lower_op_shard_binding_to_fwd_pass_runtime_shard_binding( + OperatorAtomicTaskShardBinding const &, + SymbolicLayerTrainingTensorGroupSignature const &); + +RuntimeAtomicTaskShardBinding + lower_op_shard_binding_to_bwd_pass_runtime_shard_binding( + OperatorAtomicTaskShardBinding const &, + SymbolicLayerTrainingTensorGroupSignature const &); + +RuntimeAtomicTaskShardBinding lower_op_shard_binding_to_runtime_shard_binding( + OperatorAtomicTaskShardBinding const &, + SymbolicLayerTrainingTensorGroupSignature const &, + FwbOpTaskType); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-pcg-execution/include/local-pcg-execution/task_group_execution_times.dtg.toml b/lib/local-pcg-execution/include/local-pcg-execution/task_group_execution_times.dtg.toml new file mode 100644 index 0000000000..52fa1cbc00 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/task_group_execution_times.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "TaskGroupExecutionTimes" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "utils/units/milliseconds_t.h", + "pcg/machine_space_coordinate.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", + "utils/ord/unordered_map.h", +] + +[[fields]] +name = "execution_times" +type = "std::unordered_map<::FlexFlow::MachineSpaceCoordinate, ::FlexFlow::milliseconds_t>" diff --git a/lib/local-pcg-execution/include/local-pcg-execution/training_operator_task_signature.dtg.toml b/lib/local-pcg-execution/include/local-pcg-execution/training_operator_task_signature.dtg.toml new file mode 100644 index 0000000000..fc8f54715b --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/training_operator_task_signature.dtg.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "TrainingOperatorTaskSignature" +type = "struct" +features = [ + "eq", + "ord", + 
"hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "task-spec/training_tensor_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", + "utils/ord/vector.h", +] + +[[fields]] +name = "inputs" +type = "std::vector<::FlexFlow::training_tensor_guid_t>" + +[[fields]] +name = "weights" +type = "std::vector<::FlexFlow::training_tensor_guid_t>" + +[[fields]] +name = "outputs" +type = "std::vector<::FlexFlow::training_tensor_guid_t>" diff --git a/lib/local-pcg-execution/include/local-pcg-execution/training_parallel_layer_plus_context.dtg.toml b/lib/local-pcg-execution/include/local-pcg-execution/training_parallel_layer_plus_context.dtg.toml new file mode 100644 index 0000000000..fd4bbc6182 --- /dev/null +++ b/lib/local-pcg-execution/include/local-pcg-execution/training_parallel_layer_plus_context.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "TrainingParallelLayerPlusContext" +type = "struct" +features = [] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "task-spec/training_parallel_tensor_group_with_attrs.dtg.h", +] + +[[fields]] +name = "parallel_layer_guid" +type = "::FlexFlow::parallel_layer_guid_t" + +[[fields]] +name = "parallel_layer_attrs" +type = "::FlexFlow::ParallelLayerAttrs" + +[[fields]] +name = "input_parallel_tensor_groups" +type = "std::vector<::FlexFlow::TrainingParallelTensorGroupWithAttrs>" + +[[fields]] +name = "weight_tensor_groups" +type = "std::vector<::FlexFlow::TrainingParallelTensorGroupWithAttrs>" + +[[fields]] +name = "output_tensor_groups" +type = "std::vector<::FlexFlow::TrainingParallelTensorGroupWithAttrs>" diff --git a/lib/local-pcg-execution/include/local-pcg-execution/training_parallel_tensor_shard_group.dtg.toml b/lib/local-pcg-execution/include/local-pcg-execution/training_parallel_tensor_shard_group.dtg.toml new file mode 100644 index 0000000000..e3958cf934 --- /dev/null +++ 
b/lib/local-pcg-execution/include/local-pcg-execution/training_parallel_tensor_shard_group.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TrainingParallelTensorShardGroup" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "", + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "local-execution/atomic_training_tensor_guid_t.dtg.h", +] + +[[fields]] +name = "shard_map" +type = "std::unordered_map<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::atomic_training_tensor_guid_t>" diff --git a/lib/local-pcg-execution/src/local-pcg-execution/execute_tasks_for_parallel_layer.cc b/lib/local-pcg-execution/src/local-pcg-execution/execute_tasks_for_parallel_layer.cc new file mode 100644 index 0000000000..fc4562ff64 --- /dev/null +++ b/lib/local-pcg-execution/src/local-pcg-execution/execute_tasks_for_parallel_layer.cc @@ -0,0 +1,193 @@ +#include "local-pcg-execution/execute_tasks_for_parallel_layer.h" +#include "local-execution/local_atomic_tensor_backing.h" +#include "local-execution/local_task_registry.h" +#include "local-pcg-execution/local_parallel_tensor_backing.h" +#include "local-pcg-execution/task_group_execution_times.dtg.h" +#include "task-spec/fwb_op_task_type.h" +#include "utils/containers/all_of.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/lift_optional_through_map.h" +#include "utils/containers/map_values.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +std::unordered_map + prepare_parallel_runtime_task_invocations( + RuntimeTaskInvocation const &runtime_task_invocation, + LocalParallelTensorBacking const ¶llel_tensor_backing, + LocalAtomicTensorBacking const &atomic_tensor_backing, + Allocator &allocator, + RuntimeArgConfig const &runtime_arg_config, + MappedRuntimeTaskGroup const &task_group) { + + std::unordered_map + atomic_task_invocations = + lower_parallel_runtime_task_invocation_to_atomic_task_invocation_group( + 
parallel_tensor_backing, + runtime_task_invocation, + runtime_arg_config, + task_group); + + return map_values( + atomic_task_invocations, + [&](AtomicTaskInvocation const &atomic_task_invocation) + -> LocalReadyToLaunchTask { + TaskArgumentAccessor task_arg_accessor = + get_task_arg_accessor_for_atomic_task_invocation( + atomic_tensor_backing, atomic_task_invocation, allocator); + + return LocalReadyToLaunchTask{ + atomic_task_invocation.task_id, + task_arg_accessor, + }; + }); +} + +std::optional execute_init_for_parallel_layer( + symbolic_layer_guid_t symbolic_layer_guid, + TrainingSymbolicComputationGraph const &g, + LocalParallelTensorBacking const ¶llel_tensor_backing, + LocalAtomicTensorBacking const &atomic_tensor_backing, + Allocator &allocator, + LocalTaskRegistry const &task_registry, + RuntimeArgConfig const &runtime_arg_config, + MappedRuntimeTaskGroup const &task_group) { + + SymbolicCgOpAttrsAndTrainingSignatureWithShapes attrs_and_signature = + get_attrs_and_signature_for_layer(g, symbolic_layer_guid); + + RuntimeTaskInvocation runtime_task_invocation = ({ + std::optional maybe_runtime_task_invocation = + get_init_runtime_task_invocation_for_layer(symbolic_layer_guid, + attrs_and_signature); + if (!maybe_runtime_task_invocation.has_value()) { + return std::nullopt; + } + maybe_runtime_task_invocation.value(); + }); + + std::unordered_map + prepared_tasks = + prepare_parallel_runtime_task_invocations(runtime_task_invocation, + parallel_tensor_backing, + atomic_tensor_backing, + allocator, + runtime_arg_config, + task_group); + + std::unordered_map> + op_state_by_shard = map_values( + prepared_tasks, + [&](LocalReadyToLaunchTask const &prepared_task) + -> std::optional { + return call_init_task_impl(task_registry, + prepared_task.task_id, + prepared_task.task_arg_accessor); + }); + + return transform( + lift_optional_through_map(op_state_by_shard), + [](std::unordered_map const &m) { + return MappedPerDeviceOpStatesGroup{m}; + }); +} + +static 
std::optional execute_fwb_for_parallel_layer( + symbolic_layer_guid_t symbolic_layer_guid, + TrainingSymbolicComputationGraph const &g, + LocalParallelTensorBacking const ¶llel_tensor_backing, + LocalAtomicTensorBacking const &atomic_tensor_backing, + Allocator &allocator, + LocalTaskRegistry const &task_registry, + RuntimeArgConfig const &runtime_arg_config, + MappedRuntimeTaskGroup const &task_group, + FwbOpTaskType fwb_task_type) { + + SymbolicCgOpAttrsAndTrainingSignatureWithShapes attrs_and_signature = + get_attrs_and_signature_for_layer(g, symbolic_layer_guid); + + OpTaskType op_task_type = + assert_unwrap(op_task_type_from_fwb_op_task_type(fwb_task_type)); + + RuntimeTaskInvocation runtime_task_invocation = ({ + std::optional maybe_runtime_task_invocation = + get_runtime_task_invocation_for_layer_and_type( + symbolic_layer_guid, attrs_and_signature, op_task_type); + if (!maybe_runtime_task_invocation.has_value()) { + return std::nullopt; + } + maybe_runtime_task_invocation.value(); + }); + + std::unordered_map + prepared_tasks = + prepare_parallel_runtime_task_invocations(runtime_task_invocation, + parallel_tensor_backing, + atomic_tensor_backing, + allocator, + runtime_arg_config, + task_group); + + std::unordered_map> + timing_by_shard = map_values( + prepared_tasks, + [&](LocalReadyToLaunchTask const &prepared_task) + -> std::optional { + return call_fwb_task_impl(task_registry, + prepared_task.task_id, + prepared_task.task_arg_accessor); + }); + + return transform( + lift_optional_through_map(timing_by_shard), + [](std::unordered_map const &m) { + return TaskGroupExecutionTimes{m}; + }); +} + +std::optional execute_forward_for_parallel_layer( + symbolic_layer_guid_t symbolic_layer_guid, + TrainingSymbolicComputationGraph const &g, + LocalParallelTensorBacking const ¶llel_tensor_backing, + LocalAtomicTensorBacking const &atomic_tensor_backing, + Allocator &allocator, + LocalTaskRegistry const &task_registry, + RuntimeArgConfig const &runtime_arg_config, + 
MappedRuntimeTaskGroup const &task_group) { + + return execute_fwb_for_parallel_layer(symbolic_layer_guid, + g, + parallel_tensor_backing, + atomic_tensor_backing, + allocator, + task_registry, + runtime_arg_config, + task_group, + FwbOpTaskType::FWD); +} + +std::optional execute_backward_for_parallel_layer( + symbolic_layer_guid_t symbolic_layer_guid, + TrainingSymbolicComputationGraph const &g, + LocalParallelTensorBacking const ¶llel_tensor_backing, + LocalAtomicTensorBacking const &atomic_tensor_backing, + Allocator &allocator, + LocalTaskRegistry const &task_registry, + RuntimeArgConfig const &runtime_arg_config, + MappedRuntimeTaskGroup const &task_group) { + + return execute_fwb_for_parallel_layer(symbolic_layer_guid, + g, + parallel_tensor_backing, + atomic_tensor_backing, + allocator, + task_registry, + runtime_arg_config, + task_group, + FwbOpTaskType::BWD); +} + +} // namespace FlexFlow diff --git a/lib/local-pcg-execution/src/local-pcg-execution/local_parallel_tensor_backing.cc b/lib/local-pcg-execution/src/local-pcg-execution/local_parallel_tensor_backing.cc new file mode 100644 index 0000000000..ead5349a9f --- /dev/null +++ b/lib/local-pcg-execution/src/local-pcg-execution/local_parallel_tensor_backing.cc @@ -0,0 +1,80 @@ +#include "local-pcg-execution/local_parallel_tensor_backing.h" +#include "local-pcg-execution/local_pcg_args_backing.dtg.h" +#include "local-pcg-execution/local_pcg_args_backing.h" +#include "local-pcg-execution/runtime_atomic_task_shard_binding.dtg.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/lower_op_task_invocation_to_runtime_task_invocation.h" +#include "utils/containers/map_values.h" +#include "utils/containers/map_values2.h" +#include "utils/containers/try_at.h" + +namespace FlexFlow { + +std::unordered_map + lower_parallel_runtime_task_invocation_to_atomic_task_invocation_group( + LocalParallelTensorBacking const ¶llel_tensor_backing, + LocalPcgArgsBacking const ¶llel_args_backing, + 
RuntimeTaskInvocation const &runtime_task_invocation, + MappedRuntimeTaskGroup const &runtime_task_group) { + + std::unordered_map + shard_bindings = + runtime_task_group.get_shard_bindings().as_unordered_map(); + + return map_values2( + shard_bindings, + [&](MachineSpaceCoordinate const &machine_space_coord, + RuntimeAtomicTaskShardBinding const &shard_binding) + -> AtomicTaskInvocation { + return lower_parallel_runtime_task_invocation_to_atomic_task_invocation( + parallel_tensor_backing, + runtime_task_invocation, + parallel_args_backing.runtime_arg_config, + get_op_states_for_machine_space_coord(parallel_args_backing, + machine_space_coord), + machine_space_coord, + shard_binding); + }); +} + +AtomicTaskInvocation + lower_parallel_runtime_task_invocation_to_atomic_task_invocation( + LocalParallelTensorBacking const ¶llel_tensor_backing, + RuntimeTaskInvocation const &invocation, + RuntimeArgConfig const &runtime_arg_config, + std::unordered_map> const + &per_device_op_states, + MachineSpaceCoordinate const &machine_space_coord, + RuntimeAtomicTaskShardBinding const &shard_binding) { + + std::unordered_map + tensor_bindings = + map_values(invocation.binding.get_tensor_bindings(), + [&](symbolic_training_tensor_guid_t t) + -> atomic_training_tensor_guid_t { + return parallel_tensor_backing.parallel_tensor_map.at(t); + }); + + auto get_op_state_for_layer = [&](symbolic_layer_guid_t l) + -> std::optional { + return per_device_op_states.at(l); + }; + + std::unordered_map arg_bindings = + map_values(invocation.binding.get_arg_bindings(), + [&](RuntimeArgSpec const &arg_spec) -> ConcreteArgSpec { + return lower_runtime_arg_ref_spec_to_concrete_arg_spec( + arg_spec, runtime_arg_config, get_op_state_for_layer); + }); + + return AtomicTaskInvocation{ + invocation.task_id, + AtomicTaskBinding{ + tensor_bindings, + arg_bindings, + }, + }; +} + +} // namespace FlexFlow diff --git a/lib/local-pcg-execution/src/local-pcg-execution/local_pcg_args_backing.cc 
b/lib/local-pcg-execution/src/local-pcg-execution/local_pcg_args_backing.cc new file mode 100644 index 0000000000..2910683801 --- /dev/null +++ b/lib/local-pcg-execution/src/local-pcg-execution/local_pcg_args_backing.cc @@ -0,0 +1,20 @@ +#include "local-pcg-execution/local_pcg_args_backing.h" + +namespace FlexFlow { + +std::unordered_map> + get_op_states_for_machine_space_coord( + LocalPcgArgsBacking const &args_backing, + MachineSpaceCoordinate const &coord) { + + return map_values( + args_backing.per_device_op_states, + [&](std::optional const &m_g) { + return transform(m_g, [&](MappedPerDeviceOpStatesGroup const &g) { + return g.get_per_device_op_states().at_l(coord); + }); + }); +} + +} // namespace FlexFlow diff --git a/lib/local-pcg-execution/src/local-pcg-execution/local_pcg_training_backing.cc b/lib/local-pcg-execution/src/local-pcg-execution/local_pcg_training_backing.cc new file mode 100644 index 0000000000..d9649d9e85 --- /dev/null +++ b/lib/local-pcg-execution/src/local-pcg-execution/local_pcg_training_backing.cc @@ -0,0 +1,45 @@ +#include "local-pcg-execution/local_pcg_training_backing.h" +#include "local-execution/local_task_registry.h" + +namespace FlexFlow { + +LocalPcgTrainingBacking make_local_pcg_training_backing_for_pcg( + Allocator &allocator, + std::unordered_map const &preallocated_tensors, + TrainingParallelComputationGraph const &training_pcg, + RuntimeArgConfig const &runtime_arg_config, + OptimizerAttrs const &optimizer_attrs, + MachineComputeSpecification const &machine_compute_specification) { + + NOT_IMPLEMENTED(); +} + +std::optional> execute_forward( + LocalTaskRegistry const &local_task_registry, + LocalParallelTensorBacking const &, + LocalPcgArgsBacking const &, + TrainingParallelLayerPlusContext const &training_parallel_layer, + Allocator &) { + + NOT_IMPLEMENTED(); +} + +std::optional> execute_backward() { + NOT_IMPLEMENTED(); +} + +void compute_loss(LocalPcgTrainingBacking const &, + LossAttrs const &, + Allocator &) { + 
NOT_IMPLEMENTED(); +} + +void execute_update(LocalPcgTrainingBacking const &, + parallel_layer_guid_t const &, + OptimizerAttrs const &, + Allocator &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/local-pcg-execution/src/local-pcg-execution/mapped_per_device_op_states_group.cc b/lib/local-pcg-execution/src/local-pcg-execution/mapped_per_device_op_states_group.cc new file mode 100644 index 0000000000..b94f7378ac --- /dev/null +++ b/lib/local-pcg-execution/src/local-pcg-execution/mapped_per_device_op_states_group.cc @@ -0,0 +1,124 @@ +#include "local-pcg-execution/mapped_per_device_op_states_group.h" +#include "compiler/machine_mapping/machine_view.h" +#include "op-attrs/get_operator_task_space.h" +#include "op-attrs/operator_task_space.h" +#include "op-attrs/parallel_tensor_space_coordinate.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/are_all_distinct.h" +#include "utils/containers/require_all_same.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/hash/tuple.h" +#include "utils/nonnegative_int/num_elements.h" + +namespace FlexFlow { + +MappedPerDeviceOpStatesGroup::MappedPerDeviceOpStatesGroup( + bidict const + &per_device_op_states) + : per_device_op_states(per_device_op_states) { + auto check_arity = [&](TensorRole tensor_role) -> nonnegative_int { + std::unordered_set arities = transform( + shard_bindings.right_values(), + [&](OperatorAtomicTaskShardBinding const &s) -> nonnegative_int { + return num_elements(ptensor_space_coords_for_role(s, tensor_role)); + }); + + return require_all_same(arities).value_or(0_n); + }; + + nonnegative_int num_inputs = check_arity(TensorRole::INPUT); + nonnegative_int num_weights = check_arity(TensorRole::WEIGHT); + nonnegative_int num_outputs = check_arity(TensorRole::OUTPUT); + + std::unordered_set all_keys = + all_keys_for_signature_arities( + /*num_inputs=*/num_inputs, + /*num_weights=*/num_weights, + 
/*num_outputs=*/num_outputs); + + for (TaskSignatureTensorKey const &key : all_keys) { + std::vector signatures_for_key = + vector_of(shard_bindings.right_values()); + + std::vector coords_for_key = + transform(signatures_for_key, + [&](OperatorAtomicTaskShardBinding const &signature) { + return ptensor_space_coord_for_key(signature, key); + }); + + ASSERT(are_all_distinct(coords_for_key)); + + std::vector coord_dims_for_key = + transform(coords_for_key, [](ParallelTensorSpaceCoordinate const &c) { + return ptensor_coord_num_dims(c); + }); + + require_all_same(coord_dims_for_key); + } +} + +bool MappedPerDeviceOpStatesGroup::operator==( + MappedPerDeviceOpStatesGroup const &other) const { + return this->tie() == other.tie(); +} + +bool MappedPerDeviceOpStatesGroup::operator!=( + MappedPerDeviceOpStatesGroup const &other) const { + return this->tie() == other.tie(); +} + +std::tuple< + bidict const &> + MappedPerDeviceOpStatesGroup::tie() const { + + return std::tie(this->shard_bindings); +} + +bidict const & + MappedPerDeviceOpStatesGroup::get_shard_bindings() const { + return this->shard_bindings; +} + +std::string format_as(::FlexFlow::MappedPerDeviceOpStatesGroup const &m) { + return fmt::format("", + m.get_shard_bindings()); +} + +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::MappedPerDeviceOpStatesGroup const &x) { + return (s << fmt::to_string(x)); +} + +MappedPerDeviceOpStatesGroup mapped_operator_task_group_from_machine_view( + ComputationGraphOpAttrs const &op_attrs, + std::unordered_map const + &inputs_dim_degrees, + MachineView const &machine_view) { + + OperatorTaskSpace op_task_space = + get_operator_task_space(op_attrs, inputs_dim_degrees); + + return MappedPerDeviceOpStatesGroup{ + generate_bidict( + get_machine_space_coordinates(op_task_space, machine_view), + [&](MachineSpaceCoordinate const &machine_space_coord) { + return operator_atomic_task_shard_binding_from_machine_view( + op_attrs, + inputs_dim_degrees, + machine_view, + 
machine_space_coord); + }), + }; +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::MappedPerDeviceOpStatesGroup>::operator()( + ::FlexFlow::MappedPerDeviceOpStatesGroup const &x) const { + return ::FlexFlow::get_std_hash(x.tie()); +} + +} // namespace std diff --git a/lib/local-pcg-execution/src/local-pcg-execution/mapped_runtime_task_group.cc b/lib/local-pcg-execution/src/local-pcg-execution/mapped_runtime_task_group.cc new file mode 100644 index 0000000000..f374412296 --- /dev/null +++ b/lib/local-pcg-execution/src/local-pcg-execution/mapped_runtime_task_group.cc @@ -0,0 +1,123 @@ +#include "local-pcg-execution/mapped_runtime_task_group.h" +#include "compiler/machine_mapping/machine_view.h" +#include "compiler/operator_atomic_task_shard_binding.h" +#include "local-pcg-execution/runtime_atomic_task_shard_binding.dtg.h" +#include "local-pcg-execution/runtime_atomic_task_shard_binding.h" +#include "op-attrs/get_operator_task_space.h" +#include "op-attrs/operator_task_space.h" +#include "op-attrs/parallel_tensor_space_coordinate.h" +#include "utils/bidict/algorithms/transform_values.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/are_all_distinct.h" +#include "utils/containers/require_all_same.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/hash/tuple.h" +#include "utils/nonnegative_int/num_elements.h" + +namespace FlexFlow { + +MappedRuntimeTaskGroup::MappedRuntimeTaskGroup( + bidict const + &shard_bindings) + : shard_bindings(shard_bindings) { + auto check_arity = [&](TensorRole tensor_role) -> nonnegative_int { + std::unordered_set arities = transform( + shard_bindings.right_values(), + [&](RuntimeAtomicTaskShardBinding const &s) -> nonnegative_int { + return num_elements(ptensor_space_coords_for_role(s, tensor_role)); + }); + + return require_all_same(arities).value_or(0_n); + }; + + nonnegative_int num_inputs = check_arity(TensorRole::INPUT); + 
nonnegative_int num_weights = check_arity(TensorRole::WEIGHT); + nonnegative_int num_outputs = check_arity(TensorRole::OUTPUT); + + std::unordered_set all_keys = + all_keys_for_signature_arities( + /*num_inputs=*/num_inputs, + /*num_weights=*/num_weights, + /*num_outputs=*/num_outputs); + + for (TaskSignatureTensorKey const &key : all_keys) { + std::vector signatures_for_key = + vector_of(shard_bindings.right_values()); + + std::vector coords_for_key = + transform(signatures_for_key, + [&](RuntimeAtomicTaskShardBinding const &signature) { + return ptensor_space_coord_for_key(signature, key); + }); + + ASSERT(are_all_distinct(coords_for_key)); + + std::vector coord_dims_for_key = + transform(coords_for_key, [](ParallelTensorSpaceCoordinate const &c) { + return ptensor_coord_num_dims(c); + }); + + require_all_same(coord_dims_for_key); + } +} + +bool MappedRuntimeTaskGroup::operator==( + MappedRuntimeTaskGroup const &other) const { + return this->tie() == other.tie(); +} + +bool MappedRuntimeTaskGroup::operator!=( + MappedRuntimeTaskGroup const &other) const { + return this->tie() == other.tie(); +} + +std::tuple< + bidict const &> + MappedRuntimeTaskGroup::tie() const { + + return std::tie(this->shard_bindings); +} + +bidict const & + MappedRuntimeTaskGroup::get_shard_bindings() const { + return this->shard_bindings; +} + +std::string format_as(::FlexFlow::MappedRuntimeTaskGroup const &m) { + return fmt::format("", + m.get_shard_bindings()); +} + +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::MappedRuntimeTaskGroup const &x) { + return (s << fmt::to_string(x)); +} + +MappedRuntimeTaskGroup + lower_mapped_operator_task_group_to_mapped_runtime_task_group( + MappedOperatorTaskGroup const &op_task_group, + SymbolicLayerTrainingTensorGroupSignature const + &symbolic_layer_signature, + FwbOpTaskType task_type) { + return MappedRuntimeTaskGroup{ + transform_values( + op_task_group.get_shard_bindings(), + [&](RuntimeAtomicTaskShardBinding const &op_shard_binding) + 
-> RuntimeAtomicTaskShardBinding { + return lower_op_shard_binding_to_runtime_shard_binding( + op_shard_binding, symbolic_layer_signature, task_type); + }), + }; +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::MappedRuntimeTaskGroup>::operator()( + ::FlexFlow::MappedRuntimeTaskGroup const &x) const { + return ::FlexFlow::get_std_hash(x.tie()); +} + +} // namespace std diff --git a/lib/local-pcg-execution/src/local-pcg-execution/runtime_atomic_task_shard_binding.cc b/lib/local-pcg-execution/src/local-pcg-execution/runtime_atomic_task_shard_binding.cc new file mode 100644 index 0000000000..20924b1eed --- /dev/null +++ b/lib/local-pcg-execution/src/local-pcg-execution/runtime_atomic_task_shard_binding.cc @@ -0,0 +1,88 @@ +#include "local-pcg-execution/runtime_atomic_task_shard_binding.h" +#include "compiler/operator_atomic_task_shard_binding.h" +#include "op-attrs/tensor_role.dtg.h" +#include "task-spec/fwb_tensor_type.dtg.h" +#include "task-spec/symbolic/symbolic_layer_training_tensor_group_signature.h" +#include "utils/containers/map_from_keys_and_values.h" +#include "utils/containers/merge_disjoint_maps.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +static std::unordered_map + get_tensor_shard_binding_for_type( + SymbolicLayerTrainingTensorGroupSignature const &signature, + OperatorAtomicTaskShardBinding const &shard_binding, + TensorRole tensor_role, + FwbTensorType tensor_type) { + + std::vector keys = + get_training_tensors_for_role_and_type( + signature, tensor_role, tensor_type); + + std::vector pt_coords = + ptensor_space_coords_for_role(shard_binding, tensor_role); + + return map_from_keys_and_values( + /*keys=*/keys, + /*values=*/pt_coords); +}; + +RuntimeAtomicTaskShardBinding lower_op_shard_binding_to_runtime_shard_binding( + OperatorAtomicTaskShardBinding const &op_shard_binding, + SymbolicLayerTrainingTensorGroupSignature const &signature) { + + auto get_bindings = [&](TensorRole tensor_role, 
FwbTensorType tensor_type) { + return get_tensor_shard_binding_for_type( + signature, op_shard_binding, tensor_role, tensor_type); + }; + + return RuntimeAtomicTaskShardBinding{ + merge_disjoint_maps(std::vector{ + get_bindings(TensorRole::INPUT, FwbTensorType::FORWARD), + get_bindings(TensorRole::WEIGHT, FwbTensorType::FORWARD), + get_bindings(TensorRole::OUTPUT, FwbTensorType::FORWARD), + }), + }; +} + +RuntimeAtomicTaskShardBinding + lower_op_shard_binding_to_bwd_pass_runtime_shard_binding( + OperatorAtomicTaskShardBinding const &op_shard_binding, + SymbolicLayerTrainingTensorGroupSignature const &signature) { + + auto get_bindings = [&](TensorRole tensor_role, FwbTensorType tensor_type) { + return get_tensor_shard_binding_for_type( + signature, op_shard_binding, tensor_role, tensor_type); + }; + + return RuntimeAtomicTaskShardBinding{ + merge_disjoint_maps(std::vector{ + get_bindings(TensorRole::INPUT, FwbTensorType::FORWARD), + get_bindings(TensorRole::WEIGHT, FwbTensorType::FORWARD), + get_bindings(TensorRole::OUTPUT, FwbTensorType::FORWARD), + get_bindings(TensorRole::INPUT, FwbTensorType::GRADIENT), + get_bindings(TensorRole::WEIGHT, FwbTensorType::GRADIENT), + get_bindings(TensorRole::OUTPUT, FwbTensorType::GRADIENT), + }), + }; +} + +RuntimeAtomicTaskShardBinding lower_op_shard_binding_to_runtime_shard_binding( + OperatorAtomicTaskShardBinding const &shard_binding, + SymbolicLayerTrainingTensorGroupSignature const &signature, + FwbOpTaskType task_type) { + switch (task_type) { + case FwbOpTaskType::FWD: + return lower_op_shard_binding_to_fwd_pass_runtime_shard_binding( + shard_binding, signature); + case FwbOpTaskType::BWD: + return lower_op_shard_binding_to_bwd_pass_runtime_shard_binding( + shard_binding, signature); + default: + PANIC("Unhandled FwbOpTaskType", task_type); + } +} + +} // namespace FlexFlow diff --git a/lib/local-pcg-execution/test/CMakeLists.txt b/lib/local-pcg-execution/test/CMakeLists.txt new file mode 100644 index 
0000000000..a7427fe351 --- /dev/null +++ b/lib/local-pcg-execution/test/CMakeLists.txt @@ -0,0 +1,16 @@ +ff_add_test_executable( + NAME + local-pcg-execution-tests + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + src/ + DEPS + doctest + utils-test-common + local-pcg-execution + kernels + op-attrs + task-spec +) + diff --git a/lib/local-pcg-execution/test/src/local-pcg-execution/local_pcg_training_backing.cc b/lib/local-pcg-execution/test/src/local-pcg-execution/local_pcg_training_backing.cc new file mode 100644 index 0000000000..a8cb61e63b --- /dev/null +++ b/lib/local-pcg-execution/test/src/local-pcg-execution/local_pcg_training_backing.cc @@ -0,0 +1,10 @@ +#include "local-pcg-execution/local_pcg_training_backing.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("LocalPcgTrainingBacking") { + CHECK_MESSAGE(false, "TODO: LocalPcgTrainingBacking"); + } +} diff --git a/lib/models/include/models/bert/bert_config.dtg.toml b/lib/models/include/models/bert/bert_config.dtg.toml new file mode 100644 index 0000000000..3ec803249f --- /dev/null +++ b/lib/models/include/models/bert/bert_config.dtg.toml @@ -0,0 +1,73 @@ +namespace = "FlexFlow" +name = "BertConfig" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/activation.dtg.h", + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "vocab_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "hidden_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_encoder_layers" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_heads" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "dim_feedforward" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "hidden_act" +type = "::FlexFlow::Activation" + +[[fields]] +name = "hidden_dropout_prob" +type = "float" + +[[fields]] +name = "attention_probs_dropout_prob" +type = "float" + +[[fields]] +name = "initializer_range" +type 
= "float" + +[[fields]] +name = "layer_norm_eps" +type = "float" + +[[fields]] +name = "position_embedding_type" +type = "std::string" + +[[fields]] +name = "classifier_dropout" +type = "float" + +[[fields]] +name = "sequence_length" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "batch_size" +type = "::FlexFlow::positive_int" diff --git a/lib/models/include/models/bert/bert_config.struct.toml b/lib/models/include/models/bert/bert_config.struct.toml deleted file mode 100644 index de56a25710..0000000000 --- a/lib/models/include/models/bert/bert_config.struct.toml +++ /dev/null @@ -1,72 +0,0 @@ -namespace = "FlexFlow" -name = "BertConfig" - -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/activation.dtg.h", - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "vocab_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "hidden_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "num_encoder_layers" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "num_heads" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "dim_feedforward" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "hidden_act" -type = "::FlexFlow::Activation" - -[[fields]] -name = "hidden_dropout_prob" -type = "float" - -[[fields]] -name = "attention_probs_dropout_prob" -type = "float" - -[[fields]] -name = "initializer_range" -type = "float" - -[[fields]] -name = "layer_norm_eps" -type = "float" - -[[fields]] -name = "position_embedding_type" -type = "std::string" - -[[fields]] -name = "classifier_dropout" -type = "float" - -[[fields]] -name = "sequence_length" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "batch_size" -type = "::FlexFlow::positive_int" diff --git a/lib/models/include/models/candle_uno/candle_uno_config.dtg.toml b/lib/models/include/models/candle_uno/candle_uno_config.dtg.toml new file mode 100644 index 0000000000..c3b8ddf372 --- /dev/null +++ 
b/lib/models/include/models/candle_uno/candle_uno_config.dtg.toml @@ -0,0 +1,54 @@ +namespace = "FlexFlow" +name = "CandleUnoConfig" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "", + "", + "utils/positive_int/positive_int.h", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/fmt/map.h", + "utils/hash/vector.h", + "utils/hash/map.h", +] + +[[fields]] +name = "batch_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "dense_layers" +type = "std::vector<::FlexFlow::positive_int>" + +[[fields]] +name = "dense_feature_layers" +type = "std::vector<::FlexFlow::positive_int>" + +[[fields]] +name = "feature_shapes" +type = "std::map" + +[[fields]] +name = "input_features" +type = "std::map" + +[[fields]] +name = "dropout" +type = "float" + +[[fields]] +name = "residual" +type = "bool" diff --git a/lib/models/include/models/candle_uno/candle_uno_config.struct.toml b/lib/models/include/models/candle_uno/candle_uno_config.struct.toml deleted file mode 100644 index 135c58e1cc..0000000000 --- a/lib/models/include/models/candle_uno/candle_uno_config.struct.toml +++ /dev/null @@ -1,53 +0,0 @@ -namespace = "FlexFlow" -name = "CandleUnoConfig" - -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "", - "", - "", - "utils/positive_int/positive_int.h", -] - -src_includes = [ - "utils/fmt/vector.h", - "utils/fmt/map.h", - "utils/hash/vector.h", - "utils/hash/map.h", -] - -[[fields]] -name = "batch_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "dense_layers" -type = "std::vector<::FlexFlow::positive_int>" - -[[fields]] -name = "dense_feature_layers" -type = "std::vector<::FlexFlow::positive_int>" - -[[fields]] -name = "feature_shapes" -type = "std::map" - -[[fields]] -name = "input_features" -type = "std::map" - -[[fields]] -name = "dropout" -type = "float" - -[[fields]] -name = "residual" -type = "bool" diff --git 
a/lib/models/include/models/dlrm/dlrm_arch_interaction_op.dtg.toml b/lib/models/include/models/dlrm/dlrm_arch_interaction_op.dtg.toml new file mode 100644 index 0000000000..85018a248b --- /dev/null +++ b/lib/models/include/models/dlrm/dlrm_arch_interaction_op.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "DLRMArchInteractionOp" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "DOT" + +[[values]] +name = "CAT" diff --git a/lib/models/include/models/dlrm/dlrm_arch_interaction_op.enum.toml b/lib/models/include/models/dlrm/dlrm_arch_interaction_op.enum.toml deleted file mode 100644 index 62410425da..0000000000 --- a/lib/models/include/models/dlrm/dlrm_arch_interaction_op.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "DLRMArchInteractionOp" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[values]] -name = "DOT" - -[[values]] -name = "CAT" diff --git a/lib/models/include/models/dlrm/dlrm_config.dtg.toml b/lib/models/include/models/dlrm/dlrm_config.dtg.toml new file mode 100644 index 0000000000..70bff6a7ba --- /dev/null +++ b/lib/models/include/models/dlrm/dlrm_config.dtg.toml @@ -0,0 +1,56 @@ +namespace = "FlexFlow" +name = "DLRMConfig" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "", + "models/dlrm/dlrm_arch_interaction_op.dtg.h", + "utils/positive_int/positive_int.h", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "embedding_dim" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "embedding_bag_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "embedding_size" +type = "std::vector<::FlexFlow::positive_int>" + +[[fields]] +name = "dense_arch_layer_sizes" +type = "std::vector<::FlexFlow::positive_int>" + +[[fields]] +name = "over_arch_layer_sizes" +type = "std::vector<::FlexFlow::positive_int>" + +[[fields]] +name = 
"arch_interaction_op" +type = "::FlexFlow::DLRMArchInteractionOp" + +[[fields]] +name = "batch_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "seed" +type = "int" diff --git a/lib/models/include/models/dlrm/dlrm_config.struct.toml b/lib/models/include/models/dlrm/dlrm_config.struct.toml deleted file mode 100644 index 3cf43aed48..0000000000 --- a/lib/models/include/models/dlrm/dlrm_config.struct.toml +++ /dev/null @@ -1,55 +0,0 @@ -namespace = "FlexFlow" -name = "DLRMConfig" - -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "", - "", - "models/dlrm/dlrm_arch_interaction_op.dtg.h", - "utils/positive_int/positive_int.h", -] - -src_includes = [ - "utils/fmt/vector.h", - "utils/hash/vector.h", -] - -[[fields]] -name = "embedding_dim" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "embedding_bag_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "embedding_size" -type = "std::vector<::FlexFlow::positive_int>" - -[[fields]] -name = "dense_arch_layer_sizes" -type = "std::vector<::FlexFlow::positive_int>" - -[[fields]] -name = "over_arch_layer_sizes" -type = "std::vector<::FlexFlow::positive_int>" - -[[fields]] -name = "arch_interaction_op" -type = "::FlexFlow::DLRMArchInteractionOp" - -[[fields]] -name = "batch_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "seed" -type = "int" diff --git a/lib/models/include/models/inception_v3/inception_v3_config.dtg.toml b/lib/models/include/models/inception_v3/inception_v3_config.dtg.toml new file mode 100644 index 0000000000..3064ab4645 --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3_config.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "InceptionV3Config" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "num_classes" +type = "::FlexFlow::positive_int" + +[[fields]] +name 
= "batch_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "aux_logits" +type = "bool" diff --git a/lib/models/include/models/inception_v3/inception_v3_config.struct.toml b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml deleted file mode 100644 index 0075783c87..0000000000 --- a/lib/models/include/models/inception_v3/inception_v3_config.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "InceptionV3Config" - -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "num_classes" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "batch_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "aux_logits" -type = "bool" diff --git a/lib/models/include/models/inception_v3/inception_v3_output.dtg.toml b/lib/models/include/models/inception_v3/inception_v3_output.dtg.toml new file mode 100644 index 0000000000..f0233a658b --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3_output.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "InceptionV3Output" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/tensor_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "standard_logits" +type = "::FlexFlow::tensor_guid_t" + +[[fields]] +name = "aux_logits" +type = "std::optional<::FlexFlow::tensor_guid_t>" diff --git a/lib/models/include/models/inception_v3/inception_v3_output.struct.toml b/lib/models/include/models/inception_v3/inception_v3_output.struct.toml deleted file mode 100644 index 066e6df02b..0000000000 --- a/lib/models/include/models/inception_v3/inception_v3_output.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "InceptionV3Output" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "pcg/tensor_guid_t.dtg.h", - "", -] - -src_includes = [ - 
"utils/fmt/optional.h", -] - -[[fields]] -name = "standard_logits" -type = "::FlexFlow::tensor_guid_t" - -[[fields]] -name = "aux_logits" -type = "std::optional<::FlexFlow::tensor_guid_t>" diff --git a/lib/models/include/models/transformer/transformer_config.dtg.toml b/lib/models/include/models/transformer/transformer_config.dtg.toml new file mode 100644 index 0000000000..aa8ce81274 --- /dev/null +++ b/lib/models/include/models/transformer/transformer_config.dtg.toml @@ -0,0 +1,55 @@ +namespace = "FlexFlow" +name = "TransformerConfig" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "num_features" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "sequence_length" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "batch_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "dim_feedforward" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_heads" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_encoder_layers" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_decoder_layers" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "dropout" +type = "float" + +[[fields]] +name = "layer_norm_eps" +type = "float" + +[[fields]] +name = "vocab_size" +type = "::FlexFlow::positive_int" diff --git a/lib/models/include/models/transformer/transformer_config.struct.toml b/lib/models/include/models/transformer/transformer_config.struct.toml deleted file mode 100644 index 686491eff4..0000000000 --- a/lib/models/include/models/transformer/transformer_config.struct.toml +++ /dev/null @@ -1,54 +0,0 @@ -namespace = "FlexFlow" -name = "TransformerConfig" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "num_features" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "sequence_length" 
-type = "::FlexFlow::positive_int" - -[[fields]] -name = "batch_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "dim_feedforward" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "num_heads" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "num_encoder_layers" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "num_decoder_layers" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "dropout" -type = "float" - -[[fields]] -name = "layer_norm_eps" -type = "float" - -[[fields]] -name = "vocab_size" -type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/activation.dtg.toml b/lib/op-attrs/include/op-attrs/activation.dtg.toml new file mode 100644 index 0000000000..33e2b54e45 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/activation.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "Activation" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "RELU" + +[[values]] +name = "SIGMOID" + +[[values]] +name = "TANH" + +[[values]] +name = "GELU" diff --git a/lib/op-attrs/include/op-attrs/activation.enum.toml b/lib/op-attrs/include/op-attrs/activation.enum.toml deleted file mode 100644 index 66119da9b1..0000000000 --- a/lib/op-attrs/include/op-attrs/activation.enum.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "Activation" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[values]] -name = "RELU" - -[[values]] -name = "SIGMOID" - -[[values]] -name = "TANH" - -[[values]] -name = "GELU" diff --git a/lib/op-attrs/include/op-attrs/aggregate_op.dtg.toml b/lib/op-attrs/include/op-attrs/aggregate_op.dtg.toml new file mode 100644 index 0000000000..1cf2e00e8e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/aggregate_op.dtg.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "AggregateOp" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "SUM" + +[[values]] +name = "AVG" + diff 
--git a/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml b/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml deleted file mode 100644 index 09ee99915d..0000000000 --- a/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "AggregateOp" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[values]] -name = "SUM" - -[[values]] -name = "AVG" - diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.toml new file mode 100644 index 0000000000..c8c646bd19 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.toml @@ -0,0 +1,144 @@ +namespace = "FlexFlow" +name = "ComputationGraphOpAttrs" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_matmul_attrs.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast_attrs.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/conv_2d_attrs.dtg.h", + "op-attrs/ops/dropout_attrs.dtg.h", + "op-attrs/ops/element_binary_attrs.dtg.h", + "op-attrs/ops/element_unary_attrs.dtg.h", + "op-attrs/ops/embedding_attrs.dtg.h", + "op-attrs/ops/flat_attrs.dtg.h", + "op-attrs/ops/gather_attrs.dtg.h", + "op-attrs/ops/input_attrs.dtg.h", + "op-attrs/ops/layer_norm_attrs.dtg.h", + "op-attrs/ops/linear_attrs.dtg.h", + "op-attrs/ops/noop_attrs.dtg.h", + "op-attrs/ops/pool_2d_attrs.dtg.h", + "op-attrs/ops/reduce_attrs.dtg.h", + "op-attrs/ops/reshape_attrs.dtg.h", + "op-attrs/ops/reverse_attrs.dtg.h", + "op-attrs/ops/softmax_attrs.dtg.h", + "op-attrs/ops/split_attrs.dtg.h", + "op-attrs/ops/topk_attrs.dtg.h", + "op-attrs/ops/transpose_attrs.dtg.h", + "op-attrs/ops/weight_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::BatchMatmulAttrs" +key = "batch_matmul" + +[[values]] +type = 
"::FlexFlow::BatchNormAttrs" +key = "batch_norm" + +[[values]] +type = "::FlexFlow::BroadcastAttrs" +key = "broadcast" + +[[values]] +type = "::FlexFlow::CastAttrs" +key = "cast" + +[[values]] +type = "::FlexFlow::ConcatAttrs" +key = "concat" + +[[values]] +type = "::FlexFlow::Conv2DAttrs" +key = "conv2d" + +[[values]] +type = "::FlexFlow::DropoutAttrs" +key = "dropout" + +[[values]] +type = "::FlexFlow::ElementBinaryAttrs" +key = "element_binary" + +[[values]] +type = "::FlexFlow::ElementUnaryAttrs" +key = "element_unary" + +[[values]] +type = "::FlexFlow::EmbeddingAttrs" +key = "embedding" + +[[values]] +type = "::FlexFlow::FlatAttrs" +key = "flat" + +[[values]] +type = "::FlexFlow::GatherAttrs" +key = "gather" + +[[values]] +type = "::FlexFlow::InputAttrs" +key = "input" + +[[values]] +type = "::FlexFlow::LayerNormAttrs" +key = "layer_norm" + +[[values]] +type = "::FlexFlow::LinearAttrs" +key = "linear" + +[[values]] +type = "::FlexFlow::MultiHeadAttentionAttrs" +key = "multi_head_attention" + +[[values]] +type = "::FlexFlow::NoopAttrs" +key = "noop" + +[[values]] +type = "::FlexFlow::Pool2DAttrs" +key = "pool2d" + +[[values]] +type = "::FlexFlow::ReduceAttrs" +key = "reduce" + +[[values]] +type = "::FlexFlow::ReverseAttrs" +key = "reverse" + +[[values]] +type = "::FlexFlow::ReshapeAttrs" +key = "reshape" + +[[values]] +type = "::FlexFlow::SplitAttrs" +key = "split" + +[[values]] +type = "::FlexFlow::SoftmaxAttrs" +key = "softmax" + +[[values]] +type = "::FlexFlow::TopKAttrs" +key = "topk" + +[[values]] +type = "::FlexFlow::TransposeAttrs" +key = "transpose" + +[[values]] +type = "::FlexFlow::WeightAttrs" +key = "weight" diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h index 52e6e12a8c..fd0707aa2e 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h @@ -9,7 +9,7 @@ namespace FlexFlow { 
OperatorType get_op_type(ComputationGraphOpAttrs const &); RecordFormatter as_dot(ComputationGraphOpAttrs const &); -ComputationGraphOpAttrs +std::optional compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml deleted file mode 100644 index f1c5fe6b23..0000000000 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml +++ /dev/null @@ -1,143 +0,0 @@ -namespace = "FlexFlow" -name = "ComputationGraphOpAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ops/attention_attrs.dtg.h", - "op-attrs/ops/batch_matmul_attrs.dtg.h", - "op-attrs/ops/batch_norm_attrs.dtg.h", - "op-attrs/ops/broadcast_attrs.dtg.h", - "op-attrs/ops/cast_attrs.dtg.h", - "op-attrs/ops/concat_attrs.dtg.h", - "op-attrs/ops/conv_2d_attrs.dtg.h", - "op-attrs/ops/dropout_attrs.dtg.h", - "op-attrs/ops/element_binary_attrs.dtg.h", - "op-attrs/ops/element_unary_attrs.dtg.h", - "op-attrs/ops/embedding_attrs.dtg.h", - "op-attrs/ops/flat_attrs.dtg.h", - "op-attrs/ops/gather_attrs.dtg.h", - "op-attrs/ops/input_attrs.dtg.h", - "op-attrs/ops/layer_norm_attrs.dtg.h", - "op-attrs/ops/linear_attrs.dtg.h", - "op-attrs/ops/noop_attrs.dtg.h", - "op-attrs/ops/pool_2d_attrs.dtg.h", - "op-attrs/ops/reduce_attrs.dtg.h", - "op-attrs/ops/reshape_attrs.dtg.h", - "op-attrs/ops/reverse_attrs.dtg.h", - "op-attrs/ops/softmax_attrs.dtg.h", - "op-attrs/ops/split_attrs.dtg.h", - "op-attrs/ops/topk_attrs.dtg.h", - "op-attrs/ops/transpose_attrs.dtg.h", - "op-attrs/ops/weight_attrs.dtg.h", -] - -[[values]] -type = "::FlexFlow::BatchMatmulAttrs" -key = "batch_matmul" - -[[values]] -type = "::FlexFlow::BatchNormAttrs" -key = "batch_norm" - -[[values]] -type = "::FlexFlow::BroadcastAttrs" -key = "broadcast" - -[[values]] -type = "::FlexFlow::CastAttrs" -key = "cast" - 
-[[values]] -type = "::FlexFlow::ConcatAttrs" -key = "concat" - -[[values]] -type = "::FlexFlow::Conv2DAttrs" -key = "conv2d" - -[[values]] -type = "::FlexFlow::DropoutAttrs" -key = "dropout" - -[[values]] -type = "::FlexFlow::ElementBinaryAttrs" -key = "element_binary" - -[[values]] -type = "::FlexFlow::ElementUnaryAttrs" -key = "element_unary" - -[[values]] -type = "::FlexFlow::EmbeddingAttrs" -key = "embedding" - -[[values]] -type = "::FlexFlow::FlatAttrs" -key = "flat" - -[[values]] -type = "::FlexFlow::GatherAttrs" -key = "gather" - -[[values]] -type = "::FlexFlow::InputAttrs" -key = "input" - -[[values]] -type = "::FlexFlow::LayerNormAttrs" -key = "layer_norm" - -[[values]] -type = "::FlexFlow::LinearAttrs" -key = "linear" - -[[values]] -type = "::FlexFlow::MultiHeadAttentionAttrs" -key = "multi_head_attention" - -[[values]] -type = "::FlexFlow::NoopAttrs" -key = "noop" - -[[values]] -type = "::FlexFlow::Pool2DAttrs" -key = "pool2d" - -[[values]] -type = "::FlexFlow::ReduceAttrs" -key = "reduce" - -[[values]] -type = "::FlexFlow::ReverseAttrs" -key = "reverse" - -[[values]] -type = "::FlexFlow::ReshapeAttrs" -key = "reshape" - -[[values]] -type = "::FlexFlow::SplitAttrs" -key = "split" - -[[values]] -type = "::FlexFlow::SoftmaxAttrs" -key = "softmax" - -[[values]] -type = "::FlexFlow::TopKAttrs" -key = "topk" - -[[values]] -type = "::FlexFlow::TransposeAttrs" -key = "transpose" - -[[values]] -type = "::FlexFlow::WeightAttrs" -key = "weight" diff --git a/lib/op-attrs/include/op-attrs/datatype.dtg.toml b/lib/op-attrs/include/op-attrs/datatype.dtg.toml new file mode 100644 index 0000000000..792160be4a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "DataType" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "BOOL" + +[[values]] +name = "INT32" + +[[values]] +name = "INT64" + +[[values]] +name = "HALF" + +[[values]] +name = "FLOAT" + +[[values]] 
+name = "DOUBLE" diff --git a/lib/op-attrs/include/op-attrs/datatype.enum.toml b/lib/op-attrs/include/op-attrs/datatype.enum.toml deleted file mode 100644 index 15210cfe29..0000000000 --- a/lib/op-attrs/include/op-attrs/datatype.enum.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "DataType" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[values]] -name = "BOOL" - -[[values]] -name = "INT32" - -[[values]] -name = "INT64" - -[[values]] -name = "HALF" - -[[values]] -name = "FLOAT" - -[[values]] -name = "DOUBLE" diff --git a/lib/op-attrs/include/op-attrs/datatype_value.dtg.toml b/lib/op-attrs/include/op-attrs/datatype_value.dtg.toml new file mode 100644 index 0000000000..c72c3c60be --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype_value.dtg.toml @@ -0,0 +1,39 @@ +namespace = "FlexFlow" +name = "DataTypeValue" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/half.h", +] + +src_includes = [ + "utils/json/half.h", + "utils/rapidcheck/half.h", + "utils/fmt/half.h", +] + +[[values]] +type = "half" + +[[values]] +type = "float" + +[[values]] +type = "double" + +[[values]] +type = "int32_t" + +[[values]] +type = "int64_t" + +[[values]] +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/datatype_value.variant.toml b/lib/op-attrs/include/op-attrs/datatype_value.variant.toml deleted file mode 100644 index 4c867917b0..0000000000 --- a/lib/op-attrs/include/op-attrs/datatype_value.variant.toml +++ /dev/null @@ -1,38 +0,0 @@ -namespace = "FlexFlow" -name = "DataTypeValue" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/half.h", -] - -src_includes = [ - "utils/json/half.h", - "utils/rapidcheck/half.h", - "utils/fmt/half.h", -] - -[[values]] -type = "half" - -[[values]] -type = "float" - -[[values]] -type = "double" - -[[values]] -type = "int32_t" - -[[values]] -type = "int64_t" - -[[values]] -type = 
"bool" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h deleted file mode 100644 index 5c47745209..0000000000 --- a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h +++ /dev/null @@ -1,200 +0,0 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H -#define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H - -#include "op-attrs/ff_dim_t.dtg.h" -#include "op-attrs/relative_ff_dim_t.dtg.h" -#include "utils/containers/range.h" -#include "utils/fmt/vector.h" -#include "utils/stack_vector/stack_vector.h" -#include - -namespace FlexFlow { - -template -struct DimOrdered { - DimOrdered() {} - - DimOrdered(std::initializer_list const &l) - : contents(l.begin(), l.end()) {} - - DimOrdered(std::vector const &contents) - : contents(contents.begin(), contents.end()) {} - - template - DimOrdered(It begin, It end) : contents(begin, end) {} - - template - DimOrdered(stack_vector const &contents) - : contents(contents.begin(), contents.end()) {} - - T const &at(Idx idx) const { - nonnegative_int raw = idx.value; - return this->contents.at(raw.unwrap_nonnegative()); - } - - T &at(Idx idx) { - nonnegative_int raw = idx.value; - return this->contents.at(raw.unwrap_nonnegative()); - } - - T const &operator[](Idx idx) const { - return this->at(idx); - } - - T &operator[](Idx idx) { - return this->at(idx); - } - - bool idx_is_valid(Idx const &idx) const { - nonnegative_int raw = idx.value; - return (raw < this->contents.size()); - } - - bool operator==(DimOrdered const &other) const { - return this->contents == other.contents; - } - - bool operator!=(DimOrdered const &other) const { - return this->contents != other.contents; - } - - using iterator = typename stack_vector::iterator; - using const_iterator = - typename stack_vector::const_iterator; - using reverse_iterator = - typename stack_vector::reverse_iterator; - using const_reverse_iterator = - typename 
stack_vector::const_reverse_iterator; - using value_type = T; - using pointer = value_type *; - using const_pointer = value_type const *; - using reference = value_type &; - using const_reference = value_type const &; - - iterator begin() { - return this->contents.begin(); - } - - const_iterator begin() const { - return this->cbegin(); - } - - const_iterator cbegin() const { - return this->contents.cbegin(); - } - - iterator end() { - return this->contents.end(); - } - - const_iterator end() const { - return this->cend(); - } - - const_iterator cend() const { - return this->contents.cend(); - } - - reverse_iterator rbegin() { - return this->contents.rbegin(); - } - - const_reverse_iterator rbegin() const { - return this->crbegin(); - } - - const_reverse_iterator crbegin() const { - return this->contents.crbegin(); - } - - reverse_iterator rend() { - return this->contents.rend(); - } - - const_reverse_iterator rend() const { - return this->crend(); - } - - const_reverse_iterator crend() const { - return this->contents.crend(); - } - - size_t size() const { - return this->contents.size(); - } - - size_t empty() const { - return this->contents.empty(); - } - - size_t num_dims() const { - return this->size(); - } - - friend struct ::std::hash; - -private: - stack_vector contents; -}; - -template -auto operator<(DimOrdered const &lhs, DimOrdered const &rhs) - -> std::enable_if_t, bool> { - return std::lexicographical_compare( - lhs.cbegin(), lhs.cend(), rhs.cbegin(), rhs.cend()); -} - -template -std::string format_as(DimOrdered const &v) { - std::vector as_vec(v.cbegin(), v.cend()); - return fmt::format("", as_vec); -} - -template -std::ostream &operator<<(std::ostream &s, DimOrdered const &v) { - return (s << fmt::to_string(v)); -} - -} // namespace FlexFlow - -namespace nlohmann { -template -struct adl_serializer<::FlexFlow::DimOrdered> { - static ::FlexFlow::DimOrdered from_json(nlohmann::json const &j) { - return {j.template get>()}; - } - - static void 
to_json(nlohmann::json &j, - ::FlexFlow::DimOrdered const &x) { - j = std::vector{x.cbegin(), x.cend()}; - } -}; -} // namespace nlohmann - -namespace std { - -template -struct hash<::FlexFlow::DimOrdered> { - size_t operator()(::FlexFlow::DimOrdered const &t) const { - static_assert(::FlexFlow::is_hashable::value, - "Elements must be hashable"); - - return get_std_hash(t.contents); - } -}; - -} // namespace std - -namespace rc { - -template -struct Arbitrary<::FlexFlow::DimOrdered> { - static Gen<::FlexFlow::DimOrdered> arbitrary() { - return gen::construct<::FlexFlow::DimOrdered>( - gen::arbitrary<::FlexFlow::stack_vector>()); - } -}; - -} // namespace rc - -#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h deleted file mode 100644 index 76526447be..0000000000 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H - -#include "op-attrs/dim_ordered/dim_ordered.h" -#include "utils/containers/slice.h" -#include "utils/containers/transform.h" -#include "utils/containers/vector_of.h" -#include "utils/optional.h" - -namespace FlexFlow { - -template -DimOrdered nonoverloaded_slice(DimOrdered const &d, - std::optional const &start, - std::optional const &end) { - auto to_raw_idx = [](std::optional const &idx) -> std::optional { - return transform(idx, [](Idx const &i) { return i.value; }); - }; - - return DimOrdered{ - slice(vector_of(d), to_raw_idx(start), to_raw_idx(end))}; -} -template -DimOrdered slice(DimOrdered const &d, - std::optional const &start = std::nullopt, - std::optional const &end = std::nullopt) { - return ff_dim_t_nonoverloaded_slice(d, start, end); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h 
deleted file mode 100644 index 4fd3df0abb..0000000000 --- a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H - -#include "op-attrs/dim_ordered/dim_ordered.h" -#include "utils/containers/vector_of.h" -#include "utils/containers/vector_transform.h" - -namespace FlexFlow { - -template -DimOrdered> - transform(DimOrdered const &d, F f) { - using Out = std::invoke_result_t; - - return DimOrdered{vector_transform(vector_of(d), f)}; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h deleted file mode 100644 index cc8b050f50..0000000000 --- a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H - -#include "op-attrs/dim_ordered/dim_ordered.h" -#include "utils/containers/vector_of.h" -#include "utils/containers/zip.h" - -namespace FlexFlow { - -template -DimOrdered> zip(DimOrdered const &lhs, - DimOrdered const &rhs) { - return DimOrdered>{ - zip(vector_of(lhs), vector_of(rhs))}; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/ff_dim_t.dtg.toml b/lib/op-attrs/include/op-attrs/ff_dim_t.dtg.toml new file mode 100644 index 0000000000..fe009b7ddb --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ff_dim_t.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ff_dim_t" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h" +] + +[[fields]] +name = "value" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/op-attrs/include/op-attrs/ff_dim_t.h b/lib/op-attrs/include/op-attrs/ff_dim_t.h index 0979201f67..1411886eee 
100644 --- a/lib/op-attrs/include/op-attrs/ff_dim_t.h +++ b/lib/op-attrs/include/op-attrs/ff_dim_t.h @@ -11,6 +11,8 @@ relative_ff_dim_t relative_ff_dim_t_from_ff_dim_t(ff_dim_t ff_dim); ff_dim_t add_to_ff_dim(ff_dim_t ff_dim, int value); +std::vector ff_dim_range(nonnegative_int num_elements); + } // namespace FlexFlow namespace rc { diff --git a/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml b/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml deleted file mode 100644 index 38f51da4a1..0000000000 --- a/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "ff_dim_t" - -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -includes = [ - "utils/nonnegative_int/nonnegative_int.h" -] - -[[fields]] -name = "value" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/op-attrs/include/op-attrs/ff_ordered/enumerate.h b/lib/op-attrs/include/op-attrs/ff_ordered/enumerate.h index bc8636615c..fbd828bf37 100644 --- a/lib/op-attrs/include/op-attrs/ff_ordered/enumerate.h +++ b/lib/op-attrs/include/op-attrs/ff_ordered/enumerate.h @@ -3,7 +3,7 @@ #include "op-attrs/ff_ordered/ff_ordered.h" #include "utils/bidict/bidict.h" -#include "utils/containers/count.h" +#include "utils/containers/range.h" namespace FlexFlow { @@ -18,7 +18,7 @@ namespace FlexFlow { template std::map enumerate(FFOrdered const &ff_ordered) { std::map result; - for (int raw_ff_dim : count(ff_ordered.size())) { + for (int raw_ff_dim : range(ff_ordered.size())) { ff_dim_t ff_dim = ff_dim_t{nonnegative_int{raw_ff_dim}}; result.insert({ff_dim, ff_ordered.at(ff_dim)}); } diff --git a/lib/op-attrs/include/op-attrs/ff_ordered/get_idxs.h b/lib/op-attrs/include/op-attrs/ff_ordered/get_idxs.h index 5ff390d3fe..1ee9f6b51c 100644 --- a/lib/op-attrs/include/op-attrs/ff_ordered/get_idxs.h +++ b/lib/op-attrs/include/op-attrs/ff_ordered/get_idxs.h @@ -3,14 +3,15 @@ #include "op-attrs/ff_dim_t.h" #include "op-attrs/ff_ordered/ff_ordered.h" -#include 
"utils/containers/count.h" +#include "utils/containers/range.h" +#include "utils/containers/set_of.h" #include "utils/containers/transform.h" namespace FlexFlow { template -std::vector get_idxs(FFOrdered const &d) { - return transform(count(d.size()), +std::set get_idxs(FFOrdered const &d) { + return transform(set_of(range(d.size())), [](int i) { return ff_dim_t{nonnegative_int{i}}; }); } diff --git a/lib/op-attrs/include/op-attrs/ff_ordered/map_from_ff_ordered.h b/lib/op-attrs/include/op-attrs/ff_ordered/map_from_ff_ordered.h new file mode 100644 index 0000000000..4a7e564c20 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ff_ordered/map_from_ff_ordered.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_ORDERED_MAP_FROM_FF_ORDERED_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_ORDERED_MAP_FROM_FF_ORDERED_H + +#include "op-attrs/ff_dim_t.h" +#include "op-attrs/ff_ordered/ff_ordered.h" +#include "utils/nonnegative_int/num_elements.h" + +namespace FlexFlow { + +template +std::unordered_map map_from_ff_ordered(FFOrdered const &m) { + std::unordered_map result; + + for (ff_dim_t d : ff_dim_range(num_elements(m))) { + result.insert({d, m.at(d)}); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h b/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h index b395736773..0ad9a9c062 100644 --- a/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h +++ b/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h @@ -4,13 +4,14 @@ #include "op-attrs/computation_graph_op_attrs.dtg.h" #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" namespace FlexFlow { -std::vector - get_incoming_tensor_roles(ComputationGraphOpAttrs const &, int num_inputs); -std::vector - get_incoming_tensor_roles(PCGOperatorAttrs const &, int num_inputs); +std::unordered_map + 
get_incoming_tensor_roles(ComputationGraphOpAttrs const &); +std::unordered_map + get_incoming_tensor_roles(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.h b/lib/op-attrs/include/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.h new file mode 100644 index 0000000000..3a6b4732e6 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.h @@ -0,0 +1,61 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_GET_OPERATOR_SPACE_TO_PARALLEL_TENSOR_SPACE_MAPPINGS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_GET_OPERATOR_SPACE_TO_PARALLEL_TENSOR_SPACE_MAPPINGS_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/incoming_tensor_role.dtg.h" +#include "op-attrs/num_ptensor_parallel_dims_t.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/tensor_role.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include + +namespace FlexFlow { + +std::unordered_map + get_operator_to_incoming_mappings( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees); + +std::unordered_map + get_operator_to_incoming_mappings_for_role( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees, + IncomingTensorRole role); + +std::unordered_map + get_operator_to_input_mappings( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees); + +std::unordered_map + get_operator_to_weight_mappings( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees); + +std::unordered_map + get_operator_to_output_mappings( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees); + +std::unordered_map + get_operator_to_ptensor_mappings_for_role( + ComputationGraphOpAttrs const &attrs, + 
std::unordered_map const + &inputs_degrees, + TensorRole role); + +std::unordered_map + get_operator_to_ptensor_mappings( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/get_operator_task_space.h b/lib/op-attrs/include/op-attrs/get_operator_task_space.h new file mode 100644 index 0000000000..9ee4a8779a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/get_operator_task_space.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_GET_OPERATOR_TASK_SPACE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_GET_OPERATOR_TASK_SPACE_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" + +namespace FlexFlow { + +OperatorTaskSpace get_operator_task_space( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/incoming_tensor_role.dtg.toml b/lib/op-attrs/include/op-attrs/incoming_tensor_role.dtg.toml new file mode 100644 index 0000000000..57931c8f1e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/incoming_tensor_role.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "IncomingTensorRole" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "INPUT" + +[[values]] +name = "WEIGHT" diff --git a/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml b/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml deleted file mode 100644 index 427701c801..0000000000 --- a/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "IncomingTensorRole" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "INPUT" - -[[values]] -name = "WEIGHT" 
diff --git a/lib/op-attrs/include/op-attrs/initializer_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/initializer_attrs.dtg.toml new file mode 100644 index 0000000000..6d5f03d657 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/initializer_attrs.dtg.toml @@ -0,0 +1,54 @@ +namespace = "FlexFlow" +name = "InitializerAttrs" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/initializers/glorot_uniform_attrs.dtg.h", + "op-attrs/initializers/glorot_normal_attrs.dtg.h", + "op-attrs/initializers/kaiming_normal_attrs.dtg.h", + "op-attrs/initializers/zero_initializer_attrs.dtg.h", + "op-attrs/initializers/uniform_initializer_attrs.h", + "op-attrs/initializers/norm_initializer_attrs.dtg.h", + "op-attrs/initializers/truncated_normal_initializer_attrs.dtg.h", + "op-attrs/initializers/constant_initializer_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::GlorotUniformAttrs" +key = "glorot_uniform" + +[[values]] +type = "::FlexFlow::GlorotNormalAttrs" +key = "glorot_normal" + +[[values]] +type = "::FlexFlow::KaimingNormalAttrs" +key = "kaiming_normal" + +[[values]] +type = "::FlexFlow::ZeroInitializerAttrs" +key = "zero" + +[[values]] +type = "::FlexFlow::UniformInitializerAttrs" +key = "uniform" + +[[values]] +type = "::FlexFlow::NormInitializerAttrs" +key = "normal" + +[[values]] +type = "::FlexFlow::TruncatedNormalInitializerAttrs" +key = "truncated_normal" + +[[values]] +type = "::FlexFlow::ConstantInitializerAttrs" +key = "constant" diff --git a/lib/op-attrs/include/op-attrs/initializer_attrs.variant.toml b/lib/op-attrs/include/op-attrs/initializer_attrs.variant.toml deleted file mode 100644 index 108caf1203..0000000000 --- a/lib/op-attrs/include/op-attrs/initializer_attrs.variant.toml +++ /dev/null @@ -1,53 +0,0 @@ -namespace = "FlexFlow" -name = "InitializerAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", - "rapidcheck", -] - -includes = [ - 
"op-attrs/initializers/glorot_uniform_attrs.dtg.h", - "op-attrs/initializers/glorot_normal_attrs.dtg.h", - "op-attrs/initializers/kaiming_normal_attrs.dtg.h", - "op-attrs/initializers/zero_initializer_attrs.dtg.h", - "op-attrs/initializers/uniform_initializer_attrs.h", - "op-attrs/initializers/norm_initializer_attrs.dtg.h", - "op-attrs/initializers/truncated_normal_initializer_attrs.dtg.h", - "op-attrs/initializers/constant_initializer_attrs.dtg.h", -] - -[[values]] -type = "::FlexFlow::GlorotUniformAttrs" -key = "glorot_uniform" - -[[values]] -type = "::FlexFlow::GlorotNormalAttrs" -key = "glorot_normal" - -[[values]] -type = "::FlexFlow::KaimingNormalAttrs" -key = "kaiming_normal" - -[[values]] -type = "::FlexFlow::ZeroInitializerAttrs" -key = "zero" - -[[values]] -type = "::FlexFlow::UniformInitializerAttrs" -key = "uniform" - -[[values]] -type = "::FlexFlow::NormInitializerAttrs" -key = "normal" - -[[values]] -type = "::FlexFlow::TruncatedNormalInitializerAttrs" -key = "truncated_normal" - -[[values]] -type = "::FlexFlow::ConstantInitializerAttrs" -key = "constant" diff --git a/lib/op-attrs/include/op-attrs/initializers/constant_initializer_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/initializers/constant_initializer_attrs.dtg.toml new file mode 100644 index 0000000000..6762e14990 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/initializers/constant_initializer_attrs.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ConstantInitializerAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype_value.dtg.h", +] + +[[fields]] +name = "value" +type = "::FlexFlow::DataTypeValue" diff --git a/lib/op-attrs/include/op-attrs/initializers/constant_initializer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/initializers/constant_initializer_attrs.struct.toml deleted file mode 100644 index 4e3c31bd36..0000000000 --- 
a/lib/op-attrs/include/op-attrs/initializers/constant_initializer_attrs.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "ConstantInitializerAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/datatype_value.dtg.h", -] - -[[fields]] -name = "value" -type = "::FlexFlow::DataTypeValue" diff --git a/lib/op-attrs/include/op-attrs/initializers/glorot_normal_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/initializers/glorot_normal_attrs.dtg.toml new file mode 100644 index 0000000000..fbe2c0bcba --- /dev/null +++ b/lib/op-attrs/include/op-attrs/initializers/glorot_normal_attrs.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "GlorotNormalAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/initializers/glorot_normal_attrs.struct.toml b/lib/op-attrs/include/op-attrs/initializers/glorot_normal_attrs.struct.toml deleted file mode 100644 index fd0d8eb9be..0000000000 --- a/lib/op-attrs/include/op-attrs/initializers/glorot_normal_attrs.struct.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "GlorotNormalAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[fields]] -name = "seed" -type = "int" diff --git a/lib/op-attrs/include/op-attrs/initializers/glorot_uniform_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/initializers/glorot_uniform_attrs.dtg.toml new file mode 100644 index 0000000000..6df50dc693 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/initializers/glorot_uniform_attrs.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "GlorotUniformAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/initializers/glorot_uniform_attrs.struct.toml 
b/lib/op-attrs/include/op-attrs/initializers/glorot_uniform_attrs.struct.toml deleted file mode 100644 index de7f9141b0..0000000000 --- a/lib/op-attrs/include/op-attrs/initializers/glorot_uniform_attrs.struct.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "GlorotUniformAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[fields]] -name = "seed" -type = "int" diff --git a/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_mode.dtg.toml b/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_mode.dtg.toml new file mode 100644 index 0000000000..7d8d24830a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_mode.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "KaimingInitializerMode" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "FAN_IN" + +[[values]] +name = "FAN_OUT" diff --git a/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_mode.enum.toml b/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_mode.enum.toml deleted file mode 100644 index 46af896917..0000000000 --- a/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_mode.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "KaimingInitializerMode" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "FAN_IN" - -[[values]] -name = "FAN_OUT" diff --git a/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_nonlinearity.dtg.toml b/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_nonlinearity.dtg.toml new file mode 100644 index 0000000000..bd400fb875 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_nonlinearity.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "KaimingInitializerNonlinearity" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "RELU" 
+ +[[values]] +name = "LEAKY_RELU" diff --git a/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_nonlinearity.enum.toml b/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_nonlinearity.enum.toml deleted file mode 100644 index 1a9aae9804..0000000000 --- a/lib/op-attrs/include/op-attrs/initializers/kaiming_initializer_nonlinearity.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "KaimingInitializerNonlinearity" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "RELU" - -[[values]] -name = "LEAKY_RELU" diff --git a/lib/op-attrs/include/op-attrs/initializers/kaiming_normal_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/initializers/kaiming_normal_attrs.dtg.toml new file mode 100644 index 0000000000..a6380f0829 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/initializers/kaiming_normal_attrs.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "KaimingNormalAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/initializers/kaiming_initializer_mode.dtg.h", + "op-attrs/initializers/kaiming_initializer_nonlinearity.dtg.h", +] + +[[fields]] +name = "a" +type = "float" + +[[fields]] +name = "mode" +type = "::FlexFlow::KaimingInitializerMode" + +[[fields]] +name = "nonlinearity" +type = "::FlexFlow::KaimingInitializerNonlinearity" + +[[fields]] +name = "seed" +type = "int" + diff --git a/lib/op-attrs/include/op-attrs/initializers/kaiming_normal_attrs.struct.toml b/lib/op-attrs/include/op-attrs/initializers/kaiming_normal_attrs.struct.toml deleted file mode 100644 index d6b116d296..0000000000 --- a/lib/op-attrs/include/op-attrs/initializers/kaiming_normal_attrs.struct.toml +++ /dev/null @@ -1,32 +0,0 @@ -namespace = "FlexFlow" -name = "KaimingNormalAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/initializers/kaiming_initializer_mode.dtg.h", - 
"op-attrs/initializers/kaiming_initializer_nonlinearity.dtg.h", -] - -[[fields]] -name = "a" -type = "float" - -[[fields]] -name = "mode" -type = "::FlexFlow::KaimingInitializerMode" - -[[fields]] -name = "nonlinearity" -type = "::FlexFlow::KaimingInitializerNonlinearity" - -[[fields]] -name = "seed" -type = "int" - diff --git a/lib/op-attrs/include/op-attrs/initializers/norm_initializer_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/initializers/norm_initializer_attrs.dtg.toml new file mode 100644 index 0000000000..a66c60a92b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/initializers/norm_initializer_attrs.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "NormInitializerAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" + +[[fields]] +name = "mean" +type = "float" + +[[fields]] +name = "stddev" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/initializers/norm_initializer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/initializers/norm_initializer_attrs.struct.toml deleted file mode 100644 index ec138de63e..0000000000 --- a/lib/op-attrs/include/op-attrs/initializers/norm_initializer_attrs.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "NormInitializerAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[fields]] -name = "seed" -type = "int" - -[[fields]] -name = "mean" -type = "float" - -[[fields]] -name = "stddev" -type = "float" diff --git a/lib/op-attrs/include/op-attrs/initializers/truncated_normal_initializer_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/initializers/truncated_normal_initializer_attrs.dtg.toml new file mode 100644 index 0000000000..ed0f171e42 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/initializers/truncated_normal_initializer_attrs.dtg.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "TruncatedNormalInitializerAttrs" +type = "struct" +features = [ 
+ "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" + +[[fields]] +name = "mean" +type = "float" + +[[fields]] +name = "stddev" +type = "float" + +[[fields]] +name = "min_cutoff" +type = "float" + +[[fields]] +name = "max_cutoff" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/initializers/truncated_normal_initializer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/initializers/truncated_normal_initializer_attrs.struct.toml deleted file mode 100644 index 9e4ec0272d..0000000000 --- a/lib/op-attrs/include/op-attrs/initializers/truncated_normal_initializer_attrs.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "TruncatedNormalInitializerAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[fields]] -name = "seed" -type = "int" - -[[fields]] -name = "mean" -type = "float" - -[[fields]] -name = "stddev" -type = "float" - -[[fields]] -name = "min_cutoff" -type = "float" - -[[fields]] -name = "max_cutoff" -type = "float" diff --git a/lib/op-attrs/include/op-attrs/initializers/uniform_initializer_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/initializers/uniform_initializer_attrs.dtg.toml new file mode 100644 index 0000000000..bc7b7c1196 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/initializers/uniform_initializer_attrs.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "UniformInitializerAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" + +[[fields]] +name = "min_val" +type = "float" + +[[fields]] +name = "max_val" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/initializers/uniform_initializer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/initializers/uniform_initializer_attrs.struct.toml deleted file mode 100644 index 8ee67b9d4b..0000000000 --- a/lib/op-attrs/include/op-attrs/initializers/uniform_initializer_attrs.struct.toml +++ 
/dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "UniformInitializerAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -[[fields]] -name = "seed" -type = "int" - -[[fields]] -name = "min_val" -type = "float" - -[[fields]] -name = "max_val" -type = "float" diff --git a/lib/op-attrs/include/op-attrs/initializers/zero_initializer_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/initializers/zero_initializer_attrs.dtg.toml new file mode 100644 index 0000000000..50fd63bc63 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/initializers/zero_initializer_attrs.dtg.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "ZeroInitializerAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/initializers/zero_initializer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/initializers/zero_initializer_attrs.struct.toml deleted file mode 100644 index db1b6238d5..0000000000 --- a/lib/op-attrs/include/op-attrs/initializers/zero_initializer_attrs.struct.toml +++ /dev/null @@ -1,11 +0,0 @@ -namespace = "FlexFlow" -name = "ZeroInitializerAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] -fields = [] diff --git a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.toml new file mode 100644 index 0000000000..435a81196e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "L1RegularizerAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lambda" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml deleted file mode 100644 index 60fabfb94a..0000000000 --- 
a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "L1RegularizerAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[fields]] -name = "lambda" -type = "float" diff --git a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.toml new file mode 100644 index 0000000000..5e7e5d16e9 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "L2RegularizerAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lambda" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml deleted file mode 100644 index adce4397a4..0000000000 --- a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "L2RegularizerAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[fields]] -name = "lambda" -type = "float" diff --git a/lib/op-attrs/include/op-attrs/num_ptensor_parallel_dims_t.h b/lib/op-attrs/include/op-attrs/num_ptensor_parallel_dims_t.h new file mode 100644 index 0000000000..3791d8b2ff --- /dev/null +++ b/lib/op-attrs/include/op-attrs/num_ptensor_parallel_dims_t.h @@ -0,0 +1,72 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_NUM_PTENSOR_PARALLEL_DIMS_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_NUM_PTENSOR_PARALLEL_DIMS_T_H + +#include "utils/nonnegative_int/nonnegative_int.h" +#include "utils/positive_int/positive_int.h" +#include +#include +#include +#include + +namespace FlexFlow { + +struct num_ptensor_parallel_dims_t { +public: + num_ptensor_parallel_dims_t() = delete; + explicit num_ptensor_parallel_dims_t(int); + explicit 
num_ptensor_parallel_dims_t(nonnegative_int); + explicit num_ptensor_parallel_dims_t(positive_int); + + bool operator<(num_ptensor_parallel_dims_t const &other) const; + bool operator==(num_ptensor_parallel_dims_t const &other) const; + bool operator>(num_ptensor_parallel_dims_t const &other) const; + bool operator<=(num_ptensor_parallel_dims_t const &other) const; + bool operator!=(num_ptensor_parallel_dims_t const &other) const; + bool operator>=(num_ptensor_parallel_dims_t const &other) const; + + int int_from_num_ptensor_parallel_dims() const; + nonnegative_int nonnegative_int_from_num_ptensor_parallel_dims() const; + positive_int positive_int_from_num_ptensor_parallel_dims() const; + +private: + int value; + +private: + void check_invariant() const; +}; + +std::ostream &operator<<(std::ostream &, num_ptensor_parallel_dims_t const &); +std::string format_as(num_ptensor_parallel_dims_t const &); + +} // namespace FlexFlow + +namespace nlohmann { + +template <> +struct adl_serializer<::FlexFlow::num_ptensor_parallel_dims_t> { + static ::FlexFlow::num_ptensor_parallel_dims_t from_json(json const &j); + static void to_json(json &j, ::FlexFlow::num_ptensor_parallel_dims_t t); +}; + +} // namespace nlohmann + +namespace rc { + +template <> +struct Arbitrary<::FlexFlow::num_ptensor_parallel_dims_t> { + static Gen<::FlexFlow::num_ptensor_parallel_dims_t> arbitrary(); +}; + +} // namespace rc + +namespace std { + +template <> +struct hash<::FlexFlow::num_ptensor_parallel_dims_t> { + size_t operator()( + ::FlexFlow::num_ptensor_parallel_dims_t const &) const noexcept; +}; + +} // namespace std + +#endif diff --git a/lib/op-attrs/include/op-attrs/num_ptensor_shard_dims_t.dtg.toml b/lib/op-attrs/include/op-attrs/num_ptensor_shard_dims_t.dtg.toml new file mode 100644 index 0000000000..45372cf7e8 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/num_ptensor_shard_dims_t.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "num_ptensor_shard_dims_t" +type = "struct" 
+features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +docstring = """\ +A wrapper type describing the number of shard dims (i.e., not including replica dims) in a parallel tensor, +to prevent accidentally confusing the number of shard dims and the total number of parallel dims. + +The conversion to/from @ref num_ptensor_parallel_dims_t is trivial, and provided by the +functions @ref num_ptensor_parallel_dims_from_shard_dims and @ref num_ptensor_shard_dims_from_parallel_dims. +""" + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "value" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/op-attrs/include/op-attrs/num_ptensor_shard_dims_t.h b/lib/op-attrs/include/op-attrs/num_ptensor_shard_dims_t.h new file mode 100644 index 0000000000..1cf4f94699 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/num_ptensor_shard_dims_t.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_NUM_PTENSOR_SHARD_DIMS_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_NUM_PTENSOR_SHARD_DIMS_T_H + +#include "op-attrs/num_ptensor_parallel_dims_t.h" +#include "op-attrs/num_ptensor_shard_dims_t.dtg.h" + +namespace FlexFlow { + +num_ptensor_parallel_dims_t + num_ptensor_parallel_dims_from_shard_dims(num_ptensor_shard_dims_t); +num_ptensor_shard_dims_t + num_ptensor_shard_dims_from_parallel_dims(num_ptensor_parallel_dims_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/num_tensor_dims_t.h b/lib/op-attrs/include/op-attrs/num_tensor_dims_t.h new file mode 100644 index 0000000000..2c2f2183f1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/num_tensor_dims_t.h @@ -0,0 +1,81 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_NUM_TENSOR_DIMS_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_NUM_TENSOR_DIMS_T_H + +#include "op-attrs/ff_dim_t.dtg.h" +#include "op-attrs/num_ptensor_parallel_dims_t.h" +#include "op-attrs/num_ptensor_shard_dims_t.dtg.h" +#include "op-attrs/relative_ff_dim_t.dtg.h" + 
+namespace FlexFlow { + +struct num_tensor_dims_t { +public: + num_tensor_dims_t() = delete; + explicit num_tensor_dims_t(nonnegative_int); + + bool operator<(num_tensor_dims_t other) const; + bool operator==(num_tensor_dims_t other) const; + bool operator>(num_tensor_dims_t other) const; + bool operator<=(num_tensor_dims_t other) const; + bool operator!=(num_tensor_dims_t other) const; + bool operator>=(num_tensor_dims_t other) const; + + bool operator<(nonnegative_int other) const; + bool operator==(nonnegative_int other) const; + bool operator>(nonnegative_int other) const; + bool operator<=(nonnegative_int other) const; + bool operator!=(nonnegative_int other) const; + bool operator>=(nonnegative_int other) const; + + friend bool operator<(nonnegative_int lhs, num_tensor_dims_t rhs); + friend bool operator==(nonnegative_int lhs, num_tensor_dims_t rhs); + friend bool operator>(nonnegative_int lhs, num_tensor_dims_t rhs); + friend bool operator<=(nonnegative_int lhs, num_tensor_dims_t rhs); + friend bool operator!=(nonnegative_int lhs, num_tensor_dims_t rhs); + friend bool operator>=(nonnegative_int lhs, num_tensor_dims_t rhs); + + bool operator<(int other) const; + bool operator==(int other) const; + bool operator>(int other) const; + bool operator<=(int other) const; + bool operator!=(int other) const; + bool operator>=(int other) const; + + friend bool operator<(int lhs, num_tensor_dims_t rhs); + friend bool operator==(int lhs, num_tensor_dims_t rhs); + friend bool operator>(int lhs, num_tensor_dims_t rhs); + friend bool operator<=(int lhs, num_tensor_dims_t rhs); + friend bool operator!=(int lhs, num_tensor_dims_t rhs); + friend bool operator>=(int lhs, num_tensor_dims_t rhs); + + nonnegative_int nonnegative_int_from_num_tensor_dims() const; + int int_from_num_tensor_dims() const; + +private: + nonnegative_int value; + +private: + void check_invariant() const; +}; + +nonnegative_int format_as(num_tensor_dims_t); +std::ostream &operator<<(std::ostream &, 
num_tensor_dims_t); + +num_tensor_dims_t + num_tensor_dims_from_num_ptensor_shard_dims(num_ptensor_shard_dims_t); + +num_tensor_dims_t + num_tensor_dims_from_num_ptensor_parallel_dims(num_ptensor_parallel_dims_t); + +num_ptensor_shard_dims_t + num_ptensor_shard_dims_from_num_tensor_dims(num_tensor_dims_t); + +num_ptensor_parallel_dims_t + num_ptensor_parallel_dims_from_num_tensor_dims(num_tensor_dims_t); + +std::vector tensor_dims_range(num_tensor_dims_t); +std::vector relative_tensor_dims_range(num_tensor_dims_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h deleted file mode 100644 index d94f7af4fb..0000000000 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef _OPERATOR_PARAMS_H -#define _OPERATOR_PARAMS_H - -#include "op-attrs/ops/attention.h" -#include "op-attrs/ops/batch_matmul.h" -#include "op-attrs/ops/batch_norm.h" -#include "op-attrs/ops/broadcast.h" -#include "op-attrs/ops/cast.h" -#include "op-attrs/ops/combine.h" -#include "op-attrs/ops/concat.h" -#include "op-attrs/ops/conv_2d.h" -#include "op-attrs/ops/core.h" -#include "op-attrs/ops/dropout.h" -#include "op-attrs/ops/element_binary.h" -#include "op-attrs/ops/element_unary.h" -#include "op-attrs/ops/embedding.h" -#include "op-attrs/ops/flat.h" -#include "op-attrs/ops/gather.h" -#include "op-attrs/ops/input.h" -#include "op-attrs/ops/layer_norm.h" -#include "op-attrs/ops/linear.h" -#include "op-attrs/ops/noop.h" -#include "op-attrs/ops/pool_2d.h" -#include "op-attrs/ops/reduce.h" -#include "op-attrs/ops/reduction.h" -#include "op-attrs/ops/repartition.h" -#include "op-attrs/ops/replicate.h" -#include "op-attrs/ops/reshape.h" -#include "op-attrs/ops/reverse.h" -#include "op-attrs/ops/softmax.h" -#include "op-attrs/ops/split.h" -#include "op-attrs/ops/topk.h" -#include "op-attrs/ops/transpose.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" -#include 
"utils/record_formatter.h" -#include "utils/variant.h" -#include - -namespace FlexFlow { - -std::vector get_output_shapes( - PCGOperatorAttrs const &op_params, - std::vector const &input_tensor_shapes); - -bool is_valid(PCGOperatorAttrs const &, - std::vector const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.toml b/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.toml new file mode 100644 index 0000000000..d891fc797c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "OperatorSpaceToParallelTensorSpaceMapping" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/orthotope/dim_domain_mapping.h", + "op-attrs/operator_task_space_dim_idx_t.dtg.h", + "op-attrs/parallel_tensor_dim_idx_t.dtg.h", +] + +[[fields]] +name = "raw_mapping" +type = "::FlexFlow::DimDomainMapping<::FlexFlow::operator_task_space_dim_idx_t, ::FlexFlow::parallel_tensor_dim_idx_t>" diff --git a/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.h b/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.h new file mode 100644 index 0000000000..a45700f66e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.h @@ -0,0 +1,52 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_SPACE_TO_PARALLEL_TENSOR_SPACE_MAPPING_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_SPACE_TO_PARALLEL_TENSOR_SPACE_MAPPING_H + +#include "op-attrs/num_ptensor_parallel_dims_t.h" +#include "op-attrs/num_ptensor_shard_dims_t.dtg.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include 
"op-attrs/parallel_tensor_space_coordinate.dtg.h" +#include "op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/task_space_coordinate.dtg.h" + +namespace FlexFlow { + +OperatorSpaceToParallelTensorSpaceMapping + empty_operator_space_to_ptensor_space_map(); + +OperatorTaskSpace get_operator_task_space_for_mapping( + OperatorSpaceToParallelTensorSpaceMapping const &); + +ParallelTensorDimDegrees get_parallel_tensor_space_for_mapping( + OperatorSpaceToParallelTensorSpaceMapping const &); + +OperatorSpaceToParallelTensorSpaceMapping get_identity_mapping( + OperatorTaskSpace const &operator_task_space, + ParallelTensorDimDegrees const ¶llel_tensor_dim_degrees); + +OperatorSpaceToParallelTensorSpaceMapping + operator_ptensor_space_mapping_from_projection( + DimProjection const &projection, + OperatorTaskSpace const &op_task_space, + ParallelTensorDimDegrees const ¶llel_tensor_dim_degrees); + +OperatorSpaceToParallelTensorSpaceMapping + operator_ptensor_space_mapping_from_composition( + OperatorSpaceToParallelTensorSpaceMapping const &op_to_pt1_mapping, + ParallelTensorSpaceToParallelTensorSpaceMapping const + &pt1_to_pt2_mapping); + +ParallelTensorSpaceCoordinate ptensor_coord_for_task_space_coord( + OperatorSpaceToParallelTensorSpaceMapping const &mapping, + TaskSpaceCoordinate const &task_space_coord, + num_ptensor_shard_dims_t num_shard_dims); + +TaskSpaceCoordinate task_space_coord_for_ptensor_coord( + OperatorSpaceToParallelTensorSpaceMapping const &mapping, + ParallelTensorSpaceCoordinate const &tensor_space_coordinate); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/operator_task_space.dtg.toml b/lib/op-attrs/include/op-attrs/operator_task_space.dtg.toml new file mode 100644 index 0000000000..baeafe7072 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_task_space.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorTaskSpace" +type = "struct" +features = [ + "eq", + 
"ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "utils/orthotope/minimal_orthotope.dtg.h", +] + +[[fields]] +name = "degrees" +type = "::FlexFlow::MinimalOrthotope" diff --git a/lib/op-attrs/include/op-attrs/operator_task_space.h b/lib/op-attrs/include/op-attrs/operator_task_space.h new file mode 100644 index 0000000000..426fbc1850 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_task_space.h @@ -0,0 +1,50 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_H + +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/operator_task_space_dim_idx_t.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/task_space_coordinate.dtg.h" +#include "utils/orthotope/dim_domain.dtg.h" +#include "utils/orthotope/dim_ordering.dtg.h" +#include "utils/orthotope/minimal_dim_domain.dtg.h" +#include + +namespace FlexFlow { + +OperatorTaskSpace trivial_op_task_space(); + +std::unordered_set + operator_task_space_get_dim_idxs(OperatorTaskSpace const &); + +std::unordered_set + get_task_space_coordinates(OperatorTaskSpace const &operator_task_space); + +bool operator_task_space_contains_coord(OperatorTaskSpace const &, + TaskSpaceCoordinate const &); + +TaskSpaceCoordinate get_task_space_maximum_coordinate( + OperatorTaskSpace const &operator_task_space); + +nonnegative_int + op_task_space_num_dims(OperatorTaskSpace const &operator_task_space); +positive_int num_tasks(OperatorTaskSpace const &operator_task_space); + +positive_int op_task_space_dim_size_for_idx(OperatorTaskSpace const &, + operator_task_space_dim_idx_t); + +MinimalDimDomain + minimal_dim_domain_from_operator_task_space(OperatorTaskSpace const &); + +OperatorTaskSpace operator_task_space_from_minimal_dim_domain( + MinimalDimDomain const &); + +DimOrdering + get_operator_task_space_dim_ordering(); + +OperatorTaskSpace 
get_operator_task_space_matching_parallel_tensor_dim_degrees( + ParallelTensorDimDegrees const &dim_degrees); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.dtg.toml b/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.dtg.toml new file mode 100644 index 0000000000..bc67039121 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "operator_task_space_dim_idx_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "raw_idx" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.h b/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.h new file mode 100644 index 0000000000..30a6845734 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_DIM_IDX_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_DIM_IDX_T_H + +#include "op-attrs/operator_task_space_dim_idx_t.dtg.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +namespace FlexFlow { + +std::set + operator_task_space_dim_idx_range(nonnegative_int end); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/operator_task_space_to_operator_task_space_mapping.dtg.toml b/lib/op-attrs/include/op-attrs/operator_task_space_to_operator_task_space_mapping.dtg.toml new file mode 100644 index 0000000000..63e65d0322 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_task_space_to_operator_task_space_mapping.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "OperatorTaskSpaceToOperatorTaskSpaceMapping" +type = "struct" +features = [ + "eq", + "hash", + "fmt" +] + +includes = [ + 
"utils/orthotope/dim_domain_mapping.h", + "op-attrs/operator_task_space_dim_idx_t.dtg.h", +] + +[[fields]] +name = "raw_mapping" +type = "::FlexFlow::DimDomainMapping<::FlexFlow::operator_task_space_dim_idx_t, ::FlexFlow::operator_task_space_dim_idx_t>" diff --git a/lib/op-attrs/include/op-attrs/operator_task_space_to_operator_task_space_mapping.h b/lib/op-attrs/include/op-attrs/operator_task_space_to_operator_task_space_mapping.h new file mode 100644 index 0000000000..df37c1f945 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_task_space_to_operator_task_space_mapping.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_TO_OPERATOR_TASK_SPACE_MAPPING_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_TO_OPERATOR_TASK_SPACE_MAPPING_H + +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.dtg.h" +#include "op-attrs/task_space_coordinate.dtg.h" + +namespace FlexFlow { + +OperatorTaskSpaceToOperatorTaskSpaceMapping + op_to_op_identity_mapping(OperatorTaskSpace const &, + OperatorTaskSpace const &); + +OperatorTaskSpace op_mapping_get_src_space( + OperatorTaskSpaceToOperatorTaskSpaceMapping const &); + +OperatorTaskSpace op_mapping_get_dst_space( + OperatorTaskSpaceToOperatorTaskSpaceMapping const &); + +bidict op_to_op_get_coord_mapping( + OperatorTaskSpaceToOperatorTaskSpaceMapping const &); + +OperatorTaskSpaceToOperatorTaskSpaceMapping + op_to_op_mapping_from_composition_through_tensor( + OperatorSpaceToParallelTensorSpaceMapping const &src_to_tensor_mapping, + OperatorSpaceToParallelTensorSpaceMapping const &dst_to_tensor_mapping); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/operator_type.dtg.toml b/lib/op-attrs/include/op-attrs/operator_type.dtg.toml new file mode 100644 index 0000000000..0ef9b17cef --- /dev/null +++ 
b/lib/op-attrs/include/op-attrs/operator_type.dtg.toml @@ -0,0 +1,96 @@ +namespace = "FlexFlow" +name = "OperatorType" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +values = [ + { name = "NOOP" }, + { name = "INPUT" }, + { name = "WEIGHT" }, + { name = "CONV2D" }, + { name = "DROPOUT" }, + { name = "LINEAR" }, + { name = "BATCHMATMUL" }, + { name = "POOL2D" }, + { name = "SCALAR_MULTIPLY" }, + { name = "SCALAR_ADD" }, + { name = "SCALAR_FLOOR_DIV" }, + { name = "SCALAR_TRUE_DIV" }, + { name = "SCALAR_SUB" }, + { name = "RELU" }, + { name = "IDENTITY" }, + { name = "SIGMOID" }, + { name = "TANH" }, + { name = "ELU" }, + { name = "FLAT" }, + { name = "SOFTMAX" }, + { name = "BATCHNORM" }, + { name = "CONCAT" }, + { name = "SPLIT" }, + { name = "EMBEDDING" }, + { name = "CACHE" }, + { name = "RESHAPE" }, + { name = "REVERSE" }, + { name = "TRANSPOSE" }, + { name = "EW_ADD" }, + { name = "EW_MUL" }, + { name = "MATMUL" }, + { name = "MUL" }, + { name = "ENLARGE" }, + { name = "SQUEEZE" }, + { name = "UNSQUEEZE" }, + { name = "EW_SUB" }, + { name = "EW_DIV" }, + { name = "EW_EQUAL" }, + { name = "EW_GREATER" }, + { name = "EW_LESS" }, + { name = "EW_MAX" }, + { name = "EW_MIN" }, + { name = "REDUCE_ARGMAX" }, + { name = "REDUCE_ARGMIN" }, + { name = "REDUCE_MAX" }, + { name = "REDUCE_MEAN" }, + { name = "REDUCE_MIN" }, + { name = "REDUCE_PROD" }, + { name = "REDUCE_SUM" }, + { name = "PAD" }, + { name = "SHAPE" }, + { name = "SIZE" }, + { name = "TOPK" }, + { name = "WHERE" }, + { name = "CEIL" }, + { name = "CAST" }, + { name = "EXP" }, + { name = "ROUND" }, + { name = "LOG" }, + { name = "LOGICAL_NOT" }, + { name = "SQRT" }, + { name = "SIN" }, + { name = "COS" }, + { name = "LEAKYRELU" }, + { name = "SLICE" }, + { name = "RESIZE" }, + { name = "PRELU" }, + { name = "GELU" }, + { name = "MULTIHEAD_ATTENTION" }, + { name = "FUSED" }, + { name = "RSQRT" }, + { name = "POW" }, + { name = "MEAN" }, + { name = "LAYERNORM" }, + { name = 
"GATHER" }, + { name = "BROADCAST" }, + { name = "REPARTITION" }, + { name = "COMBINE" }, + { name = "REPLICATE" }, + { name = "REDUCTION" }, + { name = "BATCH" }, + { name = "PIPELINE" }, + { name = "FUSED_PARALLEL" }, +] + diff --git a/lib/op-attrs/include/op-attrs/operator_type.enum.toml b/lib/op-attrs/include/op-attrs/operator_type.enum.toml deleted file mode 100644 index 8815d69dda..0000000000 --- a/lib/op-attrs/include/op-attrs/operator_type.enum.toml +++ /dev/null @@ -1,95 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorType" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -values = [ - { name = "NOOP" }, - { name = "INPUT" }, - { name = "WEIGHT" }, - { name = "CONV2D" }, - { name = "DROPOUT" }, - { name = "LINEAR" }, - { name = "BATCHMATMUL" }, - { name = "POOL2D" }, - { name = "SCALAR_MULTIPLY" }, - { name = "SCALAR_ADD" }, - { name = "SCALAR_FLOOR_DIV" }, - { name = "SCALAR_TRUE_DIV" }, - { name = "SCALAR_SUB" }, - { name = "RELU" }, - { name = "IDENTITY" }, - { name = "SIGMOID" }, - { name = "TANH" }, - { name = "ELU" }, - { name = "FLAT" }, - { name = "SOFTMAX" }, - { name = "BATCHNORM" }, - { name = "CONCAT" }, - { name = "SPLIT" }, - { name = "EMBEDDING" }, - { name = "CACHE" }, - { name = "RESHAPE" }, - { name = "REVERSE" }, - { name = "TRANSPOSE" }, - { name = "EW_ADD" }, - { name = "EW_MUL" }, - { name = "MATMUL" }, - { name = "MUL" }, - { name = "ENLARGE" }, - { name = "SQUEEZE" }, - { name = "UNSQUEEZE" }, - { name = "EW_SUB" }, - { name = "EW_DIV" }, - { name = "EW_EQUAL" }, - { name = "EW_GREATER" }, - { name = "EW_LESS" }, - { name = "EW_MAX" }, - { name = "EW_MIN" }, - { name = "REDUCE_ARGMAX" }, - { name = "REDUCE_ARGMIN" }, - { name = "REDUCE_MAX" }, - { name = "REDUCE_MEAN" }, - { name = "REDUCE_MIN" }, - { name = "REDUCE_PROD" }, - { name = "REDUCE_SUM" }, - { name = "PAD" }, - { name = "SHAPE" }, - { name = "SIZE" }, - { name = "TOPK" }, - { name = "WHERE" }, - { name = "CEIL" }, - { name = "CAST" }, - { name = "EXP" }, 
- { name = "ROUND" }, - { name = "LOG" }, - { name = "LOGICAL_NOT" }, - { name = "SQRT" }, - { name = "SIN" }, - { name = "COS" }, - { name = "LEAKYRELU" }, - { name = "SLICE" }, - { name = "RESIZE" }, - { name = "PRELU" }, - { name = "GELU" }, - { name = "MULTIHEAD_ATTENTION" }, - { name = "FUSED" }, - { name = "RSQRT" }, - { name = "POW" }, - { name = "MEAN" }, - { name = "LAYERNORM" }, - { name = "GATHER" }, - { name = "BROADCAST" }, - { name = "REPARTITION" }, - { name = "COMBINE" }, - { name = "REPLICATE" }, - { name = "REDUCTION" }, - { name = "BATCH" }, - { name = "PIPELINE" }, - { name = "FUSED_PARALLEL" }, -] - diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 5ca237561f..fdd3f3775f 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -1,14 +1,14 @@ -#ifndef _FLEXFLOW_ATTENTION_ATTRS_H -#define _FLEXFLOW_ATTENTION_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_H #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/initializer_attrs.dtg.h" #include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" #include "op-attrs/ops/attention_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" #include namespace FlexFlow { @@ -39,7 +39,7 @@ positive_int get_kvSeqLength(MultiHeadAttentionInputs const &); positive_int get_num_samples(MultiHeadAttentionParallelInputs const &); positive_int get_num_samples(MultiHeadAttentionInputs const &); -std::vector +std::unordered_map get_attention_incoming_tensor_roles(MultiHeadAttentionAttrs const &); tl::expected @@ -63,7 +63,7 @@ tl::expected TensorShape const &input_k, TensorShape const &input_v); 
-tl::expected, std::string> +tl::expected, std::string> get_weight_shapes(MultiHeadAttentionAttrs const &, TensorShape const &input_q, TensorShape const &input_k, @@ -106,24 +106,26 @@ tl::expected ParallelTensorShape const &input_k, ParallelTensorShape const &input_v); -tl::expected, std::string> +tl::expected, + std::string> get_weight_shapes(MultiHeadAttentionAttrs const &, ParallelTensorShape const &input_q, ParallelTensorShape const &input_k, ParallelTensorShape const &input_v); -tl::expected, std::string> get_initializers( - MultiHeadAttentionAttrs const &, - TensorShape const &input_q, - TensorShape const &input_k, - TensorShape const &input_v, - std::optional const &weights_initializer = std::nullopt, - std::optional const &input_bias_initializer = - std::nullopt, - std::optional const &output_bias_initializer = - std::nullopt); - -CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs); +tl::expected, std::string> + get_initializers( + MultiHeadAttentionAttrs const &, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v, + std::optional const &weights_initializer = + std::nullopt, + std::optional const &input_bias_initializer = + std::nullopt, + std::optional const &output_bias_initializer = + std::nullopt); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.toml new file mode 100644 index 0000000000..94fd5cc9a2 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.toml @@ -0,0 +1,40 @@ +namespace = "FlexFlow" +name = "MultiHeadAttentionInputs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype.dtg.h", + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "batch_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "sequence_length" +type = 
"::FlexFlow::positive_int" + +[[fields]] +name = "query_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "key_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "value_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml deleted file mode 100644 index 8b9aefb67e..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml +++ /dev/null @@ -1,39 +0,0 @@ -namespace = "FlexFlow" -name = "MultiHeadAttentionInputs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/datatype.dtg.h", - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "batch_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "sequence_length" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "query_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "key_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "value_size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "datatype" -type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.toml new file mode 100644 index 0000000000..0f09a2a1c1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.toml @@ -0,0 +1,47 @@ +namespace = "FlexFlow" +name = "MultiHeadAttentionParallelInputs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "op-attrs/datatype.dtg.h", + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", + 
"op-attrs/parallel_tensor_shape/sum_degree.dtg.h", +] + +[[fields]] +name = "batch_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "sequence_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "query_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "key_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "value_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml deleted file mode 100644 index b0636db353..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml +++ /dev/null @@ -1,46 +0,0 @@ -namespace = "FlexFlow" -name = "MultiHeadAttentionParallelInputs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "", - "op-attrs/datatype.dtg.h", - "op-attrs/shard_parallel_dim.dtg.h", - "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", - "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", -] - -[[fields]] -name = "batch_dim" -type = "::FlexFlow::ShardParallelDim" - -[[fields]] -name = "sequence_dim" -type = "::FlexFlow::ShardParallelDim" - -[[fields]] -name = "query_dim" -type = "::FlexFlow::ShardParallelDim" - -[[fields]] -name = "key_dim" -type = "::FlexFlow::ShardParallelDim" - -[[fields]] -name = "value_dim" -type = "::FlexFlow::ShardParallelDim" - -[[fields]] -name = "discard_copy_degree" -type = "::FlexFlow::DiscardCopyDegree" - -[[fields]] -name = "datatype" -type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.toml new file mode 100644 index 
0000000000..33824f587f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.toml @@ -0,0 +1,48 @@ +namespace = "FlexFlow" +name = "MultiHeadAttentionAttrs" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "embed_dim" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_heads" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "kdim" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "vdim" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "dropout" +type = "float" + +[[fields]] +name = "bias" +type = "bool" + +[[fields]] +name = "add_bias_kv" +type = "bool" + +[[fields]] +name = "add_zero_attn" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml deleted file mode 100644 index b9c6847cd6..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml +++ /dev/null @@ -1,47 +0,0 @@ -namespace = "FlexFlow" -name = "MultiHeadAttentionAttrs" - -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "embed_dim" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "num_heads" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "kdim" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "vdim" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "dropout" -type = "float" - -[[fields]] -name = "bias" -type = "bool" - -[[fields]] -name = "add_bias_kv" -type = "bool" - -[[fields]] -name = "add_zero_attn" -type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 333da4fa29..f17757ac85 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ 
b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -2,15 +2,12 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H #include "op-attrs/ops/batch_matmul_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include namespace FlexFlow { -CHECK_VALID_OP_ATTR(BatchMatmulAttrs); - bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.dtg.toml new file mode 100644 index 0000000000..7a82d89e8d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.dtg.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "BatchMatmulAttrs" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + +[[fields]] +name = "a_seq_length_dim" +type = "std::optional<::FlexFlow::positive_int>" + +[[fields]] +name = "b_seq_length_dim" +type = "std::optional<::FlexFlow::positive_int>" diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.struct.toml deleted file mode 100644 index 394dfb5fcc..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "BatchMatmulAttrs" - -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/nonnegative_int/nonnegative_int.h", - "", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", - "utils/rapidcheck/optional.h", -] - -[[fields]] -name = "a_seq_length_dim" -type = "std::optional<::FlexFlow::nonnegative_int>" - -[[fields]] -name = 
"b_seq_length_dim" -type = "std::optional<::FlexFlow::nonnegative_int>" diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index bcf6794f38..bbdb52cecc 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -4,14 +4,15 @@ #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/initializer_attrs.dtg.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include namespace FlexFlow { -std::vector +std::unordered_map get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &); tl::expected get_output_shape(BatchNormAttrs const &, @@ -21,7 +22,7 @@ tl::expected tl::expected get_beta_weights_shape(BatchNormAttrs const &, TensorShape const &); -tl::expected, std::string> +tl::expected, std::string> get_weight_shapes(BatchNormAttrs const &attrs, TensorShape const &input_shape); @@ -35,7 +36,8 @@ tl::expected get_beta_weights_parallel_dim_degrees(BatchNormAttrs const &, ParallelTensorDimDegrees const &); -tl::expected, std::string> +tl::expected, + std::string> get_weight_parallel_dim_degrees( BatchNormAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees); @@ -48,7 +50,8 @@ tl::expected tl::expected get_beta_weights_shape(BatchNormAttrs const &, ParallelTensorShape const &); -tl::expected, std::string> +tl::expected, + std::string> get_weight_shapes(BatchNormAttrs const &attrs, ParallelTensorShape const &input_shape); @@ -58,11 +61,9 @@ tl::expected, std::string> * see * https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/batchnorm.py#L93-L97 */ -tl::expected, std::string> +tl::expected, std::string> get_initializers(BatchNormAttrs const &attrs); 
-CHECK_VALID_OP_ATTR(BatchNormAttrs); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.toml new file mode 100644 index 0000000000..a22fee4a64 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.toml @@ -0,0 +1,38 @@ +namespace = "FlexFlow" +name = "BatchNormAttrs" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + +[[fields]] +name = "relu" +type = "bool" + +[[fields]] +name = "affine" +type = "bool" + +[[fields]] +name = "eps" +type = "float" + +[[fields]] +name = "momentum" +type = "std::optional" diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml deleted file mode 100644 index fdc3bce1fe..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +++ /dev/null @@ -1,37 +0,0 @@ -namespace = "FlexFlow" -name = "BatchNormAttrs" - -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", - "utils/rapidcheck/optional.h", -] - -[[fields]] -name = "relu" -type = "bool" - -[[fields]] -name = "affine" -type = "bool" - -[[fields]] -name = "eps" -type = "float" - -[[fields]] -name = "momentum" -type = "std::optional" diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 4fd7d49234..9b6bd49418 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -2,15 +2,13 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H #include "op-attrs/ops/broadcast_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include 
"op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include "utils/record_formatter.h" +#include namespace FlexFlow { -CHECK_VALID_OP_ATTR(BroadcastAttrs); - RecordFormatter as_dot(BroadcastAttrs const &); tl::expected get_output_shape(BroadcastAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/broadcast_attrs.dtg.toml new file mode 100644 index 0000000000..6e1d4fcbd6 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/broadcast_attrs.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "BroadcastAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_dims.dtg.h", +] + +[[fields]] +name = "target_dims" +type = "::FlexFlow::TensorDims" diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/broadcast_attrs.struct.toml deleted file mode 100644 index 52e2ee66ca..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/broadcast_attrs.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "BroadcastAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/tensor_dims.dtg.h", -] - -[[fields]] -name = "target_dims" -type = "::FlexFlow::TensorDims" diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index 30818f046d..38a1e87a76 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -1,8 +1,7 @@ -#ifndef _FLEXFLOW_CAST_ATTRS_H -#define _FLEXFLOW_CAST_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_H #include "op-attrs/ops/cast_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include "utils/record_formatter.h" @@ 
-10,8 +9,6 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(CastAttrs); - RecordFormatter as_dot(CastAttrs const &); tl::expected get_output_shape(CastAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.toml new file mode 100644 index 0000000000..a0b70d49f7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "CastAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype.dtg.h" +] + +[[fields]] +name = "dtype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml deleted file mode 100644 index 287861888c..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "CastAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/datatype.dtg.h" -] - -[[fields]] -name = "dtype" -type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index d9ca314c2b..6839bc12e1 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -2,15 +2,12 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_H #include "op-attrs/ops/combine_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "utils/record_formatter.h" #include namespace FlexFlow { -CHECK_VALID_OP_ATTR(CombineAttrs); - RecordFormatter as_dot(CombineAttrs const &); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.toml new file mode 100644 index 0000000000..08dc09f4be --- /dev/null +++ 
b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "CombineAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "combine_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "combine_degree" +type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml deleted file mode 100644 index d80f853b00..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "CombineAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ff_dim_t.h", - "op-attrs/ff_dim_t.dtg.h", - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "combine_dim" -type = "::FlexFlow::ff_dim_t" - -[[fields]] -name = "combine_degree" -type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index f07f06df85..1647553b96 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -2,14 +2,12 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_H #include "op-attrs/ops/concat_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { -CHECK_VALID_OP_ATTR(ConcatAttrs); - tl::expected get_output_shape(ConcatAttrs const &, std::vector const &); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.toml new file mode 100644 index 0000000000..c92fd91125 --- /dev/null +++ 
b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "ConcatAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", + "utils/int_ge_two/int_ge_two.h", +] + +[[fields]] +name = "axis" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "num_inputs" +type = "::FlexFlow::int_ge_two" diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml deleted file mode 100644 index f3c66d0416..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "ConcatAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ff_dim_t.h", - "op-attrs/ff_dim_t.dtg.h" -] - -[[fields]] -name = "axis" -type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index e4c7467de2..0f27b00406 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -1,18 +1,16 @@ -#ifndef _FLEXFLOW_CONV_2D_ATTRS_H -#define _FLEXFLOW_CONV_2D_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_H #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/initializer_attrs.dtg.h" #include "op-attrs/ops/conv_2d_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_slot_name.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(Conv2DAttrs); - -std::vector +std::unordered_map get_conv2d_incoming_tensor_roles(Conv2DAttrs const &); TensorShape get_kernel_shape(Conv2DAttrs const &attrs, @@ -21,8 +19,8 @@ TensorShape get_bias_shape(Conv2DAttrs const 
&attrs, TensorShape const &input); TensorShape get_output_shape(Conv2DAttrs const &attrs, TensorShape const &input); -std::vector get_weight_shapes(Conv2DAttrs const &attrs, - TensorShape const &input_shape); +std::unordered_map + get_weight_shapes(Conv2DAttrs const &attrs, TensorShape const &input_shape); ParallelTensorShape get_kernel_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &input_shape); @@ -31,11 +29,11 @@ ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &input_shape); -std::vector +std::unordered_map get_weight_shapes(Conv2DAttrs const &attrs, ParallelTensorShape const &input_shape); -std::vector get_initializers( +std::unordered_map get_initializers( Conv2DAttrs const &attrs, TensorShape const &input_shape, std::optional kernel_initializer = std::nullopt, diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.toml new file mode 100644 index 0000000000..729fe9f26b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.toml @@ -0,0 +1,37 @@ +namespace = "FlexFlow" +name = "Conv2DInputShape" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "", + "op-attrs/datatype.dtg.h", + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "num_samples" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_channels" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "height" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "width" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml deleted file mode 100644 index 
b81acbfadd..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml +++ /dev/null @@ -1,36 +0,0 @@ -namespace = "FlexFlow" -name = "Conv2DInputShape" -features = [ - "eq", - "ord", - "hash", - "fmt", - "rapidcheck", - "json", -] - -includes = [ - "", - "op-attrs/datatype.dtg.h", - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "num_samples" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "num_channels" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "height" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "width" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "datatype" -type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.toml new file mode 100644 index 0000000000..365e07be29 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.toml @@ -0,0 +1,45 @@ +namespace = "FlexFlow" +name = "Conv2DParallelInputShape" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/datatype.dtg.h", + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "sample_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "channel_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "height_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "width_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "sum_reduction_degree" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "discard_copy_reduction_degree" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml 
b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml deleted file mode 100644 index 668c61168b..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml +++ /dev/null @@ -1,44 +0,0 @@ -namespace = "FlexFlow" -name = "Conv2DParallelInputShape" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/shard_parallel_dim.dtg.h", - "op-attrs/datatype.dtg.h", - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "sample_dim" -type = "::FlexFlow::ShardParallelDim" - -[[fields]] -name = "channel_dim" -type = "::FlexFlow::ShardParallelDim" - -[[fields]] -name = "height_dim" -type = "::FlexFlow::ShardParallelDim" - -[[fields]] -name = "width_dim" -type = "::FlexFlow::ShardParallelDim" - -[[fields]] -name = "sum_reduction_degree" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "discard_copy_reduction_degree" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "datatype" -type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.toml new file mode 100644 index 0000000000..ee3f97aaa0 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.toml @@ -0,0 +1,37 @@ +namespace = "FlexFlow" +name = "Conv2DAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "op-attrs/activation.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", + "utils/positive_int/positive_int.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + +fields = [ + { name = "out_channels", type = "::FlexFlow::positive_int" }, + { name = "kernel_h", type = "::FlexFlow::positive_int" }, + { name = "kernel_w", type = "::FlexFlow::positive_int" }, + { name = "stride_h", type = "::FlexFlow::positive_int" }, + { name = "stride_w", 
type = "::FlexFlow::positive_int" }, + { name = "padding_h", type = "::FlexFlow::nonnegative_int" }, + { name = "padding_w", type = "::FlexFlow::nonnegative_int" }, + { name = "groups", type = "::FlexFlow::positive_int" }, + { name = "activation", type = "std::optional<::FlexFlow::Activation>" }, + { name = "use_bias", type = "bool" }, +] diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml deleted file mode 100644 index 469ce6570e..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +++ /dev/null @@ -1,36 +0,0 @@ -namespace = "FlexFlow" -name = "Conv2DAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "", - "op-attrs/activation.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", - "utils/positive_int/positive_int.h", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", - "utils/rapidcheck/optional.h", -] - -fields = [ - { name = "out_channels", type = "::FlexFlow::positive_int" }, - { name = "kernel_h", type = "::FlexFlow::positive_int" }, - { name = "kernel_w", type = "::FlexFlow::positive_int" }, - { name = "stride_h", type = "::FlexFlow::positive_int" }, - { name = "stride_w", type = "::FlexFlow::positive_int" }, - { name = "padding_h", type = "::FlexFlow::nonnegative_int" }, - { name = "padding_w", type = "::FlexFlow::nonnegative_int" }, - { name = "groups", type = "::FlexFlow::positive_int" }, - { name = "activation", type = "std::optional<::FlexFlow::Activation>" }, - { name = "use_bias", type = "bool" }, -] diff --git a/lib/op-attrs/include/op-attrs/ops/core.h b/lib/op-attrs/include/op-attrs/ops/core.h deleted file mode 100644 index 611b53def5..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/core.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_OPS_CORE_H -#define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_OPS_CORE_H - -#include "utils/type_traits.h" - 
-namespace FlexFlow { - -#define CHECK_VALID_OP_ATTR(TYPENAME) CHECK_WELL_BEHAVED_VALUE_TYPE(TYPENAME) - -template -using is_valid_opattr = is_well_behaved_value_type; - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index 86e5db4d77..d5f3ae0c0d 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -1,10 +1,10 @@ -#ifndef _FLEXFLOW_DROPOUT_ATTRS_H -#define _FLEXFLOW_DROPOUT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/dropout_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { @@ -12,8 +12,6 @@ TensorShape get_output_shape(DropoutAttrs const &, TensorShape const &); tl::expected get_output_shape(DropoutAttrs const &, ParallelTensorShape const &); -CHECK_VALID_OP_ATTR(DropoutAttrs); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.toml new file mode 100644 index 0000000000..91f41572a6 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "DropoutAttrs" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "rate" +type = "float" + +[[fields]] +name = "seed" +type = "unsigned long long" diff --git a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml deleted file mode 100644 index 8731e0780b..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "DropoutAttrs" - -features = [ - "eq", - "ord", - "hash", - 
"json", - "rapidcheck", - "fmt", -] - -[[fields]] -name = "rate" -type = "float" - -[[fields]] -name = "seed" -type = "unsigned long long" diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index d51c3a3afa..1fc08e1322 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -1,21 +1,45 @@ -#ifndef _FLEXFLOW_ELEMENT_BINARY_ATTRS_H -#define _FLEXFLOW_ELEMENT_BINARY_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_H -#include "op-attrs/ops/core.h" +#include "op-attrs/num_ptensor_parallel_dims_t.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" #include "op-attrs/ops/element_binary_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include namespace FlexFlow { -tl::expected get_output_shape( - ElementBinaryAttrs const &, TensorShape const &, TensorShape const &); -tl::expected - get_output_shape(ElementBinaryAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); +TensorShape get_output_shape(ElementBinaryAttrs const &, + TensorShape const &, + TensorShape const &); +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); -CHECK_VALID_OP_ATTR(ElementBinaryAttrs); +ParallelTensorDimDegrees get_output_parallel_dim_degrees( + ElementBinaryAttrs const &attrs, + ParallelTensorDimDegrees const &lhs_input_degrees, + ParallelTensorDimDegrees const &rhs_input_degrees); + +OperatorTaskSpace + get_operator_task_space(ElementBinaryAttrs const &attrs, + ParallelTensorDimDegrees const &lhs_input_degrees, + ParallelTensorDimDegrees const &rhs_input_degrees); + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_lhs_input_mapping( + ElementBinaryAttrs const &attrs, + 
ParallelTensorDimDegrees const &lhs_input_degrees, + ParallelTensorDimDegrees const &rhs_input_degrees); + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_rhs_input_mapping( + ElementBinaryAttrs const &attrs, + ParallelTensorDimDegrees const &lhs_input_degrees, + ParallelTensorDimDegrees const &rhs_input_degrees); + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_output_mapping( + ElementBinaryAttrs const &attrs, + ParallelTensorDimDegrees const &lhs_input_degrees, + ParallelTensorDimDegrees const &rhs_input_degrees); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.toml new file mode 100644 index 0000000000..eb0a195eea --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "ElementBinaryAttrs" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_type.h", + "op-attrs/datatype.h", +] + +[[fields]] +name = "type" +type = "::FlexFlow::OperatorType" + +[[fields]] +name = "compute_type" +type = "::FlexFlow::DataType" + +[[fields]] +name = "should_broadcast_lhs" +type = "bool" + +[[fields]] +name = "should_broadcast_rhs" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml deleted file mode 100644 index d167c67aed..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml +++ /dev/null @@ -1,32 +0,0 @@ -namespace = "FlexFlow" -name = "ElementBinaryAttrs" - -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/operator_type.h", - "op-attrs/datatype.h", -] - -[[fields]] -name = "type" -type = "::FlexFlow::OperatorType" - -[[fields]] -name = "compute_type" -type = "::FlexFlow::DataType" - 
-[[fields]] -name = "should_broadcast_lhs" -type = "bool" - -[[fields]] -name = "should_broadcast_rhs" -type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 1a965b2c51..004f127276 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -1,8 +1,10 @@ -#ifndef _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H -#define _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_H -#include "op-attrs/ops/core.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" #include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include @@ -11,12 +13,25 @@ namespace FlexFlow { ElementUnaryAttrs make_relu_attrs(); -tl::expected - get_output_shape(ElementUnaryAttrs const &, TensorShape const &); -tl::expected - get_output_shape(ElementUnaryAttrs const &, ParallelTensorShape const &); +TensorShape get_output_shape(ElementUnaryAttrs const &, TensorShape const &); +ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, + ParallelTensorShape const &); -CHECK_VALID_OP_ATTR(ElementUnaryAttrs); +ParallelTensorDimDegrees get_output_parallel_dim_degrees( + ElementUnaryAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees); + +OperatorTaskSpace + get_operator_task_space(ElementUnaryAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees); + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_input_mapping( + ElementUnaryAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees); + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_output_mapping( + ElementUnaryAttrs const &attrs, + 
ParallelTensorDimDegrees const &input_degrees); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.toml new file mode 100644 index 0000000000..9f2ff2010f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "ElementUnaryAttrs" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_type.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + +[[fields]] +name = "op_type" +type = "::FlexFlow::OperatorType" + +[[fields]] +name = "scalar" +type = "std::optional" diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml deleted file mode 100644 index 403bb87592..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "ElementUnaryAttrs" - -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/operator_type.dtg.h", - "", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", - "utils/rapidcheck/optional.h", -] - -[[fields]] -name = "op_type" -type = "::FlexFlow::OperatorType" - -[[fields]] -name = "scalar" -type = "std::optional" diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index d44adf5f54..ff4aecae98 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -1,18 +1,16 @@ -#ifndef _FLEXFLOW_EMBEDDING_ATTRS_H -#define _FLEXFLOW_EMBEDDING_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_H 
#include "op-attrs/initializer_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/ops/embedding_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_slot_name.dtg.h" #include "utils/record_formatter.h" #include namespace FlexFlow { -CHECK_VALID_OP_ATTR(EmbeddingAttrs); - RecordFormatter as_dot(EmbeddingAttrs const &); tl::expected get_output_shape(EmbeddingAttrs const &, @@ -31,7 +29,7 @@ tl::expected * see * https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/sparse.py#L180-L182 */ -std::vector get_initializers( +std::unordered_map get_initializers( EmbeddingAttrs const &, std::optional const &initializer_attrs = std::nullopt); diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.toml new file mode 100644 index 0000000000..bb3ef3b709 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.toml @@ -0,0 +1,40 @@ +namespace = "FlexFlow" +name = "EmbeddingAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/aggregate_op.dtg.h", + "op-attrs/datatype.dtg.h", + "utils/positive_int/positive_int.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + +[[fields]] +name = "num_entries" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "out_channels" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "aggr" +type = "std::optional<::FlexFlow::AggregateOp>" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml deleted file mode 100644 index 07f82883db..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +++ /dev/null @@ -1,39 +0,0 @@ 
-namespace = "FlexFlow" -name = "EmbeddingAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/aggregate_op.dtg.h", - "op-attrs/datatype.dtg.h", - "utils/positive_int/positive_int.h", - "", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", - "utils/rapidcheck/optional.h", -] - -[[fields]] -name = "num_entries" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "out_channels" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "aggr" -type = "std::optional<::FlexFlow::AggregateOp>" - -[[fields]] -name = "data_type" -type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 710cbdb44b..ac03d7c7a9 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -1,7 +1,6 @@ -#ifndef _FLEXFLOW_FLAT_ATTRS_H -#define _FLEXFLOW_FLAT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/flat_attrs.dtg.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" @@ -9,14 +8,12 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(FlatAttrs); - TensorShape get_output_shape(FlatAttrs const &, TensorShape const &); -tl::expected +ParallelTensorDimDegrees get_output_parallel_dim_degrees(FlatAttrs const &, ParallelTensorDimDegrees const &); -tl::expected - get_output_shape(FlatAttrs const &, ParallelTensorShape const &); +ParallelTensorShape get_output_shape(FlatAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.toml new file mode 100644 index 0000000000..ce9dccc713 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.toml @@ -0,0 +1,31 @@ +namespace = 
"FlexFlow" +name = "FlatAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "op-attrs/ff_dim_t.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", + "op-attrs/ff_dim_t.h", +] + +[[fields]] +name = "start_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "end_dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml deleted file mode 100644 index 301df8bca4..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "FlatAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "", - "op-attrs/ff_dim_t.dtg.h", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", - "utils/rapidcheck/optional.h", - "op-attrs/ff_dim_t.h", -] - -[[fields]] -name = "start_dim" -type = "::FlexFlow::ff_dim_t" - -[[fields]] -name = "end_dim" -type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index 42efd13b60..3b67b9130b 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -1,14 +1,11 @@ -#ifndef _FLEXFLOW_GATHER_ATTRS_H -#define _FLEXFLOW_GATHER_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/gather_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(GatherAttrs); - TensorShape get_output_shape(GatherAttrs const &, TensorShape const &input, TensorShape const &index); diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.toml 
b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.toml new file mode 100644 index 0000000000..bb16939e24 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "GatherAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim_t.dtg.h", +] + +src_includes = [ + "op-attrs/ff_dim_t.h", +] + +[[fields]] +name = "dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml deleted file mode 100644 index f76c7c683f..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "GatherAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ff_dim_t.dtg.h", -] - -src_includes = [ - "op-attrs/ff_dim_t.h", -] - -[[fields]] -name = "dim" -type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h index fe92c77a52..f2887421f5 100644 --- a/lib/op-attrs/include/op-attrs/ops/input.h +++ b/lib/op-attrs/include/op-attrs/ops/input.h @@ -1,18 +1,22 @@ #ifndef _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H #define _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H -#include "op-attrs/ops/core.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" #include "op-attrs/ops/input_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(InputAttrs); - TensorShape get_output_shape(InputAttrs const &); ParallelTensorShape get_output_parallel_tensor_shape(InputAttrs const &); +OperatorTaskSpace get_operator_task_space(InputAttrs const &); + +OperatorSpaceToParallelTensorSpaceMapping + 
get_operator_to_output_mapping(InputAttrs const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.toml new file mode 100644 index 0000000000..5aaebe4cb4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "InputAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_shape.dtg.h", +] + +[[fields]] +name = "tensor_shape" +type = "::FlexFlow::TensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml deleted file mode 100644 index 8965ef18fa..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "InputAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/tensor_shape.dtg.h", -] - -[[fields]] -name = "tensor_shape" -type = "::FlexFlow::TensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index 4dcbeb665e..00c1ad9b12 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -1,16 +1,17 @@ -#ifndef _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H -#define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_H #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/initializer_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/ops/layer_norm_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include namespace FlexFlow { -std::vector 
+std::unordered_map get_layer_norm_incoming_tensor_roles(LayerNormAttrs const &); tl::expected get_output_shape(LayerNormAttrs const &, @@ -20,7 +21,7 @@ tl::expected tl::expected get_beta_weights_shape(LayerNormAttrs const &, TensorShape const &); -tl::expected, std::string> +tl::expected, std::string> get_weight_shapes(LayerNormAttrs const &attrs, TensorShape const &input_shape); @@ -32,7 +33,8 @@ tl::expected tl::expected get_beta_weights_shape(LayerNormAttrs const &, ParallelTensorShape const &); -tl::expected, std::string> +tl::expected, + std::string> get_weight_shapes(LayerNormAttrs const &attrs, ParallelTensorShape const &input_shape); @@ -42,9 +44,8 @@ tl::expected, std::string> * see * https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/normalization.py#L210-L214 */ -std::vector get_initializers(LayerNormAttrs const &attrs); - -CHECK_VALID_OP_ATTR(LayerNormAttrs); +std::unordered_map + get_initializers(LayerNormAttrs const &attrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.toml new file mode 100644 index 0000000000..7faf066013 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "LayerNormAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim_t.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/set.h", + "utils/hash/set.h", + "op-attrs/ff_dim_t.h", +] + +[[fields]] +name = "axes" +type = "std::set<::FlexFlow::ff_dim_t>" + +[[fields]] +name = "elementwise_affine" +type = "bool" + +[[fields]] +name = "eps" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml deleted file mode 100644 index 12e29d8a60..0000000000 --- 
a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml +++ /dev/null @@ -1,33 +0,0 @@ -namespace = "FlexFlow" -name = "LayerNormAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ff_dim_t.dtg.h", - "", -] - -src_includes = [ - "utils/fmt/set.h", - "utils/hash/set.h", - "op-attrs/ff_dim_t.h", -] - -[[fields]] -name = "axes" -type = "std::set<::FlexFlow::ff_dim_t>" - -[[fields]] -name = "elementwise_affine" -type = "bool" - -[[fields]] -name = "eps" -type = "float" diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 107f772e03..fb44e5cb4f 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -1,22 +1,25 @@ -#ifndef _FLEXFLOW_LINEAR_ATTRS_H -#define _FLEXFLOW_LINEAR_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_H #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/initializer_attrs.dtg.h" -#include "op-attrs/ops/core.h" +#include "op-attrs/num_ptensor_parallel_dims_t.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" #include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" #include "utils/record_formatter.h" #include namespace FlexFlow { -std::vector +std::unordered_map get_linear_incoming_tensor_roles(LinearAttrs const &); -CHECK_VALID_OP_ATTR(LinearAttrs); - RecordFormatter as_dot(LinearAttrs const &); tl::expected @@ -26,9 +29,19 @@ tl::expected get_bias_shape(LinearAttrs const &attrs, tl::expected get_output_shape(LinearAttrs const &attrs, TensorShape const &input); 
-tl::expected, std::string> +tl::expected, std::string> get_weight_shapes(LinearAttrs const &attrs, TensorShape const &input_shape); +ParallelTensorDimDegrees + get_projection_parallel_dim_degrees(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input); +ParallelTensorDimDegrees + get_bias_parallel_dim_degrees(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input); +ParallelTensorDimDegrees + get_output_parallel_dim_degrees(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input); + tl::expected get_projection_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); @@ -38,16 +51,33 @@ tl::expected get_output_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); -tl::expected, std::string> +tl::expected, + std::string> get_weight_shapes(LinearAttrs const &attrs, ParallelTensorShape const &input_shape); -tl::expected, std::string> get_initializers( - LinearAttrs const &, - TensorShape const &input_shape, - std::optional const &projection_initializer = - std::nullopt, - std::optional const &kernel_initializer = std::nullopt); +tl::expected, std::string> + get_initializers(LinearAttrs const &, + TensorShape const &input_shape, + std::optional const + &projection_initializer = std::nullopt, + std::optional const &kernel_initializer = + std::nullopt); + +OperatorTaskSpace + get_operator_task_space(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees); + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_input_mapping( + LinearAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees); +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_projection_mapping( + LinearAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees); +OperatorSpaceToParallelTensorSpaceMapping + get_operator_to_bias_mapping(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees); + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_output_mapping( + LinearAttrs const 
&attrs, ParallelTensorDimDegrees const &input_degrees); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.toml new file mode 100644 index 0000000000..9c8e0587c6 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.toml @@ -0,0 +1,45 @@ +namespace = "FlexFlow" +name = "LinearAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype.dtg.h", + "op-attrs/activation.dtg.h", + "op-attrs/regularizer_attrs.dtg.h", + "", + "utils/positive_int/positive_int.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + +[[fields]] +name = "out_channels" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "use_bias" +type = "bool" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" + +[[fields]] +name = "activation" +type = "std::optional<::FlexFlow::Activation>" + +[[fields]] +name = "regularizer" +type = "std::optional<::FlexFlow::RegularizerAttrs>" diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml deleted file mode 100644 index 23513482d3..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +++ /dev/null @@ -1,44 +0,0 @@ -namespace = "FlexFlow" -name = "LinearAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/datatype.dtg.h", - "op-attrs/activation.dtg.h", - "op-attrs/regularizer_attrs.dtg.h", - "", - "utils/positive_int/positive_int.h", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", - "utils/rapidcheck/optional.h", -] - -[[fields]] -name = "out_channels" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "use_bias" -type = "bool" - -[[fields]] -name = "data_type" -type = "::FlexFlow::DataType" - -[[fields]] 
-name = "activation" -type = "std::optional<::FlexFlow::Activation>" - -[[fields]] -name = "regularizer" -type = "std::optional<::FlexFlow::RegularizerAttrs>" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions.h b/lib/op-attrs/include/op-attrs/ops/loss_functions.h index 657f8d91dc..c19d7f9e87 100644 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions.h +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "op-attrs/ops/loss_functions/loss_function.dtg.h" #include "op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.dtg.h" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.dtg.toml new file mode 100644 index 0000000000..c5f9bb7874 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "LossAttrs" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +includes = [ + "op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.dtg.h", + "op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::SparseCategoricalCrossEntropyLossAttrs" +key = "sparse_categorical_cross_entropy_loss" + +[[values]] +type = "::FlexFlow::NonconfigurableLossAttrs" +key = "nonconfigurable_loss_attrs" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml deleted file mode 100644 index 943760d949..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "LossAttrs" 
-features = [ - "eq", - "ord", - "hash", - "fmt", - "json", - "rapidcheck", -] - -includes = [ - "op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.dtg.h", - "op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.dtg.h", -] - -[[values]] -type = "::FlexFlow::SparseCategoricalCrossEntropyLossAttrs" -key = "sparse_categorical_cross_entropy_loss" - -[[values]] -type = "::FlexFlow::NonconfigurableLossAttrs" -key = "nonconfigurable_loss_attrs" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.dtg.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.dtg.toml new file mode 100644 index 0000000000..98e6a37fa6 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "LossFunction" +type = "enum" +features = [ + "fmt", + "hash", + "rapidcheck", + "json", +] + +[[values]] +name = "CATEGORICAL_CROSSENTROPY" + +[[values]] +name = "SPARSE_CATEGORICAL_CROSSENTROPY" + +[[values]] +name = "MEAN_SQUARED_ERROR_AVG_REDUCE" + +[[values]] +name = "MEAN_SQUARED_ERROR_SUM_REDUCE" + +[[values]] +name = "IDENTITY" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml deleted file mode 100644 index 9658202a45..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "LossFunction" -features = [ - "fmt", - "hash", - "rapidcheck", - "json", -] - -[[values]] -name = "CATEGORICAL_CROSSENTROPY" - -[[values]] -name = "SPARSE_CATEGORICAL_CROSSENTROPY" - -[[values]] -name = "MEAN_SQUARED_ERROR_AVG_REDUCE" - -[[values]] -name = "MEAN_SQUARED_ERROR_SUM_REDUCE" - -[[values]] -name = "IDENTITY" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.dtg.toml 
b/lib/op-attrs/include/op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.dtg.toml new file mode 100644 index 0000000000..3a1c785232 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "NonconfigurableLossAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ops/loss_functions/loss_function.dtg.h" +] + +[[fields]] +name = "loss_type" +type = "::FlexFlow::LossFunction" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.struct.toml deleted file mode 100644 index 3fe7ac86c5..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "NonconfigurableLossAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ops/loss_functions/loss_function.dtg.h" -] - -[[fields]] -name = "loss_type" -type = "::FlexFlow::LossFunction" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.dtg.toml new file mode 100644 index 0000000000..5f0184b0be --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.dtg.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "SparseCategoricalCrossEntropyLossAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[fields]] +# for aggregate_spec: More predictions than labels +name = "replace_labels" +type = "bool" diff --git 
a/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml deleted file mode 100644 index c50b432ba2..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "SparseCategoricalCrossEntropyLossAttrs" -features = [ - "eq", - "ord", - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[fields]] -# for aggregate_spec: More predictions than labels -name = "replace_labels" -type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h index 2c61dff886..8c8e191132 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop.h +++ b/lib/op-attrs/include/op-attrs/ops/noop.h @@ -1,15 +1,12 @@ -#ifndef _FLEXFLOW_OP_ATTRS_OPS_NOOP_H -#define _FLEXFLOW_OP_ATTRS_OPS_NOOP_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/noop_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(NoopAttrs); - TensorShape get_output_shape(NoopAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(NoopAttrs const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.toml new file mode 100644 index 0000000000..e487165def --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "NoopAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml 
b/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml deleted file mode 100644 index 3d9202093c..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml +++ /dev/null @@ -1,11 +0,0 @@ -namespace = "FlexFlow" -name = "NoopAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] -fields = [] diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.toml new file mode 100644 index 0000000000..be0bee2179 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "ParallelMultiHeadAttentionInputs" +type = "struct" +features = [ + "eq", + # "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.h" +] + +[[fields]] +name = "query" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "key" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "value" +type = "::FlexFlow::ParallelTensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml deleted file mode 100644 index 4809ee998a..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelMultiHeadAttentionInputs" -features = [ - "eq", - # "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/parallel_tensor_shape.h" -] - -[[fields]] -name = "query" -type = "::FlexFlow::ParallelTensorShape" - -[[fields]] -name = "key" -type = "::FlexFlow::ParallelTensorShape" - -[[fields]] -name = "value" -type = "::FlexFlow::ParallelTensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 368250c957..016e632b33 100644 --- 
a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -1,16 +1,14 @@ -#ifndef _FLEXFLOW_POOL_2D_ATTRS_H -#define _FLEXFLOW_POOL_2D_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/pool_2d_attrs.dtg.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { -CHECK_VALID_OP_ATTR(Pool2DAttrs); - tl::expected make_adaptive_pool2d_attrs(TensorDims const &input_dims, positive_int output_h, diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.toml new file mode 100644 index 0000000000..6751b6c956 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.toml @@ -0,0 +1,57 @@ +namespace = "FlexFlow" +name = "Pool2DAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/pool_op.dtg.h", + "op-attrs/activation.dtg.h", + "", + "utils/nonnegative_int/nonnegative_int.h", + "utils/positive_int/positive_int.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + +[[fields]] +name = "kernel_h" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "kernel_w" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "stride_h" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "stride_w" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "padding_h" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "padding_w" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "pool_type" +type = "::FlexFlow::PoolOp" + +[[fields]] +name = "activation" +type = "std::optional<::FlexFlow::Activation>" diff --git 
a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml deleted file mode 100644 index d0005eee19..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +++ /dev/null @@ -1,56 +0,0 @@ -namespace = "FlexFlow" -name = "Pool2DAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/pool_op.dtg.h", - "op-attrs/activation.dtg.h", - "", - "utils/nonnegative_int/nonnegative_int.h", - "utils/positive_int/positive_int.h", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", - "utils/rapidcheck/optional.h", -] - -[[fields]] -name = "kernel_h" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "kernel_w" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "stride_h" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "stride_w" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "padding_h" -type = "::FlexFlow::nonnegative_int" - -[[fields]] -name = "padding_w" -type = "::FlexFlow::nonnegative_int" - -[[fields]] -name = "pool_type" -type = "::FlexFlow::PoolOp" - -[[fields]] -name = "activation" -type = "std::optional<::FlexFlow::Activation>" diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h index 04e44b4161..5595ab9df5 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce.h @@ -1,14 +1,11 @@ -#ifndef _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H -#define _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/reduce_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(ReduceAttrs); - ParallelTensorShape get_output_shape(ReduceAttrs const &, ParallelTensorShape const &input_shape); diff --git 
a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.toml new file mode 100644 index 0000000000..be55c50a19 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "ReduceAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_type.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", + "utils/stack_vector/stack_vector.h", +] + +[[fields]] +name = "axes" +type = "::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>" + +[[fields]] +name = "op_type" +type = "::FlexFlow::OperatorType" + +[[fields]] +name = "keepdims" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml deleted file mode 100644 index 607bee3000..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "ReduceAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/operator_type.dtg.h", - "op-attrs/ff_dim_t.h", - "op-attrs/ff_dim_t.dtg.h", - "utils/stack_vector/stack_vector.h", -] - -[[fields]] -name = "axes" -type = "::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>" - -[[fields]] -name = "op_type" -type = "::FlexFlow::OperatorType" - -[[fields]] -name = "keepdims" -type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index e8b2483cd5..b107178744 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -1,7 +1,6 @@ -#ifndef _FLEXFLOW_REDUCTION_ATTRS_H -#define _FLEXFLOW_REDUCTION_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_H +#define 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/reduction_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "utils/record_formatter.h" @@ -9,8 +8,6 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(ReductionAttrs); - RecordFormatter as_dot(ReductionAttrs const &); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.toml new file mode 100644 index 0000000000..0b9357c859 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ReductionAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "reduction_degree" +type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml deleted file mode 100644 index 1ae2dcdc75..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "ReductionAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "reduction_degree" -type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index b67486ed35..7733bc6989 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -1,7 +1,6 @@ -#ifndef _FLEXFLOW_PARTITION_ATTRS_H -#define _FLEXFLOW_PARTITION_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_H -#include "op-attrs/ops/core.h" #include 
"op-attrs/ops/repartition_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "utils/record_formatter.h" @@ -9,8 +8,6 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(RepartitionAttrs); - RecordFormatter as_dot(RepartitionAttrs const &); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.toml new file mode 100644 index 0000000000..7881949b79 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "RepartitionAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "repartition_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "repartition_degree" +type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml deleted file mode 100644 index 9f08a13fcf..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "RepartitionAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ff_dim_t.h", - "op-attrs/ff_dim_t.dtg.h", - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "repartition_dim" -type = "::FlexFlow::ff_dim_t" - -[[fields]] -name = "repartition_degree" -type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 10a4636d27..6a6ecd3d1e 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -1,15 +1,12 @@ -#ifndef _FLEXFLOW_REPLICATE_ATTRS_H -#define _FLEXFLOW_REPLICATE_ATTRS_H +#ifndef 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/replicate_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "utils/record_formatter.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(ReplicateAttrs); - RecordFormatter as_dot(ReplicateAttrs const &); ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.toml new file mode 100644 index 0000000000..5c9f3fd38b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ReplicateAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "replicate_degree" +type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml deleted file mode 100644 index 739f0edfb4..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "ReplicateAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "replicate_degree" -type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index e87ca5c750..c7b8863ed6 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -1,14 +1,11 @@ -#ifndef _FLEXFLOW_RESHAPE_ATTRS_H -#define _FLEXFLOW_RESHAPE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_H +#define 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/reshape_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(ReshapeAttrs); - TensorShape get_output_shape(ReshapeAttrs const &attrs, TensorShape const &input_shape); ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.toml new file mode 100644 index 0000000000..1d86893a07 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ReshapeAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_shape.dtg.h", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::TensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml deleted file mode 100644 index 69ac761859..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "ReshapeAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/tensor_shape.dtg.h", -] - -[[fields]] -name = "shape" -type = "::FlexFlow::TensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index 023e714c20..7b8ea7cbe5 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -1,15 +1,12 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/reverse_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" namespace 
FlexFlow { -CHECK_VALID_OP_ATTR(ReverseAttrs); - TensorShape get_output_shape(ReverseAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.toml new file mode 100644 index 0000000000..8bf58ea7ee --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "ReverseAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", +] + +[[fields]] +name = "axis" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml deleted file mode 100644 index 2577ac1398..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "ReverseAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ff_dim_t.h", - "op-attrs/ff_dim_t.dtg.h", -] - -[[fields]] -name = "axis" -type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index 6eacc66b78..63bd7f1736 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -1,15 +1,13 @@ -#ifndef _FLEXFLOW_SOFTMAX_ATTRS_H -#define _FLEXFLOW_SOFTMAX_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/softmax_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { 
-CHECK_VALID_OP_ATTR(SoftmaxAttrs); - tl::expected get_output_shape(SoftmaxAttrs const &attrs, TensorShape const &input_shape); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.toml new file mode 100644 index 0000000000..bd7f7ba786 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "SoftmaxAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", +] + +[[fields]] +name = "dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml deleted file mode 100644 index 49172f44b0..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "SoftmaxAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ff_dim_t.h", - "op-attrs/ff_dim_t.dtg.h", -] - -[[fields]] -name = "dim" -type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index e6a08d6e77..b29a591b1b 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -1,7 +1,6 @@ -#ifndef _FLEXFLOW_SPLIT_ATTRS_H -#define _FLEXFLOW_SPLIT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/split_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -9,8 +8,6 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(SplitAttrs); - std::vector get_output_shapes(SplitAttrs const &, TensorShape const &); std::vector diff --git 
a/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.toml new file mode 100644 index 0000000000..5653a4f873 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "SplitAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/stack_vector/stack_vector.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "splits" +type = "::FlexFlow::stack_vector<::FlexFlow::nonnegative_int, MAX_NUM_OUTPUTS>" + +[[fields]] +name = "axis" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml deleted file mode 100644 index 7ce1ad7e34..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "SplitAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/stack_vector/stack_vector.h", - "op-attrs/ff_dim_t.h", - "op-attrs/ff_dim_t.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "splits" -type = "::FlexFlow::stack_vector<::FlexFlow::nonnegative_int, MAX_NUM_OUTPUTS>" - -[[fields]] -name = "axis" -type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index d6de90903a..cf28d0f8e9 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -1,15 +1,12 @@ -#ifndef _FLEXFLOW_TOPK_ATTRS_H -#define _FLEXFLOW_TOPK_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/topk_attrs.dtg.h" #include 
"op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(TopKAttrs); - TensorShape get_output_shape(TopKAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.toml new file mode 100644 index 0000000000..86f40380af --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "TopKAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "k" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "sorted" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml deleted file mode 100644 index 8feaff4dc0..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "TopKAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "k" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "sorted" -type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 6de83ee414..f79b8c4225 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -1,19 +1,34 @@ -#ifndef _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H -#define _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_H -#include "op-attrs/ops/core.h" +#include 
"op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" #include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(TransposeAttrs); - TensorShape get_output_shape(TransposeAttrs const &, TensorShape const &); + +ParallelTensorDimDegrees + get_output_parallel_dim_degrees(TransposeAttrs const &, + ParallelTensorDimDegrees const &); + ParallelTensorShape get_output_shape(TransposeAttrs const &, ParallelTensorShape const &); +OperatorTaskSpace + get_operator_task_space(TransposeAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees); + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_input_mapping( + TransposeAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees); + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_output_mapping( + TransposeAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.toml new file mode 100644 index 0000000000..ac48150b5d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "TransposeAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_dim_permutation.h", +] + +[[fields]] +name = "permutation" +type = "::FlexFlow::TensorDimPermutation" diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml deleted file mode 100644 index 50756f095b..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" 
-name = "TransposeAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ff_dim_t.h", - "op-attrs/ff_dim_t.dtg.h", - "op-attrs/ff_ordered/ff_ordered.h", -] - -[[fields]] -name = "perm" -type = "::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>" diff --git a/lib/op-attrs/include/op-attrs/ops/weight.h b/lib/op-attrs/include/op-attrs/ops/weight.h index 66eb0064ed..3d488ef24c 100644 --- a/lib/op-attrs/include/op-attrs/ops/weight.h +++ b/lib/op-attrs/include/op-attrs/ops/weight.h @@ -1,7 +1,8 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_H -#include "op-attrs/ops/core.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -9,13 +10,16 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(WeightAttrs); - RecordFormatter as_dot(WeightAttrs const &); TensorShape get_output_shape(WeightAttrs const &); ParallelTensorShape get_output_parallel_tensor_shape(WeightAttrs const &); +OperatorTaskSpace get_operator_task_space(WeightAttrs const &); + +OperatorSpaceToParallelTensorSpaceMapping + get_operator_to_output_mapping(WeightAttrs const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.toml new file mode 100644 index 0000000000..af75b19f6e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "WeightAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_shape.dtg.h", + "op-attrs/initializer_attrs.dtg.h", +] + +[[fields]] +name = "tensor_shape" +type = "::FlexFlow::TensorShape" + +[[fields]] +name = 
"initializer" +type = "::FlexFlow::InitializerAttrs" diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml deleted file mode 100644 index 2f62143ce6..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "WeightAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/tensor_shape.dtg.h", - "op-attrs/initializer_attrs.dtg.h", -] - -[[fields]] -name = "tensor_shape" -type = "::FlexFlow::TensorShape" - -[[fields]] -name = "initializer" -type = "::FlexFlow::InitializerAttrs" diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.dtg.toml b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.toml new file mode 100644 index 0000000000..230f45048d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "ParallelDim" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/replica_parallel_dim.dtg.h", +] + +[[values]] +type = "::FlexFlow::ShardParallelDim" +key = "shard_dim" + +[[values]] +type = "::FlexFlow::ReplicaParallelDim" +key = "replica_dim" diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml b/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml deleted file mode 100644 index e27e6509fe..0000000000 --- a/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelDim" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/shard_parallel_dim.dtg.h", - "op-attrs/replica_parallel_dim.dtg.h", -] - -[[values]] -type = "::FlexFlow::ShardParallelDim" -key = "shard_dim" - -[[values]] -type = "::FlexFlow::ReplicaParallelDim" -key = "replica_dim" diff 
--git a/lib/op-attrs/include/op-attrs/parallel_op_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/parallel_op_attrs.dtg.toml new file mode 100644 index 0000000000..63b948a8bd --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_op_attrs.dtg.toml @@ -0,0 +1,35 @@ +namespace = "FlexFlow" +name = "ParallelOpAttrs" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/reduction_attrs.dtg.h", + "op-attrs/ops/repartition_attrs.dtg.h", + "op-attrs/ops/replicate_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::CombineAttrs" +key = "combine_distributed" + +[[values]] +type = "::FlexFlow::ReductionAttrs" +key = "reduce_distributed" + +[[values]] +type = "::FlexFlow::RepartitionAttrs" +key = "partition_distributed" + +[[values]] +type = "::FlexFlow::ReplicateAttrs" +key = "replicate_distributed" + diff --git a/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml b/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml deleted file mode 100644 index f1631a41f2..0000000000 --- a/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml +++ /dev/null @@ -1,34 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelOpAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ops/combine_attrs.dtg.h", - "op-attrs/ops/reduction_attrs.dtg.h", - "op-attrs/ops/repartition_attrs.dtg.h", - "op-attrs/ops/replicate_attrs.dtg.h", -] - -[[values]] -type = "::FlexFlow::CombineAttrs" -key = "combine_distributed" - -[[values]] -type = "::FlexFlow::ReductionAttrs" -key = "reduce_distributed" - -[[values]] -type = "::FlexFlow::RepartitionAttrs" -key = "partition_distributed" - -[[values]] -type = "::FlexFlow::ReplicateAttrs" -key = "replicate_distributed" - diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.dtg.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.dtg.toml new file 
mode 100644 index 0000000000..a0ad167338 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "ParallelTensorDimDegrees" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", + "op-attrs/ff_ordered/ff_ordered.h", + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "sum_degree" +type = "::FlexFlow::SumDegree" + +[[fields]] +name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" + +[[fields]] +name = "shard_degrees" +type = "::FlexFlow::FFOrdered<::FlexFlow::positive_int>" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h new file mode 100644 index 0000000000..5582bf6e07 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIM_DEGREES_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIM_DEGREES_H + +#include "op-attrs/num_ptensor_shard_dims_t.dtg.h" +#include "op-attrs/num_tensor_dims_t.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" +#include "op-attrs/parallel_tensor_space_coordinate.dtg.h" +#include "utils/orthotope/dim_domain.dtg.h" +#include "utils/orthotope/minimal_dim_domain.dtg.h" + +namespace FlexFlow { + +num_ptensor_shard_dims_t + get_ptensor_dim_degrees_num_shard_dims(ParallelTensorDimDegrees const &); +num_tensor_dims_t + get_ptensor_dim_degrees_num_tensor_dims(ParallelTensorDimDegrees const &); + +std::unordered_set + get_parallel_tensor_dim_indices(ParallelTensorDimDegrees const &); + +std::set get_nontrivial_parallel_tensor_dim_indices( + ParallelTensorDimDegrees const &); + +positive_int + 
get_degree_for_parallel_tensor_dim_idx(ParallelTensorDimDegrees const &, + parallel_tensor_dim_idx_t const &); + +std::unordered_map + get_parallel_tensor_degree_map(ParallelTensorDimDegrees const &); + +std::unordered_set + get_parallel_tensor_space_coordinates(ParallelTensorDimDegrees const &); + +DimDomain + dim_domain_from_parallel_tensor_dim_degrees( + ParallelTensorDimDegrees const &); + +ParallelTensorDimDegrees parallel_tensor_dim_degrees_from_dim_domain( + DimDomain const &); + +MinimalDimDomain + minimal_dim_domain_from_parallel_tensor_dim_degrees( + ParallelTensorDimDegrees const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml deleted file mode 100644 index e25627f709..0000000000 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelTensorDimDegrees" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", - "rapidcheck", -] - -includes = [ - "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", - "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", - "op-attrs/ff_ordered/ff_ordered.h", - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "sum_degree" -type = "::FlexFlow::SumDegree" - -[[fields]] -name = "discard_copy_degree" -type = "::FlexFlow::DiscardCopyDegree" - -[[fields]] -name = "shard_degrees" -type = "::FlexFlow::FFOrdered<::FlexFlow::positive_int>" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.dtg.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.dtg.toml new file mode 100644 index 0000000000..e546801977 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.dtg.toml @@ -0,0 +1,37 @@ +namespace = "FlexFlow" +name = "parallel_tensor_dim_idx_t" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + 
"rapidcheck", +] + +docstring = """\ +@brief Index type for the dimensions of @ref ParallelTensorDimDegrees. + +@ref parallel_tensor_dim_idx_t is to @ref ParallelTensorDimDegrees as +@ref operator_task_space_dim_idx_t is to @ref OperatorTaskSpace as +@ref MachineSpecificationDimension is to @ref MachineComputeSpecification. +""" + +includes = [ + "op-attrs/ff_dim_t.dtg.h", + "op-attrs/replica_type.dtg.h", +] + +src_includes = [ + "op-attrs/ff_dim_t.h", +] + +[[values]] +type = "::FlexFlow::ff_dim_t" +key = "shard_dim" + + +[[values]] +type = "::FlexFlow::ReplicaType" +key = "replica_dim" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.h new file mode 100644 index 0000000000..1311471e09 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIM_IDX_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIM_IDX_T_H + +#include "op-attrs/num_ptensor_shard_dims_t.dtg.h" +#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" +#include "utils/orthotope/dim_ordering.dtg.h" + +namespace FlexFlow { + +parallel_tensor_dim_idx_t sum_dim_idx(); +parallel_tensor_dim_idx_t discard_copy_dim_idx(); +parallel_tensor_dim_idx_t shard_dim_idx(ff_dim_t); + +bool is_dim_idx_for_reduction_dimension(parallel_tensor_dim_idx_t); + +std::set + dim_idxs_for_num_shard_dims(num_ptensor_shard_dims_t num_shard_dims); + +DimOrdering get_parallel_tensor_dim_ordering(); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml deleted file mode 100644 index 7e7356a5e7..0000000000 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "parallel_tensor_dim_idx_t" -features = [ - "eq", - "ord", - 
"hash", - "json", - "fmt", -] - -includes = [ - "op-attrs/ff_dim_t.dtg.h", - "op-attrs/replica_type.dtg.h", -] - -[[values]] -type = "::FlexFlow::ff_dim_t" - -[[values]] -type = "::FlexFlow::ReplicaType" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.toml new file mode 100644 index 0000000000..33e2e29db1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "ParallelTensorDims" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_ordered/ff_ordered.h", + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/replica_parallel_dim_set.dtg.h", + "", + "utils/fmt/unordered_map.h", + "utils/fmt/pair.h", +] + +[[fields]] +name = "shard_dims" +type = "::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>" + +[[fields]] +name = "replica_dims" +type = "::FlexFlow::ReplicaParallelDimSet" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 435a962963..9e71785013 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H +#include "op-attrs/num_ptensor_shard_dims_t.dtg.h" #include "op-attrs/parallel_dim.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_dims.dtg.h" @@ -13,7 +14,7 @@ FFOrdered ff_ordered_shard_degrees(ParallelTensorDims const &); std::unordered_set replica_dims(ParallelTensorDims const &); /* size_t get_volume(ParallelTensorDims const &); */ -nonnegative_int num_shard_dims(ParallelTensorDims const &); +num_ptensor_shard_dims_t num_shard_dims(ParallelTensorDims const &); ParallelTensorDimDegrees 
get_parallel_degrees(ParallelTensorDims const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml deleted file mode 100644 index d2f8758377..0000000000 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelTensorDims" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ff_ordered/ff_ordered.h", - "op-attrs/shard_parallel_dim.dtg.h", - "op-attrs/replica_parallel_dim_set.dtg.h", - "", - "utils/fmt/unordered_map.h", - "utils/fmt/pair.h", -] - -[[fields]] -name = "shard_dims" -type = "::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>" - -[[fields]] -name = "replica_dims" -type = "::FlexFlow::ReplicaParallelDimSet" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.toml new file mode 100644 index 0000000000..89a61fcedc --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "ParallelTensorShape" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_dims.dtg.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "dims" +type = "::FlexFlow::ParallelTensorDims" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index e366f99b8e..e23ae33cbf 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -2,6 +2,7 @@ #define _OP_META_PARALLEL_TENSOR_SHAPE_H #include "op-attrs/ff_dim_t.h" +#include "op-attrs/num_ptensor_shard_dims_t.dtg.h" #include "op-attrs/parallel_dim.h" #include 
"op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" @@ -12,7 +13,7 @@ namespace FlexFlow { -nonnegative_int num_shard_dims(ParallelTensorShape const &); +num_ptensor_shard_dims_t num_shard_dims(ParallelTensorShape const &); ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &, relative_ff_dim_t); ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &, relative_ff_dim_t); @@ -34,9 +35,12 @@ ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &, ParallelTensorDimDegrees const &); +TensorShape get_piece_shape(ParallelTensorShape const &); +num_bytes_t get_piece_size_in_bytes(ParallelTensorShape const &); + std::unordered_set replica_dims(ParallelTensorShape const &); -TensorShape get_piece_shape(ParallelTensorShape const &); + positive_int get_num_replica_dims(ParallelTensorShape const &); positive_int get_num_replicas(ParallelTensorShape const &); @@ -48,7 +52,6 @@ positive_int get_total_parallel_degree(ParallelTensorShape const &); bool is_valid(ParallelTensorShape const &); TensorShape require_not_parallel(ParallelTensorShape const &); -TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &); std::vector get_tensor_shapes_unsafe(std::vector const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml deleted file mode 100644 index 806af55cba..0000000000 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelTensorShape" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/parallel_tensor_dims.dtg.h", - "op-attrs/datatype.dtg.h", -] - -[[fields]] -name = "dims" -type = "::FlexFlow::ParallelTensorDims" - -[[fields]] -name = "data_type" -type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.toml 
b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.toml new file mode 100644 index 0000000000..4f0213c76e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "DiscardCopyDegree" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "value" +type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml deleted file mode 100644 index d60495bc3a..0000000000 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "DiscardCopyDegree" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "value" -type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.toml new file mode 100644 index 0000000000..ec7db0b438 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "SumDegree" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "value" +type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml deleted file mode 100644 index f16586c4c9..0000000000 --- 
a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "SumDegree" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "value" -type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.dtg.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.dtg.toml new file mode 100644 index 0000000000..c0db511565 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "ParallelTensorSpaceCoordinate" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/ff_ordered/ff_ordered.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "sum_component" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "discard_copy_component" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "shard_components" +type = "::FlexFlow::FFOrdered<::FlexFlow::nonnegative_int>" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.h b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.h new file mode 100644 index 0000000000..3fd684c5ef --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SPACE_COORDINATE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SPACE_COORDINATE_H + +#include "op-attrs/num_ptensor_parallel_dims_t.h" +#include "op-attrs/num_ptensor_shard_dims_t.dtg.h" +#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" +#include "op-attrs/parallel_tensor_space_coordinate.dtg.h" +#include "utils/orthotope/dim_coord.dtg.h" + +namespace FlexFlow { + +num_ptensor_parallel_dims_t + 
ptensor_coord_num_dims(ParallelTensorSpaceCoordinate const &); +num_ptensor_shard_dims_t + ptensor_coord_num_shard_dims(ParallelTensorSpaceCoordinate const &); + +std::unordered_set + get_dim_idxs_in_ptensor_space_coord(ParallelTensorSpaceCoordinate const &); + +nonnegative_int ptensor_coord_component_for_ptensor_dim_idx( + ParallelTensorSpaceCoordinate const &, parallel_tensor_dim_idx_t); + +ParallelTensorSpaceCoordinate parallel_tensor_space_coord_from_map( + std::unordered_map const &); + +ParallelTensorSpaceCoordinate parallel_tensor_space_coord_from_dim_coord( + DimCoord const &); + +DimCoord dim_coord_from_parallel_tensor_space_coord( + ParallelTensorSpaceCoordinate const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.dtg.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.dtg.toml new file mode 100644 index 0000000000..b10140bbac --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "ParallelTensorSpaceToParallelTensorSpaceMapping" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/orthotope/dim_domain_mapping.h", + "op-attrs/parallel_tensor_dim_idx_t.dtg.h", +] + +[[fields]] +name = "raw_mapping" +type = "::FlexFlow::DimDomainMapping<::FlexFlow::parallel_tensor_dim_idx_t, ::FlexFlow::parallel_tensor_dim_idx_t>" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.h b/lib/op-attrs/include/op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.h new file mode 100644 index 0000000000..34568efea6 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SPACE_MAPPING_H +#define 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SPACE_MAPPING_H + +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" +#include "op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.dtg.h" +#include "utils/orthotope/dim_projection.dtg.h" + +namespace FlexFlow { + +ParallelTensorSpaceToParallelTensorSpaceMapping + parallel_tensor_space_mapping_from_projection( + DimProjection const &projection, + ParallelTensorDimDegrees const &l_degrees, + ParallelTensorDimDegrees const &r_degrees); + +ParallelTensorSpaceToParallelTensorSpaceMapping + invert_parallel_tensor_space_mapping( + ParallelTensorSpaceToParallelTensorSpaceMapping const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml new file mode 100644 index 0000000000..88a65f75c5 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml @@ -0,0 +1,164 @@ +namespace = "FlexFlow" +name = "PCGOperatorAttrs" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_matmul_attrs.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast_attrs.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/conv_2d_attrs.dtg.h", + "op-attrs/ops/dropout_attrs.dtg.h", + "op-attrs/ops/element_binary_attrs.dtg.h", + "op-attrs/ops/element_unary_attrs.dtg.h", + "op-attrs/ops/embedding_attrs.dtg.h", + "op-attrs/ops/flat_attrs.dtg.h", + "op-attrs/ops/gather_attrs.dtg.h", + "op-attrs/ops/input_attrs.dtg.h", + "op-attrs/ops/layer_norm_attrs.dtg.h", + "op-attrs/ops/linear_attrs.dtg.h", + "op-attrs/ops/noop_attrs.dtg.h", + "op-attrs/ops/pool_2d_attrs.dtg.h", + "op-attrs/ops/reduce_attrs.dtg.h", + 
"op-attrs/ops/reduction_attrs.dtg.h", + "op-attrs/ops/repartition_attrs.dtg.h", + "op-attrs/ops/replicate_attrs.dtg.h", + "op-attrs/ops/reshape_attrs.dtg.h", + "op-attrs/ops/reverse_attrs.dtg.h", + "op-attrs/ops/softmax_attrs.dtg.h", + "op-attrs/ops/split_attrs.dtg.h", + "op-attrs/ops/topk_attrs.dtg.h", + "op-attrs/ops/transpose_attrs.dtg.h", + "op-attrs/ops/weight_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::BatchMatmulAttrs" +key = "batch_matmul" + +[[values]] +type = "::FlexFlow::BatchNormAttrs" +key = "batch_norm" + +[[values]] +type = "::FlexFlow::BroadcastAttrs" +key = "broadcast" + +[[values]] +type = "::FlexFlow::CastAttrs" +key = "cast" + +[[values]] +type = "::FlexFlow::CombineAttrs" +key = "combine_distributed" + +[[values]] +type = "::FlexFlow::ConcatAttrs" +key = "concat" + +[[values]] +type = "::FlexFlow::Conv2DAttrs" +key = "conv2d" + +[[values]] +type = "::FlexFlow::DropoutAttrs" +key = "dropout" + +[[values]] +type = "::FlexFlow::ElementBinaryAttrs" +key = "element_binary" + +[[values]] +type = "::FlexFlow::ElementUnaryAttrs" +key = "element_unary" + +[[values]] +type = "::FlexFlow::EmbeddingAttrs" +key = "embedding" + +[[values]] +type = "::FlexFlow::FlatAttrs" +key = "flat" + +[[values]] +type = "::FlexFlow::GatherAttrs" +key = "gather" + +[[values]] +type = "::FlexFlow::InputAttrs" +key = "input" + +[[values]] +type = "::FlexFlow::LayerNormAttrs" +key = "layer_norm" + +[[values]] +type = "::FlexFlow::LinearAttrs" +key = "linear" + +[[values]] +type = "::FlexFlow::MultiHeadAttentionAttrs" +key = "multi_head_attention" + +[[values]] +type = "::FlexFlow::NoopAttrs" +key = "noop" + +[[values]] +type = "::FlexFlow::Pool2DAttrs" +key = "pool2d" + +[[values]] +type = "::FlexFlow::ReduceAttrs" +key = "reduce" + +[[values]] +type = "::FlexFlow::ReductionAttrs" +key = "reduce_distributed" + +[[values]] +type = "::FlexFlow::RepartitionAttrs" +key = "partition_distributed" + +[[values]] +type = "::FlexFlow::ReplicateAttrs" +key = 
"replicate_distributed" + +[[values]] +type = "::FlexFlow::ReverseAttrs" +key = "reverse" + +[[values]] +type = "::FlexFlow::ReshapeAttrs" +key = "reshape" + +[[values]] +type = "::FlexFlow::SplitAttrs" +key = "split" + +[[values]] +type = "::FlexFlow::SoftmaxAttrs" +key = "softmax" + +[[values]] +type = "::FlexFlow::TopKAttrs" +key = "topk" + +[[values]] +type = "::FlexFlow::TransposeAttrs" +key = "transpose" + +[[values]] +type = "::FlexFlow::WeightAttrs" +key = "weight" diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml deleted file mode 100644 index fdd11ac11f..0000000000 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml +++ /dev/null @@ -1,163 +0,0 @@ -namespace = "FlexFlow" -name = "PCGOperatorAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ops/attention_attrs.dtg.h", - "op-attrs/ops/batch_matmul_attrs.dtg.h", - "op-attrs/ops/batch_norm_attrs.dtg.h", - "op-attrs/ops/broadcast_attrs.dtg.h", - "op-attrs/ops/cast_attrs.dtg.h", - "op-attrs/ops/combine_attrs.dtg.h", - "op-attrs/ops/concat_attrs.dtg.h", - "op-attrs/ops/conv_2d_attrs.dtg.h", - "op-attrs/ops/dropout_attrs.dtg.h", - "op-attrs/ops/element_binary_attrs.dtg.h", - "op-attrs/ops/element_unary_attrs.dtg.h", - "op-attrs/ops/embedding_attrs.dtg.h", - "op-attrs/ops/flat_attrs.dtg.h", - "op-attrs/ops/gather_attrs.dtg.h", - "op-attrs/ops/input_attrs.dtg.h", - "op-attrs/ops/layer_norm_attrs.dtg.h", - "op-attrs/ops/linear_attrs.dtg.h", - "op-attrs/ops/noop_attrs.dtg.h", - "op-attrs/ops/pool_2d_attrs.dtg.h", - "op-attrs/ops/reduce_attrs.dtg.h", - "op-attrs/ops/reduction_attrs.dtg.h", - "op-attrs/ops/repartition_attrs.dtg.h", - "op-attrs/ops/replicate_attrs.dtg.h", - "op-attrs/ops/reshape_attrs.dtg.h", - "op-attrs/ops/reverse_attrs.dtg.h", - "op-attrs/ops/softmax_attrs.dtg.h", - "op-attrs/ops/split_attrs.dtg.h", - "op-attrs/ops/topk_attrs.dtg.h", 
- "op-attrs/ops/transpose_attrs.dtg.h", - "op-attrs/ops/weight_attrs.dtg.h", -] - -[[values]] -type = "::FlexFlow::BatchMatmulAttrs" -key = "batch_matmul" - -[[values]] -type = "::FlexFlow::BatchNormAttrs" -key = "batch_norm" - -[[values]] -type = "::FlexFlow::BroadcastAttrs" -key = "broadcast" - -[[values]] -type = "::FlexFlow::CastAttrs" -key = "cast" - -[[values]] -type = "::FlexFlow::CombineAttrs" -key = "combine_distributed" - -[[values]] -type = "::FlexFlow::ConcatAttrs" -key = "concat" - -[[values]] -type = "::FlexFlow::Conv2DAttrs" -key = "conv2d" - -[[values]] -type = "::FlexFlow::DropoutAttrs" -key = "dropout" - -[[values]] -type = "::FlexFlow::ElementBinaryAttrs" -key = "element_binary" - -[[values]] -type = "::FlexFlow::ElementUnaryAttrs" -key = "element_unary" - -[[values]] -type = "::FlexFlow::EmbeddingAttrs" -key = "embedding" - -[[values]] -type = "::FlexFlow::FlatAttrs" -key = "flat" - -[[values]] -type = "::FlexFlow::GatherAttrs" -key = "gather" - -[[values]] -type = "::FlexFlow::InputAttrs" -key = "input" - -[[values]] -type = "::FlexFlow::LayerNormAttrs" -key = "layer_norm" - -[[values]] -type = "::FlexFlow::LinearAttrs" -key = "linear" - -[[values]] -type = "::FlexFlow::MultiHeadAttentionAttrs" -key = "multi_head_attention" - -[[values]] -type = "::FlexFlow::NoopAttrs" -key = "noop" - -[[values]] -type = "::FlexFlow::Pool2DAttrs" -key = "pool2d" - -[[values]] -type = "::FlexFlow::ReduceAttrs" -key = "reduce" - -[[values]] -type = "::FlexFlow::ReductionAttrs" -key = "reduce_distributed" - -[[values]] -type = "::FlexFlow::RepartitionAttrs" -key = "partition_distributed" - -[[values]] -type = "::FlexFlow::ReplicateAttrs" -key = "replicate_distributed" - -[[values]] -type = "::FlexFlow::ReverseAttrs" -key = "reverse" - -[[values]] -type = "::FlexFlow::ReshapeAttrs" -key = "reshape" - -[[values]] -type = "::FlexFlow::SplitAttrs" -key = "split" - -[[values]] -type = "::FlexFlow::SoftmaxAttrs" -key = "softmax" - -[[values]] -type = 
"::FlexFlow::TopKAttrs" -key = "topk" - -[[values]] -type = "::FlexFlow::TransposeAttrs" -key = "transpose" - -[[values]] -type = "::FlexFlow::WeightAttrs" -key = "weight" diff --git a/lib/op-attrs/include/op-attrs/pool_op.dtg.toml b/lib/op-attrs/include/op-attrs/pool_op.dtg.toml new file mode 100644 index 0000000000..13c9dd8c06 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pool_op.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "PoolOp" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "MAX" + +[[values]] +name = "AVG" diff --git a/lib/op-attrs/include/op-attrs/pool_op.enum.toml b/lib/op-attrs/include/op-attrs/pool_op.enum.toml deleted file mode 100644 index 88f4dfea19..0000000000 --- a/lib/op-attrs/include/op-attrs/pool_op.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "PoolOp" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[values]] -name = "MAX" - -[[values]] -name = "AVG" diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.toml new file mode 100644 index 0000000000..0a2e686103 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "RegularizerAttrs" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/l1_regularizer_attrs.dtg.h", + "op-attrs/l2_regularizer_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::L1RegularizerAttrs" +key = "l1" + +[[values]] +type = "::FlexFlow::L2RegularizerAttrs" +key = "l2" diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml b/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml deleted file mode 100644 index d650c7f6a9..0000000000 --- a/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "RegularizerAttrs" 
-features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/l1_regularizer_attrs.dtg.h", - "op-attrs/l2_regularizer_attrs.dtg.h", -] - -[[values]] -type = "::FlexFlow::L1RegularizerAttrs" -key = "l1" - -[[values]] -type = "::FlexFlow::L2RegularizerAttrs" -key = "l2" diff --git a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.dtg.toml b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.dtg.toml new file mode 100644 index 0000000000..b8d20ac25b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "relative_ff_dim_t" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h index 5205b1ead8..9ca3ce0afb 100644 --- a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h +++ b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h @@ -2,12 +2,13 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_RELATIVE_FF_DIM_T_H #include "op-attrs/ff_dim_t.dtg.h" +#include "op-attrs/num_tensor_dims_t.h" #include "op-attrs/relative_ff_dim_t.dtg.h" -#include "rapidcheck.h" +#include namespace FlexFlow { ff_dim_t ff_dim_t_from_relative_ff_dim_t(relative_ff_dim_t ff_dim, - nonnegative_int input_dim); + num_tensor_dims_t input_dim); } // namespace FlexFlow namespace rc { diff --git a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.struct.toml b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.struct.toml deleted file mode 100644 index a93b649052..0000000000 --- a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.struct.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "relative_ff_dim_t" - -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -[[fields]] -name = "value" -type = "int" diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.toml 
b/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.toml new file mode 100644 index 0000000000..5e7ec34b5f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "ReplicaParallelDim" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/replica_type.dtg.h", + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "degree" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "replica_type" +type = "::FlexFlow::ReplicaType" diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml b/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml deleted file mode 100644 index ac4c2563dc..0000000000 --- a/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "ReplicaParallelDim" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/replica_type.dtg.h", - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "degree" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "replica_type" -type = "::FlexFlow::ReplicaType" diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.toml b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.toml new file mode 100644 index 0000000000..04a17e8290 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "ReplicaParallelDimSet" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", +] + +[[fields]] +name = "sum_degree" +type = "::FlexFlow::SumDegree" + +[[fields]] +name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" diff --git 
a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml deleted file mode 100644 index 66f50bee9f..0000000000 --- a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "ReplicaParallelDimSet" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", - "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", -] - -[[fields]] -name = "sum_degree" -type = "::FlexFlow::SumDegree" - -[[fields]] -name = "discard_copy_degree" -type = "::FlexFlow::DiscardCopyDegree" diff --git a/lib/op-attrs/include/op-attrs/replica_type.dtg.toml b/lib/op-attrs/include/op-attrs/replica_type.dtg.toml new file mode 100644 index 0000000000..aa23feb774 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_type.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "ReplicaType" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "DISCARD_COPY" + +[[values]] +name = "SUM" diff --git a/lib/op-attrs/include/op-attrs/replica_type.enum.toml b/lib/op-attrs/include/op-attrs/replica_type.enum.toml deleted file mode 100644 index 0c0eb5e3ab..0000000000 --- a/lib/op-attrs/include/op-attrs/replica_type.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "ReplicaType" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[values]] -name = "SUM" - -[[values]] -name = "DISCARD_COPY" diff --git a/lib/op-attrs/include/op-attrs/shape_inference.h b/lib/op-attrs/include/op-attrs/shape_inference.h index 8c679f442a..14184fac92 100644 --- a/lib/op-attrs/include/op-attrs/shape_inference.h +++ b/lib/op-attrs/include/op-attrs/shape_inference.h @@ -1,28 +1,31 @@ -#ifndef _FLEXFLOW_INCLUDE_OP_ATTRS_SHAPE_INFERENCE_H -#define _FLEXFLOW_INCLUDE_OP_ATTRS_SHAPE_INFERENCE_H +#ifndef 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHAPE_INFERENCE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHAPE_INFERENCE_H #include "op-attrs/computation_graph_op_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" #include namespace FlexFlow { -std::vector - get_output_shapes(ComputationGraphOpAttrs const &, - std::vector const &input_shapes); +std::unordered_map get_output_shapes( + ComputationGraphOpAttrs const &, + std::unordered_map const &input_shapes); -std::vector - get_weight_shapes(ComputationGraphOpAttrs const &, - std::vector const &input_shapes); +std::unordered_map get_weight_shapes( + ComputationGraphOpAttrs const &, + std::unordered_map const &input_shapes); -std::vector - get_output_shapes(PCGOperatorAttrs const &, - std::vector const &input_shapes); +std::unordered_map get_output_shapes( + PCGOperatorAttrs const &, + std::unordered_map const + &input_shapes); -std::vector - get_weight_shapes(PCGOperatorAttrs const &, - std::vector const &input_shapes); +std::unordered_map get_weight_shapes( + PCGOperatorAttrs const &, + std::unordered_map const + &input_shapes); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.toml b/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.toml new file mode 100644 index 0000000000..8c98611cfe --- /dev/null +++ b/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ShardParallelDim" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "degree" +type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml b/lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml deleted file mode 100644 index 
a11897070f..0000000000 --- a/lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "ShardParallelDim" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "size" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "degree" -type = "::FlexFlow::positive_int" diff --git a/lib/op-attrs/include/op-attrs/task_space_coordinate.dtg.toml b/lib/op-attrs/include/op-attrs/task_space_coordinate.dtg.toml new file mode 100644 index 0000000000..5ca3bfe809 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/task_space_coordinate.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "TaskSpaceCoordinate" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "utils/orthotope/orthotope_coord.dtg.h", +] + +[[fields]] +name = "orthotope_coord" +type = "::FlexFlow::OrthotopeCoord" diff --git a/lib/op-attrs/include/op-attrs/task_space_coordinate.h b/lib/op-attrs/include/op-attrs/task_space_coordinate.h new file mode 100644 index 0000000000..6337747675 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/task_space_coordinate.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TASK_SPACE_COORDINATE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TASK_SPACE_COORDINATE_H + +#include "op-attrs/operator_task_space_dim_idx_t.dtg.h" +#include "op-attrs/task_space_coordinate.dtg.h" +#include "utils/orthotope/dim_coord.dtg.h" + +namespace FlexFlow { + +nonnegative_int task_space_coord_num_dims(TaskSpaceCoordinate const &); + +TaskSpaceCoordinate + make_task_space_coordinate(std::vector const &); + +TaskSpaceCoordinate task_space_coordinate_from_dim_coord( + DimCoord const &); + +DimCoord + dim_coord_from_task_space_coordinate(TaskSpaceCoordinate const &); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/op-attrs/include/op-attrs/tensor_dim_permutation.h b/lib/op-attrs/include/op-attrs/tensor_dim_permutation.h new file mode 100644 index 0000000000..cc85b7f746 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dim_permutation.h @@ -0,0 +1,100 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIM_PERMUTATION_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIM_PERMUTATION_H + +#include "op-attrs/ff_dim_t.dtg.h" +#include "op-attrs/ff_ordered/ff_ordered.h" +#include "op-attrs/num_tensor_dims_t.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +struct TensorDimPermutation { + TensorDimPermutation() = delete; + + TensorDimPermutation(bidict const &); + + bool operator==(TensorDimPermutation const &) const; + bool operator!=(TensorDimPermutation const &) const; + + bool operator<(TensorDimPermutation const &) const; + bool operator>(TensorDimPermutation const &) const; + bool operator<=(TensorDimPermutation const &) const; + bool operator>=(TensorDimPermutation const &) const; + + ff_dim_t at_l(ff_dim_t) const; + ff_dim_t at_r(ff_dim_t) const; + + num_tensor_dims_t num_tensor_dims() const; + + bidict const &as_bidict() const; + +private: + bidict raw; + +private: + std::tuple tie() const; + + friend struct std::hash; +}; + +bidict format_as(TensorDimPermutation const &); +std::ostream &operator<<(std::ostream &, TensorDimPermutation const &); + +TensorDimPermutation + compose_tensor_dim_permutations(TensorDimPermutation const &, + TensorDimPermutation const &); + +TensorDimPermutation + invert_tensor_dim_permutation(TensorDimPermutation const &); + +TensorDims permute_tensor_dims(TensorDimPermutation const &, + TensorDims const &); + +TensorShape permute_tensor_shape(TensorDimPermutation const &, + TensorShape const &); + +ParallelTensorDimDegrees + 
permute_parallel_tensor_dim_degrees(TensorDimPermutation const &, + ParallelTensorDimDegrees const &); + +ParallelTensorDims permute_parallel_tensor_dims(TensorDimPermutation const &, + ParallelTensorDims const &); + +ParallelTensorShape permute_parallel_tensor_shape(TensorDimPermutation const &, + ParallelTensorShape const &); + +} // namespace FlexFlow + +namespace nlohmann { + +template <> +struct adl_serializer<::FlexFlow::TensorDimPermutation> { + static ::FlexFlow::TensorDimPermutation from_json(json const &); + static void to_json(json &, ::FlexFlow::TensorDimPermutation const &); +}; + +} // namespace nlohmann + +namespace rc { + +template <> +struct Arbitrary<::FlexFlow::TensorDimPermutation> { + static Gen<::FlexFlow::TensorDimPermutation> arbitrary(); +}; + +} // namespace rc + +namespace std { + +template <> +struct hash<::FlexFlow::TensorDimPermutation> { + size_t operator()(::FlexFlow::TensorDimPermutation const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.dtg.toml b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.toml new file mode 100644 index 0000000000..73656fe3c4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "TensorDims" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_ordered/ff_ordered.h", + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "ff_ordered" +type = "::FlexFlow::FFOrdered<::FlexFlow::positive_int>" diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index 0f5b987944..e0c8aa2dc6 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H +#include "op-attrs/num_tensor_dims_t.h" #include 
"op-attrs/parallel_tensor_dims.dtg.h" #include "op-attrs/tensor_dims.dtg.h" #include "op-attrs/tensor_dims_coord.dtg.h" @@ -12,7 +13,7 @@ FFOrdered const &ff_ordered(TensorDims const &); bool tensor_dims_has_dim(TensorDims const &, ff_dim_t); -nonnegative_int get_num_dims(TensorDims const &); +num_tensor_dims_t get_num_dims(TensorDims const &); positive_int dim_at_idx(TensorDims const &, relative_ff_dim_t); positive_int &dim_at_idx(TensorDims &, relative_ff_dim_t); diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml deleted file mode 100644 index a1039798c9..0000000000 --- a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "TensorDims" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/ff_ordered/ff_ordered.h", - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "ff_ordered" -type = "::FlexFlow::FFOrdered<::FlexFlow::positive_int>" diff --git a/lib/op-attrs/include/op-attrs/tensor_dims_coord.dtg.toml b/lib/op-attrs/include/op-attrs/tensor_dims_coord.dtg.toml new file mode 100644 index 0000000000..64e39aa21a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dims_coord.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "TensorDimsCoord" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "op-attrs/ff_ordered/ff_ordered.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "ff_ordered" +type = "::FlexFlow::FFOrdered<::FlexFlow::nonnegative_int>" diff --git a/lib/op-attrs/include/op-attrs/tensor_dims_coord.struct.toml b/lib/op-attrs/include/op-attrs/tensor_dims_coord.struct.toml deleted file mode 100644 index 53f4405389..0000000000 --- a/lib/op-attrs/include/op-attrs/tensor_dims_coord.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "TensorDimsCoord" -features = 
[ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -includes = [ - "op-attrs/ff_ordered/ff_ordered.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "ff_ordered" -type = "::FlexFlow::FFOrdered<::FlexFlow::nonnegative_int>" diff --git a/lib/op-attrs/include/op-attrs/tensor_role.dtg.toml b/lib/op-attrs/include/op-attrs/tensor_role.dtg.toml new file mode 100644 index 0000000000..e95ab1e280 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_role.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "TensorRole" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "INPUT" + +[[values]] +name = "WEIGHT" + +[[values]] +name = "OUTPUT" diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.dtg.toml b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.toml new file mode 100644 index 0000000000..0f18c56da5 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "TensorShape" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_dims.dtg.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "dims" +type = "::FlexFlow::TensorDims" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml deleted file mode 100644 index 901c3b9e60..0000000000 --- a/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "TensorShape" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/tensor_dims.dtg.h", - "op-attrs/datatype.dtg.h", -] - -[[fields]] -name = "dims" -type = "::FlexFlow::TensorDims" - -[[fields]] -name = "data_type" -type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/tensor_slot_name.dtg.toml 
b/lib/op-attrs/include/op-attrs/tensor_slot_name.dtg.toml new file mode 100644 index 0000000000..2021305dbd --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_slot_name.dtg.toml @@ -0,0 +1,102 @@ +namespace = "FlexFlow" +name = "TensorSlotName" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "INPUT" + +[[values]] +name = "OUTPUT" + +[[values]] +name = "FILTER" + +[[values]] +name = "BIAS" + +[[values]] +name = "INPUT_BIAS" + +[[values]] +name = "OUTPUT_BIAS" + +[[values]] +name = "LHS_INPUT" + +[[values]] +name = "RHS_INPUT" + +[[values]] +name = "INDEX" + +[[values]] +name = "GAMMA" + +[[values]] +name = "BETA" + +[[values]] +name = "WEIGHT" + +[[values]] +name = "QUERY" + +[[values]] +name = "KEY" + +[[values]] +name = "VALUE" + +[[values]] +name = "LOGIT" + +[[values]] +name = "INPUT_0" + +[[values]] +name = "INPUT_1" + +[[values]] +name = "INPUT_2" + +[[values]] +name = "INPUT_3" + +[[values]] +name = "INPUT_4" + +[[values]] +name = "INPUT_5" + +[[values]] +name = "INPUT_6" + +[[values]] +name = "INPUT_7" + +[[values]] +name = "OUTPUT_0" + +[[values]] +name = "OUTPUT_1" + +[[values]] +name = "OUTPUT_2" + +[[values]] +name = "OUTPUT_3" + +[[values]] +name = "OUTPUT_4" + +[[values]] +name = "OUTPUT_5" + +[[values]] +name = "SCALE" diff --git a/lib/op-attrs/include/op-attrs/tensor_slot_name.h b/lib/op-attrs/include/op-attrs/tensor_slot_name.h new file mode 100644 index 0000000000..0122bbc90f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_slot_name.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SLOT_NAME_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SLOT_NAME_H + +#include "op-attrs/tensor_slot_name.dtg.h" + +namespace FlexFlow { + +std::vector get_variadic_inputs_slot_name_sequence(); +std::vector get_variadic_outputs_slot_name_sequence(); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc 
b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc index f65a0f5f08..9d1a9f68d4 100644 --- a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -31,20 +31,14 @@ RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { return result; } -ComputationGraphOpAttrs +std::optional compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &op) { - auto fail_on_parallel_op = [](auto const &attrs) -> ComputationGraphOpAttrs { - throw mk_runtime_error( - fmt::format("Encountered parallel operator in " - "compgraph_op_attrs_from_pcg_op_attrs: {}", - attrs)); - }; - - return op.visit(overload{ - [&](CombineAttrs const &attrs) { return fail_on_parallel_op(attrs); }, - [&](ReductionAttrs const &attrs) { return fail_on_parallel_op(attrs); }, - [&](RepartitionAttrs const &attrs) { return fail_on_parallel_op(attrs); }, - [&](ReplicateAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + + return op.visit>(overload{ + [&](CombineAttrs const &attrs) { return std::nullopt; }, + [&](ReductionAttrs const &attrs) { return std::nullopt; }, + [&](RepartitionAttrs const &attrs) { return std::nullopt; }, + [&](ReplicateAttrs const &attrs) { return std::nullopt; }, [](auto const &attrs) { return ComputationGraphOpAttrs{attrs}; }, }); } diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/slice.cc b/lib/op-attrs/src/op-attrs/dim_ordered/slice.cc deleted file mode 100644 index 8c3dbd7bbc..0000000000 --- a/lib/op-attrs/src/op-attrs/dim_ordered/slice.cc +++ /dev/null @@ -1 +0,0 @@ -#include "op-attrs/dim_ordered/slice.h" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/transform.cc b/lib/op-attrs/src/op-attrs/dim_ordered/transform.cc deleted file mode 100644 index 73683eba94..0000000000 --- a/lib/op-attrs/src/op-attrs/dim_ordered/transform.cc +++ /dev/null @@ -1 +0,0 @@ -#include "op-attrs/dim_ordered/transform.h" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/zip.cc 
b/lib/op-attrs/src/op-attrs/dim_ordered/zip.cc deleted file mode 100644 index 208fc4a719..0000000000 --- a/lib/op-attrs/src/op-attrs/dim_ordered/zip.cc +++ /dev/null @@ -1 +0,0 @@ -#include "op-attrs/dim_ordered/zip.h" diff --git a/lib/op-attrs/src/op-attrs/ff_dim_t.cc b/lib/op-attrs/src/op-attrs/ff_dim_t.cc index 63c783d909..0d3de74735 100644 --- a/lib/op-attrs/src/op-attrs/ff_dim_t.cc +++ b/lib/op-attrs/src/op-attrs/ff_dim_t.cc @@ -1,4 +1,6 @@ #include "op-attrs/ff_dim_t.h" +#include "utils/containers/transform.h" +#include "utils/nonnegative_int/nonnegative_range.h" namespace FlexFlow { @@ -11,6 +13,11 @@ ff_dim_t add_to_ff_dim(ff_dim_t ff_dim, int value) { return ff_dim_t{nonnegative_int{ff_dim.value.unwrap_nonnegative() + value}}; } +std::vector ff_dim_range(nonnegative_int num_elements) { + return transform(nonnegative_range(num_elements), + [](nonnegative_int idx) { return ff_dim_t{idx}; }); +} + } // namespace FlexFlow namespace rc { diff --git a/lib/op-attrs/src/op-attrs/ff_ordered/concat.cc b/lib/op-attrs/src/op-attrs/ff_ordered/concat.cc new file mode 100644 index 0000000000..3fa2eb8dba --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ff_ordered/concat.cc @@ -0,0 +1,12 @@ +#include "op-attrs/ff_ordered/concat.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template FFOrdered concat(FFOrdered const &, FFOrdered const &); + +template FFOrdered concat(std::vector> const &); + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ff_ordered/ff_ordered_of.cc b/lib/op-attrs/src/op-attrs/ff_ordered/ff_ordered_of.cc new file mode 100644 index 0000000000..0e0e8711d6 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ff_ordered/ff_ordered_of.cc @@ -0,0 +1,12 @@ +#include "op-attrs/ff_ordered/ff_ordered_of.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template FFOrdered ff_ordered_of(std::vector const &); + +template FFOrdered ff_ordered_of(std::unordered_set 
const &); + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ff_ordered/get_idxs.cc b/lib/op-attrs/src/op-attrs/ff_ordered/get_idxs.cc index 3da15bebba..7b93643735 100644 --- a/lib/op-attrs/src/op-attrs/ff_ordered/get_idxs.cc +++ b/lib/op-attrs/src/op-attrs/ff_ordered/get_idxs.cc @@ -5,6 +5,6 @@ namespace FlexFlow { using T = value_type<0>; -template std::vector get_idxs(FFOrdered const &); +template std::set get_idxs(FFOrdered const &); } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ff_ordered/map_from_ff_ordered.cc b/lib/op-attrs/src/op-attrs/ff_ordered/map_from_ff_ordered.cc new file mode 100644 index 0000000000..8c4e6c2e37 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ff_ordered/map_from_ff_ordered.cc @@ -0,0 +1,11 @@ +#include "op-attrs/ff_ordered/map_from_ff_ordered.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template std::unordered_map + map_from_ff_ordered(FFOrdered const &); + +} // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ff_ordered/zip_with.cc b/lib/op-attrs/src/op-attrs/ff_ordered/zip_with.cc similarity index 100% rename from lib/op-attrs/include/op-attrs/ff_ordered/zip_with.cc rename to lib/op-attrs/src/op-attrs/ff_ordered/zip_with.cc diff --git a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc index 21efc26466..eec9ae869c 100644 --- a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc @@ -5,100 +5,163 @@ #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" #include "op-attrs/pcg_operator_attrs.h" +#include "op-attrs/tensor_slot_name.h" #include "utils/overload.h" namespace FlexFlow { -std::vector get_incoming_tensor_roles( - ComputationGraphOpAttrs const &comp_graph_op_attrs, int num_incoming) { +std::unordered_map + get_incoming_tensor_roles( + ComputationGraphOpAttrs const &comp_graph_op_attrs) { 
return get_incoming_tensor_roles( - pcg_op_attrs_from_compgraph_op_attrs(comp_graph_op_attrs), num_incoming); + pcg_op_attrs_from_compgraph_op_attrs(comp_graph_op_attrs)); } -std::vector - get_incoming_tensor_roles(PCGOperatorAttrs const &pcg_op_attrs, - int num_incoming) { - return pcg_op_attrs.visit>(overload{ - [](BatchMatmulAttrs const &) { - return std::vector{IncomingTensorRole::INPUT, - IncomingTensorRole::INPUT}; - }, - [](BatchNormAttrs const &attrs) { - return get_batch_norm_incoming_tensor_roles(attrs); - }, - [](BroadcastAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](CastAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, - [](CombineAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [&](ConcatAttrs const &) { - return std::vector(num_incoming, IncomingTensorRole::INPUT); - }, - [](Conv2DAttrs const &attrs) { - return get_conv2d_incoming_tensor_roles(attrs); - }, - [](DropoutAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](ElementBinaryAttrs const &) { - return std::vector{IncomingTensorRole::INPUT, - IncomingTensorRole::INPUT}; - }, - [](ElementUnaryAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](EmbeddingAttrs const &) { - return std::vector{IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT}; - }, - [](FlatAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, - [](GatherAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](InputAttrs const &) { return std::vector{}; }, - [](LayerNormAttrs const &attrs) { - return get_layer_norm_incoming_tensor_roles(attrs); - }, - [](LinearAttrs const &attrs) { - return get_linear_incoming_tensor_roles(attrs); - }, - [](MultiHeadAttentionAttrs const &attrs) { - return get_attention_incoming_tensor_roles(attrs); - }, - [](NoopAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, - [](Pool2DAttrs const &) { - return 
std::vector{IncomingTensorRole::INPUT}; - }, - [](ReduceAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](ReductionAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](RepartitionAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](ReplicateAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](ReverseAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](ReshapeAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](SplitAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, - [](SoftmaxAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](TopKAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, - [](TransposeAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; - }, - [](WeightAttrs const &) { return std::vector{}; }, - }); +std::unordered_map + get_incoming_tensor_roles(PCGOperatorAttrs const &pcg_op_attrs) { + return pcg_op_attrs + .visit>(overload{ + [](BatchMatmulAttrs const &) { + return std::unordered_map{ + {TensorSlotName::LHS_INPUT, IncomingTensorRole::INPUT}, + {TensorSlotName::RHS_INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](BatchNormAttrs const &attrs) { + return get_batch_norm_incoming_tensor_roles(attrs); + }, + [](BroadcastAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](CastAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](CombineAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [&](ConcatAttrs const &) { + return generate_map(get_variadic_inputs_slot_name_sequence(), + [](TensorSlotName) -> IncomingTensorRole { + return IncomingTensorRole::INPUT; + }); + }, + [](Conv2DAttrs const &attrs) { + return get_conv2d_incoming_tensor_roles(attrs); + }, + [](DropoutAttrs 
const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](ElementBinaryAttrs const &) { + return std::unordered_map{ + {TensorSlotName::LHS_INPUT, IncomingTensorRole::INPUT}, + {TensorSlotName::RHS_INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](ElementUnaryAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](EmbeddingAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + {TensorSlotName::WEIGHT, IncomingTensorRole::WEIGHT}, + }; + }, + [](FlatAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](GatherAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](InputAttrs const &) { + return std::unordered_map{}; + }, + [](LayerNormAttrs const &attrs) { + return get_layer_norm_incoming_tensor_roles(attrs); + }, + [](LinearAttrs const &attrs) { + return get_linear_incoming_tensor_roles(attrs); + }, + [](MultiHeadAttentionAttrs const &attrs) { + return get_attention_incoming_tensor_roles(attrs); + }, + [](NoopAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](Pool2DAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](ReduceAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](ReductionAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](RepartitionAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](ReplicateAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](ReverseAttrs const &) { + return std::unordered_map{ + 
{TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](ReshapeAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](SplitAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](SoftmaxAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](TopKAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](TransposeAttrs const &) { + return std::unordered_map{ + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; + }, + [](WeightAttrs const &) { + return std::unordered_map{}; + }, + }); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.cc b/lib/op-attrs/src/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.cc new file mode 100644 index 0000000000..618eb533ff --- /dev/null +++ b/lib/op-attrs/src/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.cc @@ -0,0 +1,291 @@ +#include "op-attrs/get_operator_space_to_parallel_tensor_space_mappings.h" +#include "op-attrs/get_incoming_tensor_roles.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/input.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/transpose.h" +#include "op-attrs/ops/weight.h" +#include "utils/containers/filtrans.h" +#include "utils/containers/get_only.h" +#include "utils/containers/merge_disjoint_maps.h" +#include "utils/containers/require_only_key.h" +#include "utils/containers/require_two_keys.h" +#include "utils/containers/zip_values_strict.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::unordered_map + get_operator_to_incoming_mappings( + ComputationGraphOpAttrs const &comp_graph_op_attrs, + std::unordered_map const + &inputs_degrees) { + return comp_graph_op_attrs.visit< + 
std::unordered_map>(overload{ + [&](ElementBinaryAttrs const &attrs) + -> std::unordered_map { + ASSERT(inputs_degrees.size() == 2); + + ParallelTensorDimDegrees lhs_degrees = + inputs_degrees.at(TensorSlotName::LHS_INPUT); + ParallelTensorDimDegrees rhs_degrees = + inputs_degrees.at(TensorSlotName::RHS_INPUT); + + return { + { + TensorSlotName::LHS_INPUT, + get_operator_to_lhs_input_mapping( + attrs, lhs_degrees, rhs_degrees), + }, + { + TensorSlotName::RHS_INPUT, + get_operator_to_rhs_input_mapping( + attrs, lhs_degrees, rhs_degrees), + }, + }; + }, + [&](ElementUnaryAttrs const &attrs) + -> std::unordered_map { + ParallelTensorDimDegrees input_degrees = + require_only_key(inputs_degrees, TensorSlotName::INPUT); + + return { + { + TensorSlotName::INPUT, + get_operator_to_input_mapping(attrs, input_degrees), + }, + }; + }, + [&](InputAttrs const &) { + ASSERT(inputs_degrees.size() == 0); + + return std::unordered_map{}; + }, + [&](LinearAttrs const &attrs) + -> std::unordered_map { + ParallelTensorDimDegrees input_degrees = + require_only_key(inputs_degrees, TensorSlotName::INPUT); + + std::unordered_map + result = { + {TensorSlotName::INPUT, + get_operator_to_input_mapping(attrs, input_degrees)}, + {TensorSlotName::WEIGHT, + get_operator_to_projection_mapping(attrs, input_degrees)}, + }; + + if (attrs.use_bias) { + result.insert({TensorSlotName::BIAS, + get_operator_to_bias_mapping(attrs, input_degrees)}); + }; + + return result; + }, + [&](TransposeAttrs const &attrs) + -> std::unordered_map { + ParallelTensorDimDegrees input_degrees = + require_only_key(inputs_degrees, TensorSlotName::INPUT); + + return { + { + TensorSlotName::INPUT, + get_operator_to_input_mapping(attrs, input_degrees), + }, + }; + }, + [&](WeightAttrs const &) { + ASSERT(inputs_degrees.size() == 0); + + return std::unordered_map{}; + }, + [](auto const &attrs) + -> std::unordered_map { + PANIC("Missing implmentation of get_operator_to_input_mappings", attrs); + }, + }); +} + 
+std::unordered_map + get_operator_to_incoming_mappings_for_role( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees, + IncomingTensorRole incoming_tensor_role) { + + std::unordered_map + incoming_mappings = + get_operator_to_incoming_mappings(attrs, inputs_degrees); + + std::unordered_map incoming_tensor_roles = + get_incoming_tensor_roles(attrs); + + return filtermap_values( + zip_values_strict(incoming_mappings, incoming_tensor_roles), + [&](std::pair const &p) + -> std::optional { + auto const &[mapping, role] = p; + + if (role == incoming_tensor_role) { + return mapping; + } else { + return std::nullopt; + } + }); +} + +std::unordered_map + get_operator_to_input_mappings( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees) { + return get_operator_to_incoming_mappings_for_role( + attrs, inputs_degrees, IncomingTensorRole::INPUT); +} + +std::unordered_map + get_operator_to_weight_mappings( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees) { + + return get_operator_to_incoming_mappings_for_role( + attrs, inputs_degrees, IncomingTensorRole::WEIGHT); +} + +std::unordered_map + get_operator_to_output_mappings( + ComputationGraphOpAttrs const &comp_graph_op_attrs, + std::unordered_map const + &inputs_degrees) { + + return comp_graph_op_attrs.visit< + std::unordered_map>(overload{ + [&](ElementBinaryAttrs const &attrs) + -> std::unordered_map { + auto [lhs_degrees, rhs_degrees] = + require_two_keys(inputs_degrees, + TensorSlotName::LHS_INPUT, + TensorSlotName::RHS_INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_operator_to_output_mapping(attrs, lhs_degrees, rhs_degrees), + }, + }; + }, + [&](ElementUnaryAttrs const &attrs) + -> std::unordered_map { + ParallelTensorDimDegrees input_degrees = + require_only_key(inputs_degrees, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_operator_to_output_mapping(attrs, input_degrees), + }, + }; + 
}, + [&](LinearAttrs const &attrs) + -> std::unordered_map { + ParallelTensorDimDegrees input_degrees = + require_only_key(inputs_degrees, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_operator_to_output_mapping(attrs, input_degrees), + }, + }; + }, + [&](InputAttrs const &attrs) + -> std::unordered_map { + ASSERT(inputs_degrees.size() == 0); + + return { + { + TensorSlotName::OUTPUT, + get_operator_to_output_mapping(attrs), + }, + }; + }, + [&](TransposeAttrs const &attrs) + -> std::unordered_map { + ParallelTensorDimDegrees input_degrees = + require_only_key(inputs_degrees, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_operator_to_output_mapping(attrs, input_degrees), + }, + }; + }, + [&](WeightAttrs const &attrs) + -> std::unordered_map { + ASSERT(inputs_degrees.size() == 0); + + return { + { + TensorSlotName::OUTPUT, + get_operator_to_output_mapping(attrs), + }, + }; + }, + [](auto const &attrs) + -> std::unordered_map { + PANIC("Missing implmentation of get_operator_to_input_mappings", attrs); + }, + }); +} + +std::unordered_map + get_operator_to_ptensor_mappings_for_role( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees, + TensorRole role) { + switch (role) { + case TensorRole::INPUT: + return get_operator_to_input_mappings(attrs, inputs_degrees); + case TensorRole::WEIGHT: + return get_operator_to_weight_mappings(attrs, inputs_degrees); + case TensorRole::OUTPUT: + return get_operator_to_output_mappings(attrs, inputs_degrees); + default: + PANIC("Unhandled TensorRole", role); + } +} + +std::unordered_map + get_operator_to_ptensor_mappings( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees) { + return merge_disjoint_maps(std::vector{ + get_operator_to_input_mappings(attrs, inputs_degrees), + get_operator_to_weight_mappings(attrs, inputs_degrees), + get_operator_to_output_mappings(attrs, inputs_degrees), + }); +} + +} // namespace 
FlexFlow diff --git a/lib/op-attrs/src/op-attrs/get_operator_task_space.cc b/lib/op-attrs/src/op-attrs/get_operator_task_space.cc new file mode 100644 index 0000000000..40ec0b964c --- /dev/null +++ b/lib/op-attrs/src/op-attrs/get_operator_task_space.cc @@ -0,0 +1,65 @@ +#include "op-attrs/get_operator_task_space.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/input.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/transpose.h" +#include "op-attrs/ops/weight.h" +#include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" +#include "utils/containers/require_two_keys.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +OperatorTaskSpace get_operator_task_space( + ComputationGraphOpAttrs const &attrs, + std::unordered_map const + &inputs_degrees) { + return attrs.visit(overload{ + [&](ElementUnaryAttrs const &attrs) { + ParallelTensorDimDegrees input = + require_only_key(inputs_degrees, TensorSlotName::INPUT); + + return get_operator_task_space(attrs, input); + }, + [&](ElementBinaryAttrs const &attrs) { + auto [lhs, rhs] = require_two_keys(inputs_degrees, + TensorSlotName::LHS_INPUT, + TensorSlotName::RHS_INPUT); + + return get_operator_task_space( + /*attrs=*/attrs, + /*lhs_input_degrees=*/lhs, + /*rhs_input_degrees=*/rhs); + }, + [&](LinearAttrs const &attrs) { + ParallelTensorDimDegrees input = + require_only_key(inputs_degrees, TensorSlotName::INPUT); + + return get_operator_task_space(attrs, input); + }, + [&](InputAttrs const &attrs) { + ASSERT(inputs_degrees.size() == 0); + + return get_operator_task_space(attrs); + }, + [&](TransposeAttrs const &attrs) { + ParallelTensorDimDegrees input = + require_only_key(inputs_degrees, TensorSlotName::INPUT); + + return get_operator_task_space(attrs, input); + }, + [&](WeightAttrs const &attrs) { + ASSERT(inputs_degrees.size() == 0); + + return get_operator_task_space(attrs); + }, + [](auto const &attrs) -> 
OperatorTaskSpace { + PANIC("Missing implmentation of get_operator_task_space", attrs); + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/num_ptensor_parallel_dims_t.cc b/lib/op-attrs/src/op-attrs/num_ptensor_parallel_dims_t.cc new file mode 100644 index 0000000000..db8fb12ff2 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/num_ptensor_parallel_dims_t.cc @@ -0,0 +1,112 @@ +#include "op-attrs/num_ptensor_parallel_dims_t.h" +#include "utils/hash-utils.h" +#include +#include +#include + +namespace FlexFlow { + +num_ptensor_parallel_dims_t::num_ptensor_parallel_dims_t(int value) + : value(value) { + this->check_invariant(); +} + +num_ptensor_parallel_dims_t::num_ptensor_parallel_dims_t(nonnegative_int value) + : value(value.unwrap_nonnegative()) {} + +num_ptensor_parallel_dims_t::num_ptensor_parallel_dims_t(positive_int value) + : value(value.int_from_positive_int()) {} + +bool num_ptensor_parallel_dims_t::operator<( + num_ptensor_parallel_dims_t const &other) const { + return this->value < other.value; +} + +bool num_ptensor_parallel_dims_t::operator==( + num_ptensor_parallel_dims_t const &other) const { + return this->value == other.value; +} + +bool num_ptensor_parallel_dims_t::operator>( + num_ptensor_parallel_dims_t const &other) const { + return this->value > other.value; +} + +bool num_ptensor_parallel_dims_t::operator<=( + num_ptensor_parallel_dims_t const &other) const { + return this->value <= other.value; +} + +bool num_ptensor_parallel_dims_t::operator!=( + num_ptensor_parallel_dims_t const &other) const { + return this->value != other.value; +} + +bool num_ptensor_parallel_dims_t::operator>=( + num_ptensor_parallel_dims_t const &other) const { + return this->value >= other.value; +} + +int num_ptensor_parallel_dims_t::int_from_num_ptensor_parallel_dims() const { + return this->value; +} + +nonnegative_int num_ptensor_parallel_dims_t:: + nonnegative_int_from_num_ptensor_parallel_dims() const { + return 
nonnegative_int{this->value}; +} + +positive_int + num_ptensor_parallel_dims_t::positive_int_from_num_ptensor_parallel_dims() + const { + return positive_int{this->value}; +} + +void num_ptensor_parallel_dims_t::check_invariant() const { + ASSERT(this->value >= 2); + ASSERT(this->value <= MAX_TENSOR_DIM + 2); +} + +std::ostream &operator<<(std::ostream &s, + num_ptensor_parallel_dims_t const &m) { + return (s << fmt::to_string(m)); +} + +std::string format_as(num_ptensor_parallel_dims_t const &m) { + return fmt::format("{} parallel dims", + m.int_from_num_ptensor_parallel_dims()); +} + +} // namespace FlexFlow + +namespace nlohmann { +::FlexFlow::num_ptensor_parallel_dims_t + adl_serializer<::FlexFlow::num_ptensor_parallel_dims_t>::from_json( + json const &j) { + return ::FlexFlow::num_ptensor_parallel_dims_t{j.template get()}; +} + +void adl_serializer<::FlexFlow::num_ptensor_parallel_dims_t>::to_json( + json &j, ::FlexFlow::num_ptensor_parallel_dims_t t) { + j = t.int_from_num_ptensor_parallel_dims(); +} +} // namespace nlohmann + +namespace rc { + +Gen<::FlexFlow::num_ptensor_parallel_dims_t> + Arbitrary<::FlexFlow::num_ptensor_parallel_dims_t>::arbitrary() { + return gen::construct<::FlexFlow::num_ptensor_parallel_dims_t>( + gen::arbitrary()); +} + +} // namespace rc + +namespace std { + +size_t hash<::FlexFlow::num_ptensor_parallel_dims_t>::operator()( + ::FlexFlow::num_ptensor_parallel_dims_t const &m) const noexcept { + return ::FlexFlow::get_std_hash(m.int_from_num_ptensor_parallel_dims()); +} + +} // namespace std diff --git a/lib/op-attrs/src/op-attrs/num_ptensor_shard_dims_t.cc b/lib/op-attrs/src/op-attrs/num_ptensor_shard_dims_t.cc new file mode 100644 index 0000000000..00e3137eee --- /dev/null +++ b/lib/op-attrs/src/op-attrs/num_ptensor_shard_dims_t.cc @@ -0,0 +1,18 @@ +#include "op-attrs/num_ptensor_shard_dims_t.h" + +namespace FlexFlow { + +num_ptensor_parallel_dims_t num_ptensor_parallel_dims_from_shard_dims( + num_ptensor_shard_dims_t 
num_shard_dims) { + return num_ptensor_parallel_dims_t{num_shard_dims.value + 2_p}; +} + +num_ptensor_shard_dims_t num_ptensor_shard_dims_from_parallel_dims( + num_ptensor_parallel_dims_t num_parallel_dims) { + return num_ptensor_shard_dims_t{ + nonnegative_int{num_parallel_dims.int_from_num_ptensor_parallel_dims() - + 2}, + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/num_tensor_dims_t.cc b/lib/op-attrs/src/op-attrs/num_tensor_dims_t.cc new file mode 100644 index 0000000000..50843a944f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/num_tensor_dims_t.cc @@ -0,0 +1,192 @@ +#include "op-attrs/num_tensor_dims_t.h" +#include "op-attrs/num_ptensor_shard_dims_t.h" +#include "utils/containers/transform.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include + +namespace FlexFlow { + +num_tensor_dims_t::num_tensor_dims_t(nonnegative_int value_) : value(value_) { + ASSERT(this->value <= MAX_TENSOR_DIM); +} + +bool num_tensor_dims_t::operator<(num_tensor_dims_t other) const { + return this->value < other.value; +} + +bool num_tensor_dims_t::operator==(num_tensor_dims_t other) const { + return this->value == other.value; +} + +bool num_tensor_dims_t::operator>(num_tensor_dims_t other) const { + return this->value > other.value; +} + +bool num_tensor_dims_t::operator<=(num_tensor_dims_t other) const { + return this->value <= other.value; +} + +bool num_tensor_dims_t::operator!=(num_tensor_dims_t other) const { + return this->value != other.value; +} + +bool num_tensor_dims_t::operator>=(num_tensor_dims_t other) const { + return this->value >= other.value; +} + +bool num_tensor_dims_t::operator<(nonnegative_int other) const { + return this->value < other; +} + +bool num_tensor_dims_t::operator==(nonnegative_int other) const { + return this->value == other; +} + +bool num_tensor_dims_t::operator>(nonnegative_int other) const { + return this->value > other; +} + +bool num_tensor_dims_t::operator<=(nonnegative_int other) const { + return 
this->value <= other; +} + +bool num_tensor_dims_t::operator!=(nonnegative_int other) const { + return this->value != other; +} + +bool num_tensor_dims_t::operator>=(nonnegative_int other) const { + return this->value >= other; +} + +bool operator<(nonnegative_int lhs, num_tensor_dims_t rhs) { + return lhs < rhs.value; +} + +bool operator==(nonnegative_int lhs, num_tensor_dims_t rhs) { + return lhs == rhs.value; +} + +bool operator>(nonnegative_int lhs, num_tensor_dims_t rhs) { + return lhs > rhs.value; +} + +bool operator<=(nonnegative_int lhs, num_tensor_dims_t rhs) { + return lhs <= rhs.value; +} + +bool operator!=(nonnegative_int lhs, num_tensor_dims_t rhs) { + return lhs != rhs.value; +} + +bool operator>=(nonnegative_int lhs, num_tensor_dims_t rhs) { + return lhs >= rhs.value; +} + +bool num_tensor_dims_t::operator<(int other) const { + return this->value < other; +} + +bool num_tensor_dims_t::operator==(int other) const { + return this->value == other; +} + +bool num_tensor_dims_t::operator>(int other) const { + return this->value > other; +} + +bool num_tensor_dims_t::operator<=(int other) const { + return this->value <= other; +} + +bool num_tensor_dims_t::operator!=(int other) const { + return this->value != other; +} + +bool num_tensor_dims_t::operator>=(int other) const { + return this->value >= other; +} + +bool operator<(int lhs, num_tensor_dims_t rhs) { + return lhs < rhs.value; +} + +bool operator==(int lhs, num_tensor_dims_t rhs) { + return lhs == rhs.value; +} + +bool operator>(int lhs, num_tensor_dims_t rhs) { + return lhs > rhs.value; +} + +bool operator<=(int lhs, num_tensor_dims_t rhs) { + return lhs <= rhs.value; +} + +bool operator!=(int lhs, num_tensor_dims_t rhs) { + return lhs != rhs.value; +} + +bool operator>=(int lhs, num_tensor_dims_t rhs) { + return lhs >= rhs.value; +} + +nonnegative_int + num_tensor_dims_t::nonnegative_int_from_num_tensor_dims() const { + return this->value; +} + +int num_tensor_dims_t::int_from_num_tensor_dims() 
const { + return this->value.unwrap_nonnegative(); +} + +void num_tensor_dims_t::check_invariant() const { + ASSERT(this->value <= MAX_TENSOR_DIM); +} + +nonnegative_int format_as(num_tensor_dims_t num_tensor_dims) { + return num_tensor_dims.nonnegative_int_from_num_tensor_dims(); +} + +std::ostream &operator<<(std::ostream &s, num_tensor_dims_t num_tensor_dims) { + return (s << fmt::to_string(num_tensor_dims)); +} + +num_tensor_dims_t num_tensor_dims_from_num_ptensor_shard_dims( + num_ptensor_shard_dims_t num_ptensor_shard_dims) { + return num_tensor_dims_t{num_ptensor_shard_dims.value}; +} + +num_tensor_dims_t num_tensor_dims_from_num_ptensor_parallel_dims( + num_ptensor_parallel_dims_t num_ptensor_parallel_dims) { + return num_tensor_dims_from_num_ptensor_shard_dims( + num_ptensor_shard_dims_from_parallel_dims(num_ptensor_parallel_dims)); +} + +num_ptensor_shard_dims_t num_ptensor_shard_dims_from_num_tensor_dims( + num_tensor_dims_t num_tensor_dims) { + return num_ptensor_shard_dims_t{ + num_tensor_dims.nonnegative_int_from_num_tensor_dims()}; +} + +num_ptensor_parallel_dims_t num_ptensor_parallel_dims_from_num_tensor_dims( + num_tensor_dims_t num_tensor_dims) { + return num_ptensor_parallel_dims_from_shard_dims( + num_ptensor_shard_dims_from_num_tensor_dims(num_tensor_dims)); +} + +std::vector tensor_dims_range(num_tensor_dims_t num_tensor_dims) { + return transform( + nonnegative_range(num_tensor_dims.nonnegative_int_from_num_tensor_dims()), + [](nonnegative_int idx) { return ff_dim_t{idx}; }); +} + +std::vector + relative_tensor_dims_range(num_tensor_dims_t num_tensor_dims) { + return transform( + nonnegative_range(num_tensor_dims.nonnegative_int_from_num_tensor_dims()), + [](nonnegative_int idx) { + return relative_ff_dim_t{idx.unwrap_nonnegative()}; + }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc b/lib/op-attrs/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc new 
file mode 100644 index 0000000000..c88043f6ce --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc @@ -0,0 +1,134 @@ +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/num_ptensor_shard_dims_t.h" +#include "op-attrs/operator_task_space.h" +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" +#include "op-attrs/parallel_tensor_space_coordinate.h" +#include "op-attrs/task_space_coordinate.h" +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/containers/set_of.h" +#include "utils/containers/transform.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/nonnegative_int/num_elements.h" +#include "utils/nonnegative_int/range.h" +#include "utils/orthotope/dim_projection.h" +#include "utils/orthotope/minimal_dim_domain.h" +#include "utils/orthotope/minimal_dim_domain_mapping.h" + +namespace FlexFlow { + +OperatorSpaceToParallelTensorSpaceMapping + empty_operator_space_to_ptensor_space_map() { + + return OperatorSpaceToParallelTensorSpaceMapping{ + empty_dim_domain_mapping(), + }; +} + +OperatorTaskSpace get_operator_task_space_for_mapping( + OperatorSpaceToParallelTensorSpaceMapping const &mapping) { + + return operator_task_space_from_minimal_dim_domain( + require_dim_domain_is_minimal(mapping.raw_mapping.l_domain)); +} + +ParallelTensorDimDegrees get_parallel_tensor_space_for_mapping( + OperatorSpaceToParallelTensorSpaceMapping const &mapping) { + + return parallel_tensor_dim_degrees_from_dim_domain( + mapping.raw_mapping.r_domain); +} + +OperatorSpaceToParallelTensorSpaceMapping get_identity_mapping( + OperatorTaskSpace const &operator_task_space, + ParallelTensorDimDegrees const ¶llel_tensor_dim_degrees) { + + MinimalDimDomain pt_minimal_dim_domain = + minimal_dim_domain_from_parallel_tensor_dim_degrees( + parallel_tensor_dim_degrees); + + 
ASSERT(op_task_space_num_dims(operator_task_space) == + minimal_dim_domain_num_dims(pt_minimal_dim_domain)); + + std::vector op_minimal_domain_dims = + sorted_by(operator_task_space_get_dim_idxs(operator_task_space), + get_operator_task_space_dim_ordering().lt); + + std::vector pt_minimal_domain_dims = + sorted_by(get_minimal_domain_dims(pt_minimal_dim_domain), + get_parallel_tensor_dim_ordering().lt); + + bidict projection = + bidict_from_keys_and_values(op_minimal_domain_dims, + pt_minimal_domain_dims); + + return operator_ptensor_space_mapping_from_projection( + DimProjection{EqProjection{projection}}, + operator_task_space, + parallel_tensor_dim_degrees); +} + +OperatorSpaceToParallelTensorSpaceMapping + operator_ptensor_space_mapping_from_projection( + DimProjection const &projection, + OperatorTaskSpace const &operator_task_space, + ParallelTensorDimDegrees const ¶llel_tensor_dim_degrees) { + + return OperatorSpaceToParallelTensorSpaceMapping{ + dim_domain_mapping_from_projection( + /*projection=*/projection, + /*l_domain=*/ + lift_minimal_dim_domain( + minimal_dim_domain_from_operator_task_space(operator_task_space)), + /*r_domain=*/ + lift_minimal_dim_domain( + minimal_dim_domain_from_parallel_tensor_dim_degrees( + parallel_tensor_dim_degrees)), + /*l_dim_ordering=*/get_operator_task_space_dim_ordering(), + /*r_dim_ordering=*/get_parallel_tensor_dim_ordering()), + }; +} + +OperatorSpaceToParallelTensorSpaceMapping + operator_ptensor_space_mapping_from_composition( + OperatorSpaceToParallelTensorSpaceMapping const &op_to_pt1_mapping, + ParallelTensorSpaceToParallelTensorSpaceMapping const + &pt1_to_pt2_mapping) { + + return OperatorSpaceToParallelTensorSpaceMapping{ + compose_dim_domain_mappings_through_minimal( + op_to_pt1_mapping.raw_mapping, pt1_to_pt2_mapping.raw_mapping), + }; +} + +ParallelTensorSpaceCoordinate ptensor_coord_for_task_space_coord( + OperatorSpaceToParallelTensorSpaceMapping const &mapping, + TaskSpaceCoordinate const 
&task_space_coordinate, + num_ptensor_shard_dims_t num_dims) { + + std::unordered_set ptensor_dim_idxs = + unordered_set_of(dim_idxs_for_num_shard_dims(num_dims)); + + DimCoord mapped_dim_coord = + mapping.raw_mapping.at_l( + dim_coord_from_task_space_coordinate(task_space_coordinate)); + + DimCoord lifted_dim_coord = + lift_dim_coord(mapped_dim_coord, ptensor_dim_idxs); + + return parallel_tensor_space_coord_from_dim_coord(lifted_dim_coord); +} + +TaskSpaceCoordinate task_space_coord_for_ptensor_coord( + OperatorSpaceToParallelTensorSpaceMapping const &mapping, + ParallelTensorSpaceCoordinate const &ptensor_space_coord) { + + DimCoord dim_coord = mapping.raw_mapping.at_r( + dim_coord_from_parallel_tensor_space_coord(ptensor_space_coord)); + + return task_space_coordinate_from_dim_coord(dim_coord); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_task_space.cc b/lib/op-attrs/src/op-attrs/operator_task_space.cc new file mode 100644 index 0000000000..98f7525564 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_task_space.cc @@ -0,0 +1,109 @@ +#include "op-attrs/operator_task_space.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/operator_task_space_dim_idx_t.h" +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/containers/cartesian_product.h" +#include "utils/containers/extend.h" +#include "utils/containers/maximum.h" +#include "utils/containers/product.h" +#include "utils/containers/range.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/vector_of.h" +#include "utils/fmt/unordered_set.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/nonnegative_int/num_elements.h" +#include "utils/orthotope/dim_domain.h" +#include "utils/orthotope/dim_ordering.h" +#include 
"utils/orthotope/minimal_dim_domain.h" +#include "utils/orthotope/minimal_orthotope.h" +#include "utils/orthotope/orthotope.dtg.h" +#include "utils/orthotope/orthotope.h" + +namespace FlexFlow { + +OperatorTaskSpace trivial_op_task_space() { + return OperatorTaskSpace{MinimalOrthotope{{}}}; +} + +std::unordered_set + operator_task_space_get_dim_idxs(OperatorTaskSpace const &op_task_space) { + return get_minimal_domain_dims( + minimal_dim_domain_from_operator_task_space(op_task_space)); +} + +std::unordered_set + get_task_space_coordinates(OperatorTaskSpace const &task) { + + std::vector> coordinate_ranges = + transform(task.degrees.dims, [&](int_ge_two num_points) { + return nonnegative_range(num_points.nonnegative_int_from_int_ge_two()); + }); + + std::unordered_set> raw_coordinates = + unordered_set_of(cartesian_product(coordinate_ranges)); + std::unordered_set task_space_coordinates = + transform(raw_coordinates, [](std::vector const &point) { + return TaskSpaceCoordinate{OrthotopeCoord{point}}; + }); + return task_space_coordinates; +} + +bool operator_task_space_contains_coord(OperatorTaskSpace const &task_space, + TaskSpaceCoordinate const &coord) { + return contains(get_task_space_coordinates(task_space), coord); +} + +TaskSpaceCoordinate + get_task_space_maximum_coordinate(OperatorTaskSpace const &task) { + return maximum(get_task_space_coordinates(task)); +} + +nonnegative_int op_task_space_num_dims(OperatorTaskSpace const &op_task_space) { + return minimal_orthotope_get_num_dims(op_task_space.degrees); +} + +positive_int num_tasks(OperatorTaskSpace const &op_task_space) { + return minimal_orthotope_get_volume(op_task_space.degrees); +} + +MinimalDimDomain + minimal_dim_domain_from_operator_task_space( + OperatorTaskSpace const &operator_task_space) { + + MinimalOrthotope minimal_orthotope = operator_task_space.degrees; + + return minimal_dim_domain_from_minimal_orthotope( + minimal_orthotope, + unordered_set_of(operator_task_space_dim_idx_range( + 
minimal_orthotope_get_num_dims(minimal_orthotope))), + get_operator_task_space_dim_ordering()); +} + +OperatorTaskSpace operator_task_space_from_minimal_dim_domain( + MinimalDimDomain const &minimal_dim_domain) { + + return OperatorTaskSpace{ + minimal_orthotope_from_minimal_dim_domain( + minimal_dim_domain, get_operator_task_space_dim_ordering()), + }; +} + +DimOrdering + get_operator_task_space_dim_ordering() { + return make_default_dim_ordering(); +} + +OperatorTaskSpace get_operator_task_space_matching_parallel_tensor_dim_degrees( + ParallelTensorDimDegrees const &dim_degrees) { + return OperatorTaskSpace{ + minimal_orthotope_from_minimal_dim_domain( + minimal_dim_domain_from_parallel_tensor_dim_degrees(dim_degrees), + get_parallel_tensor_dim_ordering()), + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_task_space_dim_idx_t.cc b/lib/op-attrs/src/op-attrs/operator_task_space_dim_idx_t.cc new file mode 100644 index 0000000000..c78afc502f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_task_space_dim_idx_t.cc @@ -0,0 +1,15 @@ +#include "op-attrs/operator_task_space_dim_idx_t.h" +#include "utils/containers/set_of.h" +#include "utils/containers/transform.h" +#include "utils/nonnegative_int/range.h" + +namespace FlexFlow { + +std::set + operator_task_space_dim_idx_range(nonnegative_int end) { + return transform(set_of(range(end)), [](nonnegative_int raw_idx) { + return operator_task_space_dim_idx_t{raw_idx}; + }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_task_space_to_operator_task_space_mapping.cc b/lib/op-attrs/src/op-attrs/operator_task_space_to_operator_task_space_mapping.cc new file mode 100644 index 0000000000..605578acdc --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_task_space_to_operator_task_space_mapping.cc @@ -0,0 +1,61 @@ +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" +#include "op-attrs/operator_task_space.h" +#include 
"op-attrs/task_space_coordinate.h" +#include "utils/bidict/algorithms/transform_keys.h" +#include "utils/bidict/algorithms/transform_values.h" +#include "utils/orthotope/minimal_dim_domain.h" +#include "utils/orthotope/minimal_dim_domain_mapping.h" + +namespace FlexFlow { + +OperatorTaskSpaceToOperatorTaskSpaceMapping + op_to_op_identity_mapping(OperatorTaskSpace const &src_space, + OperatorTaskSpace const &dst_space) { + + return OperatorTaskSpaceToOperatorTaskSpaceMapping{ + dim_domain_mapping_identity_map( + /*l_domain=*/lift_minimal_dim_domain( + minimal_dim_domain_from_operator_task_space(src_space)), + /*r_domain=*/ + lift_minimal_dim_domain( + minimal_dim_domain_from_operator_task_space(dst_space)), + /*l_dim_ordering=*/get_operator_task_space_dim_ordering(), + /*r_dim_ordering=*/get_operator_task_space_dim_ordering()), + }; +} + +OperatorTaskSpace op_mapping_get_src_space( + OperatorTaskSpaceToOperatorTaskSpaceMapping const &mapping) { + + return operator_task_space_from_minimal_dim_domain( + require_dim_domain_is_minimal(mapping.raw_mapping.l_domain)); +} + +OperatorTaskSpace op_mapping_get_dst_space( + OperatorTaskSpaceToOperatorTaskSpaceMapping const &mapping) { + + return operator_task_space_from_minimal_dim_domain( + require_dim_domain_is_minimal(mapping.raw_mapping.r_domain)); +} + +bidict op_to_op_get_coord_mapping( + OperatorTaskSpaceToOperatorTaskSpaceMapping const &mapping) { + return transform_values(transform_keys(mapping.raw_mapping.coord_mapping, + task_space_coordinate_from_dim_coord), + task_space_coordinate_from_dim_coord); +} + +OperatorTaskSpaceToOperatorTaskSpaceMapping + op_to_op_mapping_from_composition_through_tensor( + OperatorSpaceToParallelTensorSpaceMapping const &src_to_tensor_mapping, + OperatorSpaceToParallelTensorSpaceMapping const + &dst_to_tensor_mapping) { + + return OperatorTaskSpaceToOperatorTaskSpaceMapping{ + compose_dim_domain_mappings_through_minimal( + src_to_tensor_mapping.raw_mapping, + 
invert_dim_domain_mapping(dst_to_tensor_mapping.raw_mapping)), + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index cc6ef8cfac..816c2787cd 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -5,8 +5,10 @@ #include "op-attrs/tensor_dims.h" #include "op-attrs/tensor_shape.h" #include "utils/containers/extend.h" +#include "utils/exception.h" #include "utils/expected.h" #include "utils/integer_conversions.h" +#include namespace FlexFlow { @@ -95,27 +97,26 @@ positive_int get_num_samples(MultiHeadAttentionInputs const &inputs) { } static void check_attrs(MultiHeadAttentionAttrs const &attrs) { - if (attrs.add_bias_kv) { - throw mk_runtime_error("add_bias_kv is not yet supported. If you need this " - "functionality, please create an issue."); - } + ASSERT(!attrs.add_bias_kv, + "add_bias_kv is not yet supported. If you need this " + "functionality, please create an issue."); } -std::vector +std::unordered_map get_attention_incoming_tensor_roles(MultiHeadAttentionAttrs const &attrs) { check_attrs(attrs); - std::vector roles = std::vector{ - IncomingTensorRole::INPUT, - IncomingTensorRole::INPUT, - IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT, + std::unordered_map roles = { + {TensorSlotName::QUERY, IncomingTensorRole::INPUT}, + {TensorSlotName::KEY, IncomingTensorRole::INPUT}, + {TensorSlotName::VALUE, IncomingTensorRole::INPUT}, + {TensorSlotName::WEIGHT, IncomingTensorRole::WEIGHT}, }; if (attrs.bias) { - extend(roles, - std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + roles[TensorSlotName::INPUT_BIAS] = IncomingTensorRole::WEIGHT; + roles[TensorSlotName::OUTPUT_BIAS] = IncomingTensorRole::WEIGHT; } return roles; @@ -232,21 +233,29 @@ tl::expected }; } -tl::expected, std::string> +tl::expected, std::string> get_weight_shapes(MultiHeadAttentionAttrs const &attrs, TensorShape const &input_q, 
TensorShape const &input_k, TensorShape const &input_v) { - std::vector weight_shapes = { - PROPAGATE_ERR(get_weights_shape(attrs, input_q, input_k, input_v)), + std::unordered_map weight_shapes = { + { + TensorSlotName::WEIGHT, + PROPAGATE_ERR(get_weights_shape(attrs, input_q, input_k, input_v)), + }, }; if (attrs.bias) { - weight_shapes.push_back( - PROPAGATE_ERR(get_input_bias_shape(attrs, input_q, input_k, input_v))); - weight_shapes.push_back( - PROPAGATE_ERR(get_output_bias_shape(attrs, input_q, input_k, input_v))); + weight_shapes.insert({ + TensorSlotName::INPUT_BIAS, + PROPAGATE_ERR(get_input_bias_shape(attrs, input_q, input_k, input_v)), + }); + + weight_shapes.insert({ + TensorSlotName::OUTPUT_BIAS, + PROPAGATE_ERR(get_output_bias_shape(attrs, input_q, input_k, input_v)), + }); } return weight_shapes; @@ -407,34 +416,44 @@ positive_int get_oSize(TensorShape const &) { NOT_IMPLEMENTED(); } -tl::expected, std::string> +tl::expected, + std::string> get_weight_shapes(MultiHeadAttentionAttrs const &attrs, ParallelTensorShape const &input_q, ParallelTensorShape const &input_k, ParallelTensorShape const &input_v) { - std::vector weight_shapes = { - PROPAGATE_ERR(get_weights_shape(attrs, input_q, input_k, input_v)), + std::unordered_map weight_shapes = { + { + TensorSlotName::WEIGHT, + PROPAGATE_ERR(get_weights_shape(attrs, input_q, input_k, input_v)), + }, }; if (attrs.bias) { - weight_shapes.push_back( - PROPAGATE_ERR(get_input_bias_shape(attrs, input_q, input_k, input_v))); - weight_shapes.push_back( - PROPAGATE_ERR(get_output_bias_shape(attrs, input_q, input_k, input_v))); + weight_shapes.insert({ + TensorSlotName::INPUT_BIAS, + PROPAGATE_ERR(get_input_bias_shape(attrs, input_q, input_k, input_v)), + }); + + weight_shapes.insert({ + TensorSlotName::OUTPUT_BIAS, + PROPAGATE_ERR(get_output_bias_shape(attrs, input_q, input_k, input_v)), + }); } return weight_shapes; } -tl::expected, std::string> get_initializers( - MultiHeadAttentionAttrs const &attrs, - 
TensorShape const &input_q, - TensorShape const &input_k, - TensorShape const &input_v, - std::optional const &maybe_weights_initializer, - std::optional const &maybe_input_bias_initializer, - std::optional const &maybe_output_bias_initializer) { +tl::expected, std::string> + get_initializers( + MultiHeadAttentionAttrs const &attrs, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v, + std::optional const &maybe_weights_initializer, + std::optional const &maybe_input_bias_initializer, + std::optional const &maybe_output_bias_initializer) { check_attrs(attrs); if (!attrs.bias && maybe_input_bias_initializer.has_value()) { @@ -473,14 +492,14 @@ tl::expected, std::string> get_initializers( maybe_output_bias_initializer.value_or(default_output_bias_initializer); if (attrs.bias) { - return std::vector{ - weights_initializer, - input_bias_initializer, - output_bias_initializer, + return std::unordered_map{ + {TensorSlotName::WEIGHT, weights_initializer}, + {TensorSlotName::INPUT_BIAS, input_bias_initializer}, + {TensorSlotName::OUTPUT_BIAS, output_bias_initializer}, }; } else { - return std::vector{ - weights_initializer, + return std::unordered_map{ + {TensorSlotName::WEIGHT, weights_initializer}, }; } } diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc index 3225f1aef2..f9d00dc523 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc @@ -23,19 +23,19 @@ tl::expected unpar_parse_result.error())); } - if (num_shard_dims(input_q) != 3) { + if (num_shard_dims(input_q).value != 3) { return tl::unexpected( fmt::format("Query input has incorrect number of dims: {} != {}", num_shard_dims(input_q), 3)); } - if (num_shard_dims(input_k) != 3) { + if (num_shard_dims(input_k).value != 3) { return 
tl::unexpected( fmt::format("Key input has incorrect number of dims: {} != {}", num_shard_dims(input_k), 3)); } - if (num_shard_dims(input_v) != 3) { + if (num_shard_dims(input_v).value != 3) { return tl::unexpected( fmt::format("Value input has incorrect number of dims: {} != {}", num_shard_dims(input_v), diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc index 3c76561d17..8fb34dc191 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc @@ -1,6 +1,7 @@ #include "op-attrs/ops/batch_matmul.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_dims.h" +#include "utils/exception.h" namespace FlexFlow { @@ -35,10 +36,12 @@ tl::expected get_output_shape(BatchMatmulAttrs const &attrs, TensorShape const &input_lhs, TensorShape const &input_rhs) { - // If input_lhs is a (b×n×m) tensor, - // input_rhs is a (b×m×p) tensor, - // out will be a (b×n×p) tensor. - // https://pytorch.org/docs/stable/generated/torch.bmm.html + /** + * If input_lhs is a (b×n×m) tensor, + * input_rhs is a (b×m×p) tensor, + * out will be a (b×n×p) tensor. 
+ * https://pytorch.org/docs/stable/generated/torch.bmm.html + */ if (get_num_dims(input_lhs.dims) != 3) { return tl::unexpected( @@ -91,13 +94,13 @@ tl::expected get_output_shape(BatchMatmulAttrs const &attrs, ParallelTensorShape const &input_lhs, ParallelTensorShape const &input_rhs) { - if (num_shard_dims(input_lhs) != 3) { + if (num_shard_dims(input_lhs).value != 3) { return tl::unexpected( fmt::format("LHS input has incorrect number of shard dims: {} != {}", num_shard_dims(input_lhs), 3)); } - if (num_shard_dims(input_rhs) != 3) { + if (num_shard_dims(input_rhs).value != 3) { return tl::unexpected( fmt::format("RHS input has incorrect number of shard dims: {} != {}", num_shard_dims(input_rhs), diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index cfe5bafaba..5d451c617d 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -10,13 +10,18 @@ namespace FlexFlow { -std::vector +std::unordered_map get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &attrs) { - std::vector result = {IncomingTensorRole::INPUT}; + std::unordered_map result = { + { + TensorSlotName::INPUT, + IncomingTensorRole::INPUT, + }, + }; if (attrs.affine) { - extend(result, - std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + result[TensorSlotName::GAMMA] = IncomingTensorRole::WEIGHT; + result[TensorSlotName::BETA] = IncomingTensorRole::WEIGHT; } return result; @@ -91,7 +96,7 @@ tl::expected return get_gamma_weights_shape(attrs, input_shape); } -tl::expected, std::string> +tl::expected, std::string> get_weight_shapes(BatchNormAttrs const &attrs, TensorShape const &input_shape) { @@ -100,9 +105,15 @@ tl::expected, std::string> TensorShape beta_shape = PROPAGATE_ERR(get_beta_weights_shape(attrs, input_shape)); - return std::vector{ - gamma_shape, - beta_shape, + return std::unordered_map{ + { + TensorSlotName::GAMMA, + gamma_shape, + }, + { + 
TensorSlotName::BETA, + beta_shape, + }, }; } @@ -199,7 +210,8 @@ tl::expected return get_gamma_weights_parallel_dim_degrees(attrs, input_degrees); } -tl::expected, std::string> +tl::expected, + std::string> get_weight_parallel_dim_degrees( BatchNormAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { @@ -209,9 +221,15 @@ tl::expected, std::string> ParallelTensorDimDegrees beta_degrees = PROPAGATE_ERR( get_beta_weights_parallel_dim_degrees(attrs, input_degrees)); - return std::vector{ - gamma_degrees, - beta_degrees, + return std::unordered_map{ + { + TensorSlotName::GAMMA, + gamma_degrees, + }, + { + TensorSlotName::BETA, + beta_degrees, + }, }; } @@ -292,7 +310,8 @@ tl::expected return lift_to_parallel_with_degrees(unpar, degrees); } -tl::expected, std::string> +tl::expected, + std::string> get_weight_shapes(BatchNormAttrs const &attrs, ParallelTensorShape const &input_shape) { @@ -301,13 +320,19 @@ tl::expected, std::string> ParallelTensorShape beta_shape = PROPAGATE_ERR(get_beta_weights_shape(attrs, input_shape)); - return std::vector{ - gamma_shape, - beta_shape, + return std::unordered_map{ + { + TensorSlotName::GAMMA, + gamma_shape, + }, + { + TensorSlotName::BETA, + beta_shape, + }, }; } -tl::expected, std::string> +tl::expected, std::string> get_initializers(BatchNormAttrs const &attrs) { if (attrs.affine) { InitializerAttrs gamma_initializer = @@ -316,9 +341,18 @@ tl::expected, std::string> InitializerAttrs beta_initializer = InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; - return std::vector{gamma_initializer, beta_initializer}; + return std::unordered_map{ + { + TensorSlotName::GAMMA, + gamma_initializer, + }, + { + TensorSlotName::BETA, + beta_initializer, + }, + }; } else { - return std::vector{}; + return std::unordered_map{}; } } diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.cc index d84a9ee46e..927d4fd913 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.cc 
+++ b/lib/op-attrs/src/op-attrs/ops/broadcast.cc @@ -1,5 +1,7 @@ #include "op-attrs/ops/broadcast.h" +#include "op-attrs/num_tensor_dims_t.h" #include "op-attrs/tensor_dims.h" +#include "utils/exception.h" #include "utils/record_formatter.h" namespace FlexFlow { @@ -13,9 +15,9 @@ RecordFormatter as_dot(BroadcastAttrs const &attrs) { return rr; }; - for (int i = 0; i < get_num_dims(attrs.target_dims); i++) { - r << kv(fmt::format("target_dims[{}]", i), - dim_at_idx(attrs.target_dims, relative_ff_dim_t{i})); + for (ff_dim_t dim_idx : tensor_dims_range(get_num_dims(attrs.target_dims))) { + r << kv(fmt::format("target_dims[{}]", dim_idx.value), + dim_at_idx(attrs.target_dims, dim_idx)); } return r; diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index 6ff1b8a06e..b50446a693 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -4,18 +4,19 @@ #include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h" #include "utils/fmt/optional.h" #include "utils/integer_conversions.h" +#include namespace FlexFlow { -std::vector +std::unordered_map get_conv2d_incoming_tensor_roles(Conv2DAttrs const &attrs) { - std::vector result = { - IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT, + std::unordered_map result = { + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + {TensorSlotName::FILTER, IncomingTensorRole::WEIGHT}, }; if (attrs.use_bias) { - result.push_back(IncomingTensorRole::WEIGHT); + result[TensorSlotName::BIAS] = IncomingTensorRole::WEIGHT; } return result; @@ -88,14 +89,21 @@ TensorShape get_output_shape(Conv2DAttrs const &attrs, input.datatype}; } -std::vector get_weight_shapes(Conv2DAttrs const &attrs, - TensorShape const &input_shape) { - std::vector weight_shapes = { - get_kernel_shape(attrs, input_shape), +std::unordered_map + get_weight_shapes(Conv2DAttrs const &attrs, + TensorShape const &input_shape) { + std::unordered_map weight_shapes = { + { + 
TensorSlotName::FILTER, + get_kernel_shape(attrs, input_shape), + }, }; if (attrs.use_bias) { - weight_shapes.push_back(get_bias_shape(attrs, input_shape)); + weight_shapes.insert({ + TensorSlotName::BIAS, + get_bias_shape(attrs, input_shape), + }); } return weight_shapes; @@ -172,15 +180,21 @@ ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, unpar, sum_degree, discard_copy_degree, shard_degrees); } -std::vector +std::unordered_map get_weight_shapes(Conv2DAttrs const &attrs, ParallelTensorShape const &input_shape) { - std::vector weight_shapes = { - get_kernel_shape(attrs, input_shape), + std::unordered_map weight_shapes = { + { + TensorSlotName::FILTER, + get_kernel_shape(attrs, input_shape), + }, }; if (attrs.use_bias) { - weight_shapes.push_back(get_bias_shape(attrs, input_shape)); + weight_shapes.insert({ + TensorSlotName::BIAS, + get_bias_shape(attrs, input_shape), + }); } return weight_shapes; @@ -192,18 +206,12 @@ std::vector * see * https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L178-L187 */ -std::vector +std::unordered_map get_initializers(Conv2DAttrs const &attrs, TensorShape const &input_shape, std::optional maybe_kernel_initializer, std::optional maybe_bias_initializer) { - if (!attrs.use_bias && maybe_bias_initializer.has_value()) { - throw mk_runtime_error(fmt::format( - "Unexpectedly received bias initializer while use_bias=false: {}", - maybe_bias_initializer)); - } - TensorShape kernel_shape = get_kernel_shape(attrs, input_shape); InitializerAttrs kernel_default_initializer = @@ -233,9 +241,14 @@ std::vector maybe_bias_initializer.value_or(bias_default_initializer); if (attrs.use_bias) { - return {kernel_initializer, bias_initializer}; + return { + {TensorSlotName::FILTER, kernel_initializer}, + {TensorSlotName::BIAS, bias_initializer}, + }; } else { - return {kernel_initializer}; + return { + {TensorSlotName::FILTER, kernel_initializer}, + }; } } diff --git 
a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc index 79bb14f2b2..8038c1a4c5 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc @@ -1,6 +1,8 @@ #include "op-attrs/ops/conv_2d/conv_2d_input_shape.h" +#include "op-attrs/num_tensor_dims_t.h" #include "op-attrs/tensor_dims.h" #include "op-attrs/tensor_shape.h" +#include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc index 8143353b2d..e08bd4bec2 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc @@ -5,7 +5,7 @@ namespace FlexFlow { Conv2DParallelInputShape parse_parallel_input_shape(ParallelTensorShape const &input) { - assert(num_shard_dims(input) == 4); + assert(num_shard_dims(input).value == 4); ShardParallelDim sample_dim = shard_dim_at_idx(input, relative_ff_dim_t{0}); ShardParallelDim channel_dim = shard_dim_at_idx(input, relative_ff_dim_t{1}); diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary.cc b/lib/op-attrs/src/op-attrs/ops/element_binary.cc index 16957a036c..27b73f00d0 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_binary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_binary.cc @@ -1,55 +1,64 @@ #include "op-attrs/ops/element_binary.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/operator_task_space.h" +#include "utils/containers/require_same.h" +#include "utils/exception.h" namespace FlexFlow { -tl::expected - get_output_shape(ElementBinaryAttrs const &attrs, - TensorShape const &input_lhs, - TensorShape const &input_rhs) { - assert(!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs)); +TensorShape get_output_shape(ElementBinaryAttrs const 
&attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs) { + ASSERT(!attrs.should_broadcast_lhs && !attrs.should_broadcast_rhs, + "ElementBinary broadcasting is currently not supported. " + "Contact @lockshaw if you want this feature implemented."); if (attrs.should_broadcast_lhs) { NOT_IMPLEMENTED(); } else if (attrs.should_broadcast_rhs) { NOT_IMPLEMENTED(); } else { - if (input_lhs != input_rhs) { - return tl::unexpected(fmt::format( - "Expected input shapes to match, but receieved LHS ({}) != RHS ({})", - input_lhs, - input_rhs)); - } + ASSERT(input_lhs == input_rhs, "Expected input shapes to match"); return input_lhs; } } -tl::expected - get_output_shape(ElementBinaryAttrs const &attrs, - ParallelTensorShape const &input_lhs, - ParallelTensorShape const &input_rhs) { - assert(!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs)); +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &attrs, + ParallelTensorShape const &input_lhs, + ParallelTensorShape const &input_rhs) { + TensorShape output_shape = get_output_shape( + attrs, get_reduced_shape(input_lhs), get_reduced_shape(input_rhs)); + + ParallelTensorDimDegrees output_degrees = get_output_parallel_dim_degrees( + attrs, get_parallel_degrees(input_lhs), get_parallel_degrees(input_rhs)); + + return lift_to_parallel_with_degrees(output_shape, output_degrees); +} + +ParallelTensorDimDegrees get_output_parallel_dim_degrees( + ElementBinaryAttrs const &attrs, + ParallelTensorDimDegrees const &lhs_input_degrees, + ParallelTensorDimDegrees const &rhs_input_degrees) { + ASSERT(!attrs.should_broadcast_lhs && !attrs.should_broadcast_rhs, + "ElementBinary broadcasting is currently not supported. 
" + "Contact @lockshaw if you want this feature implemented."); + + ASSERT(lhs_input_degrees == rhs_input_degrees); if (attrs.should_broadcast_lhs) { NOT_IMPLEMENTED(); } else if (attrs.should_broadcast_rhs) { NOT_IMPLEMENTED(); } else { - if (input_lhs != input_rhs) { - return tl::unexpected(fmt::format( - "Expected input shapes to match, but receieved LHS ({}) != RHS ({})", - input_lhs, - input_rhs)); - } + ASSERT(lhs_input_degrees == rhs_input_degrees, + "Expected input degrees to match"); switch (attrs.type) { case OperatorType::EW_ADD: { - if (get_discard_copy_degree(input_lhs) != 1) { - return tl::unexpected( - fmt::format("Elementwise Add expected discard copy degree of " - "inputs to be 1, but receieved {}", - get_discard_copy_degree(input_lhs))); - } + ASSERT( + lhs_input_degrees.discard_copy_degree.value == 1, + "Elementwise Add expected discard copy degree of inputs to be 1"); break; } @@ -64,12 +73,56 @@ tl::expected case OperatorType::EW_MIN: NOT_IMPLEMENTED(); default: - return tl::unexpected(fmt::format( - "Unexpected element-wise binary operator {}", attrs.type)); + PANIC("Unexpected element-wise binary operator", attrs.type); } - return input_lhs; + return lhs_input_degrees; } } +OperatorTaskSpace + get_operator_task_space(ElementBinaryAttrs const &attrs, + ParallelTensorDimDegrees const &lhs_input_degrees, + ParallelTensorDimDegrees const &rhs_input_degrees) { + + ParallelTensorDimDegrees output_degrees = get_output_parallel_dim_degrees( + attrs, lhs_input_degrees, rhs_input_degrees); + + return get_operator_task_space_matching_parallel_tensor_dim_degrees( + output_degrees); +} + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_lhs_input_mapping( + ElementBinaryAttrs const &attrs, + ParallelTensorDimDegrees const &lhs_input_degrees, + ParallelTensorDimDegrees const &rhs_input_degrees) { + + return get_identity_mapping( + get_operator_task_space(attrs, lhs_input_degrees, rhs_input_degrees), + lhs_input_degrees); +} + 
+OperatorSpaceToParallelTensorSpaceMapping get_operator_to_rhs_input_mapping( + ElementBinaryAttrs const &attrs, + ParallelTensorDimDegrees const &lhs_input_degrees, + ParallelTensorDimDegrees const &rhs_input_degrees) { + + return get_identity_mapping( + get_operator_task_space(attrs, lhs_input_degrees, rhs_input_degrees), + rhs_input_degrees); +} + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_output_mapping( + ElementBinaryAttrs const &attrs, + ParallelTensorDimDegrees const &lhs_input_degrees, + ParallelTensorDimDegrees const &rhs_input_degrees) { + + ParallelTensorDimDegrees output_dim_degrees = get_output_parallel_dim_degrees( + attrs, lhs_input_degrees, rhs_input_degrees); + + return get_identity_mapping( + get_operator_task_space(attrs, lhs_input_degrees, rhs_input_degrees), + output_dim_degrees); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc index fd65e1f5c9..9d02923689 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc @@ -1,5 +1,10 @@ #include "op-attrs/ops/element_unary.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/operator_task_space.h" +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" #include "op-attrs/parallel_tensor_shape.h" +#include "utils/orthotope/minimal_dim_domain.h" namespace FlexFlow { @@ -10,28 +15,58 @@ ElementUnaryAttrs make_relu_attrs() { }; } -tl::expected - get_output_shape(ElementUnaryAttrs const &attrs, - TensorShape const &input_shape) { +TensorShape get_output_shape(ElementUnaryAttrs const &attrs, + TensorShape const &input_shape) { return input_shape; } -tl::expected - get_output_shape(ElementUnaryAttrs const &attrs, - ParallelTensorShape const &input_shape) { - if (get_sum_degree(input_shape) != 1) { - return tl::unexpected( - fmt::format("Expected sum degree 1, but 
receieved sum degree {}", - get_sum_degree(input_shape))); - } - - if (get_discard_copy_degree(input_shape) != 1) { - return tl::unexpected(fmt::format( - "Expected discard copy degree 1, but received discartd copy degree {}", - get_discard_copy_degree(input_shape))); - } +ParallelTensorShape get_output_shape(ElementUnaryAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape output_shape = + get_output_shape(attrs, get_reduced_shape(input_shape)); - return input_shape; + ParallelTensorDimDegrees output_degrees = + get_output_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); + + return lift_to_parallel_with_degrees(output_shape, output_degrees); +} + +ParallelTensorDimDegrees get_output_parallel_dim_degrees( + ElementUnaryAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + ASSERT(input_degrees.sum_degree.value == 1); + ASSERT(input_degrees.discard_copy_degree.value == 1); + + return input_degrees; +} + +OperatorTaskSpace + get_operator_task_space(ElementUnaryAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + ParallelTensorDimDegrees output_degrees = + get_output_parallel_dim_degrees(attrs, input_degrees); + + return get_operator_task_space_matching_parallel_tensor_dim_degrees( + output_degrees); +} + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_input_mapping( + ElementUnaryAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + + return get_identity_mapping(get_operator_task_space(attrs, input_degrees), + input_degrees); +} + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_output_mapping( + ElementUnaryAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + + ParallelTensorDimDegrees output_degrees = + get_output_parallel_dim_degrees(attrs, input_degrees); + + return get_identity_mapping(get_operator_task_space(attrs, input_degrees), + output_degrees); } } // namespace FlexFlow diff --git 
a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc index e0e1a44b3b..451468ba28 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc @@ -130,7 +130,7 @@ tl::expected unpar, sum_degree, discard_copy_degree, shard_degrees); } -std::vector get_initializers( +std::unordered_map get_initializers( EmbeddingAttrs const &, std::optional const &maybe_initializer_attrs) { InitializerAttrs default_initializer_attrs = InitializerAttrs{ @@ -141,7 +141,12 @@ std::vector get_initializers( }, }; - return {maybe_initializer_attrs.value_or(default_initializer_attrs)}; + return { + { + TensorSlotName::WEIGHT, + maybe_initializer_attrs.value_or(default_initializer_attrs), + }, + }; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index 14180cecf8..5469380c05 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -34,9 +34,8 @@ TensorShape get_output_shape(FlatAttrs const &attrs, }; } -tl::expected - get_output_parallel_dim_degrees( - FlatAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { +ParallelTensorDimDegrees get_output_parallel_dim_degrees( + FlatAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { FFOrdered flattened_dim_degrees = slice(input_degrees.shard_degrees, attrs.start_dim, attrs.end_dim); @@ -44,14 +43,10 @@ tl::expected return input_degrees; } - if (any_of(flattened_dim_degrees, - [](positive_int degree) { return degree != 1; })) { - return tl::unexpected( - fmt::format("get_output_parallel_dim_degrees for {} expected all shard " - "degrees of flattened dimensions to be 1, but received {}", - attrs, - input_degrees)); - } + ASSERT(any_of(flattened_dim_degrees, + [](positive_int degree) { return degree != 1; }), + "get_output_parallel_dim_degrees for {} expected all shard degrees of " + "flattened dimensions to be 1"); return 
ParallelTensorDimDegrees{ /*sum_degree=*/input_degrees.sum_degree, @@ -65,20 +60,12 @@ tl::expected }; } -tl::expected - get_output_shape(FlatAttrs const &attrs, - ParallelTensorShape const &input_shape) { +ParallelTensorShape get_output_shape(FlatAttrs const &attrs, + ParallelTensorShape const &input_shape) { TensorShape unpar = get_output_shape(attrs, get_reduced_shape(input_shape)); - ParallelTensorDimDegrees degrees = ({ - tl::expected returned = - get_output_parallel_dim_degrees(attrs, - get_parallel_degrees(input_shape)); - if (!returned.has_value()) { - return tl::unexpected(returned.error()); - } - returned.value(); - }); + ParallelTensorDimDegrees degrees = + get_output_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); return lift_to_parallel_with_degrees(unpar, degrees); } diff --git a/lib/op-attrs/src/op-attrs/ops/gather.cc b/lib/op-attrs/src/op-attrs/ops/gather.cc index 4b1053aee1..2c5a4bbdc0 100644 --- a/lib/op-attrs/src/op-attrs/ops/gather.cc +++ b/lib/op-attrs/src/op-attrs/ops/gather.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/gather.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/input.cc b/lib/op-attrs/src/op-attrs/ops/input.cc index d1f68584b9..1c0bdf0740 100644 --- a/lib/op-attrs/src/op-attrs/ops/input.cc +++ b/lib/op-attrs/src/op-attrs/ops/input.cc @@ -1,4 +1,6 @@ #include "op-attrs/ops/input.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/operator_task_space.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { @@ -11,4 +13,14 @@ ParallelTensorShape get_output_parallel_tensor_shape(InputAttrs const &attrs) { return lift_to_parallel(attrs.tensor_shape); } +OperatorTaskSpace get_operator_task_space(InputAttrs const &) { + return trivial_op_task_space(); +} + +OperatorSpaceToParallelTensorSpaceMapping + get_operator_to_output_mapping(InputAttrs const &attrs) { + + return empty_operator_space_to_ptensor_space_map(); +} + } // 
namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index e0db1cdfe7..81aa2d8a52 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -9,18 +9,21 @@ #include "utils/containers/contains.h" #include "utils/containers/extend.h" #include "utils/containers/filter.h" +#include "utils/containers/vector_of.h" #include "utils/expected.h" #include "utils/fmt/set.h" namespace FlexFlow { -std::vector +std::unordered_map get_layer_norm_incoming_tensor_roles(LayerNormAttrs const &attrs) { - std::vector result = {IncomingTensorRole::INPUT}; + std::unordered_map result = { + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + }; if (attrs.elementwise_affine) { - extend(result, - std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + result[TensorSlotName::GAMMA] = IncomingTensorRole::WEIGHT; + result[TensorSlotName::BETA] = IncomingTensorRole::WEIGHT; } return result; @@ -72,7 +75,7 @@ tl::expected } std::vector non_layer_norm_dim_idxs = filter( - get_idxs(input_shape.dims.ff_ordered), + vector_of(get_idxs(input_shape.dims.ff_ordered)), [&](ff_dim_t const &dim_idx) { return !contains(attrs.axes, dim_idx); }); std::vector raw_weight_dims = transform(non_layer_norm_dim_idxs, [&](ff_dim_t const &dim_idx) { @@ -97,7 +100,7 @@ tl::expected return get_gamma_weights_shape(attrs, input_shape); } -tl::expected, std::string> +tl::expected, std::string> get_weight_shapes(LayerNormAttrs const &attrs, TensorShape const &input_shape) { @@ -106,9 +109,15 @@ tl::expected, std::string> TensorShape beta_shape = PROPAGATE_ERR(get_beta_weights_shape(attrs, input_shape)); - return std::vector{ - gamma_shape, - beta_shape, + return std::unordered_map{ + { + TensorSlotName::GAMMA, + gamma_shape, + }, + { + TensorSlotName::BETA, + beta_shape, + }, }; } @@ -180,7 +189,7 @@ tl::expected } std::vector non_layer_norm_dim_idxs = filter( - 
get_idxs(input_shape.dims.shard_dims), + vector_of(get_idxs(input_shape.dims.shard_dims)), [&](ff_dim_t const &dim_idx) { return !contains(attrs.axes, dim_idx); }); std::vector raw_weight_shard_dims = transform(non_layer_norm_dim_idxs, [&](ff_dim_t const &dim_idx) { @@ -211,7 +220,8 @@ tl::expected return get_gamma_weights_shape(attrs, input_shape); } -tl::expected, std::string> +tl::expected, + std::string> get_weight_shapes(LayerNormAttrs const &attrs, ParallelTensorShape const &input_shape) { @@ -220,13 +230,20 @@ tl::expected, std::string> ParallelTensorShape beta_shape = PROPAGATE_ERR(get_beta_weights_shape(attrs, input_shape)); - return std::vector{ - gamma_shape, - beta_shape, + return std::unordered_map{ + { + TensorSlotName::GAMMA, + gamma_shape, + }, + { + TensorSlotName::BETA, + beta_shape, + }, }; } -std::vector get_initializers(LayerNormAttrs const &attrs) { +std::unordered_map + get_initializers(LayerNormAttrs const &attrs) { if (attrs.elementwise_affine) { InitializerAttrs gamma_initializer = InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; @@ -234,7 +251,10 @@ std::vector get_initializers(LayerNormAttrs const &attrs) { InitializerAttrs beta_initializer = InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; - return {gamma_initializer, beta_initializer}; + return { + {TensorSlotName::GAMMA, gamma_initializer}, + {TensorSlotName::BETA, beta_initializer}, + }; } else { return {}; } diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index 37f504f873..2518df77e4 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -2,25 +2,39 @@ #include "op-attrs/ff_ordered/slice.h" #include "op-attrs/ff_ordered/transform.h" #include "op-attrs/initializers/kaiming_initializer_mode.h" +#include "op-attrs/num_ptensor_shard_dims_t.h" +#include "op-attrs/num_tensor_dims_t.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.h" 
+#include "op-attrs/operator_task_space.h" +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" #include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/relative_ff_dim_t.h" #include "op-attrs/tensor_dims.h" #include "op-attrs/tensor_shape.h" #include "utils/containers/product.h" +#include "utils/containers/unordered_set_of.h" #include "utils/expected.h" #include "utils/fmt/optional.h" #include "utils/integer_conversions.h" +#include "utils/orthotope/dim_projection.h" +#include "utils/orthotope/down_projection.h" +#include "utils/orthotope/eq_projection.h" +#include "utils/orthotope/minimal_dim_domain_mapping.h" +#include "utils/orthotope/up_projection.h" namespace FlexFlow { -std::vector +std::unordered_map get_linear_incoming_tensor_roles(LinearAttrs const &attrs) { - std::vector result = { - IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT, + std::unordered_map result = { + {TensorSlotName::INPUT, IncomingTensorRole::INPUT}, + {TensorSlotName::WEIGHT, IncomingTensorRole::WEIGHT}, }; if (attrs.use_bias) { - result.push_back(IncomingTensorRole::WEIGHT); + result[TensorSlotName::BIAS] = IncomingTensorRole::WEIGHT; } return result; @@ -74,16 +88,22 @@ tl::expected return output_shape; } -tl::expected, std::string> +tl::expected, std::string> get_weight_shapes(LinearAttrs const &attrs, TensorShape const &input_shape) { - std::vector weight_shapes = { - PROPAGATE_ERR(get_projection_shape(attrs, input_shape)), + std::unordered_map weight_shapes = { + { + TensorSlotName::WEIGHT, + PROPAGATE_ERR(get_projection_shape(attrs, input_shape)), + }, }; if (attrs.use_bias) { - weight_shapes.push_back(PROPAGATE_ERR(get_bias_shape(attrs, input_shape))); + weight_shapes.insert({ + TensorSlotName::BIAS, + PROPAGATE_ERR(get_bias_shape(attrs, input_shape)), + }); } return weight_shapes; @@ -101,18 +121,10 @@ tl::expected result_unpar.value(); }); - 
SumDegree sum_degree = SumDegree{1_p}; - DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ - get_sum_degree(input) * product(slice(ff_ordered_shard_degrees(input), - relative_ff_dim_t{0}, - relative_ff_dim_t{-1}))}; - FFOrdered shard_degrees = FFOrdered{ - get_discard_copy_degree(input), - shard_dim_at_idx(input, relative_ff_dim_t{-1}).degree, - }; + ParallelTensorDimDegrees projection_degrees = + get_projection_parallel_dim_degrees(attrs, get_parallel_degrees(input)); - return lift_to_parallel_with_degrees( - unpar, sum_degree, discard_copy_degree, shard_degrees); + return lift_to_parallel_with_degrees(unpar, projection_degrees); } tl::expected @@ -126,18 +138,10 @@ tl::expected result_unpar.value(); }); - SumDegree sum_degree = - SumDegree{get_sum_degree(input) * - shard_dim_at_idx(input, relative_ff_dim_t{-1}).degree}; - DiscardCopyDegree discard_copy_degree = - DiscardCopyDegree{product(slice(ff_ordered_shard_degrees(input), - relative_ff_dim_t{0}, - relative_ff_dim_t{-1}))}; - FFOrdered shard_degrees = - FFOrdered{get_discard_copy_degree(input)}; + ParallelTensorDimDegrees bias_degrees = + get_bias_parallel_dim_degrees(attrs, get_parallel_degrees(input)); - return lift_to_parallel_with_degrees( - unpar, sum_degree, discard_copy_degree, shard_degrees); + return lift_to_parallel_with_degrees(unpar, bias_degrees); } tl::expected @@ -152,27 +156,84 @@ tl::expected result_unpar.value(); }); - SumDegree sum_degree = - SumDegree{get_sum_degree(input) * - shard_dim_at_idx(input, relative_ff_dim_t{-1}).degree}; + ParallelTensorDimDegrees output_degrees = + get_output_parallel_dim_degrees(attrs, get_parallel_degrees(input)); + + return lift_to_parallel_with_degrees(unpar, output_degrees); +} + +ParallelTensorDimDegrees + get_projection_parallel_dim_degrees(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input) { + SumDegree sum_degree = SumDegree{1_p}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ + input.sum_degree.value * 
product(slice(input.shard_degrees, + relative_ff_dim_t{0}, + relative_ff_dim_t{-1}))}; + FFOrdered shard_degrees = FFOrdered{ + input.discard_copy_degree.value, + input.shard_degrees.at(relative_ff_dim_t{-1}), + }; + + return ParallelTensorDimDegrees{ + /*sum_degree=*/sum_degree, + /*discard_copy_degree=*/discard_copy_degree, + /*shard_degrees=*/shard_degrees, + }; +} + +ParallelTensorDimDegrees + get_bias_parallel_dim_degrees(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input) { + + SumDegree sum_degree = SumDegree{ + input.sum_degree.value * input.shard_degrees.at(relative_ff_dim_t{-1}), + }; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{product( + slice(input.shard_degrees, relative_ff_dim_t{0}, relative_ff_dim_t{-1}))}; + FFOrdered shard_degrees = + FFOrdered{input.discard_copy_degree.value}; + + return ParallelTensorDimDegrees{ + /*sum_degree=*/sum_degree, + /*discard_copy_degree=*/discard_copy_degree, + /*shard_degrees=*/shard_degrees, + }; +} + +ParallelTensorDimDegrees + get_output_parallel_dim_degrees(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input) { + SumDegree sum_degree = SumDegree{ + input.sum_degree.value * input.shard_degrees.at(relative_ff_dim_t{-1}), + }; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{1_p}; - FFOrdered shard_degrees = ff_ordered_shard_degrees(input); - shard_degrees.at(relative_ff_dim_t{-1}) = get_discard_copy_degree(input); + FFOrdered shard_degrees = input.shard_degrees; + shard_degrees.at(relative_ff_dim_t{-1}) = input.discard_copy_degree.value; - return lift_to_parallel_with_degrees( - unpar, sum_degree, discard_copy_degree, shard_degrees); + return ParallelTensorDimDegrees{ + /*sum_degree=*/sum_degree, + /*discard_copy_degree=*/discard_copy_degree, + /*shard_degrees=*/shard_degrees, + }; } -tl::expected, std::string> +tl::expected, + std::string> get_weight_shapes(LinearAttrs const &attrs, ParallelTensorShape const &input_shape) { - std::vector weight_shapes = { - 
PROPAGATE_ERR(get_projection_shape(attrs, input_shape)), + std::unordered_map weight_shapes = { + { + TensorSlotName::WEIGHT, + PROPAGATE_ERR(get_projection_shape(attrs, input_shape)), + }, }; if (attrs.use_bias) { - weight_shapes.push_back(PROPAGATE_ERR(get_bias_shape(attrs, input_shape))); + weight_shapes.insert({TensorSlotName::BIAS, + PROPAGATE_ERR(get_bias_shape(attrs, input_shape))}); } return weight_shapes; @@ -184,11 +245,12 @@ tl::expected, std::string> * see * https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/linear.py#L114-L122 */ -tl::expected, std::string> get_initializers( - LinearAttrs const &attrs, - TensorShape const &input_shape, - std::optional const &maybe_projection_initializer, - std::optional const &maybe_bias_initializer) { +tl::expected, std::string> + get_initializers( + LinearAttrs const &attrs, + TensorShape const &input_shape, + std::optional const &maybe_projection_initializer, + std::optional const &maybe_bias_initializer) { if (!attrs.use_bias && maybe_bias_initializer.has_value()) { return tl::unexpected( @@ -227,10 +289,217 @@ tl::expected, std::string> get_initializers( maybe_bias_initializer.value_or(bias_default_initializer); if (attrs.use_bias) { - return std::vector{projection_initializer, bias_initializer}; + return std::unordered_map{ + {TensorSlotName::WEIGHT, projection_initializer}, + {TensorSlotName::BIAS, bias_initializer}, + }; } else { - return std::vector{projection_initializer}; + return std::unordered_map{ + {TensorSlotName::WEIGHT, projection_initializer}, + }; } } +OperatorTaskSpace + get_operator_task_space(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + + ParallelTensorDimDegrees output_degrees = + get_output_parallel_dim_degrees(attrs, input_degrees); + + return get_operator_task_space_matching_parallel_tensor_dim_degrees( + output_degrees); +} + +static ParallelTensorSpaceToParallelTensorSpaceMapping + 
get_input_to_output_mapping(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + + num_tensor_dims_t input_num_dims = + get_ptensor_dim_degrees_num_tensor_dims(input_degrees); + + DownProjection + inp_to_out = make_empty_down_projection(); + + ff_dim_t input_channel_dim = + ff_dim_t_from_relative_ff_dim_t(relative_ff_dim_t{-1}, input_num_dims); + + num_tensor_dims_t output_num_dims = input_num_dims; + ff_dim_t output_channel_dim = + ff_dim_t_from_relative_ff_dim_t(relative_ff_dim_t{-1}, output_num_dims); + + project_dims(inp_to_out, + /*from=*/{sum_dim_idx(), shard_dim_idx(input_channel_dim)}, + /*onto=*/sum_dim_idx()); + project_dims(inp_to_out, + /*from=*/{discard_copy_dim_idx()}, + /*onto=*/shard_dim_idx(output_channel_dim)); + + for (ff_dim_t const &idx : slice(tensor_dims_range(input_num_dims), 0, -1)) { + project_dims(inp_to_out, + /*from=*/{shard_dim_idx(idx)}, + /*onto=*/shard_dim_idx(idx)); + } + + ParallelTensorDimDegrees output_degrees = + get_output_parallel_dim_degrees(attrs, input_degrees); + + return parallel_tensor_space_mapping_from_projection( + DimProjection{inp_to_out}, input_degrees, output_degrees); +} + +static ParallelTensorSpaceToParallelTensorSpaceMapping + get_input_to_projection_mapping( + LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + + num_ptensor_shard_dims_t input_num_shard_dims = + get_ptensor_dim_degrees_num_shard_dims(input_degrees); + + DownProjection + inp_to_proj = make_empty_down_projection(); + + parallel_tensor_dim_idx_t input_channel_dim = parallel_tensor_dim_idx_t{ + ff_dim_t{ + nonnegative_int{ + input_num_shard_dims.value.unwrap_nonnegative() - 1, + }, + }, + }; + + { + std::unordered_set dims_from = + unordered_set_of(dim_idxs_for_num_shard_dims(input_num_shard_dims)); + dims_from.insert(sum_dim_idx()); + dims_from.erase(input_channel_dim); + dims_from.erase(discard_copy_dim_idx()); + + project_dims(inp_to_proj, + /*from=*/dims_from, + 
/*onto=*/discard_copy_dim_idx()); + } + + parallel_tensor_dim_idx_t projection_in_channel_dim = + parallel_tensor_dim_idx_t{ff_dim_t{0_n}}; + + parallel_tensor_dim_idx_t projection_out_channel_dim = + parallel_tensor_dim_idx_t{ff_dim_t{1_n}}; + + project_dims(inp_to_proj, + /*from=*/{discard_copy_dim_idx()}, + /*onto=*/projection_out_channel_dim); + + project_dims(inp_to_proj, + /*from=*/{input_channel_dim}, + /*onto=*/projection_in_channel_dim); + + ParallelTensorDimDegrees projection_degrees = + get_projection_parallel_dim_degrees(attrs, input_degrees); + + return parallel_tensor_space_mapping_from_projection( + DimProjection{inp_to_proj}, input_degrees, projection_degrees); +} + +static ParallelTensorSpaceToParallelTensorSpaceMapping + get_input_to_bias_mapping(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + ASSERT(attrs.use_bias); + + num_ptensor_shard_dims_t input_num_shard_dims = + get_ptensor_dim_degrees_num_shard_dims(input_degrees); + + ParallelTensorDimDegrees bias_degrees = + get_bias_parallel_dim_degrees(attrs, input_degrees); + + DownProjection + inp_to_bias = make_empty_down_projection(); + + parallel_tensor_dim_idx_t input_channel_dim = parallel_tensor_dim_idx_t{ + ff_dim_t{ + nonnegative_int{ + input_num_shard_dims.value.unwrap_nonnegative() - 1, + }, + }, + }; + + { + std::unordered_set dims_from = + unordered_set_of(dim_idxs_for_num_shard_dims(input_num_shard_dims)); + dims_from.erase(input_channel_dim); + dims_from.erase(sum_dim_idx()); + + project_dims(inp_to_bias, + /*from=*/dims_from, + /*onto=*/discard_copy_dim_idx()); + } + + parallel_tensor_dim_idx_t bias_out_channel_dim = + parallel_tensor_dim_idx_t{ff_dim_t{0_n}}; + + project_dims(inp_to_bias, + /*from=*/ + { + sum_dim_idx(), + input_channel_dim, + }, + /*onto=*/sum_dim_idx()); + + DimDomain l_domain = + dim_domain_from_parallel_tensor_dim_degrees(input_degrees); + DimDomain r_domain = + dim_domain_from_parallel_tensor_dim_degrees(bias_degrees); + + return 
parallel_tensor_space_mapping_from_projection( + DimProjection{inp_to_bias}, input_degrees, bias_degrees); +} + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_projection_mapping( + LinearAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + + return operator_ptensor_space_mapping_from_composition( + get_operator_to_input_mapping(attrs, input_degrees), + get_input_to_projection_mapping(attrs, input_degrees)); +} + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_input_mapping( + LinearAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + + DimDomainMapping + inp_to_out = + get_input_to_output_mapping(attrs, input_degrees).raw_mapping; + + DimDomainMapping + op_to_out = + get_operator_to_output_mapping(attrs, input_degrees).raw_mapping; + + DimDomainMapping + op_to_inp = compose_dim_domain_mappings_through_minimal( + op_to_out, invert_dim_domain_mapping(inp_to_out)); + + return OperatorSpaceToParallelTensorSpaceMapping{ + op_to_inp, + }; +} + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_bias_mapping( + LinearAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + + return operator_ptensor_space_mapping_from_composition( + get_operator_to_input_mapping(attrs, input_degrees), + get_input_to_bias_mapping(attrs, input_degrees)); +} + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_output_mapping( + LinearAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + + ParallelTensorDimDegrees output_degrees = + get_output_parallel_dim_degrees(attrs, input_degrees); + + return get_identity_mapping(get_operator_task_space(attrs, input_degrees), + output_degrees); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index ee75340ed0..289895a87b 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -47,13 +47,15 @@ tl::expected input_dims)); } - // 
Note that for some reason the stack overflow post linked above states that - // `kernel_size = ind - (outd-1)*stride`, but some simplification yields - // `kernel_size` = `ind - (outd - 1)*stride` - // = `ind - (outd - 1) * (ind / outd)` - // = `ind - ind + (ind /outd)` - // = `ind / outd` - // = `stride` + /** + * Note that for some reason the stack overflow post linked above states that + * `kernel_size = ind - (outd-1)*stride`, but some simplification yields + * `kernel_size` = `ind - (outd - 1)*stride` + * = `ind - (outd - 1) * (ind / outd)` + * = `ind - ind + (ind /outd)` + * = `ind / outd` + * = `stride` + */ positive_int kernel_h = positive_int{input_h / output_h}; positive_int kernel_w = positive_int{input_w / output_w}; diff --git a/lib/op-attrs/src/op-attrs/ops/reduce.cc b/lib/op-attrs/src/op-attrs/ops/reduce.cc index 2a8bf06ecf..e5474ae124 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduce.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduce.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/reduce.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/reshape.cc b/lib/op-attrs/src/op-attrs/ops/reshape.cc index 6216ad8c6c..d8ea92d540 100644 --- a/lib/op-attrs/src/op-attrs/ops/reshape.cc +++ b/lib/op-attrs/src/op-attrs/ops/reshape.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/reshape.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/reverse.cc b/lib/op-attrs/src/op-attrs/ops/reverse.cc index c38d7e4782..3a063d1af9 100644 --- a/lib/op-attrs/src/op-attrs/ops/reverse.cc +++ b/lib/op-attrs/src/op-attrs/ops/reverse.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/reverse.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/split.cc b/lib/op-attrs/src/op-attrs/ops/split.cc index a9fe691584..ed737b18d9 100644 --- a/lib/op-attrs/src/op-attrs/ops/split.cc +++ b/lib/op-attrs/src/op-attrs/ops/split.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/split.h" +#include 
"utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/topk.cc b/lib/op-attrs/src/op-attrs/ops/topk.cc index 7a6868340b..179d6cfdd3 100644 --- a/lib/op-attrs/src/op-attrs/ops/topk.cc +++ b/lib/op-attrs/src/op-attrs/ops/topk.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/topk.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/transpose.cc b/lib/op-attrs/src/op-attrs/ops/transpose.cc index 50e6fb35f5..a13a5d1724 100644 --- a/lib/op-attrs/src/op-attrs/ops/transpose.cc +++ b/lib/op-attrs/src/op-attrs/ops/transpose.cc @@ -1,14 +1,94 @@ #include "op-attrs/ops/transpose.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/operator_task_space.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.h" +#include "utils/bidict/algorithms/transform_keys.h" +#include "utils/bidict/algorithms/transform_values.h" namespace FlexFlow { -TensorShape get_output_shape(TransposeAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); +TensorShape get_output_shape(TransposeAttrs const &attrs, + TensorShape const &input_shape) { + return permute_tensor_shape(attrs.permutation, input_shape); } -ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, +ParallelTensorDimDegrees get_output_parallel_dim_degrees( + TransposeAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + return permute_parallel_tensor_dim_degrees(attrs.permutation, input_degrees); +} + +ParallelTensorShape get_output_shape(TransposeAttrs const &attrs, ParallelTensorShape const &input_shape) { - NOT_IMPLEMENTED(); + TensorShape output_shape = + get_output_shape(attrs, get_reduced_shape(input_shape)); + + ParallelTensorDimDegrees output_degrees = + get_output_parallel_dim_degrees(attrs, 
get_parallel_degrees(input_shape)); + + return lift_to_parallel_with_degrees(output_shape, output_degrees); +} + +OperatorTaskSpace + get_operator_task_space(TransposeAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + ParallelTensorDimDegrees output_degrees = + get_output_parallel_dim_degrees(attrs, input_degrees); + + return get_operator_task_space_matching_parallel_tensor_dim_degrees( + output_degrees); +} + +static ParallelTensorSpaceToParallelTensorSpaceMapping + get_input_to_output_mapping(TransposeAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + auto ff_dim_to_pt_dim = [](ff_dim_t d) -> parallel_tensor_dim_idx_t { + return parallel_tensor_dim_idx_t{d}; + }; + + EqProjection + inp_to_out = EqProjection{ + transform_keys( + transform_values(attrs.permutation.as_bidict(), ff_dim_to_pt_dim), + ff_dim_to_pt_dim), + }; + + project_dims(inp_to_out, sum_dim_idx(), sum_dim_idx()); + project_dims(inp_to_out, discard_copy_dim_idx(), discard_copy_dim_idx()); + + ParallelTensorDimDegrees output_degrees = + get_output_parallel_dim_degrees(attrs, input_degrees); + + return parallel_tensor_space_mapping_from_projection( + DimProjection{inp_to_out}, input_degrees, output_degrees); +} + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_input_mapping( + TransposeAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + ParallelTensorSpaceToParallelTensorSpaceMapping inp_to_out = + get_input_to_output_mapping(attrs, input_degrees); + + ParallelTensorSpaceToParallelTensorSpaceMapping out_to_inp = + invert_parallel_tensor_space_mapping(inp_to_out); + + OperatorSpaceToParallelTensorSpaceMapping op_to_out = + get_operator_to_output_mapping(attrs, input_degrees); + + return operator_ptensor_space_mapping_from_composition(op_to_out, out_to_inp); +} + +OperatorSpaceToParallelTensorSpaceMapping get_operator_to_output_mapping( + TransposeAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + 
ParallelTensorDimDegrees output_degrees = + get_output_parallel_dim_degrees(attrs, input_degrees); + + return get_identity_mapping(get_operator_task_space(attrs, input_degrees), + output_degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/weight.cc b/lib/op-attrs/src/op-attrs/ops/weight.cc index 710529af0a..ba63eca6ee 100644 --- a/lib/op-attrs/src/op-attrs/ops/weight.cc +++ b/lib/op-attrs/src/op-attrs/ops/weight.cc @@ -1,4 +1,6 @@ #include "op-attrs/ops/weight.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/operator_task_space.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { @@ -23,4 +25,14 @@ ParallelTensorShape get_output_parallel_tensor_shape(WeightAttrs const &attrs) { return lift_to_parallel(attrs.tensor_shape); } +OperatorTaskSpace get_operator_task_space(WeightAttrs const &) { + return trivial_op_task_space(); +} + +OperatorSpaceToParallelTensorSpaceMapping + get_operator_to_output_mapping(WeightAttrs const &attrs) { + + return empty_operator_space_to_ptensor_space_map(); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc b/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc index c458d4149d..ccacd9bc3e 100644 --- a/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc +++ b/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc @@ -3,6 +3,7 @@ #include "op-attrs/ops/reduction.h" #include "op-attrs/ops/repartition.h" #include "op-attrs/ops/replicate.h" +#include "utils/exception.h" #include "utils/overload.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc new file mode 100644 index 0000000000..51d7968033 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc @@ -0,0 +1,170 @@ +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include "op-attrs/ff_ordered/ff_ordered_from_map.h" +#include "op-attrs/ff_ordered/get_idxs.h" 
+#include "op-attrs/num_tensor_dims_t.h" +#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" +#include "op-attrs/parallel_tensor_space_coordinate.h" +#include "utils/containers/binary_merge_disjoint_maps.h" +#include "utils/containers/filtermap_keys.h" +#include "utils/containers/filtrans.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_all_assignments.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/map_values.h" +#include "utils/containers/range.h" +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/nonnegative_int/num_elements.h" +#include "utils/orthotope/minimal_dim_domain.h" + +namespace FlexFlow { + +num_ptensor_shard_dims_t get_ptensor_dim_degrees_num_shard_dims( + ParallelTensorDimDegrees const °rees) { + return num_ptensor_shard_dims_t{ + num_elements(degrees.shard_degrees), + }; +} + +num_tensor_dims_t get_ptensor_dim_degrees_num_tensor_dims( + ParallelTensorDimDegrees const °rees) { + return num_tensor_dims_from_num_ptensor_shard_dims( + get_ptensor_dim_degrees_num_shard_dims(degrees)); +} + +std::unordered_set + get_parallel_tensor_dim_indices(ParallelTensorDimDegrees const °rees) { + + std::unordered_set result = + unordered_set_of(dim_idxs_for_num_shard_dims( + get_ptensor_dim_degrees_num_shard_dims(degrees))); + result.insert(sum_dim_idx()); + result.insert(discard_copy_dim_idx()); + return result; +} + +std::set get_nontrivial_parallel_tensor_dim_indices( + ParallelTensorDimDegrees const °rees) { + std::set nontrivial_replica_dims; + + if (degrees.sum_degree.value > 1) { + nontrivial_replica_dims.insert(parallel_tensor_dim_idx_t{ReplicaType::SUM}); + } + + if (degrees.discard_copy_degree.value > 1) { + nontrivial_replica_dims.insert( + parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}); + } + + 
std::set nontrivial_shard_dims = filtrans( + get_idxs(degrees.shard_degrees), + [&](ff_dim_t const &dim) -> std::optional { + if (degrees.shard_degrees.at(dim) > 1) { + return parallel_tensor_dim_idx_t{dim}; + } else { + return std::nullopt; + } + }); + + return set_union(nontrivial_replica_dims, nontrivial_shard_dims); +} + +positive_int get_degree_for_parallel_tensor_dim_idx( + ParallelTensorDimDegrees const &dim_degrees, + parallel_tensor_dim_idx_t const &idx) { + if (idx == sum_dim_idx()) { + return dim_degrees.sum_degree.value; + } else if (idx == discard_copy_dim_idx()) { + return dim_degrees.discard_copy_degree.value; + } else { + return dim_degrees.shard_degrees.at(idx.require_shard_dim()); + } +} + +std::unordered_map + get_parallel_tensor_degree_map(ParallelTensorDimDegrees const °rees) { + + std::unordered_map + replica_dim_degrees = { + {parallel_tensor_dim_idx_t{ReplicaType::SUM}, + degrees.sum_degree.value}, + {parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, + degrees.discard_copy_degree.value}, + }; + + std::unordered_map shard_dim_degrees = + generate_map(get_idxs(degrees.shard_degrees), [&](ff_dim_t const &dim) { + return degrees.shard_degrees.at(dim); + }); + + return binary_merge_disjoint_maps( + /*lhs=*/replica_dim_degrees, + /*rhs=*/map_keys(shard_dim_degrees, [](ff_dim_t const &dim) { + return parallel_tensor_dim_idx_t{dim}; + })); +} + +std::unordered_set + get_parallel_tensor_space_coordinates( + ParallelTensorDimDegrees const °rees) { + + std::unordered_map degree_map = + get_parallel_tensor_degree_map(degrees); + + std::unordered_map> + possible_per_dim_coords = map_values(degree_map, [](positive_int degree) { + return unordered_set_of(nonnegative_range(degree)); + }); + + return transform( + get_all_assignments(possible_per_dim_coords), + [](std::unordered_map const + &m) { return parallel_tensor_space_coord_from_map(m); }); +} + +DimDomain + dim_domain_from_parallel_tensor_dim_degrees( + ParallelTensorDimDegrees const &dim_degrees) 
{ + + return DimDomain{ + generate_map(get_parallel_tensor_dim_indices(dim_degrees), + [&](parallel_tensor_dim_idx_t idx) { + return get_degree_for_parallel_tensor_dim_idx(dim_degrees, + idx); + }), + }; +} + +ParallelTensorDimDegrees parallel_tensor_dim_degrees_from_dim_domain( + DimDomain const &dim_domain) { + + std::unordered_map shard_dims = + filtermap_keys(dim_domain.dims, [](parallel_tensor_dim_idx_t dim_idx) { + return dim_idx.try_require_shard_dim(); + }); + + return ParallelTensorDimDegrees{ + /*sum_degree=*/SumDegree{ + dim_domain.dims.at(sum_dim_idx()), + }, + /*discard_copy_degree=*/ + DiscardCopyDegree{ + dim_domain.dims.at(discard_copy_dim_idx()), + }, + /*shard_degres=*/ff_ordered_from_map(shard_dims), + }; +} + +MinimalDimDomain + minimal_dim_domain_from_parallel_tensor_dim_degrees( + ParallelTensorDimDegrees const &dim_degrees) { + + return minimal_dim_domain_from_dim_domain( + dim_domain_from_parallel_tensor_dim_degrees(dim_degrees)); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_idx_t.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_idx_t.cc new file mode 100644 index 0000000000..39ca693c58 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_idx_t.cc @@ -0,0 +1,52 @@ +#include "op-attrs/parallel_tensor_dim_idx_t.h" +#include "op-attrs/ff_dim_t.h" +#include "utils/containers/set_of.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +parallel_tensor_dim_idx_t sum_dim_idx() { + return parallel_tensor_dim_idx_t{ReplicaType::SUM}; +} + +parallel_tensor_dim_idx_t discard_copy_dim_idx() { + return parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}; +} + +parallel_tensor_dim_idx_t shard_dim_idx(ff_dim_t idx) { + return parallel_tensor_dim_idx_t{idx}; +} + +bool is_dim_idx_for_reduction_dimension(parallel_tensor_dim_idx_t dim_idx) { + return (dim_idx == sum_dim_idx()) || (dim_idx == discard_copy_dim_idx()); +} + +std::set + 
dim_idxs_for_num_shard_dims(num_ptensor_shard_dims_t num_shard_dims) { + std::set result = + transform(set_of(ff_dim_range(num_shard_dims.value)), shard_dim_idx); + result.insert(sum_dim_idx()); + result.insert(discard_copy_dim_idx()); + + return result; +} + +DimOrdering get_parallel_tensor_dim_ordering() { + + return DimOrdering{ + /*lt=*/[](parallel_tensor_dim_idx_t lhs, + parallel_tensor_dim_idx_t rhs) -> bool { + if (lhs.is_shard_dim() && rhs.is_shard_dim()) { + return lhs.require_shard_dim() < rhs.require_shard_dim(); + } else if (lhs.is_shard_dim() && !rhs.is_shard_dim()) { + return false; + } else if (!lhs.is_shard_dim() && rhs.is_shard_dim()) { + return true; + } else { + return lhs.require_replica_dim() > rhs.require_replica_dim(); + } + }, + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 1c77bc6ca8..71419e4a57 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -29,8 +29,10 @@ std::unordered_set return get_replica_dims(d.replica_dims); } -nonnegative_int num_shard_dims(ParallelTensorDims const &dims) { - return num_elements(dims.shard_dims); +num_ptensor_shard_dims_t num_shard_dims(ParallelTensorDims const &dims) { + return num_ptensor_shard_dims_t{ + num_elements(dims.shard_dims), + }; } ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &d) { @@ -42,8 +44,9 @@ ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &d) { } ParallelTensorDims lift_to_parallel(TensorDims const &dims) { - std::vector shard_degrees = - repeat_element(/*num_times=*/get_num_dims(dims), /*element=*/1_p); + std::vector shard_degrees = repeat_element( + /*num_times=*/get_num_dims(dims).nonnegative_int_from_num_tensor_dims(), + /*element=*/1_p); return lift_to_parallel_with_degrees(dims, SumDegree{1_p}, DiscardCopyDegree{1_p}, diff --git 
a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 1b8f6f1dfa..91d3d0b1aa 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -5,13 +5,15 @@ #include "utils/containers/product.h" #include "utils/containers/range.h" #include "utils/containers/transform.h" +#include "utils/exception.h" #include "utils/hash-utils.h" #include "utils/nonnegative_int/nonnegative_range.h" #include "utils/overload.h" +#include namespace FlexFlow { -nonnegative_int num_shard_dims(ParallelTensorShape const &s) { +num_ptensor_shard_dims_t num_shard_dims(ParallelTensorShape const &s) { return num_shard_dims(s.dims); } @@ -97,25 +99,22 @@ ParallelTensorShape TensorShape require_not_parallel(ParallelTensorShape const &s) { positive_int total_degree = get_total_parallel_degree(s); - if (total_degree != 1_p) { - throw mk_runtime_error( - fmt::format("Error: require_not_parallel received a parallel tensor " - "shape with parallel degree {}: {}", - total_degree, - s)); - } + ASSERT(total_degree != 1_p, + "Error: require_not_parallel received a parallel tensor shape with " + "non-zero parallel degree", + s); return get_reduced_shape(s); } -TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { - NOT_IMPLEMENTED(); -} - TensorShape get_piece_shape(ParallelTensorShape const &s) { return get_reduced_shape(s); } +num_bytes_t get_piece_size_in_bytes(ParallelTensorShape const &s) { + return get_size_in_bytes(get_piece_shape(s)); +} + TensorShape get_reduced_shape(ParallelTensorShape const &s) { return TensorShape{ get_reduced_dims(s.dims), @@ -142,7 +141,7 @@ std::unordered_set get_parallel_tensor_dim_indices(ParallelTensorShape const &shape) { std::unordered_set indices; extend(indices, - transform(nonnegative_range(num_shard_dims(shape.dims)), + transform(nonnegative_range(num_shard_dims(shape.dims).value), [](nonnegative_int idx) { return 
parallel_tensor_dim_idx_t{ff_dim_t{idx}}; })); diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc new file mode 100644 index 0000000000..0c6e157697 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc @@ -0,0 +1,84 @@ +#include "op-attrs/parallel_tensor_space_coordinate.h" +#include "op-attrs/ff_ordered/ff_ordered_from_map.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/filtermap_keys.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/nonnegative_int/num_elements.h" + +namespace FlexFlow { + +num_ptensor_parallel_dims_t + ptensor_coord_num_dims(ParallelTensorSpaceCoordinate const &c) { + return num_ptensor_parallel_dims_t{ + 2_n + num_elements(c.shard_components), + }; +} + +num_ptensor_shard_dims_t + ptensor_coord_num_shard_dims(ParallelTensorSpaceCoordinate const &c) { + return num_ptensor_shard_dims_t{ + num_elements(c.shard_components), + }; +} + +std::unordered_set + get_dim_idxs_in_ptensor_space_coord( + ParallelTensorSpaceCoordinate const &coord) { + + std::unordered_set result = unordered_set_of( + dim_idxs_for_num_shard_dims(ptensor_coord_num_shard_dims(coord))); + result.insert(sum_dim_idx()); + result.insert(discard_copy_dim_idx()); + return result; +} + +nonnegative_int ptensor_coord_component_for_ptensor_dim_idx( + ParallelTensorSpaceCoordinate const &coord, + parallel_tensor_dim_idx_t dim_idx) { + if (dim_idx == sum_dim_idx()) { + return coord.sum_component; + } else if (dim_idx == discard_copy_dim_idx()) { + return coord.discard_copy_component; + } else { + return coord.shard_components.at(dim_idx.require_shard_dim()); + } +} + +ParallelTensorSpaceCoordinate parallel_tensor_space_coord_from_map( + std::unordered_map const &m) { + ASSERT(contains_key(m, sum_dim_idx())); + ASSERT(contains_key(m, 
discard_copy_dim_idx())); + + std::unordered_map shard_map = + filtermap_keys(m, [](parallel_tensor_dim_idx_t const &d) { + return d.try_require_shard_dim(); + }); + + return ParallelTensorSpaceCoordinate{ + /*sum_idx=*/m.at(parallel_tensor_dim_idx_t{ReplicaType::SUM}), + /*discard_copy_idx=*/ + m.at(parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}), + /*shard_idxs=*/ff_ordered_from_map(shard_map), + }; +} + +ParallelTensorSpaceCoordinate parallel_tensor_space_coord_from_dim_coord( + DimCoord const &dim_coord) { + return parallel_tensor_space_coord_from_map(dim_coord.raw); +} + +DimCoord dim_coord_from_parallel_tensor_space_coord( + ParallelTensorSpaceCoordinate const &coord) { + + return DimCoord{ + generate_map(get_dim_idxs_in_ptensor_space_coord(coord), + [&](parallel_tensor_dim_idx_t idx) { + return ptensor_coord_component_for_ptensor_dim_idx(coord, + idx); + }), + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.cc new file mode 100644 index 0000000000..2a161838cd --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.cc @@ -0,0 +1,53 @@ +#include "op-attrs/parallel_tensor_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" + +namespace FlexFlow { + +ParallelTensorSpaceToParallelTensorSpaceMapping + parallel_tensor_space_mapping_from_projection( + DimProjection const &projection, + ParallelTensorDimDegrees const &l_degrees, + ParallelTensorDimDegrees const &r_degrees) { + + // TODO(@lockshaw)(#pr): + // { + // std::unordered_set + // l_dims = + // unordered_set_of(get_nontrivial_parallel_tensor_dim_indices(l_degrees)); + // std::unordered_set + // projection_input_dims = input_dims_of_projection(projection); + // + // ASSERT(l_dims == projection_input_dims); + 
// } + // + // { + // std::unordered_set + // r_dims = + // unordered_set_of(get_nontrivial_parallel_tensor_dim_indices(r_degrees)); + // std::unordered_set + // projection_output_dims = output_dims_of_projection(projection); + // + // ASSERT(r_dims == projection_output_dims); + // } + + return ParallelTensorSpaceToParallelTensorSpaceMapping{ + dim_domain_mapping_from_projection( + /*projection=*/projection, + /*l_domain=*/dim_domain_from_parallel_tensor_dim_degrees(l_degrees), + /*r_domain=*/dim_domain_from_parallel_tensor_dim_degrees(r_degrees), + /*l_dim_ordering=*/get_parallel_tensor_dim_ordering(), + /*r_dim_ordering=*/get_parallel_tensor_dim_ordering()), + }; +} + +ParallelTensorSpaceToParallelTensorSpaceMapping + invert_parallel_tensor_space_mapping( + ParallelTensorSpaceToParallelTensorSpaceMapping const &m) { + return ParallelTensorSpaceToParallelTensorSpaceMapping{ + invert_dim_domain_mapping(m.raw_mapping), + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc b/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc index a987841b18..91caa03f36 100644 --- a/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc +++ b/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc @@ -3,10 +3,10 @@ namespace FlexFlow { ff_dim_t ff_dim_t_from_relative_ff_dim_t(relative_ff_dim_t ff_dim, - nonnegative_int input_dim) { + num_tensor_dims_t input_dim) { int raw = ff_dim.value; if (raw < 0) { - raw = input_dim.unwrap_nonnegative() + raw; + raw = input_dim.int_from_num_tensor_dims() + raw; } return ff_dim_t{nonnegative_int{raw}}; } diff --git a/lib/op-attrs/src/op-attrs/shape_inference.cc b/lib/op-attrs/src/op-attrs/shape_inference.cc index 4a0ff72fb4..a3f8066dee 100644 --- a/lib/op-attrs/src/op-attrs/shape_inference.cc +++ b/lib/op-attrs/src/op-attrs/shape_inference.cc @@ -20,333 +20,863 @@ #include "op-attrs/ops/repartition.h" #include "op-attrs/ops/replicate.h" #include "op-attrs/ops/softmax.h" +#include "op-attrs/ops/transpose.h" #include 
"op-attrs/ops/weight.h" +#include "op-attrs/tensor_slot_name.h" #include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" +#include "utils/containers/require_two_keys.h" +#include "utils/containers/slice.h" #include "utils/overload.h" namespace FlexFlow { template -static std::pair require_2(std::vector const &v) { - assert(v.size() == 2); +static std::tuple + require_3(std::unordered_map const &v, + TensorSlotName k1, + TensorSlotName k2, + TensorSlotName k3) { + ASSERT(v.size() == 3); - return {v.at(0), v.at(1)}; + return {v.at(k1), v.at(k2), v.at(k3)}; } template -static std::tuple require_3(std::vector const &v) { - assert(v.size() == 3); +static std::vector + require_only_slots_sequence(std::unordered_map const &v, + std::vector const &slots) { + nonnegative_int v_num_slots = num_elements(v); + ASSERT(v_num_slots <= slots.size()); - return {v.at(0), v.at(1), v.at(2)}; -} + std::vector expected_slots = + slice(slots, 0, v_num_slots.unwrap_nonnegative()); + + ASSERT(unordered_set_of(expected_slots) == keys(v)); + + return transform(expected_slots, [&](TensorSlotName const &slot_name) { + return v.at(slot_name); + }); +}; + +std::unordered_map get_output_shapes( + ComputationGraphOpAttrs const &op_attrs, + std::unordered_map const &input_shapes) { + return op_attrs.visit>( + overload{ + [&](BatchMatmulAttrs const &attrs) + -> std::unordered_map { + auto [lhs, rhs] = require_two_keys(input_shapes, + TensorSlotName::LHS_INPUT, + TensorSlotName::RHS_INPUT); + + return { + { + TensorSlotName::OUTPUT, + TensorShape{ + throw_if_unexpected(get_output_shape(attrs, lhs, rhs)), + }, + }, + }; + }, + [&](BatchNormAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](CastAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + 
require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](ConcatAttrs const &attrs) + -> std::unordered_map { + std::vector inputs = require_only_slots_sequence( + input_shapes, get_variadic_inputs_slot_name_sequence()); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, inputs)), + }, + }; + }, + [&](Conv2DAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, input), + }, + }; + }, + [&](DropoutAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, input), + }, + }; + }, + [&](ElementBinaryAttrs const &attrs) + -> std::unordered_map { + auto [lhs, rhs] = require_two_keys(input_shapes, + TensorSlotName::LHS_INPUT, + TensorSlotName::RHS_INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, lhs, rhs), + }, + }; + }, + [&](ElementUnaryAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, input), + }, + }; + }, + [&](EmbeddingAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](FlatAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, input), + }, + }; + }, + [&](GatherAttrs const &attrs) + -> std::unordered_map { + auto [input, index] = require_two_keys( + 
input_shapes, TensorSlotName::INPUT, TensorSlotName::INDEX); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, input, index), + }, + }; + }, + [&](InputAttrs const &attrs) + -> std::unordered_map { + ASSERT(input_shapes.size() == 0); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs), + }, + }; + }, + [&](LayerNormAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](LinearAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](MultiHeadAttentionAttrs const &attrs) + -> std::unordered_map { + auto [query, key, value] = require_3(input_shapes, + TensorSlotName::QUERY, + TensorSlotName::KEY, + TensorSlotName::VALUE); + + return { + {TensorSlotName::OUTPUT, + throw_if_unexpected( + get_output_shape(attrs, query, key, value))}, + }; + }, + [&](Pool2DAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + {TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input))}, + }; + }, + [&](SoftmaxAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); -std::vector - get_output_shapes(ComputationGraphOpAttrs const &op_attrs, - std::vector const &input_shapes) { - return op_attrs.visit>(overload{ - [&](BatchMatmulAttrs const &attrs) -> std::vector { - auto [i1, i2] = require_2(input_shapes); - - return {throw_if_unexpected(get_output_shape(attrs, i1, i2))}; - }, - [&](BatchNormAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - 
[&](CastAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](ConcatAttrs const &attrs) -> std::vector { - return {throw_if_unexpected(get_output_shape(attrs, input_shapes))}; - }, - [&](Conv2DAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, get_only(input_shapes))}; - }, - [&](DropoutAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, get_only(input_shapes))}; - }, - [&](ElementBinaryAttrs const &attrs) -> std::vector { - auto [i1, i2] = require_2(input_shapes); - - return {throw_if_unexpected(get_output_shape(attrs, i1, i2))}; - }, - [&](ElementUnaryAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](EmbeddingAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](FlatAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, get_only(input_shapes))}; - }, - [&](GatherAttrs const &attrs) -> std::vector { - return { - get_output_shape(attrs, input_shapes.at(0), input_shapes.at(1))}; - }, - [&](InputAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs)}; - }, - [&](LayerNormAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](LinearAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](MultiHeadAttentionAttrs const &attrs) -> std::vector { - auto [i1, i2, i3] = require_3(input_shapes); - - return {throw_if_unexpected(get_output_shape(attrs, i1, i2, i3))}; - }, - [&](Pool2DAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](SoftmaxAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, 
get_only(input_shapes)))}; - }, - [&](WeightAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs)}; - }, - [&](auto const &attrs) -> std::vector { - NOT_IMPLEMENTED(); - }}); + return { + {TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input))}, + }; + }, + [&](TransposeAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, input), + }, + }; + }, + [&](WeightAttrs const &attrs) + -> std::unordered_map { + ASSERT(input_shapes.size() == 0); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs), + }, + }; + }, + [&](auto const &attrs) + -> std::unordered_map { + NOT_IMPLEMENTED(); + }}); } -std::vector - get_weight_shapes(ComputationGraphOpAttrs const &op_attrs, - std::vector const &input_shapes) { - return op_attrs.visit>(overload{ - [&](BatchMatmulAttrs const &attrs) -> std::vector { - return {}; - }, - [&](BatchNormAttrs const &attrs) -> std::vector { - return throw_if_unexpected( - get_weight_shapes(attrs, get_only(input_shapes))); - }, - [&](CastAttrs const &attrs) -> std::vector { return {}; }, - [&](ConcatAttrs const &attrs) -> std::vector { return {}; }, - [&](Conv2DAttrs const &attrs) -> std::vector { - return get_weight_shapes(attrs, get_only(input_shapes)); - }, - [&](DropoutAttrs const &attrs) -> std::vector { return {}; }, - [&](ElementBinaryAttrs const &attrs) -> std::vector { - return {}; - }, - [&](ElementUnaryAttrs const &attrs) -> std::vector { - return {}; - }, - [&](EmbeddingAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_weights_shape(attrs, get_only(input_shapes)))}; - }, - [&](FlatAttrs const &attrs) -> std::vector { return {}; }, - [&](GatherAttrs const &attrs) -> std::vector { return {}; }, - [&](InputAttrs const &attrs) -> std::vector { return {}; }, - [&](LayerNormAttrs const &attrs) -> std::vector { - return 
throw_if_unexpected( - get_weight_shapes(attrs, get_only(input_shapes))); - }, - [&](LinearAttrs const &attrs) -> std::vector { - return throw_if_unexpected( - get_weight_shapes(attrs, get_only(input_shapes))); - }, - [&](MultiHeadAttentionAttrs const &attrs) -> std::vector { - auto [i1, i2, i3] = require_3(input_shapes); - - return throw_if_unexpected(get_weight_shapes(attrs, i1, i2, i3)); - }, - [&](Pool2DAttrs const &attrs) -> std::vector { return {}; }, - [&](SoftmaxAttrs const &attrs) -> std::vector { return {}; }, - [&](WeightAttrs const &attrs) -> std::vector { return {}; }, - [&](auto const &attrs) -> std::vector { - NOT_IMPLEMENTED(); - }}); +std::unordered_map get_weight_shapes( + ComputationGraphOpAttrs const &op_attrs, + std::unordered_map const &input_shapes) { + return op_attrs.visit>( + overload{ + [&](BatchMatmulAttrs const &attrs) + -> std::unordered_map { + require_two_keys(input_shapes, + TensorSlotName::LHS_INPUT, + TensorSlotName::RHS_INPUT); + + return {}; + }, + [&](BatchNormAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return throw_if_unexpected(get_weight_shapes(attrs, input)); + }, + [&](CastAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + return {}; + }, + [&](ConcatAttrs const &attrs) + -> std::unordered_map { + require_only_slots_sequence( + input_shapes, get_variadic_inputs_slot_name_sequence()); + + return {}; + }, + [&](Conv2DAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return get_weight_shapes(attrs, input); + }, + [&](DropoutAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + return {}; + }, + [&](ElementBinaryAttrs const &attrs) + -> std::unordered_map { + require_two_keys(input_shapes, + TensorSlotName::LHS_INPUT, + TensorSlotName::RHS_INPUT); + return {}; + 
}, + [&](ElementUnaryAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + return {}; + }, + [&](EmbeddingAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::WEIGHT, + TensorShape{ + throw_if_unexpected(get_weights_shape(attrs, input)), + }, + }, + }; + }, + [&](FlatAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + return {}; + }, + [&](GatherAttrs const &attrs) + -> std::unordered_map { + require_two_keys( + input_shapes, TensorSlotName::INPUT, TensorSlotName::INDEX); + return {}; + }, + [&](InputAttrs const &attrs) + -> std::unordered_map { + ASSERT(input_shapes.size() == 0); + return {}; + }, + [&](LayerNormAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return throw_if_unexpected(get_weight_shapes(attrs, input)); + }, + [&](LinearAttrs const &attrs) + -> std::unordered_map { + TensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return throw_if_unexpected(get_weight_shapes(attrs, input)); + }, + [&](MultiHeadAttentionAttrs const &attrs) + -> std::unordered_map { + auto [query, key, value] = require_3(input_shapes, + TensorSlotName::QUERY, + TensorSlotName::KEY, + TensorSlotName::VALUE); + + return throw_if_unexpected( + get_weight_shapes(attrs, query, key, value)); + }, + [&](Pool2DAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](SoftmaxAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](WeightAttrs const &attrs) + -> std::unordered_map { + ASSERT(input_shapes.size() == 0); + return {}; + }, + [&](auto const &attrs) + -> std::unordered_map { + NOT_IMPLEMENTED(); + }}); } -std::vector - 
get_output_shapes(PCGOperatorAttrs const &pcg_op_attrs, - std::vector const &input_shapes) { - return pcg_op_attrs.visit>(overload{ - [&](BatchMatmulAttrs const &attrs) -> std::vector { - auto [i1, i2] = require_2(input_shapes); - - return {throw_if_unexpected(get_output_shape(attrs, i1, i2))}; - }, - [&](BatchNormAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](CastAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](CombineAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](ConcatAttrs const &attrs) -> std::vector { - return {throw_if_unexpected(get_output_shape(attrs, input_shapes))}; - }, - [&](Conv2DAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, get_only(input_shapes))}; - }, - [&](DropoutAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](ElementBinaryAttrs const &attrs) -> std::vector { - auto [i1, i2] = require_2(input_shapes); - - return {throw_if_unexpected(get_output_shape(attrs, i1, i2))}; - }, - [&](ElementUnaryAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](EmbeddingAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](FlatAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](GatherAttrs const &attrs) -> std::vector { - return { - get_output_shape(attrs, input_shapes.at(0), input_shapes.at(1))}; - }, - [&](InputAttrs const &attrs) -> std::vector { - return {get_output_parallel_tensor_shape(attrs)}; - }, - [&](LayerNormAttrs const &attrs) -> std::vector { - return 
{throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](LinearAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](MultiHeadAttentionAttrs const &attrs) - -> std::vector { - auto [i1, i2, i3] = require_3(input_shapes); - - return {throw_if_unexpected(get_output_shape(attrs, i1, i2, i3))}; - }, - [&](Pool2DAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](ReductionAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](RepartitionAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](ReplicateAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, get_only(input_shapes))}; - }, - [&](SoftmaxAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; - }, - [&](WeightAttrs const &attrs) -> std::vector { - return {get_output_parallel_tensor_shape(attrs)}; - }, - [&](auto const &attrs) -> std::vector { - NOT_IMPLEMENTED(); - }}); +std::unordered_map get_output_shapes( + PCGOperatorAttrs const &pcg_op_attrs, + std::unordered_map const + &input_shapes) { + return pcg_op_attrs + .visit>(overload{ + [&](BatchMatmulAttrs const &attrs) + -> std::unordered_map { + auto [lhs, rhs] = require_two_keys(input_shapes, + TensorSlotName::LHS_INPUT, + TensorSlotName::RHS_INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, lhs, rhs)), + }, + }; + }, + [&](BatchNormAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](CastAttrs const 
&attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + {TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input))}, + }; + }, + [&](CombineAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + {TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input))}, + }; + }, + [&](ConcatAttrs const &attrs) + -> std::unordered_map { + std::vector inputs = + require_only_slots_sequence( + input_shapes, get_variadic_inputs_slot_name_sequence()); + + return { + {TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, inputs))}, + }; + }, + [&](Conv2DAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + {TensorSlotName::OUTPUT, get_output_shape(attrs, input)}, + }; + }, + [&](DropoutAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](ElementBinaryAttrs const &attrs) + -> std::unordered_map { + auto [lhs, rhs] = require_two_keys(input_shapes, + TensorSlotName::LHS_INPUT, + TensorSlotName::RHS_INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, lhs, rhs), + }, + }; + }, + [&](ElementUnaryAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, input), + }, + }; + }, + [&](EmbeddingAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, 
input)), + }, + }; + }, + [&](FlatAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, input), + }, + }; + }, + [&](GatherAttrs const &attrs) + -> std::unordered_map { + auto [input, index] = require_two_keys( + input_shapes, TensorSlotName::INPUT, TensorSlotName::INDEX); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, input, index), + }, + }; + }, + [&](InputAttrs const &attrs) + -> std::unordered_map { + ASSERT(input_shapes.size() == 0); + + return { + { + TensorSlotName::OUTPUT, + get_output_parallel_tensor_shape(attrs), + }, + }; + }, + [&](LayerNormAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](LinearAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](MultiHeadAttentionAttrs const &attrs) + -> std::unordered_map { + auto [i1, i2, i3] = require_3(input_shapes, + TensorSlotName::QUERY, + TensorSlotName::KEY, + TensorSlotName::VALUE); + + return { + {TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, i1, i2, i3))}, + }; + }, + [&](Pool2DAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](ReductionAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + 
throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](RepartitionAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](ReplicateAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, input), + }, + }; + }, + [&](SoftmaxAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + throw_if_unexpected(get_output_shape(attrs, input)), + }, + }; + }, + [&](TransposeAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::OUTPUT, + get_output_shape(attrs, input), + }, + }; + }, + [&](WeightAttrs const &attrs) + -> std::unordered_map { + ASSERT(input_shapes.size() == 0); + + return { + { + TensorSlotName::OUTPUT, + get_output_parallel_tensor_shape(attrs), + }, + }; + }, + [&](auto const &attrs) + -> std::unordered_map { + NOT_IMPLEMENTED(); + }}); } -std::vector - get_weight_shapes(PCGOperatorAttrs const &pcg_op_attrs, - std::vector const &input_shapes) { - return pcg_op_attrs.visit>(overload{ - [&](BatchMatmulAttrs const &attrs) -> std::vector { - return {}; - }, - [&](BatchNormAttrs const &attrs) -> std::vector { - return throw_if_unexpected( - get_weight_shapes(attrs, get_only(input_shapes))); - }, - [&](CastAttrs const &attrs) -> std::vector { - return {}; - }, - [&](CombineAttrs const &attrs) -> std::vector { - return {}; - }, - [&](ConcatAttrs const &attrs) -> std::vector { - return {}; - }, - [&](Conv2DAttrs const &attrs) -> std::vector { - return get_weight_shapes(attrs, 
get_only(input_shapes)); - }, - [&](DropoutAttrs const &attrs) -> std::vector { - return {}; - }, - [&](ElementBinaryAttrs const &attrs) -> std::vector { - return {}; - }, - [&](ElementUnaryAttrs const &attrs) -> std::vector { - return {}; - }, - [&](EmbeddingAttrs const &attrs) -> std::vector { - return { - throw_if_unexpected( - get_weights_shape(attrs, get_only(input_shapes))), - }; - }, - [&](FlatAttrs const &attrs) -> std::vector { - return {}; - }, - [&](GatherAttrs const &attrs) -> std::vector { - return {}; - }, - [&](InputAttrs const &attrs) -> std::vector { - return {}; - }, - [&](LayerNormAttrs const &attrs) -> std::vector { - return throw_if_unexpected( - get_weight_shapes(attrs, get_only(input_shapes))); - }, - [&](LinearAttrs const &attrs) -> std::vector { - return throw_if_unexpected( - get_weight_shapes(attrs, get_only(input_shapes))); - }, - [&](MultiHeadAttentionAttrs const &attrs) - -> std::vector { - auto [i1, i2, i3] = require_3(input_shapes); - - return throw_if_unexpected(get_weight_shapes(attrs, i1, i2, i3)); - }, - [&](Pool2DAttrs const &attrs) -> std::vector { - return {}; - }, - [&](RepartitionAttrs const &attrs) -> std::vector { - return {}; - }, - [&](ReplicateAttrs const &attrs) -> std::vector { - return {}; - }, - [&](ReductionAttrs const &attrs) -> std::vector { - return {}; - }, - [&](SoftmaxAttrs const &attrs) -> std::vector { - return {}; - }, - [&](WeightAttrs const &attrs) -> std::vector { - return {}; - }, - [&](auto const &attrs) -> std::vector { - NOT_IMPLEMENTED(); - }}); +std::unordered_map get_weight_shapes( + PCGOperatorAttrs const &pcg_op_attrs, + std::unordered_map const + &input_shapes) { + return pcg_op_attrs + .visit>(overload{ + [&](BatchMatmulAttrs const &attrs) + -> std::unordered_map { + require_two_keys(input_shapes, + TensorSlotName::LHS_INPUT, + TensorSlotName::RHS_INPUT); + + return {}; + }, + [&](BatchNormAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + 
require_only_key(input_shapes, TensorSlotName::INPUT); + + return throw_if_unexpected(get_weight_shapes(attrs, input)); + }, + [&](CastAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](CombineAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](ConcatAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](Conv2DAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return get_weight_shapes(attrs, input); + }, + [&](DropoutAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](ElementBinaryAttrs const &attrs) + -> std::unordered_map { + require_two_keys(input_shapes, + TensorSlotName::LHS_INPUT, + TensorSlotName::RHS_INPUT); + + return {}; + }, + [&](ElementUnaryAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](EmbeddingAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return { + { + TensorSlotName::WEIGHT, + throw_if_unexpected(get_weights_shape(attrs, input)), + }, + }; + }, + [&](FlatAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](GatherAttrs const &attrs) + -> std::unordered_map { + require_two_keys( + input_shapes, TensorSlotName::INPUT, TensorSlotName::INDEX); + + return {}; + }, + [&](InputAttrs const &attrs) + -> std::unordered_map { + ASSERT(input_shapes.size() == 0); + + return {}; + }, + [&](LayerNormAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + 
return throw_if_unexpected(get_weight_shapes(attrs, input)); + }, + [&](LinearAttrs const &attrs) + -> std::unordered_map { + ParallelTensorShape input = + require_only_key(input_shapes, TensorSlotName::INPUT); + + return throw_if_unexpected(get_weight_shapes(attrs, input)); + }, + [&](MultiHeadAttentionAttrs const &attrs) + -> std::unordered_map { + auto [query, key, value] = require_3(input_shapes, + TensorSlotName::QUERY, + TensorSlotName::KEY, + TensorSlotName::VALUE); + + return throw_if_unexpected( + get_weight_shapes(attrs, query, key, value)); + }, + [&](Pool2DAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](RepartitionAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](ReplicateAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](ReductionAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](SoftmaxAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](TransposeAttrs const &attrs) + -> std::unordered_map { + require_only_key(input_shapes, TensorSlotName::INPUT); + + return {}; + }, + [&](WeightAttrs const &attrs) + -> std::unordered_map { + ASSERT(input_shapes.size() == 0); + + return {}; + }, + [&](auto const &attrs) + -> std::unordered_map { + NOT_IMPLEMENTED(); + }}); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/task_space_coordinate.cc b/lib/op-attrs/src/op-attrs/task_space_coordinate.cc new file mode 100644 index 0000000000..302825f27e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/task_space_coordinate.cc @@ -0,0 +1,55 @@ +#include "op-attrs/task_space_coordinate.h" +#include "op-attrs/operator_task_space.h" +#include 
"op-attrs/operator_task_space_dim_idx_t.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/vector_from_idx_map.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/nonnegative_int/num_elements.h" +#include "utils/orthotope/dim_coord.h" +#include "utils/orthotope/orthotope_coord.h" + +namespace FlexFlow { + +nonnegative_int task_space_coord_num_dims(TaskSpaceCoordinate const &coord) { + return orthotope_coord_num_dims(coord.orthotope_coord); +} + +TaskSpaceCoordinate + make_task_space_coordinate(std::vector const &elems) { + return TaskSpaceCoordinate{OrthotopeCoord{elems}}; +} + +TaskSpaceCoordinate task_space_coordinate_from_dim_coord( + DimCoord const &dim_coord) { + std::unordered_set coord_dims = + get_coord_dims(dim_coord); + + std::set dims = + operator_task_space_dim_idx_range(num_elements(coord_dims)); + + ASSERT(coord_dims == unordered_set_of(dims)); + + std::unordered_map idx_map = + map_keys(dim_coord.raw, + [](operator_task_space_dim_idx_t idx) { return idx.raw_idx; }); + + return TaskSpaceCoordinate{ + OrthotopeCoord{ + vector_from_idx_map(idx_map).value(), + }, + }; +} + +DimCoord + dim_coord_from_task_space_coordinate(TaskSpaceCoordinate const &coord) { + + return dim_coord_from_orthotope_coord( + coord.orthotope_coord, + unordered_set_of(operator_task_space_dim_idx_range( + orthotope_coord_num_dims(coord.orthotope_coord))), + get_operator_task_space_dim_ordering()); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_dim_permutation.cc b/lib/op-attrs/src/op-attrs/tensor_dim_permutation.cc new file mode 100644 index 0000000000..1f6fa4b5d4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_dim_permutation.cc @@ -0,0 +1,216 @@ +#include "op-attrs/tensor_dim_permutation.h" +#include "op-attrs/ff_ordered/ff_ordered_from_map.h" +#include "op-attrs/ff_ordered/map_from_ff_ordered.h" +#include 
"utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/bidict/algorithms/exhaustive_relational_join.h" +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/bidict.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/maximum.h" +#include "utils/containers/minimum.h" +#include "utils/containers/permute_with_key.h" +#include "utils/containers/require_same.h" +#include "utils/fmt/unordered_set.h" +#include "utils/hash/tuple.h" + +namespace FlexFlow { + +static void + check_are_contiguous_from_one(std::unordered_set const &idxs) { + if (idxs.empty()) { + return; + } + + ASSERT(minimum(idxs) == ff_dim_t{0_n}); + ASSERT(maximum(idxs) == ff_dim_t{nonnegative_int{idxs.size() - 1}}); +} + +TensorDimPermutation::TensorDimPermutation( + bidict const &raw) + : raw(raw) { + check_are_contiguous_from_one(right_entries(raw)); + check_are_contiguous_from_one(left_entries(raw)); +} + +bool TensorDimPermutation::operator==(TensorDimPermutation const &other) const { + return this->tie() == other.tie(); +} + +bool TensorDimPermutation::operator!=(TensorDimPermutation const &other) const { + return this->tie() == other.tie(); +} + +bool TensorDimPermutation::operator<(TensorDimPermutation const &other) const { + return this->tie() < other.tie(); +} + +bool TensorDimPermutation::operator>(TensorDimPermutation const &other) const { + return this->tie() > other.tie(); +} + +bool TensorDimPermutation::operator<=(TensorDimPermutation const &other) const { + return this->tie() <= other.tie(); +} + +bool TensorDimPermutation::operator>=(TensorDimPermutation const &other) const { + return this->tie() >= other.tie(); +} + +ff_dim_t TensorDimPermutation::at_l(ff_dim_t l) const { + return this->raw.at_l(l); +} + +ff_dim_t TensorDimPermutation::at_r(ff_dim_t r) const { + return this->raw.at_r(r); +} + +num_tensor_dims_t TensorDimPermutation::num_tensor_dims() const { + return 
num_tensor_dims_t{ + num_elements(this->raw), + }; +} + +bidict const &TensorDimPermutation::as_bidict() const { + return this->raw; +} + +std::tuple const &> + TensorDimPermutation::tie() const { + return std::tie(this->raw); +} + +bidict format_as(TensorDimPermutation const &p) { + return p.as_bidict(); +} + +std::ostream &operator<<(std::ostream &s, TensorDimPermutation const &p) { + return (s << fmt::to_string(p)); +} + +TensorDimPermutation + compose_tensor_dim_permutations(TensorDimPermutation const &lhs, + TensorDimPermutation const &rhs) { + + ASSERT(lhs.num_tensor_dims() == rhs.num_tensor_dims()); + + return TensorDimPermutation{ + exhaustive_relational_join(lhs.as_bidict(), rhs.as_bidict()), + }; +} + +TensorDimPermutation + invert_tensor_dim_permutation(TensorDimPermutation const &p) { + + return TensorDimPermutation{ + p.as_bidict().reversed(), + }; +} + +template +static FFOrdered permute_ff_ordered(TensorDimPermutation const &permutation, + FFOrdered const &ff_ordered) { + return ff_ordered_from_map( + map_keys(map_from_ff_ordered(ff_ordered), + [&](ff_dim_t k) { return permutation.at_l(k); })); +} + +TensorDims permute_tensor_dims(TensorDimPermutation const &permutation, + TensorDims const &dims) { + + return TensorDims{ + permute_ff_ordered(permutation, dims.ff_ordered), + }; +} + +TensorShape permute_tensor_shape(TensorDimPermutation const &permutation, + TensorShape const &shape) { + return TensorShape{ + /*dims=*/permute_tensor_dims(permutation, shape.dims), + /*data_type=*/shape.data_type, + }; +} + +ParallelTensorDimDegrees permute_parallel_tensor_dim_degrees( + TensorDimPermutation const &permutation, + ParallelTensorDimDegrees const ¶llel_tensor_dim_degrees) { + return ParallelTensorDimDegrees{ + /*sum_degree=*/parallel_tensor_dim_degrees.sum_degree, + /*discard_copy_degree=*/parallel_tensor_dim_degrees.discard_copy_degree, + /*shard_degrees=*/ + permute_ff_ordered(permutation, + parallel_tensor_dim_degrees.shard_degrees), + }; +} + 
+ParallelTensorDims permute_parallel_tensor_dims( + TensorDimPermutation const &permutation, + ParallelTensorDims const ¶llel_tensor_dims) { + return ParallelTensorDims{ + /*shard_dims=*/permute_ff_ordered(permutation, + parallel_tensor_dims.shard_dims), + /*replica_dims=*/parallel_tensor_dims.replica_dims, + }; +} + +ParallelTensorShape permute_parallel_tensor_shape( + TensorDimPermutation const &permutation, + ParallelTensorShape const ¶llel_tensor_shape) { + return ParallelTensorShape{ + /*dims=*/permute_parallel_tensor_dims(permutation, + parallel_tensor_shape.dims), + /*data_type=*/parallel_tensor_shape.data_type, + }; +} + +} // namespace FlexFlow + +namespace nlohmann { + +::FlexFlow::TensorDimPermutation + adl_serializer<::FlexFlow::TensorDimPermutation>::from_json(json const &j) { + ::FlexFlow::bidict<::FlexFlow::ff_dim_t, ::FlexFlow::ff_dim_t> b = j; + + return ::FlexFlow::TensorDimPermutation{b}; +} + +void adl_serializer<::FlexFlow::TensorDimPermutation>::to_json( + json &j, ::FlexFlow::TensorDimPermutation const &p) { + j = p.as_bidict(); +} + +} // namespace nlohmann + +namespace rc { + +Gen<::FlexFlow::TensorDimPermutation> + Arbitrary<::FlexFlow::TensorDimPermutation>::arbitrary() { + using namespace ::FlexFlow; + + Gen> key_permutation_gen = gen::withSize([=](int size) { + nonnegative_int reduced_size = std::min(nonnegative_int{size}, 5_n); + std::vector sized_keys = ff_dim_range(reduced_size); + return gen::map(gen::arbitrary(), + [=](int key) -> std::vector { + return permute_with_key(key, sized_keys); + }); + }); + + return gen::construct(gen::apply( + [](std::vector const &ks, std::vector const &vs) { + return bidict_from_keys_and_values(ks, vs); + }, + key_permutation_gen, + key_permutation_gen)); +} + +} // namespace rc + +namespace std { + +size_t hash<::FlexFlow::TensorDimPermutation>::operator()( + ::FlexFlow::TensorDimPermutation const &p) const { + return get_std_hash(p.tie()); +} + +} // namespace std diff --git 
a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index 435f211a01..c69418c90c 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -31,8 +31,10 @@ bool tensor_dims_has_dim(TensorDims const &tensor_dims, ff_dim_t dim) { return contains(get_idxs(tensor_dims.ff_ordered), dim); } -nonnegative_int get_num_dims(TensorDims const &dims) { - return num_elements(dims.ff_ordered); +num_tensor_dims_t get_num_dims(TensorDims const &dims) { + return num_tensor_dims_t{ + num_elements(dims.ff_ordered), + }; } positive_int dim_at_idx(TensorDims const &dims, relative_ff_dim_t idx) { @@ -114,8 +116,8 @@ TensorDimsCoord get_broadcast_src_coord(TensorDims const &input_dims, input_dims, output_dims); - relative_ff_dim_t trailing_start_idx = - relative_ff_dim_t{-1 * get_num_dims(input_dims).unwrap_nonnegative()}; + relative_ff_dim_t trailing_start_idx = relative_ff_dim_t{ + -1 * get_num_dims(input_dims).int_from_num_tensor_dims()}; FFOrdered trailing_entries = slice(dst_coord.ff_ordered, trailing_start_idx); diff --git a/lib/op-attrs/src/op-attrs/tensor_slot_name.cc b/lib/op-attrs/src/op-attrs/tensor_slot_name.cc new file mode 100644 index 0000000000..790bd36580 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_slot_name.cc @@ -0,0 +1,29 @@ +#include "op-attrs/tensor_slot_name.h" + +namespace FlexFlow { + +std::vector get_variadic_inputs_slot_name_sequence() { + return std::vector{ + TensorSlotName::INPUT_0, + TensorSlotName::INPUT_1, + TensorSlotName::INPUT_2, + TensorSlotName::INPUT_3, + TensorSlotName::INPUT_4, + TensorSlotName::INPUT_5, + TensorSlotName::INPUT_6, + TensorSlotName::INPUT_7, + }; +}; + +std::vector get_variadic_outputs_slot_name_sequence() { + return std::vector{ + TensorSlotName::OUTPUT_0, + TensorSlotName::OUTPUT_1, + TensorSlotName::OUTPUT_2, + TensorSlotName::OUTPUT_3, + TensorSlotName::OUTPUT_4, + TensorSlotName::OUTPUT_5, + }; +} + +} // namespace FlexFlow diff --git 
a/lib/op-attrs/src/parallel_dim_mapping_record.cc b/lib/op-attrs/src/parallel_dim_mapping_record.cc deleted file mode 100644 index 5e734e88cd..0000000000 --- a/lib/op-attrs/src/parallel_dim_mapping_record.cc +++ /dev/null @@ -1,60 +0,0 @@ -#include "parallel_dim_mapping_record.h" -#include - -namespace FlexFlow { - -ParallelDimMappingRecord::ParallelDimMappingRecord(MappingRecordType type) - : type(type), output_dim(-1), input_dim(-1), weight_dim(-1), output_idx(-1), - input_idx(-1), weight_idx(-1) {} - -/*static*/ -ParallelDimMappingRecord ParallelDimMappingRecord::input_output_record( - int input_idx, - int input_dim, - int output_idx, - int output_dim, - std::optional operation) { - ParallelDimMappingRecord r(MappingRecordType::INPUT_OUTPUT); - r.operation = operation; - - assert(output_idx >= 0); - assert(output_dim >= 0); - assert(input_idx >= 0); - assert(input_dim >= 0); - - r.output_idx = output_idx; - r.output_dim = output_dim; - r.input_idx = input_idx; - r.input_dim = input_dim; - - return r; -} - -/*static*/ -ParallelDimMappingRecord ParallelDimMappingRecord::input_weight_record( - int input_idx, - int input_dim, - int weight_idx, - int weight_dim, - std::optional operation) { - ParallelDimMappingRecord r(MappingRecordType::INPUT_WEIGHT); - r.operation = operation; - - assert(input_idx >= 0); - assert(input_dim >= 0); - assert(weight_idx >= 0); - assert(weight_dim >= 0); - - r.input_idx = input_idx; - r.input_dim = input_dim; - r.weight_idx = weight_idx; - r.weight_dim = weight_dim; - - return r; -} - -MappingRecordType ParallelDimMappingRecord::get_type() const { - return this->type; -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim_mapping_record.h b/lib/op-attrs/src/parallel_dim_mapping_record.h deleted file mode 100644 index c37ac79b40..0000000000 --- a/lib/op-attrs/src/parallel_dim_mapping_record.h +++ /dev/null @@ -1,54 +0,0 @@ -#ifndef _FLEXFLOW_OP_META_SRC_PARELLEL_DIM_MAPPING_RECORD_H -#define 
_FLEXFLOW_OP_META_SRC_PARELLEL_DIM_MAPPING_RECORD_H - -#include "utils/visitable.h" -#include - -namespace FlexFlow { - -enum class MappingRecordType { INPUT_OUTPUT, INPUT_WEIGHT }; - -enum class MappingOperation { PARTITION, REPLICATE }; - -class ParallelDimMappingRecord { -private: - ParallelDimMappingRecord(MappingRecordType); - -public: - ParallelDimMappingRecord() = delete; - - static ParallelDimMappingRecord input_output_record( - int input_idx, - int input_dim, - int output_idx, - int output_dim, - std::optional operation = std::nullopt); - static ParallelDimMappingRecord input_weight_record( - int input_idx, - int input_dim, - int weight_idx, - int weight_dim, - std::optional operation = std::nullopt); - MappingRecordType get_type() const; - -public: - MappingRecordType type; - std::optional operation; - - int output_dim, input_dim, weight_dim; - int output_idx, input_idx, weight_idx; -}; - -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::ParallelDimMappingRecord, - type, - operation, - output_dim, - input_dim, - weight_dim, - output_idx, - input_idx, - weight_idx); - -#endif diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc deleted file mode 100644 index a5a261da25..0000000000 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "op-attrs/dim_ordered/dim_ordered.h" -#include "doctest/doctest.h" -#include "test/utils/rapidcheck.h" - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - - TEST_CASE_TEMPLATE( - "Arbitrary> with T=", T, int, double, char) { - RC_SUBCASE([](DimOrdered) {}); - } -} diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc deleted file mode 100644 index b77bb8f71e..0000000000 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "op-attrs/dim_ordered/zip.h" -#include 
"op-attrs/ff_dim_t.dtg.h" -#include "test/utils/doctest/fmt/pair.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("zip(DimOrdered, DimOrdered)") { - DimOrdered lhs_input = {9, 9, 8, 9}; - DimOrdered rhs_input = {"m", "m", "k", "l", "m"}; - - SUBCASE("lhs is longer") { - DimOrdered> result = - zip(lhs_input, rhs_input); - - DimOrdered> correct = { - {9, "m"}, - {9, "m"}, - {8, "k"}, - {9, "l"}, - }; - - CHECK(result == correct); - } - - SUBCASE("rhs is longer") { - DimOrdered> result = - zip(rhs_input, lhs_input); - - DimOrdered> correct = { - {"m", 9}, - {"m", 9}, - {"k", 8}, - {"l", 9}, - }; - - CHECK(result == correct); - } - } -} diff --git a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc index 4688ad4008..e03970b039 100644 --- a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc @@ -7,17 +7,28 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE( "get_incoming_tensor_roles(ComputationGraphOpAttrs, int num_incoming)") { SUBCASE("Concat") { - int num_incoming = 4; - ComputationGraphOpAttrs attrs = - ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{nonnegative_int{0}}}}; + ComputationGraphOpAttrs attrs = ComputationGraphOpAttrs{ + ConcatAttrs{ + /*axis=*/ff_dim_t{0_n}, + /*num_inputs=*/3_ge2, + }, + }; - std::vector result = - get_incoming_tensor_roles(attrs, num_incoming); - std::vector correct = { - IncomingTensorRole::INPUT, - IncomingTensorRole::INPUT, - IncomingTensorRole::INPUT, - IncomingTensorRole::INPUT, + std::unordered_map result = + get_incoming_tensor_roles(attrs); + std::unordered_map correct = { + { + TensorSlotName::INPUT_0, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::INPUT_1, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::INPUT_2, + IncomingTensorRole::INPUT, + }, }; CHECK(result == correct); diff --git 
a/lib/op-attrs/test/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc b/lib/op-attrs/test/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc new file mode 100644 index 0000000000..1528d7c365 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc @@ -0,0 +1,134 @@ +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" +#include "utils/orthotope/up_projection.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "get_identity_mapping(OperatorTaskSpace, ParallelTensorDimDegrees)") { + ParallelTensorDimDegrees dim_degrees = ParallelTensorDimDegrees{ + /*sum_degree=*/SumDegree{2_p}, + /*discard_copy_degree=*/DiscardCopyDegree{1_p}, + /*shard_degrees=*/ + FFOrdered{ + 1_p, + 3_p, + 1_p, + }, + }; + + OperatorTaskSpace operator_task_space = OperatorTaskSpace{MinimalOrthotope{{ + 3_ge2, + 2_ge2, + }}}; + + OperatorSpaceToParallelTensorSpaceMapping result = + get_identity_mapping(operator_task_space, dim_degrees); + + auto make_op_coord = [](nonnegative_int x, nonnegative_int y) { + return DimCoord{{ + {operator_task_space_dim_idx_t{0_n}, x}, + {operator_task_space_dim_idx_t{1_n}, y}, + }}; + }; + + auto make_pt_coord = [](nonnegative_int sum_coord_entry, + nonnegative_int shard_coord_entry) { + return DimCoord{{ + {sum_dim_idx(), sum_coord_entry}, + {discard_copy_dim_idx(), 0_n}, + {shard_dim_idx(ff_dim_t{0_n}), 0_n}, + {shard_dim_idx(ff_dim_t{1_n}), shard_coord_entry}, + {shard_dim_idx(ff_dim_t{2_n}), 0_n}, + }}; + }; + + OperatorSpaceToParallelTensorSpaceMapping correct = + OperatorSpaceToParallelTensorSpaceMapping{ + DimDomainMapping{ + /*coord_mapping=*/bidict< + DimCoord, + DimCoord>{ + {make_op_coord(0_n, 0_n), make_pt_coord(0_n, 0_n)}, + {make_op_coord(0_n, 1_n), make_pt_coord(1_n, 0_n)}, + {make_op_coord(1_n, 0_n), make_pt_coord(0_n, 1_n)}, + {make_op_coord(1_n, 1_n), make_pt_coord(1_n, 1_n)}, 
+ {make_op_coord(2_n, 0_n), make_pt_coord(0_n, 2_n)}, + {make_op_coord(2_n, 1_n), make_pt_coord(1_n, 2_n)}, + }, + /*l_domain=*/ + DimDomain{{ + {operator_task_space_dim_idx_t{0_n}, 3_p}, + {operator_task_space_dim_idx_t{1_n}, 2_p}, + }}, + /*r_domain=*/ + DimDomain{{ + {sum_dim_idx(), 2_p}, + {discard_copy_dim_idx(), 1_p}, + {shard_dim_idx(ff_dim_t{0_n}), 1_p}, + {shard_dim_idx(ff_dim_t{1_n}), 3_p}, + {shard_dim_idx(ff_dim_t{2_n}), 1_p}, + }}, + }, + }; + + CHECK(result == correct); + } + + TEST_CASE("ptensor_coord_for_task_space_coord") { + SUBCASE("identity projection") { + OperatorTaskSpace op_task_space = OperatorTaskSpace{ + MinimalOrthotope{{ + 5_ge2, + 3_ge2, + 12_ge2, + 2_ge2, + }}, + }; + + ParallelTensorDimDegrees dim_degrees = ParallelTensorDimDegrees{ + /*sum_degree=*/SumDegree{5_p}, + /*discard_copy_degree=*/DiscardCopyDegree{3_p}, + /*shard_degrees=*/ + FFOrdered{ + 12_p, + 2_p, + }, + }; + + OperatorSpaceToParallelTensorSpaceMapping mapping = + get_identity_mapping(op_task_space, dim_degrees); + + TaskSpaceCoordinate task_space_coordinate = TaskSpaceCoordinate{ + OrthotopeCoord{ + std::vector{ + 3_n, + 2_n, + 10_n, + 1_n, + }, + }, + }; + + ParallelTensorSpaceCoordinate result = ptensor_coord_for_task_space_coord( + /*mapping=*/mapping, + /*task_space_coord=*/task_space_coordinate, + /*num_dims=*/num_ptensor_shard_dims_t{2_n}); + + ParallelTensorSpaceCoordinate correct = ParallelTensorSpaceCoordinate{ + /*sum_component=*/3_n, + /*discard_copy_component=*/2_n, + /*shard_components=*/ + FFOrdered{ + 10_n, + 1_n, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/operator_task_space.cc b/lib/op-attrs/test/src/op-attrs/operator_task_space.cc new file mode 100644 index 0000000000..faa04a7bba --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/operator_task_space.cc @@ -0,0 +1,84 @@ +#include "op-attrs/operator_task_space.h" +#include "utils/fmt/unordered_set.h" +#include + +using namespace FlexFlow; + 
+TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_task_space_coordinates") { + + SUBCASE("OperatorTaskSpace has 0 dimensions") { + OperatorTaskSpace task = OperatorTaskSpace{MinimalOrthotope{{}}}; + + std::unordered_set correct = { + TaskSpaceCoordinate{OrthotopeCoord{{}}}}; + std::unordered_set result = + get_task_space_coordinates(task); + CHECK(correct == result); + } + + SUBCASE("OperatorTaskSpace has 2 dimensions") { + + OperatorTaskSpace task = + OperatorTaskSpace{MinimalOrthotope{{2_ge2, 2_ge2}}}; + + std::unordered_set correct = {{ + TaskSpaceCoordinate{OrthotopeCoord{{0_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{0_n, 1_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{1_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{1_n, 1_n}}}, + }}; + std::unordered_set result = + get_task_space_coordinates(task); + CHECK(correct == result); + } + + SUBCASE("OperatorTaskSpace has 3 dimensions") { + + OperatorTaskSpace task = + OperatorTaskSpace{MinimalOrthotope{{3_ge2, 2_ge2, 2_ge2}}}; + + std::unordered_set correct = {{ + TaskSpaceCoordinate{OrthotopeCoord{{0_n, 0_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{0_n, 0_n, 1_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{0_n, 1_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{0_n, 1_n, 1_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{1_n, 0_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{1_n, 0_n, 1_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{1_n, 1_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{1_n, 1_n, 1_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{2_n, 0_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{2_n, 0_n, 1_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{2_n, 1_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{2_n, 1_n, 1_n}}}, + }}; + std::unordered_set result = + get_task_space_coordinates(task); + CHECK(correct == result); + } + } + + TEST_CASE("get_task_space_maximum_coordinate") { + SUBCASE("OperatorTaskSpace has 2 dimensions") { + + OperatorTaskSpace task = + OperatorTaskSpace{MinimalOrthotope{{3_ge2, 2_ge2}}}; + + 
TaskSpaceCoordinate correct = + TaskSpaceCoordinate{OrthotopeCoord{{2_n, 1_n}}}; + TaskSpaceCoordinate result = get_task_space_maximum_coordinate(task); + CHECK(correct == result); + } + + SUBCASE("OperatorTaskSpace has 3 dimensions") { + + OperatorTaskSpace task = + OperatorTaskSpace{MinimalOrthotope{{3_ge2, 2_ge2, 4_ge2}}}; + + TaskSpaceCoordinate correct = + TaskSpaceCoordinate{OrthotopeCoord{{2_n, 1_n, 3_n}}}; + TaskSpaceCoordinate result = get_task_space_maximum_coordinate(task); + CHECK(correct == result); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/operator_task_space_dim_idx_t.cc b/lib/op-attrs/test/src/op-attrs/operator_task_space_dim_idx_t.cc new file mode 100644 index 0000000000..5bb0102671 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/operator_task_space_dim_idx_t.cc @@ -0,0 +1,15 @@ +#include "op-attrs/operator_task_space_dim_idx_t.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("operator_task_space_dim_idx_range(nonnegative_int)") { + SUBCASE("end is zero") { + std::set result = + operator_task_space_dim_idx_range(nonnegative_int{0}); + std::set correct = { + operator_task_space_dim_idx_t{nonnegative_int{0}}}; + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/attention.cc b/lib/op-attrs/test/src/op-attrs/ops/attention.cc index a4f8cd62fd..5de69360f8 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/attention.cc @@ -24,14 +24,26 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("without bias") { MultiHeadAttentionAttrs attrs = make_attrs(/*bias=*/false); - tl::expected, std::string> result = + std::unordered_map result = get_attention_incoming_tensor_roles(attrs); - tl::expected, std::string> correct = - std::vector{ - IncomingTensorRole::INPUT, - IncomingTensorRole::INPUT, - IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT, + std::unordered_map correct = + std::unordered_map{ + { + TensorSlotName::KEY, + IncomingTensorRole::INPUT, + }, + { + 
TensorSlotName::QUERY, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::VALUE, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::WEIGHT, + IncomingTensorRole::WEIGHT, + }, }; CHECK(result == correct); @@ -40,16 +52,34 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("with bias") { MultiHeadAttentionAttrs attrs = make_attrs(/*bias=*/true); - tl::expected, std::string> result = + std::unordered_map result = get_attention_incoming_tensor_roles(attrs); - tl::expected, std::string> correct = - std::vector{ - IncomingTensorRole::INPUT, - IncomingTensorRole::INPUT, - IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT, - IncomingTensorRole::WEIGHT, - IncomingTensorRole::WEIGHT, + std::unordered_map correct = + std::unordered_map{ + { + TensorSlotName::KEY, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::QUERY, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::VALUE, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::WEIGHT, + IncomingTensorRole::WEIGHT, + }, + { + TensorSlotName::INPUT_BIAS, + IncomingTensorRole::WEIGHT, + }, + { + TensorSlotName::OUTPUT_BIAS, + IncomingTensorRole::WEIGHT, + }, }; CHECK(result == correct); diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc index d251fb731d..1044c379f0 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc @@ -12,9 +12,9 @@ TEST_SUITE(FF_TEST_SUITE) { positive_int p = 10_p; BatchMatmulAttrs attrs = BatchMatmulAttrs{ - /*a_seq_length_dim=*/0_n, // TODO figure out if these arguments are + /*a_seq_length_dim=*/1_p, // TODO figure out if these arguments are // still relevant - /*b_seq_length_dim=*/0_n, + /*b_seq_length_dim=*/1_p, }; TensorShape input_lhs_shape = TensorShape{ @@ -106,9 +106,9 @@ TEST_SUITE(FF_TEST_SUITE) { positive_int o_sum = 11_p; BatchMatmulAttrs attrs = BatchMatmulAttrs{ - /*a_seq_length_dim=*/0_n, // TODO figure out if these arguments are + 
/*a_seq_length_dim=*/0_p, // TODO figure out if these arguments are // still relevant - /*b_seq_length_dim=*/0_n, + /*b_seq_length_dim=*/0_p, }; auto make_lhs = [&](SumDegree o_sum, diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc index b70e8fcb4e..e39649f9bd 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc @@ -21,12 +21,21 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("affine = true") { BatchNormAttrs attrs = make_attrs(/*affine=*/true); - std::vector result = + std::unordered_map result = get_batch_norm_incoming_tensor_roles(attrs); - std::vector correct = { - IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT, - IncomingTensorRole::WEIGHT, + std::unordered_map correct = { + { + TensorSlotName::INPUT, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::GAMMA, + IncomingTensorRole::WEIGHT, + }, + { + TensorSlotName::BETA, + IncomingTensorRole::WEIGHT, + }, }; CHECK(result == correct); @@ -35,10 +44,13 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("affine = false") { BatchNormAttrs attrs = make_attrs(/*affine=*/false); - std::vector result = + std::unordered_map result = get_batch_norm_incoming_tensor_roles(attrs); - std::vector correct = { - IncomingTensorRole::INPUT, + std::unordered_map correct = { + { + TensorSlotName::INPUT, + IncomingTensorRole::INPUT, + }, }; CHECK(result == correct); diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc deleted file mode 100644 index 3d86576279..0000000000 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include "op-attrs/ops/batch_norm_attrs.dtg.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("BatchNormAttrs to/from json") { - BatchNormAttrs correct = BatchNormAttrs{ - /*relu=*/false, - /*affine=*/true, - /*eps=*/1e-5, - /*momentum=*/0.1, - }; - 
- nlohmann::json j = correct; - BatchNormAttrs result = j.get(); - - CHECK(result == correct); - } -} diff --git a/lib/op-attrs/test/src/op-attrs/ops/concat.cc b/lib/op-attrs/test/src/op-attrs/ops/concat.cc index 95fa7d67c7..939c799171 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/concat.cc @@ -10,7 +10,8 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(ConcatAttrs, std::vector)") { ConcatAttrs attrs = ConcatAttrs{ - ff_dim_t{1_n}, + /*axis=*/ff_dim_t{1_n}, + /*num_inputs=*/3_ge2, }; SUBCASE("empty input shapes list passed") { @@ -81,7 +82,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("concat axis is out of bounds") { attrs = ConcatAttrs{ - ff_dim_t{3_n}, + /*axis=*/ff_dim_t{3_n}, + /*num_inputs=*/3_ge2, }; std::vector input_shapes = { @@ -115,7 +117,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(ConcatAttrs, std::vector)") { ConcatAttrs attrs = ConcatAttrs{ - ff_dim_t{1_n}, + /*axis=*/ff_dim_t{1_n}, + /*num_inputs=*/3_ge2, }; positive_int dim0_size = 12_p; diff --git a/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc index 56407c03f1..9c5cd9009b 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc @@ -1,6 +1,6 @@ #include "op-attrs/ops/conv_2d.h" -#include "doctest/doctest.h" #include "utils/integer_conversions.h" +#include using namespace ::FlexFlow; @@ -22,12 +22,21 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("with bias") { Conv2DAttrs attrs = make_attrs(/*use_bias=*/true); - std::vector result = + std::unordered_map result = get_conv2d_incoming_tensor_roles(attrs); - std::vector correct = { - IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT, - IncomingTensorRole::WEIGHT, + std::unordered_map correct = { + { + TensorSlotName::INPUT, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::FILTER, + IncomingTensorRole::WEIGHT, + }, + { + TensorSlotName::BIAS, + 
IncomingTensorRole::WEIGHT, + }, }; CHECK(result == correct); @@ -36,11 +45,17 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("without bias") { Conv2DAttrs attrs = make_attrs(/*use_bias=*/false); - std::vector result = + std::unordered_map result = get_conv2d_incoming_tensor_roles(attrs); - std::vector correct = { - IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT, + std::unordered_map correct = { + { + TensorSlotName::INPUT, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::FILTER, + IncomingTensorRole::WEIGHT, + }, }; CHECK(result == correct); diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc index 72d499d20e..877284b511 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc @@ -44,12 +44,7 @@ TEST_SUITE(FF_TEST_SUITE) { TensorShape incorrect_rhs = input_lhs; dim_at_idx(incorrect_rhs.dims, relative_ff_dim_t{0}) += 1_p; - tl::expected result = - get_output_shape(attrs, input_lhs, incorrect_rhs); - - CHECK_MESSAGE(!result.has_value(), - "Unexpected successful result: ", - result.error()); + CHECK_THROWS(get_output_shape(attrs, input_lhs, incorrect_rhs)); } } @@ -146,12 +141,8 @@ TEST_SUITE(FF_TEST_SUITE) { make_lhs(SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p); ParallelTensorShape input_rhs = make_rhs(SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p); - tl::expected result = - get_output_shape(attrs, input_lhs, input_rhs); - CHECK_MESSAGE(!result.has_value(), - "Unexpected successful result: ", - result.error()); + CHECK_THROWS(get_output_shape(attrs, input_lhs, input_rhs)); } SUBCASE("invalid mismatched parallelism degrees") { @@ -161,12 +152,8 @@ TEST_SUITE(FF_TEST_SUITE) { make_lhs(SumDegree{1_p}, DiscardCopyDegree{1_p}, 1_p, degree, 1_p); ParallelTensorShape input_rhs = make_rhs(SumDegree{1_p}, DiscardCopyDegree{1_p}, 1_p, 1_p, degree); - tl::expected result = - get_output_shape(attrs, input_lhs, 
input_rhs); - CHECK_MESSAGE(!result.has_value(), - "Unexpected successful result: ", - result.error()); + CHECK_THROWS(get_output_shape(attrs, input_lhs, input_rhs)); } } } diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc index 355feb4c5f..672b160cbd 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc @@ -56,25 +56,19 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("sum degree > 1") { positive_int degree = 2_p; - tl::expected result = get_output_shape( + CHECK_THROWS(get_output_shape( attrs, - make_input(SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p)); - - CHECK_MESSAGE(!result.has_value(), - "Unexpected successful result: ", - result.error()); + make_input( + SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p))); } SUBCASE("discard copy degree > 1") { positive_int degree = 2_p; - tl::expected result = get_output_shape( + CHECK_THROWS(get_output_shape( attrs, - make_input(SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p)); - - CHECK_MESSAGE(!result.has_value(), - "Unexpected successful result: ", - result.error()); + make_input( + SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p))); } } } diff --git a/lib/op-attrs/test/src/op-attrs/ops/flat.cc b/lib/op-attrs/test/src/op-attrs/ops/flat.cc index c4fe8a5250..8a2c609cd4 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/flat.cc @@ -148,11 +148,7 @@ TEST_SUITE(FF_TEST_SUITE) { FFOrdered{1_p, 1_p, 2_p, 1_p}, }; - std::optional result = - optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); - std::optional correct = std::nullopt; - - CHECK(result == correct); + CHECK_THROWS(get_output_parallel_dim_degrees(attrs, input)); } SUBCASE("allows sum parallelism") { @@ -162,14 +158,13 @@ TEST_SUITE(FF_TEST_SUITE) { FFOrdered{1_p, 1_p, 1_p, 1_p}, }; - std::optional result = - 
optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); - std::optional correct = - ParallelTensorDimDegrees{ - SumDegree{2_p}, - DiscardCopyDegree{1_p}, - FFOrdered{1_p, 1_p, 1_p}, - }; + ParallelTensorDimDegrees result = + get_output_parallel_dim_degrees(attrs, input); + ParallelTensorDimDegrees correct = ParallelTensorDimDegrees{ + SumDegree{2_p}, + DiscardCopyDegree{1_p}, + FFOrdered{1_p, 1_p, 1_p}, + }; CHECK(result == correct); } @@ -181,14 +176,13 @@ TEST_SUITE(FF_TEST_SUITE) { FFOrdered{1_p, 1_p, 1_p, 1_p}, }; - std::optional result = - optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); - std::optional correct = - ParallelTensorDimDegrees{ - SumDegree{1_p}, - DiscardCopyDegree{2_p}, - FFOrdered{1_p, 1_p, 1_p}, - }; + ParallelTensorDimDegrees result = + get_output_parallel_dim_degrees(attrs, input); + ParallelTensorDimDegrees correct = ParallelTensorDimDegrees{ + SumDegree{1_p}, + DiscardCopyDegree{2_p}, + FFOrdered{1_p, 1_p, 1_p}, + }; CHECK(result == correct); } diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc index ba311ffb1a..14591cb3d6 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc @@ -20,12 +20,21 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("elementwise_affine = true") { LayerNormAttrs attrs = make_attrs(/*elementwise_affine=*/true); - std::vector result = + std::unordered_map result = get_layer_norm_incoming_tensor_roles(attrs); - std::vector correct = { - IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT, - IncomingTensorRole::WEIGHT, + std::unordered_map correct = { + { + TensorSlotName::INPUT, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::GAMMA, + IncomingTensorRole::WEIGHT, + }, + { + TensorSlotName::BETA, + IncomingTensorRole::WEIGHT, + }, }; CHECK(result == correct); @@ -34,10 +43,13 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("elementwise_affine = false") { LayerNormAttrs 
attrs = make_attrs(/*elementwise_affine=*/false); - std::vector result = + std::unordered_map result = get_layer_norm_incoming_tensor_roles(attrs); - std::vector correct = { - IncomingTensorRole::INPUT, + std::unordered_map correct = { + { + TensorSlotName::INPUT, + IncomingTensorRole::INPUT, + }, }; CHECK(result == correct); diff --git a/lib/op-attrs/test/src/op-attrs/ops/linear.cc b/lib/op-attrs/test/src/op-attrs/ops/linear.cc index 4e0dd149ab..c46e36bf7b 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/linear.cc @@ -21,12 +21,21 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("use_bias = true") { LinearAttrs attrs = make_attrs(/*use_bias=*/true); - std::vector result = + std::unordered_map result = get_linear_incoming_tensor_roles(attrs); - std::vector correct = { - IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT, - IncomingTensorRole::WEIGHT, + std::unordered_map correct = { + { + TensorSlotName::INPUT, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::WEIGHT, + IncomingTensorRole::WEIGHT, + }, + { + TensorSlotName::BIAS, + IncomingTensorRole::WEIGHT, + }, }; CHECK(result == correct); @@ -35,11 +44,17 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("use_bias = false") { LinearAttrs attrs = make_attrs(/*use_bias=*/false); - std::vector result = + std::unordered_map result = get_linear_incoming_tensor_roles(attrs); - std::vector correct = { - IncomingTensorRole::INPUT, - IncomingTensorRole::WEIGHT, + std::unordered_map correct = { + { + TensorSlotName::INPUT, + IncomingTensorRole::INPUT, + }, + { + TensorSlotName::WEIGHT, + IncomingTensorRole::WEIGHT, + }, }; CHECK(result == correct); @@ -288,4 +303,30 @@ TEST_SUITE(FF_TEST_SUITE) { } } } + + TEST_CASE("get_operator_to_input_mapping(LinearAttrs, nonnegative_int)") { + LinearAttrs attrs = LinearAttrs{ + /*out_channels=*/16_p, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*activation=*/Activation::RELU, + /*regularizer=*/std::nullopt, + }; + + 
ParallelTensorDimDegrees input_dims = ParallelTensorDimDegrees{ + /*sum_degree=*/SumDegree{2_p}, + /*discard_copy_dedgree=*/DiscardCopyDegree{1_p}, + /*shard_degrees=*/ + FFOrdered{ + 1_p, + 1_p, + }, + }; + + OperatorSpaceToParallelTensorSpaceMapping result = + get_operator_to_input_mapping(attrs, input_dims); + + // TODO(@lockshaw): implement some actual checks here + NOT_IMPLEMENTED(); + } } diff --git a/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc b/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc new file mode 100644 index 0000000000..6d0e072db5 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc @@ -0,0 +1,161 @@ +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include + +using namespace ::FlexFlow; + +static parallel_tensor_dim_idx_t shard_dim_idx_from_raw(int idx) { + return parallel_tensor_dim_idx_t{ff_dim_t{nonnegative_int{idx}}}; +} + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_parallel_tensor_degree_map") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + SumDegree{3_p}, + DiscardCopyDegree{1_p}, + FFOrdered{ + 1_p, + 2_p, + 1_p, + }, + }; + + std::unordered_map result = + get_parallel_tensor_degree_map(degrees); + std::unordered_map correct = { + {parallel_tensor_dim_idx_t{ReplicaType::SUM}, 3_p}, + {parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, 1_p}, + {shard_dim_idx_from_raw(0), 1_p}, + {shard_dim_idx_from_raw(1), 2_p}, + {shard_dim_idx_from_raw(2), 1_p}, + }; + + CHECK(result == correct); + } + + TEST_CASE("get_parallel_tensor_space_coordinates") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + SumDegree{3_p}, + DiscardCopyDegree{1_p}, + FFOrdered{ + 1_p, + 2_p, + 1_p, + }, + }; + + std::unordered_set result = + 
get_parallel_tensor_space_coordinates(degrees); + std::unordered_set correct = { + ParallelTensorSpaceCoordinate{ + /*sum_idx=*/0_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 0_n, 0_n}, + }, + ParallelTensorSpaceCoordinate{ + /*sum_idx=*/1_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 0_n, 0_n}, + }, + ParallelTensorSpaceCoordinate{ + /*sum_idx=*/2_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 0_n, 0_n}, + }, + ParallelTensorSpaceCoordinate{ + /*sum_idx=*/0_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 1_n, 0_n}, + }, + ParallelTensorSpaceCoordinate{ + /*sum_idx=*/1_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 1_n, 0_n}, + }, + ParallelTensorSpaceCoordinate{ + /*sum_idx=*/2_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 1_n, 0_n}, + }, + }; + + CHECK(result == correct); + } + + TEST_CASE( + "get_nontrivial_parallel_tensor_dim_indices(ParallelTensorDimDegrees)") { + SUBCASE("a replica dim has degree 1") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + SumDegree{3_p}, + DiscardCopyDegree{1_p}, + FFOrdered{4_p, 2_p, 4_p}, + }; + + std::set result = + get_nontrivial_parallel_tensor_dim_indices(degrees); + std::set correct = { + parallel_tensor_dim_idx_t{ReplicaType::SUM}, + shard_dim_idx_from_raw(0), + shard_dim_idx_from_raw(1), + shard_dim_idx_from_raw(2), + }; + + CHECK(result == correct); + } + + SUBCASE("a shard dim has degree 1") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + SumDegree{3_p}, + DiscardCopyDegree{2_p}, + FFOrdered{1_p, 4_p, 1_p}, + }; + + std::set result = + get_nontrivial_parallel_tensor_dim_indices(degrees); + std::set correct = { + parallel_tensor_dim_idx_t{ReplicaType::SUM}, + parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, + shard_dim_idx_from_raw(1), + }; + + CHECK(result == correct); + } + + SUBCASE("no dims have degree 1") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + 
SumDegree{3_p}, + DiscardCopyDegree{2_p}, + FFOrdered{4_p, 2_p, 5_p}, + }; + + std::set result = + get_nontrivial_parallel_tensor_dim_indices(degrees); + std::set correct = { + parallel_tensor_dim_idx_t{ReplicaType::SUM}, + parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, + shard_dim_idx_from_raw(0), + shard_dim_idx_from_raw(1), + shard_dim_idx_from_raw(2), + }; + + CHECK(result == correct); + } + + SUBCASE("all dims have degree 1") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + SumDegree{1_p}, + DiscardCopyDegree{1_p}, + FFOrdered{1_p, 1_p, 1_p}, + }; + + std::set result = + get_nontrivial_parallel_tensor_dim_indices(degrees); + std::set correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_idx_t.cc b/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_idx_t.cc new file mode 100644 index 0000000000..8edb5d19a9 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_idx_t.cc @@ -0,0 +1,80 @@ +#include "op-attrs/parallel_tensor_dim_idx_t.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" +#include "utils/containers/sorted_by.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_parallel_tensor_dim_ordering") { + DimOrdering ordering = + get_parallel_tensor_dim_ordering(); + + RC_SUBCASE("is antireflexive", [&](parallel_tensor_dim_idx_t const &idx) { + RC_ASSERT(!(idx < idx)); + }); + + RC_SUBCASE("is antisymmetric", + [&](parallel_tensor_dim_idx_t const &a, + parallel_tensor_dim_idx_t const &b) { + RC_PRE(a < b); + + RC_ASSERT(!(b < a)); + }); + + RC_SUBCASE("is transitive", + [&](parallel_tensor_dim_idx_t const &a, + parallel_tensor_dim_idx_t const &b, + parallel_tensor_dim_idx_t const &c) { + RC_PRE(a < b); + RC_PRE(b < c); + + RC_ASSERT(a < c); + }); + + SUBCASE("sum is less than discard") { + bool result = ordering.lt(sum_dim_idx(), discard_copy_dim_idx()); + bool correct = true; + + 
CHECK(result == correct); + } + + SUBCASE("discard is less than shard dim") { + bool result = + ordering.lt(discard_copy_dim_idx(), shard_dim_idx(ff_dim_t{0_n})); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("shard dim 0 is less than shard dim 1") { + bool result = ordering.lt(shard_dim_idx(ff_dim_t{0_n}), + shard_dim_idx(ff_dim_t{1_n})); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("properly sorts a set of dimensions") { + std::unordered_set input = { + sum_dim_idx(), + shard_dim_idx(ff_dim_t{1_n}), + shard_dim_idx(ff_dim_t{0_n}), + discard_copy_dim_idx(), + }; + + std::vector result = + sorted_by(input, get_parallel_tensor_dim_ordering().lt); + + std::vector correct = { + sum_dim_idx(), + discard_copy_dim_idx(), + shard_dim_idx(ff_dim_t{0_n}), + shard_dim_idx(ff_dim_t{1_n}), + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/regularizer_attrs.cc b/lib/op-attrs/test/src/op-attrs/regularizer_attrs.cc deleted file mode 100644 index 6e172d1e8e..0000000000 --- a/lib/op-attrs/test/src/op-attrs/regularizer_attrs.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "op-attrs/regularizer_attrs.dtg.h" -#include "test/utils/rapidcheck.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("Arbitrary") { - RC_SUBCASE([](RegularizerAttrs reg) { - RC_ASSERT(reg.has() || reg.has()); - }); - } -} diff --git a/lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc b/lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc index e3f3f4534e..944bef5bf7 100644 --- a/lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc +++ b/lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc @@ -5,7 +5,7 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ff_dim_t_from_relative_ff_dim_t") { - nonnegative_int input_dim = 5_n; + num_tensor_dims_t input_dim = num_tensor_dims_t{5_n}; SUBCASE("relative index is zero") { relative_ff_dim_t relative_ff_dim = relative_ff_dim_t{0}; diff --git 
a/lib/op-attrs/test/src/op-attrs/tensor_dim_permutation.cc b/lib/op-attrs/test/src/op-attrs/tensor_dim_permutation.cc new file mode 100644 index 0000000000..dac99762b4 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/tensor_dim_permutation.cc @@ -0,0 +1,84 @@ +#include "op-attrs/tensor_dim_permutation.h" +#include "test/utils/rapidcheck/doctest.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("TensorDimPermutation") { + SUBCASE("fails if constructed with a non-contiguous key set") { + CHECK_THROWS(TensorDimPermutation{bidict{ + {ff_dim_t{2_n}, ff_dim_t{0_n}}, + {ff_dim_t{0_n}, ff_dim_t{1_n}}, + }}); + } + + SUBCASE("fails if constructed with a key set that doesn't start at 1") { + CHECK_THROWS(TensorDimPermutation{bidict{ + {ff_dim_t{0_n}, ff_dim_t{1_n}}, + {ff_dim_t{1_n}, ff_dim_t{2_n}}, + }}); + } + + SUBCASE("can be constructed with empty bidict") { + TensorDimPermutation p = + TensorDimPermutation{bidict{}}; + CHECK(p.num_tensor_dims() == num_tensor_dims_t{0_n}); + } + + SUBCASE("can be constructed with non-empty bidict") { + bidict b = bidict{ + {ff_dim_t{0_n}, ff_dim_t{2_n}}, + {ff_dim_t{1_n}, ff_dim_t{3_n}}, + {ff_dim_t{3_n}, ff_dim_t{0_n}}, + {ff_dim_t{2_n}, ff_dim_t{1_n}}, + }; + + TensorDimPermutation p = TensorDimPermutation{b}; + + SUBCASE("at_l") { + SUBCASE("key is present") { + ff_dim_t result = p.at_l(ff_dim_t{1_n}); + ff_dim_t correct = ff_dim_t{3_n}; + + CHECK(result == correct); + } + + SUBCASE("key is not present") { + CHECK_THROWS(p.at_l(ff_dim_t{4_n})); + } + } + + SUBCASE("at_r") { + SUBCASE("key is present") { + ff_dim_t result = p.at_r(ff_dim_t{1_n}); + ff_dim_t correct = ff_dim_t{2_n}; + + CHECK(result == correct); + } + + SUBCASE("key is not present") { + CHECK_THROWS(p.at_r(ff_dim_t{4_n})); + } + } + + SUBCASE("num_tensor_dims") { + num_tensor_dims_t result = p.num_tensor_dims(); + num_tensor_dims_t correct = num_tensor_dims_t{4_n}; + + CHECK(result == correct); + } + + SUBCASE("as_bidict") { + 
bidict result = p.as_bidict(); + bidict correct = b; + + CHECK(result == correct); + } + } + } + + TEST_CASE("Arbitrary") { + RC_SUBCASE([](TensorDimPermutation) {}); + } +} diff --git a/lib/pcg/include/pcg/cg_operator_plus_signature.struct.toml b/lib/pcg/include/pcg/cg_operator_plus_signature.struct.toml deleted file mode 100644 index f4714a87c8..0000000000 --- a/lib/pcg/include/pcg/cg_operator_plus_signature.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "CGOperatorPlusSignature" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", -] - -includes = [ - "op-attrs/computation_graph_op_attrs.dtg.h", - "pcg/cg_operator_tensor_shape_signature.dtg.h", - "", -] - -[[fields]] -name = "op_attrs" -type = "::FlexFlow::ComputationGraphOpAttrs" - -[[fields]] -name = "tensor_shape_signature" -type = "::FlexFlow::CGOperatorTensorShapeSignature" diff --git a/lib/pcg/include/pcg/cg_operator_tensor_shape_signature.h b/lib/pcg/include/pcg/cg_operator_tensor_shape_signature.h deleted file mode 100644 index 3629aaff43..0000000000 --- a/lib/pcg/include/pcg/cg_operator_tensor_shape_signature.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CG_OPERATOR_TENSOR_SHAPE_SIGNATURE_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CG_OPERATOR_TENSOR_SHAPE_SIGNATURE_H - -#include "pcg/cg_operator_tensor_shape_signature.dtg.h" -#include "pcg/tensor_role.dtg.h" - -namespace FlexFlow { - -std::vector - tensor_shapes_for_role(CGOperatorTensorShapeSignature const &signature, - TensorRole tensor_role); - -TensorShape tensor_shape_for_role_and_index( - CGOperatorTensorShapeSignature const &signature, - TensorRole tensor_role, - nonnegative_int index); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/cg_operator_tensor_shape_signature.struct.toml b/lib/pcg/include/pcg/cg_operator_tensor_shape_signature.struct.toml deleted file mode 100644 index a2a6c047c6..0000000000 --- 
a/lib/pcg/include/pcg/cg_operator_tensor_shape_signature.struct.toml +++ /dev/null @@ -1,32 +0,0 @@ -namespace = "FlexFlow" -name = "CGOperatorTensorShapeSignature" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", - "rapidcheck", -] - -includes = [ - "op-attrs/tensor_shape.dtg.h", - "", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "input_shapes" -type = "std::vector<::FlexFlow::TensorShape>" - -[[fields]] -name = "weight_shapes" -type = "std::vector<::FlexFlow::TensorShape>" - -[[fields]] -name = "output_shapes" -type = "std::vector<::FlexFlow::TensorShape>" diff --git a/lib/pcg/include/pcg/computation_graph.dtg.toml b/lib/pcg/include/pcg/computation_graph.dtg.toml new file mode 100644 index 0000000000..2bb26771a8 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "ComputationGraph" +type = "struct" +features = [ ] + +includes = [ + "pcg/layer_attrs.dtg.h", + "pcg/tensor_attrs.dtg.h", + "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::LabelledKwargDataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs, ::FlexFlow::TensorSlotName>" diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index 60e825c11a..cd42328fd1 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H #include "op-attrs/incoming_tensor_role.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" #include "pcg/computation_graph.dtg.h" #include "pcg/computation_graph/computation_graph_edge.dtg.h" #include "pcg/computation_graph/layer_added_result.dtg.h" @@ -18,9 +19,10 @@ std::unordered_set get_layers(ComputationGraph const &); LayerAddedResult add_layer( ComputationGraph &computation_graph, LayerAttrs 
const &attrs, - std::vector const &inputs, - std::vector const &weights, - std::optional> const &outputs = std::nullopt); + std::unordered_map const &inputs, + std::unordered_map const &weights, + std::optional> const + &outputs = std::nullopt); LayerAddedResult add_input_layer(ComputationGraph &computation_graph, TensorShape const &tensor_shape); @@ -34,20 +36,20 @@ bool are_tensor_guid_shapes_equivalent(ComputationGraph const &cg, std::vector topological_ordering(ComputationGraph const &cg); -std::vector get_outgoing_tensors(ComputationGraph const &cg, - layer_guid_t n); +std::unordered_map + get_outgoing_tensors(ComputationGraph const &cg, layer_guid_t n); -std::vector get_incoming_tensors(ComputationGraph const &cg, - layer_guid_t n); +std::unordered_map + get_incoming_tensors(ComputationGraph const &cg, layer_guid_t n); -std::vector get_incoming_inputs(ComputationGraph const &, - layer_guid_t const &); +std::unordered_map + get_incoming_inputs(ComputationGraph const &, layer_guid_t const &); -std::vector get_incoming_input_shapes(ComputationGraph const &, - layer_guid_t const &); +std::unordered_map + get_incoming_input_shapes(ComputationGraph const &, layer_guid_t const &); -std::vector get_incoming_weights(ComputationGraph const &, - layer_guid_t const &); +std::unordered_map + get_incoming_weights(ComputationGraph const &, layer_guid_t const &); std::unordered_set get_all_tensors(ComputationGraph const &); std::unordered_map diff --git a/lib/pcg/include/pcg/computation_graph.struct.toml b/lib/pcg/include/pcg/computation_graph.struct.toml deleted file mode 100644 index 3e7a3cb9f1..0000000000 --- a/lib/pcg/include/pcg/computation_graph.struct.toml +++ /dev/null @@ -1,13 +0,0 @@ -namespace = "FlexFlow" -name = "ComputationGraph" -features = [ ] - -includes = [ - "pcg/layer_attrs.dtg.h", - "pcg/tensor_attrs.dtg.h", - "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h", -] - -[[fields]] -name = "raw_graph" -type = 
"::FlexFlow::LabelledDataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/computation_graph/computation_graph_edge.dtg.toml b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.dtg.toml new file mode 100644 index 0000000000..0ba06ff16c --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "ComputationGraphEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::KwargDataflowEdge<::FlexFlow::TensorSlotName>" diff --git a/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml deleted file mode 100644 index 311c47d277..0000000000 --- a/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "ComputationGraphEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_edge.dtg.h", -] - -[[fields]] -name = "raw_edge" -type = "::FlexFlow::DataflowEdge" diff --git a/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.toml b/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.toml new file mode 100644 index 0000000000..17256abe5a --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "LayerAddedResult" +type = "struct" +features = [ + "eq", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "pcg/tensor_guid_t.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "layer" +type = "::FlexFlow::layer_guid_t" + +[[fields]] +name = "outputs" +type = 
"std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::tensor_guid_t>" diff --git a/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml b/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml deleted file mode 100644 index d7b669fb3a..0000000000 --- a/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "LayerAddedResult" -features = [ - "eq", - "fmt", -] - -includes = [ - "pcg/layer_guid_t.dtg.h", - "pcg/tensor_guid_t.dtg.h", - "utils/fmt/vector.h" -] - -[[fields]] -name = "layer" -type = "::FlexFlow::layer_guid_t" - -[[fields]] -name = "outputs" -type = "std::vector<::FlexFlow::tensor_guid_t>" diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index d90898716f..064a4dd20d 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -255,11 +255,12 @@ struct ComputationGraphBuilder { TensorShape get_shape(tensor_guid_t const &) const; private: - std::vector add_layer( + std::unordered_map add_layer( LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::optional> const &outputs = std::nullopt); + std::unordered_map const &inputs, + std::unordered_map const &weights, + std::optional> const + &outputs = std::nullopt); tensor_guid_t broadcast(tensor_guid_t const &, TensorDims const &, std::string const &); diff --git a/lib/pcg/include/pcg/cost_values.h b/lib/pcg/include/pcg/cost_values.h deleted file mode 100644 index ddb942b87d..0000000000 --- a/lib/pcg/include/pcg/cost_values.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_COST_VALUES_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_COST_VALUES_H - -namespace FlexFlow { - -struct CostValues; - -} - -#endif diff --git a/lib/pcg/include/pcg/cpu_id_t.dtg.toml b/lib/pcg/include/pcg/cpu_id_t.dtg.toml new file mode 100644 index 
0000000000..a5182827b3 --- /dev/null +++ b/lib/pcg/include/pcg/cpu_id_t.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "cpu_id_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "cpu_index" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/pcg/include/pcg/cpu_id_t.struct.toml b/lib/pcg/include/pcg/cpu_id_t.struct.toml deleted file mode 100644 index 152debbded..0000000000 --- a/lib/pcg/include/pcg/cpu_id_t.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "cpu_id_t" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "cpu_index" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/pcg/include/pcg/create_grad.dtg.toml b/lib/pcg/include/pcg/create_grad.dtg.toml new file mode 100644 index 0000000000..56aec96f6b --- /dev/null +++ b/lib/pcg/include/pcg/create_grad.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "CreateGrad" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "YES" + +[[values]] +name = "NO" diff --git a/lib/pcg/include/pcg/create_grad.enum.toml b/lib/pcg/include/pcg/create_grad.enum.toml deleted file mode 100644 index 20febe49fb..0000000000 --- a/lib/pcg/include/pcg/create_grad.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "CreateGrad" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[values]] -name = "YES" - -[[values]] -name = "NO" diff --git a/lib/pcg/include/pcg/device_id_t.dtg.toml b/lib/pcg/include/pcg/device_id_t.dtg.toml new file mode 100644 index 0000000000..4efcb07975 --- /dev/null +++ b/lib/pcg/include/pcg/device_id_t.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "device_id_t" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + 
"fmt", +] + +includes = [ + "pcg/cpu_id_t.dtg.h", + "pcg/gpu_id_t.dtg.h", +] + +[[values]] +type = "::FlexFlow::gpu_id_t" +key = "gpu" + +[[values]] +type = "::FlexFlow::cpu_id_t" +key = "cpu" diff --git a/lib/pcg/include/pcg/device_id_t.h b/lib/pcg/include/pcg/device_id_t.h new file mode 100644 index 0000000000..e8e605b068 --- /dev/null +++ b/lib/pcg/include/pcg/device_id_t.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_ID_T_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_ID_T_H + +#include "pcg/device_id_t.dtg.h" +#include "pcg/device_type.dtg.h" + +namespace FlexFlow { + +device_id_t make_device_id_t_from_idx(nonnegative_int idx, + DeviceType device_type); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/device_id_t.variant.toml b/lib/pcg/include/pcg/device_id_t.variant.toml deleted file mode 100644 index 71af18919f..0000000000 --- a/lib/pcg/include/pcg/device_id_t.variant.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "device_id_t" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -includes = [ - "pcg/cpu_id_t.dtg.h", - "pcg/gpu_id_t.dtg.h", -] - -[[values]] -type = "::FlexFlow::gpu_id_t" -key = "gpu" - -[[values]] -type = "::FlexFlow::cpu_id_t" -key = "cpu" diff --git a/lib/pcg/include/pcg/device_type.dtg.toml b/lib/pcg/include/pcg/device_type.dtg.toml new file mode 100644 index 0000000000..8649183fe4 --- /dev/null +++ b/lib/pcg/include/pcg/device_type.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "DeviceType" +type = "enum" +features = [ + "hash", + "json", + "fmt", + "rapidcheck", +] + +[[values]] +name = "GPU" + +[[values]] +name = "CPU" diff --git a/lib/pcg/include/pcg/device_type.enum.toml b/lib/pcg/include/pcg/device_type.enum.toml deleted file mode 100644 index 67f89fbc6f..0000000000 --- a/lib/pcg/include/pcg/device_type.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "DeviceType" -features = [ - "hash", - "json", - "fmt", - "rapidcheck", -] - 
-[[values]] -name = "GPU" - -[[values]] -name = "CPU" diff --git a/lib/pcg/include/pcg/file_format/keyed_variant.h b/lib/pcg/include/pcg/file_format/keyed_variant.h deleted file mode 100644 index 5e29d8c252..0000000000 --- a/lib/pcg/include/pcg/file_format/keyed_variant.h +++ /dev/null @@ -1,132 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_KEYED_VARIANT_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_KEYED_VARIANT_H - -#include "utils/json/is_jsonable.h" -#include "utils/sequence.h" -#include "utils/strong_typedef.h" -#include "utils/variant.h" -#include - -namespace FlexFlow { - -template -struct KeyedVariant { - KeyedVariant() = delete; - KeyedVariant(Variant const &v) : v(v) {} - - Variant v; - - friend bool operator==(KeyedVariant const &lhs, KeyedVariant const &rhs) { - return lhs.v == rhs.v; - } - - friend bool operator!=(KeyedVariant const &lhs, KeyedVariant const &rhs) { - return lhs.v != rhs.v; - } - - friend bool operator<(KeyedVariant const &lhs, KeyedVariant const &rhs) { - return lhs.v < rhs.v; - } -}; - -struct ToJsonFunctor { - ToJsonFunctor(nlohmann::json &j) : j(j) {} - - nlohmann::json &j; - - template - void operator()(T const &t) { - static_assert(is_jsonable::value, ""); - - j = t; - } -}; - -template -void to_json(nlohmann::json &j, KeyedVariant const &v) { - static_assert(is_jsonable::value, ""); - - K key = static_cast(v.value.index()); - j["type"] = key; - nlohmann::json &jj = j["value"]; - visit(ToJsonFunctor{j["value"]}, v.value); -} - -template -struct FromJsonFunctor { - FromJsonFunctor(nlohmann::json const &j, int idx) : j(j), idx(idx) {} - - nlohmann::json const &j; - int idx; - - template - void operator()(T &t) { - if (idx == index_of_type::value) { - t = j.get(); - } - } -}; - -template -std::string get_json_name(T const &t) { - return nlohmann::json{t}.get(); -} - -template -struct FromJsonMoveOnlyFunctor { - FromJsonMoveOnlyFunctor(nlohmann::json const &j, Key const &key) : j(j) {} - - nlohmann::json const &j; - Key 
const &key; - - template - Variant operator()(std::integral_constant const &) const { - return j.get::type>(); - } -}; - -template -Variant from_json_moveonly(nlohmann::json const &j, K const &key) { - FromJsonMoveOnlyFunctor func(j); - return seq_get(func, idx, seq_count_t::value>{}); -} - -template -typename std::enable_if::value>::type - from_json(nlohmann::json const &j, KeyedVariant &v) { - K key = j.at("type").get(); - std::string key_string = j.at("type").get(); - - visit(FromJsonFunctor{j.at("value"), key_string}, v.value); -} - -template -KeyedVariant keyed_variant_from_json(nlohmann::json const &j) { - K key = j.at("type").get(); - - return KeyedVariant{ - from_json_moveonly(j, static_cast(key))}; -} - -} // namespace FlexFlow - -namespace nlohmann { - -template -struct adl_serializer<::FlexFlow::KeyedVariant> { - static void to_json(json &j, ::FlexFlow::KeyedVariant const &v) { - return ::FlexFlow::to_json(v); - } - - static ::FlexFlow::KeyedVariant from_json(json const &j) { - return ::FlexFlow::keyed_variant_from_json(j); - } -}; - -} // namespace nlohmann - -namespace FlexFlow { -static_assert(is_jsonable>>::value, ""); -} - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/data_type_value.h b/lib/pcg/include/pcg/file_format/v1/data_type_value.h deleted file mode 100644 index dae0ccb368..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/data_type_value.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_DATA_TYPE_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_DATA_TYPE_H - -#include "utils/half.h" -#include - -namespace FlexFlow { - -using V1DataTypeValue = - std::variant; - -} // namespace FlexFlow - -namespace nlohmann { - -template <> -struct adl_serializer { - static void to_json(json &j, half h) { - j = json{static_cast(h)}; - } - - static void from_json(json const &j, half &h) { - h = static_cast(j.get()); - } -}; - -} // namespace nlohmann - -namespace FlexFlow { -static_assert(is_jsonable::value, 
""); -static_assert(is_json_serializable::value, ""); -static_assert(is_json_deserializable::value, ""); -static_assert(is_jsonable::value, ""); -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h deleted file mode 100644 index 9554995fa0..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H - -#include "pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h" -#include "utils/graph/dataflow_graph/dataflow_graph_view.h" - -namespace FlexFlow { - -V1DataflowGraph to_v1(DataflowGraphView const &); -V1DataflowGraph to_v1(DataflowGraphView const &, - std::unordered_map const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml deleted file mode 100644 index 57b559a18e..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml +++ /dev/null @@ -1,32 +0,0 @@ -namespace = "FlexFlow" -name = "V1DataflowGraph" -features = [ - "eq", - # "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "", - "", - "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -src_includes = [ - "utils/fmt/vector.h", - "utils/hash/vector.h", - "utils/fmt/unordered_set.h", - "utils/hash/unordered_set.h", -] - -[[fields]] -name = "nodes" -type = "std::vector<::FlexFlow::nonnegative_int>" - -[[fields]] -name = "edges" -type = "std::unordered_set<::FlexFlow::V1GraphEdge>" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.toml new file mode 100644 
index 0000000000..082e1a8a22 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.toml @@ -0,0 +1,35 @@ +namespace = "FlexFlow" +name = "V1GraphEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +template_params = [ + "SlotName", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "srcNode" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "srcSlot" +type = "SlotName" + +[[fields]] +name = "dstNode" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "dstSlot" +type = "SlotName" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml deleted file mode 100644 index 9150c20056..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "V1GraphEdge" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "srcNode" -type = "::FlexFlow::nonnegative_int" - -[[fields]] -name = "srcIdx" -type = "::FlexFlow::nonnegative_int" - -[[fields]] -name = "dstNode" -type = "::FlexFlow::nonnegative_int" - -[[fields]] -name = "dstIdx" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.toml new file mode 100644 index 0000000000..aa198da7ad --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "V1GraphOutput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "srcNode" +type = "size_t" + +[[fields]] +name = "srcIdx" +type = "size_t" diff --git 
a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml deleted file mode 100644 index ba41f7e43f..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "V1GraphOutput" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -[[fields]] -name = "srcNode" -type = "size_t" - -[[fields]] -name = "srcIdx" -type = "size_t" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.dtg.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.dtg.toml new file mode 100644 index 0000000000..7f07e3b194 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.dtg.toml @@ -0,0 +1,37 @@ +namespace = "FlexFlow" +name = "V1KwargDataflowGraph" +type = "struct" +features = [ + "eq", + # "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +template_params = [ + "SlotName", +] + +includes = [ + "", + "", + "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", +] + +[[fields]] +name = "nodes" +type = "std::vector<::FlexFlow::nonnegative_int>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::V1GraphEdge>" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..a923d18ce6 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_KWARG_DATAFLOW_GRAPH_H + +#include 
"pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.dtg.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/enumerate.h" +#include "utils/containers/sorted.h" +#include "utils/containers/values.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +template +V1KwargDataflowGraph + to_v1(KwargDataflowGraphView const &g) { + bidict node_enumeration_bidict = + bidict_from_enumerating(get_nodes(g)); + std::unordered_map node_enumeration = + node_enumeration_bidict.reversed().as_unordered_map(); + return to_v1(g, node_enumeration); +} + +template +V1KwargDataflowGraph + to_v1(KwargDataflowGraphView const &g, + std::unordered_map const &nodes) { + std::unordered_set> edges; + for (KwargDataflowEdge const &e : get_all_kwarg_dataflow_edges(g)) { + edges.insert(V1GraphEdge{nodes.at(e.src.node), + e.src.slot_name, + nodes.at(e.dst.node), + e.dst.slot_name}); + } + + return V1KwargDataflowGraph{ + sorted(values(nodes)), + edges, + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h deleted file mode 100644 index 426bad5a82..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h +++ /dev/null @@ -1,49 +0,0 @@ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_LABELLED_DATAFLOW_GRAPH_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_LABELLED_DATAFLOW_GRAPH_H - -#include "pcg/file_format/v1/graphs/v1_dataflow_graph.h" -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h" -#include "utils/bidict/algorithms/bidict_from_enumerating.h" -#include "utils/containers/map_values.h" -#include 
"utils/containers/transform.h" -#include "utils/graph/dataflow_graph/algorithms.h" -#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" -#include "utils/graph/node/algorithms.h" - -namespace FlexFlow { - -template -std::pair, - bidict> - to_v1_including_node_numbering( - LabelledDataflowGraphView const &g) { - - bidict nodes = bidict_from_enumerating(get_nodes(g)); - - V1DataflowGraph unlabelled = to_v1(g, nodes.reversed()); - - std::unordered_map node_labels = map_values( - nodes.as_unordered_map(), [&](Node const &n) { return g.at(n); }); - - std::unordered_map> output_labels = - map_values(nodes.as_unordered_map(), [&](Node const &n) { - return transform(get_outputs(g, n), - [&](DataflowOutput const &o) { return g.at(o); }); - }); - - return { - V1LabelledDataflowGraph{ - node_labels, output_labels, unlabelled}, - nodes, - }; -} - -template -V1LabelledDataflowGraph - to_v1(LabelledDataflowGraphView const &g) { - return to_v1_including_node_numbering(g).first; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml deleted file mode 100644 index 1f69f5cd93..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml +++ /dev/null @@ -1,41 +0,0 @@ -namespace = "FlexFlow" -name = "V1LabelledDataflowGraph" -features = [ - "eq", - # "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -template_params = [ - "NodeLabel", - "OutputLabel", -] - -includes = [ - "", - "pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h", - "pcg/file_format/v1/graphs/v1_graph_output.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -src_includes = [ - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", - "utils/fmt/vector.h", - "utils/hash/vector.h", -] - -[[fields]] -name = "node_labels" -type = "std::unordered_map<::FlexFlow::nonnegative_int, 
NodeLabel>" - -[[fields]] -name = "output_labels" -type = "std::unordered_map<::FlexFlow::nonnegative_int, std::vector>" - -[[fields]] -name = "graph" -type = "::FlexFlow::V1DataflowGraph" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.dtg.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.dtg.toml new file mode 100644 index 0000000000..96b79928e6 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.dtg.toml @@ -0,0 +1,43 @@ +namespace = "FlexFlow" +name = "V1LabelledKwargDataflowGraph" +type = "struct" +features = [ + "eq", + # "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +template_params = [ + "NodeLabel", + "OutputLabel", + "SlotName", +] + +includes = [ + "", + "pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.dtg.h", + "pcg/file_format/v1/graphs/v1_graph_output.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "node_labels" +type = "std::unordered_map<::FlexFlow::nonnegative_int, NodeLabel>" + +[[fields]] +name = "output_labels" +type = "std::unordered_map<::FlexFlow::nonnegative_int, std::unordered_map>" + +[[fields]] +name = "graph" +type = "::FlexFlow::V1KwargDataflowGraph" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..dbe660c3a6 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h @@ -0,0 +1,55 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_LABELLED_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_LABELLED_KWARG_DATAFLOW_GRAPH_H + +#include "pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h" +#include 
"pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.dtg.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/map_values.h" +#include "utils/containers/transform.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +template +std::pair, + bidict> + to_v1_including_node_numbering( + LabelledKwargDataflowGraphView const + &g) { + + bidict nodes = bidict_from_enumerating(get_nodes(g)); + + V1KwargDataflowGraph unlabelled = to_v1(g, nodes.reversed()); + + std::unordered_map node_labels = map_values( + nodes.as_unordered_map(), [&](Node const &n) { return g.at(n); }); + + std::unordered_map> + output_labels = map_values( + nodes.as_unordered_map(), + [&](Node const &n) -> std::unordered_map { + return map_values( + get_outgoing_kwarg_dataflow_outputs_for_node(g, n), + [&](KwargDataflowOutput const &o) { + return g.at(o); + }); + }); + + return { + V1LabelledKwargDataflowGraph{ + node_labels, output_labels, unlabelled}, + nodes, + }; +} + +template +V1LabelledKwargDataflowGraph to_v1( + LabelledKwargDataflowGraphView const &g) { + return to_v1_including_node_numbering(g).first; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.dtg.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.dtg.toml new file mode 100644 index 0000000000..49a397ac04 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "V1BinaryParallelSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct V1BinarySPDecomposition" +] + +post_includes = [ + 
"pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml deleted file mode 100644 index d2d0c3bc77..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "V1BinaryParallelSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct V1BinarySPDecomposition" -] - -post_includes = [ - "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", -] - -[[fields]] -name = "left_child" -type = "::FlexFlow::V1BinarySPDecomposition" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::V1BinarySPDecomposition" -indirect = true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.dtg.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.dtg.toml new file mode 100644 index 0000000000..eddf3aff25 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "V1BinarySeriesSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct V1BinarySPDecomposition" +] + +post_includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = 
true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml deleted file mode 100644 index 317fa8b6ce..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "V1BinarySeriesSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct V1BinarySPDecomposition" -] - -post_includes = [ - "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", -] - -[[fields]] -name = "left_child" -type = "::FlexFlow::V1BinarySPDecomposition" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::V1BinarySPDecomposition" -indirect = true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.toml new file mode 100644 index 0000000000..703a1d0a32 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "V1BinarySPDecomposition" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.dtg.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[values]] +type = "::FlexFlow::V1BinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::V1BinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::nonnegative_int" +key = "leaf" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml 
b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml deleted file mode 100644 index bd60564465..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "V1BinarySPDecomposition" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.dtg.h", - "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -[[values]] -type = "::FlexFlow::V1BinarySeriesSplit" -key = "series" - -[[values]] -type = "::FlexFlow::V1BinaryParallelSplit" -key = "parallel" - -[[values]] -type = "::FlexFlow::nonnegative_int" -key = "leaf" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.dtg.toml b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.dtg.toml new file mode 100644 index 0000000000..89d81b4630 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "V1ComputationGraph" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/layer_attrs.dtg.h", + "pcg/tensor_attrs.dtg.h", + "pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::V1LabelledKwargDataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs, ::FlexFlow::TensorSlotName>" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml deleted file mode 100644 index 0d7135ec74..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "V1ComputationGraph" -features = [ - "eq", - "hash", - 
"fmt", - "json", -] - -includes = [ - "pcg/layer_attrs.dtg.h", - "pcg/tensor_attrs.dtg.h", - "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h", -] - -[[fields]] -name = "raw_graph" -type = "::FlexFlow::V1LabelledDataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.dtg.toml b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.dtg.toml new file mode 100644 index 0000000000..8f0d80877e --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "V1ParallelComputationGraph" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::V1LabelledKwargDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs, ::FlexFlow::TensorSlotName>" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml deleted file mode 100644 index 16be4a9561..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "V1ParallelComputationGraph" -features = [ - "eq", - "hash", - "fmt", - "json", -] - -includes = [ - "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", - "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", - "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h", -] - -[[fields]] -name = "raw_graph" -type = "::FlexFlow::V1LabelledDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git 
a/lib/pcg/include/pcg/gpu_id_t.dtg.toml b/lib/pcg/include/pcg/gpu_id_t.dtg.toml new file mode 100644 index 0000000000..26867f2673 --- /dev/null +++ b/lib/pcg/include/pcg/gpu_id_t.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "gpu_id_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "gpu_index" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/pcg/include/pcg/gpu_id_t.struct.toml b/lib/pcg/include/pcg/gpu_id_t.struct.toml deleted file mode 100644 index 7a85b4c0a7..0000000000 --- a/lib/pcg/include/pcg/gpu_id_t.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "gpu_id_t" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "gpu_index" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/pcg/include/pcg/layer_attrs.dtg.toml b/lib/pcg/include/pcg/layer_attrs.dtg.toml new file mode 100644 index 0000000000..4fd14fc77f --- /dev/null +++ b/lib/pcg/include/pcg/layer_attrs.dtg.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "LayerAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/computation_graph_op_attrs.dtg.h", + "utils/stack_string.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::ComputationGraphOpAttrs" + +[[fields]] +name = "name" +type = "std::optional<::FlexFlow::stack_string>" + diff --git a/lib/pcg/include/pcg/layer_attrs.struct.toml b/lib/pcg/include/pcg/layer_attrs.struct.toml deleted file mode 100644 index 0e22bf0ccf..0000000000 --- a/lib/pcg/include/pcg/layer_attrs.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "LayerAttrs" -features = [ - "eq", - "ord", - "hash", - 
"json", - # "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/computation_graph_op_attrs.dtg.h", - "utils/stack_string.h", - "", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", -] - -[[fields]] -name = "op_attrs" -type = "::FlexFlow::ComputationGraphOpAttrs" - -[[fields]] -name = "name" -type = "std::optional<::FlexFlow::stack_string>" - diff --git a/lib/pcg/include/pcg/layer_guid_t.dtg.toml b/lib/pcg/include/pcg/layer_guid_t.dtg.toml new file mode 100644 index 0000000000..d73cf547da --- /dev/null +++ b/lib/pcg/include/pcg/layer_guid_t.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "layer_guid_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "raw_node" +type = "::FlexFlow::Node" diff --git a/lib/pcg/include/pcg/layer_guid_t.struct.toml b/lib/pcg/include/pcg/layer_guid_t.struct.toml deleted file mode 100644 index 7f820cbd6d..0000000000 --- a/lib/pcg/include/pcg/layer_guid_t.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "layer_guid_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "raw_node" -type = "::FlexFlow::Node" diff --git a/lib/pcg/include/pcg/machine_compute_specification.dtg.toml b/lib/pcg/include/pcg/machine_compute_specification.dtg.toml new file mode 100644 index 0000000000..00bf9a5f22 --- /dev/null +++ b/lib/pcg/include/pcg/machine_compute_specification.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "MachineComputeSpecification" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "num_nodes" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_cpus_per_node" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_gpus_per_node" +type = "::FlexFlow::positive_int" 
diff --git a/lib/pcg/include/pcg/machine_compute_specification.h b/lib/pcg/include/pcg/machine_compute_specification.h new file mode 100644 index 0000000000..835e9040e0 --- /dev/null +++ b/lib/pcg/include/pcg/machine_compute_specification.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_COMPUTE_SPECIFICATION_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_COMPUTE_SPECIFICATION_H + +#include "pcg/device_id_t.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" + +namespace FlexFlow { + +positive_int get_num_gpus(MachineComputeSpecification const &ms); +positive_int get_num_cpus(MachineComputeSpecification const &ms); +positive_int get_num_devices(MachineComputeSpecification const &ms, + DeviceType const &device_type); +positive_int get_num_devices_per_node(MachineComputeSpecification const &ms, + DeviceType const &device_type); + +bool is_valid_machine_space_coordinate(MachineComputeSpecification const &ms, + MachineSpaceCoordinate const &coord); + +device_id_t get_device_id(MachineComputeSpecification const &ms, + MachineSpaceCoordinate const &coord); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/machine_interconnect_specification.dtg.toml b/lib/pcg/include/pcg/machine_interconnect_specification.dtg.toml new file mode 100644 index 0000000000..88f2b98eb8 --- /dev/null +++ b/lib/pcg/include/pcg/machine_interconnect_specification.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "MachineInterconnectSpecification" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/units/bytes_per_second_t.h", +] + +[[fields]] +name = "inter_node_bandwidth" +type = "::FlexFlow::bytes_per_second_t" + +[[fields]] +name = "intra_node_bandwidth" +type = "::FlexFlow::bytes_per_second_t" diff --git a/lib/pcg/include/pcg/machine_space_coordinate.dtg.toml 
b/lib/pcg/include/pcg/machine_space_coordinate.dtg.toml new file mode 100644 index 0000000000..41f4d563f3 --- /dev/null +++ b/lib/pcg/include/pcg/machine_space_coordinate.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "MachineSpaceCoordinate" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "pcg/device_type.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "node_idx" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "device_idx" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "device_type" +type = "::FlexFlow::DeviceType" diff --git a/lib/pcg/include/pcg/machine_space_coordinate.struct.toml b/lib/pcg/include/pcg/machine_space_coordinate.struct.toml deleted file mode 100644 index 2528eab849..0000000000 --- a/lib/pcg/include/pcg/machine_space_coordinate.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "MachineSpaceCoordinate" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "pcg/device_type.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "node_idx" -type = "::FlexFlow::nonnegative_int" - -[[fields]] -name = "device_idx" -type = "::FlexFlow::nonnegative_int" - -[[fields]] -name = "device_type" -type = "::FlexFlow::DeviceType" diff --git a/lib/pcg/include/pcg/machine_space_offset.dtg.toml b/lib/pcg/include/pcg/machine_space_offset.dtg.toml new file mode 100644 index 0000000000..57f884906b --- /dev/null +++ b/lib/pcg/include/pcg/machine_space_offset.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "MachineSpaceOffset" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/device_type.dtg.h", +] + +[[fields]] +name = "node_offset" +type = "int" + +[[fields]] +name = "device_offset" +type = "int" + +[[fields]] +name = "device_type" +type = "::FlexFlow::DeviceType" diff 
--git a/lib/pcg/include/pcg/machine_space_offset.struct.toml b/lib/pcg/include/pcg/machine_space_offset.struct.toml deleted file mode 100644 index 3f6eab38fd..0000000000 --- a/lib/pcg/include/pcg/machine_space_offset.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "MachineSpaceOffset" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "pcg/device_type.dtg.h", -] - -[[fields]] -name = "node_offset" -type = "int" - -[[fields]] -name = "device_offset" -type = "int" - -[[fields]] -name = "device_type" -type = "::FlexFlow::DeviceType" diff --git a/lib/pcg/include/pcg/machine_specification.dtg.toml b/lib/pcg/include/pcg/machine_specification.dtg.toml new file mode 100644 index 0000000000..49e2011c3a --- /dev/null +++ b/lib/pcg/include/pcg/machine_specification.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "MachineSpecification" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "pcg/machine_compute_specification.dtg.h", + "pcg/machine_interconnect_specification.dtg.h", +] + +[[fields]] +name = "compute_specification" +type = "::FlexFlow::MachineComputeSpecification" + +[[fields]] +name = "interconnect_specification" +type = "::FlexFlow::MachineInterconnectSpecification" diff --git a/lib/pcg/include/pcg/machine_specification.h b/lib/pcg/include/pcg/machine_specification.h deleted file mode 100644 index 48c6e9a7a6..0000000000 --- a/lib/pcg/include/pcg/machine_specification.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H - -#include "pcg/device_id_t.dtg.h" -#include "pcg/device_type.dtg.h" -#include "pcg/machine_space_coordinate.dtg.h" -#include "pcg/machine_specification.dtg.h" - -namespace FlexFlow { - -positive_int get_num_gpus(MachineSpecification const &ms); -positive_int get_num_cpus(MachineSpecification const &ms); 
-positive_int get_num_devices(MachineSpecification const &ms, - DeviceType const &device_type); -positive_int get_num_devices_per_node(MachineSpecification const &ms, - DeviceType const &device_type); - -bool is_valid_machine_space_coordinate(MachineSpecification const &ms, - MachineSpaceCoordinate const &coord); - -device_id_t get_device_id(MachineSpecification const &ms, - MachineSpaceCoordinate const &coord); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/machine_specification.struct.toml b/lib/pcg/include/pcg/machine_specification.struct.toml deleted file mode 100644 index 49e9bd9d78..0000000000 --- a/lib/pcg/include/pcg/machine_specification.struct.toml +++ /dev/null @@ -1,34 +0,0 @@ -namespace = "FlexFlow" -name = "MachineSpecification" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "num_nodes" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "num_cpus_per_node" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "num_gpus_per_node" -type = "::FlexFlow::positive_int" - -[[fields]] -name = "inter_node_bandwidth" -type = "float" - -[[fields]] -name = "intra_node_bandwidth" -type = "float" diff --git a/lib/pcg/include/pcg/machine_specification_dimension.dtg.toml b/lib/pcg/include/pcg/machine_specification_dimension.dtg.toml new file mode 100644 index 0000000000..a0cf568690 --- /dev/null +++ b/lib/pcg/include/pcg/machine_specification_dimension.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "MachineSpecificationDimension" +type = "enum" +features = [ + "hash", + "json", + "fmt", + "rapidcheck", +] + +[[values]] +name = "INTER_NODE" + +[[values]] +name = "INTRA_NODE" diff --git a/lib/pcg/include/pcg/machine_specification_dimension.enum.toml b/lib/pcg/include/pcg/machine_specification_dimension.enum.toml deleted file mode 100644 index 837b4306da..0000000000 --- 
a/lib/pcg/include/pcg/machine_specification_dimension.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "MachineSpecificationDimension" -features = [ - "hash", - "json", - "fmt", - "rapidcheck", -] - -[[values]] -name = "INTER_NODE" - -[[values]] -name = "INTRA_NODE" diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h deleted file mode 100644 index 6ed9e7dd9c..0000000000 --- a/lib/pcg/include/pcg/machine_view.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H - -#include "machine_specification.dtg.h" -#include "machine_view.dtg.h" -#include "pcg/device_id_t.dtg.h" -#include "pcg/operator_task_space.dtg.h" -#include "pcg/task_space_coordinate.dtg.h" -#include -#include -#include - -namespace FlexFlow { - -size_t num_dims(MachineView const &mv); - -DeviceType get_device_type(MachineView const &mv); - -std::vector get_strides(MachineView const &mv); - -std::vector - get_dimensions(MachineView const &mv); - -MachineView machine_view_from_strides_and_machine_spec_dimensions( - MachineSpaceCoordinate const &start, - std::vector const &strides, - std::vector const &dims); - -std::optional - get_machine_space_coordinate(OperatorTaskSpace const &task, - MachineView const &mv, - TaskSpaceCoordinate const &coordinates, - MachineSpecification const &ms); - -std::unordered_set - get_machine_space_coordinates(OperatorTaskSpace const &task, - MachineView const &mv, - MachineSpecification const &ms); - -std::unordered_set get_device_ids(OperatorTaskSpace const &task, - MachineView const &mv, - MachineSpecification const &ms); - -MachineView make_1d_machine_view(MachineSpaceCoordinate const &start, - MachineSpecificationDimension const &dim, - stride_t stride); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/machine_view.struct.toml b/lib/pcg/include/pcg/machine_view.struct.toml deleted file mode 100644 index 
e4de69eafc..0000000000 --- a/lib/pcg/include/pcg/machine_view.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "MachineView" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "pcg/machine_view_dimension.dtg.h", - "pcg/machine_space_coordinate.dtg.h" -] - -src_includes = [ - "utils/fmt/vector.h", - "utils/hash/vector.h" -] - - -[[fields]] -name = "start" -type = "::FlexFlow::MachineSpaceCoordinate" - -[[fields]] -name = "dimensions" -type = "std::vector<::FlexFlow::MachineViewDimension>" diff --git a/lib/pcg/include/pcg/machine_view_dimension.struct.toml b/lib/pcg/include/pcg/machine_view_dimension.struct.toml deleted file mode 100644 index 03b0ac51e4..0000000000 --- a/lib/pcg/include/pcg/machine_view_dimension.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "MachineViewDimension" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "pcg/machine_specification_dimension.dtg.h", - "pcg/stride_t.dtg.h", -] - - -[[fields]] -name = "stride" -type = "::FlexFlow::stride_t" - -[[fields]] -name = "projection" -type = "::FlexFlow::MachineSpecificationDimension" diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h new file mode 100644 index 0000000000..5b1cad5e99 --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_OPERATOR_TASK_GROUP_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_OPERATOR_TASK_GROUP_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "utils/bidict/bidict.h" + 
+namespace FlexFlow { + +struct MappedOperatorTaskGroup { + MappedOperatorTaskGroup() = delete; + + explicit MappedOperatorTaskGroup( + bidict const + &shard_bindings); + + [[nodiscard]] bool operator==(MappedOperatorTaskGroup const &) const; + [[nodiscard]] bool operator!=(MappedOperatorTaskGroup const &) const; + + [[nodiscard]] bidict const & + get_shard_bindings() const; + +private: + bidict shard_bindings; + +private: + [[nodiscard]] std::tuple tie() const; + + friend struct ::std::hash; +}; + +std::string format_as(::FlexFlow::MappedOperatorTaskGroup const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::MappedOperatorTaskGroup const &); + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::MappedOperatorTaskGroup> { + size_t operator()(::FlexFlow::MappedOperatorTaskGroup const &) const; +}; + +} // namespace std +#endif diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.toml b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.toml new file mode 100644 index 0000000000..8786cfe889 --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MappedParallelComputationGraph" +type = "struct" +features = [] + +includes = [ + "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", + "pcg/parallel_computation_graph/parallel_computation_graph.h", + "", +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "mapped_tasks" +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MappedOperatorTaskGroup>" diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h new file mode 100644 index 0000000000..0e3db03a91 --- 
/dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_PARALLEL_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_PARALLEL_COMPUTATION_GRAPH_H + +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" + +namespace FlexFlow { + +std::string format_as(MappedParallelComputationGraph const &); +std::ostream &operator<<(std::ostream &, + MappedParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.toml b/lib/pcg/include/pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.toml new file mode 100644 index 0000000000..c06eff0375 --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "OperatorAtomicTaskShardBinding" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", + "utils/ord/unordered_map.h", +] + +[[fields]] +name = "tensor_coords" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorSpaceCoordinate>" diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.h new file mode 100644 index 0000000000..f856670188 --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.h @@ -0,0 +1,16 @@ +#ifndef 
_FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_OPERATOR_ATOMIC_TASK_SHARD_BINDING_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_OPERATOR_ATOMIC_TASK_SHARD_BINDING_H + +#include "op-attrs/tensor_role.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" + +namespace FlexFlow { + +ParallelTensorSpaceCoordinate + ptensor_space_coord_for_slot_name(OperatorAtomicTaskShardBinding const &, + TensorSlotName const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/metric.dtg.toml b/lib/pcg/include/pcg/metric.dtg.toml new file mode 100644 index 0000000000..1f7a367736 --- /dev/null +++ b/lib/pcg/include/pcg/metric.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "Metric" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "ACCURACY" + +[[values]] +name = "CATEGORICAL_CROSSENTROPY" + +[[values]] +name = "SPARSE_CATEGORICAL_CROSSENTROPY" + +[[values]] +name = "MEAN_SQUARED_ERROR" + +[[values]] +name = "ROOT_MEAN_SQUARED_ERROR" + +[[values]] +name = "MEAN_ABSOLUTE_ERROR" diff --git a/lib/pcg/include/pcg/metric.enum.toml b/lib/pcg/include/pcg/metric.enum.toml deleted file mode 100644 index ebb2323203..0000000000 --- a/lib/pcg/include/pcg/metric.enum.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "Metric" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[values]] -name = "ACCURACY" - -[[values]] -name = "CATEGORICAL_CROSSENTROPY" - -[[values]] -name = "SPARSE_CATEGORICAL_CROSSENTROPY" - -[[values]] -name = "MEAN_SQUARED_ERROR" - -[[values]] -name = "ROOT_MEAN_SQUARED_ERROR" - -[[values]] -name = "MEAN_ABSOLUTE_ERROR" diff --git a/lib/pcg/include/pcg/model_compilation.h b/lib/pcg/include/pcg/model_compilation.h deleted file mode 100644 index 1ab66161ec..0000000000 --- a/lib/pcg/include/pcg/model_compilation.h +++ /dev/null @@ -1,26 
+0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MODEL_COMPILATION_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_MODEL_COMPILATION_H - -#include "pcg/computation_graph.h" -#include "pcg/optimizer.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "pcg/tensor_mapping.h" - -namespace FlexFlow { - -struct ModelCompilationInput { - ComputationGraph computation_graph; - Optimizer optimizer; -}; -FF_VISITABLE_STRUCT(ModelCompilationInput, computation_graph, optimizer); - -struct ModelCompilationResult { - ModelCompilationInput input; - ParallelComputationGraph pcg; - req tensor_mapping; -}; -FF_VISITABLE_STRUCT(ModelCompilationResult, input, pcg, tensor_mapping); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/multi_dimensional_stride.struct.toml b/lib/pcg/include/pcg/multi_dimensional_stride.struct.toml deleted file mode 100644 index 9fa5a77f77..0000000000 --- a/lib/pcg/include/pcg/multi_dimensional_stride.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "MultiDimensionalStride" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "", - "pcg/stride_t.dtg.h", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h" - -] - -[[fields]] -name = "raw_strides" -type = "std::vector<::FlexFlow::stride_t>" diff --git a/lib/pcg/include/pcg/num_points_t.dtg.toml b/lib/pcg/include/pcg/num_points_t.dtg.toml new file mode 100644 index 0000000000..cd065178b4 --- /dev/null +++ b/lib/pcg/include/pcg/num_points_t.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "num_points_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/positive_int/positive_int.h" +] + +[[fields]] +name = "unwrapped" +type = "::FlexFlow::positive_int" diff --git a/lib/pcg/include/pcg/num_points_t.struct.toml b/lib/pcg/include/pcg/num_points_t.struct.toml deleted file mode 100644 index b389245c63..0000000000 --- 
a/lib/pcg/include/pcg/num_points_t.struct.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "num_points_t" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[fields]] -name = "unwrapped" -type = "int" diff --git a/lib/pcg/include/pcg/open_dataflow_graph.h b/lib/pcg/include/pcg/open_dataflow_graph.h deleted file mode 100644 index b3367686b3..0000000000 --- a/lib/pcg/include/pcg/open_dataflow_graph.h +++ /dev/null @@ -1,81 +0,0 @@ -// #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPEN_DATAFLOW_GRAPH_H -// #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPEN_DATAFLOW_GRAPH_H -// -// #include "utils/containers/enumerate_vector.h" -// #include "utils/graph.h" -// #include "pcg/dataflow_input.dtg.h" -// -// namespace FlexFlow { -// -// template -// struct OpenDataflowGraph { -// public: -// OpenDataflowGraph() -// : g(OutputLabelledOpenMultiDiGraph::template -// create< -// UnorderedOutputLabelledOpenMultiDiGraph>()) -// { } -// -// DataflowInput add_external_input(OutputLabel const &label) { -// /* size_t src_node_idx = edge_uid_ctr; */ -// /* edge_uid_ctr++; */ -// /* size_t src_port_idx = 0; */ -// /* edge_uid_t edge_uid = { src_node_idx, src_port_idx }; */ -// /* return MultiDiOutput{edge_uid}; */ -// } -// -// std::vector add_operator(NodeLabel const &func, -// std::vector const &inputs, std::vector const -// &outputs) { -// Node n = this->g.add_node(func); -// for (auto const &[idx, input] : enumerate_vector(inputs)) { -// this->g.add_edge(MultiDiEdge{input.src, input.src_idx, n, -// this->make_port_for_idx(idx)}); -// } -// -// std::vector result; -// for (auto const &[idx, label] : enumerate_vector(outputs)) { -// MultiDiOutput output = MultiDiOutput{n, this->make_port_for_idx(idx)}; -// this->g.add_output(output, label); -// result.push_back(output); -// } -// -// return result; -// } -// -// NodePort make_port_for_idx(int idx) { -// if (!this->port_mapping.contains_l(idx)) { -// this->port_mapping.equate(idx, 
this->g.add_node_port()); -// } -// return this->port_mapping.at_l(idx); -// } -// -// NodePort port_for_idx(int idx) const { -// return this->port_mapping.at_l(idx); -// } -// -// int idx_for_port(NodePort const &p) const { -// return this->port_mapping.at_r(p); -// } -// -// OutputLabelledMultiDiGraphView const -// &get_raw_graph() const { -// return this->g; -// } -// -// NodeLabel const &at(Node const &n) const { -// return this->g.at(n); -// } -// -// OutputLabel const &at(MultiDiOutput const &o) const { -// return this->g.at(o); -// } -// private: -// OutputLabelledOpenMultiDiGraph g; -// bidict port_mapping; -// size_t edge_uid_ctr = 0; -// }; -// -// } // namespace FlexFlow -// -// #endif diff --git a/lib/pcg/include/pcg/operator_space_to_machine_space_mapping.dtg.toml b/lib/pcg/include/pcg/operator_space_to_machine_space_mapping.dtg.toml new file mode 100644 index 0000000000..a97d84da12 --- /dev/null +++ b/lib/pcg/include/pcg/operator_space_to_machine_space_mapping.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "OperatorSpaceToMachineSpaceMapping" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "op-attrs/operator_task_space.dtg.h", + "op-attrs/task_space_coordinate.dtg.h", + "pcg/machine_space_coordinate.dtg.h", + "utils/bidict/bidict.h", +] + +[[fields]] +name = "raw_mapping" +type = "::FlexFlow::bidict<::FlexFlow::TaskSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>" + +[[fields]] +name = "operator_task_space" +type = "::FlexFlow::OperatorTaskSpace" diff --git a/lib/pcg/include/pcg/operator_task_space.h b/lib/pcg/include/pcg/operator_task_space.h deleted file mode 100644 index ceb0146f15..0000000000 --- a/lib/pcg/include/pcg/operator_task_space.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H -#define _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H - -#include "pcg/operator_task_space.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include 
"pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" -#include "pcg/task_space_coordinate.dtg.h" -#include -#include - -namespace FlexFlow { - -std::unordered_set - get_task_space_coordinates(OperatorTaskSpace const &task); - -TaskSpaceCoordinate - get_task_space_maximum_coordinate(OperatorTaskSpace const &task); - -nonnegative_int num_dims(OperatorTaskSpace const &task); -positive_int num_tasks(OperatorTaskSpace const &task); - -OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, - parallel_layer_guid_t const &layer); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/operator_task_space.struct.toml b/lib/pcg/include/pcg/operator_task_space.struct.toml deleted file mode 100644 index 389e12e8f2..0000000000 --- a/lib/pcg/include/pcg/operator_task_space.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorTaskSpace" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "", - "utils/positive_int/positive_int.h", -] - -src_includes = [ - "utils/fmt/vector.h", - "utils/hash/vector.h" -] - -[[fields]] -name = "degrees" -type = "std::vector<::FlexFlow::positive_int>" diff --git a/lib/pcg/include/pcg/optimizer_attrs.dtg.toml b/lib/pcg/include/pcg/optimizer_attrs.dtg.toml new file mode 100644 index 0000000000..a3b4313ae5 --- /dev/null +++ b/lib/pcg/include/pcg/optimizer_attrs.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "OptimizerAttrs" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "pcg/optimizers/sgd_optimizer_attrs.dtg.h", + "pcg/optimizers/adam_optimizer_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::SGDOptimizerAttrs" +key = "sgd_optimizer" + +[[values]] +type = "::FlexFlow::AdamOptimizerAttrs" +key = "adam_optimizer" diff --git a/lib/pcg/include/pcg/optimizer_attrs.h b/lib/pcg/include/pcg/optimizer_attrs.h index 5d9ea8d112..b554b68284 100644 --- 
a/lib/pcg/include/pcg/optimizer_attrs.h +++ b/lib/pcg/include/pcg/optimizer_attrs.h @@ -2,12 +2,15 @@ #define _FLEXFLOW_PCG_OPTIMIZER_ATTRS_H #include "pcg/optimizer_attrs.dtg.h" +#include "pcg/optimizer_slot_name.dtg.h" #include "utils/nonnegative_int/nonnegative_int.h" namespace FlexFlow { OptimizerAttrs get_optimizer_attrs_for_next_iter(OptimizerAttrs const &old); -nonnegative_int get_num_optimizer_tensors(OptimizerAttrs const &); + +std::unordered_set + get_slot_names_for_optimizer(OptimizerAttrs const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/optimizer_attrs.variant.toml b/lib/pcg/include/pcg/optimizer_attrs.variant.toml deleted file mode 100644 index 585c150700..0000000000 --- a/lib/pcg/include/pcg/optimizer_attrs.variant.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "OptimizerAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", - "rapidcheck", -] - -includes = [ - "pcg/optimizers/sgd_optimizer_attrs.dtg.h", - "pcg/optimizers/adam_optimizer_attrs.dtg.h", -] - -[[values]] -type = "::FlexFlow::SGDOptimizerAttrs" -key = "sgd_optimizer" - -[[values]] -type = "::FlexFlow::AdamOptimizerAttrs" -key = "adam_optimizer" diff --git a/lib/pcg/include/pcg/optimizer_slot_name.dtg.toml b/lib/pcg/include/pcg/optimizer_slot_name.dtg.toml new file mode 100644 index 0000000000..c5f7a28548 --- /dev/null +++ b/lib/pcg/include/pcg/optimizer_slot_name.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "OptimizerSlotName" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "SGD_V" + +[[values]] +name = "ADAM_M" + +[[values]] +name = "ADAM_V" diff --git a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.toml b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.toml new file mode 100644 index 0000000000..d3e2b9460f --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.toml @@ -0,0 +1,43 @@ +namespace = "FlexFlow" +name = "AdamOptimizerAttrs" 
+type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "alpha" +type = "double" + +[[fields]] +name = "beta1" +type = "double" + +[[fields]] +name = "beta2" +type = "double" + +[[fields]] +name = "weight_decay" +type = "double" + +[[fields]] +name = "alpha_t" +type = "double" + +[[fields]] +name = "beta_t" +type = "double" + +[[fields]] +name = "beta2_t" +type = "double" + +[[fields]] +name = "epsilon" +type = "double" diff --git a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml deleted file mode 100644 index c25baa6c89..0000000000 --- a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml +++ /dev/null @@ -1,42 +0,0 @@ -namespace = "FlexFlow" -name = "AdamOptimizerAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[fields]] -name = "alpha" -type = "double" - -[[fields]] -name = "beta1" -type = "double" - -[[fields]] -name = "beta2" -type = "double" - -[[fields]] -name = "weight_decay" -type = "double" - -[[fields]] -name = "alpha_t" -type = "double" - -[[fields]] -name = "beta_t" -type = "double" - -[[fields]] -name = "beta2_t" -type = "double" - -[[fields]] -name = "epsilon" -type = "double" diff --git a/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.toml b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.toml new file mode 100644 index 0000000000..d9a1084e4f --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "SGDOptimizerAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lr" +type = "double" + +[[fields]] +name = "momentum" +type = "double" + +[[fields]] +name = "nesterov" +type = "bool" + +[[fields]] +name = "weight_decay" +type = "double" diff --git 
a/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml deleted file mode 100644 index 37affb0e1f..0000000000 --- a/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "SGDOptimizerAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[fields]] -name = "lr" -type = "double" - -[[fields]] -name = "momentum" -type = "double" - -[[fields]] -name = "nesterov" -type = "bool" - -[[fields]] -name = "weight_decay" -type = "double" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.dtg.toml new file mode 100644 index 0000000000..3e46248850 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "ParallelComputationGraph" +type = "struct" +features = [ ] + +includes = [ + "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::LabelledKwargDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs, ::FlexFlow::TensorSlotName>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 3542e73dea..25dc0721cd 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -1,6 +1,10 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H #define 
_FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" @@ -20,17 +24,21 @@ std::unordered_set ParallelLayerAddedResult add_parallel_layer( ParallelComputationGraph &pcg, ParallelLayerAttrs const &layer_attrs, - std::vector const &inputs, - std::vector const &weights, - std::optional> const &outputs = std::nullopt); + std::unordered_map const &inputs, + std::unordered_map const &weights, + std::optional> const + &outputs = std::nullopt); ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, TensorShape const &tensor_shape); +OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &layer); + std::unordered_set - get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &, - parallel_layer_guid_t const &, - parallel_layer_guid_t const &); + get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &src, + parallel_layer_guid_t const &dst); std::unordered_set get_edges(ParallelComputationGraph const &); @@ -39,27 +47,43 @@ std::unordered_set get_outgoing_edges(ParallelComputationGraph const &, parallel_layer_guid_t const &); -std::unordered_set +std::unordered_map get_incoming_edges(ParallelComputationGraph const &, parallel_layer_guid_t const &); std::unordered_set get_initial_layers(ParallelComputationGraph const &); -std::vector +std::unordered_map get_incoming_tensors(ParallelComputationGraph const &, parallel_layer_guid_t const &); -std::vector +std::unordered_map 
get_layer_outputs(ParallelComputationGraph const &, parallel_layer_guid_t const &); -std::vector +std::unordered_map + pcg_get_operator_to_incoming_mappings(ParallelComputationGraph const &, + parallel_layer_guid_t const &); + +std::unordered_map + pcg_get_operator_to_output_mappings(ParallelComputationGraph const &, + parallel_layer_guid_t const &); + +OperatorTaskSpaceToOperatorTaskSpaceMapping + pcg_get_mapping_along_edge(ParallelComputationGraph const &, + ParallelComputationGraphEdge const &); + +std::unordered_map get_incoming_inputs(ParallelComputationGraph const &, parallel_layer_guid_t const &); -std::vector +std::unordered_map get_incoming_weights(ParallelComputationGraph const &, parallel_layer_guid_t const &); +std::unordered_map + get_incoming_input_degrees(ParallelComputationGraph const &, + parallel_layer_guid_t const &); + std::unordered_set get_successors(ParallelComputationGraph const &, parallel_layer_guid_t const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml deleted file mode 100644 index c97333701c..0000000000 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml +++ /dev/null @@ -1,13 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelComputationGraph" -features = [ ] - -includes = [ - "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h", - "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", - "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", -] - -[[fields]] -name = "raw_graph" -type = "::FlexFlow::LabelledDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index aad2770101..b0adec3ab1 100644 --- 
a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -143,10 +143,11 @@ struct ParallelComputationGraphBuilder { std::string const &name); private: - std::vector - add_layer(ParallelLayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weight_initializers); + std::unordered_map add_layer( + ParallelLayerAttrs const &layer, + std::unordered_map const &inputs, + std::unordered_map const + &weight_initializers); parallel_tensor_guid_t add_weight(ParallelTensorShape const &weight_tensor_shape, diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.toml new file mode 100644 index 0000000000..c8d8ca48bf --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "ParallelComputationGraphEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::KwargDataflowEdge<::FlexFlow::TensorSlotName>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h index 5bce560020..e40573a2e5 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h @@ -11,7 +11,10 @@ parallel_tensor_guid_t get_parallel_tensor(ParallelComputationGraphEdge const &); parallel_layer_guid_t get_src_layer(ParallelComputationGraphEdge const &); parallel_layer_guid_t get_dst_layer(ParallelComputationGraphEdge const &); 
-nonnegative_int get_dst_layer_input_idx(ParallelComputationGraphEdge const &); +TensorSlotName + get_src_layer_output_slot_name(ParallelComputationGraphEdge const &); +TensorSlotName + get_dst_layer_input_slot_name(ParallelComputationGraphEdge const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml deleted file mode 100644 index 25ef3f5d27..0000000000 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelComputationGraphEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_edge.dtg.h", -] - -[[fields]] -name = "raw_edge" -type = "::FlexFlow::DataflowEdge" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.toml new file mode 100644 index 0000000000..455e61c783 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "ParallelLayerAddedResult" +type = "struct" + +features = [ + "eq", + "fmt", +] + +includes = [ + "", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "parallel_layer" +type = "::FlexFlow::parallel_layer_guid_t" + +[[fields]] +name = "outputs" +type = "std::unordered_map<::FlexFlow::TensorSlotName, ::FlexFlow::parallel_tensor_guid_t>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.struct.toml 
b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.struct.toml deleted file mode 100644 index f3113255ef..0000000000 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelLayerAddedResult" - -features = [ - "eq", - "ord", - "fmt", -] - -includes = [ - "", - "utils/fmt/vector.h", - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h", -] - -[[fields]] -name = "parallel_layer" -type = "::FlexFlow::parallel_layer_guid_t" - -[[fields]] -name = "outputs" -type = "std::vector<::FlexFlow::parallel_tensor_guid_t>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.toml new file mode 100644 index 0000000000..a292adf7d4 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "ParallelLayerAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "utils/stack_string.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "name" +type = "std::optional<::FlexFlow::stack_string>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml deleted file mode 100644 index 027b9f6c80..0000000000 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelLayerAttrs" -features = [ - "eq", - "ord", - "hash", - 
"json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/pcg_operator_attrs.dtg.h", - "utils/stack_string.h", - "", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", - "utils/rapidcheck/optional.h", -] - -[[fields]] -name = "op_attrs" -type = "::FlexFlow::PCGOperatorAttrs" - -[[fields]] -name = "name" -type = "std::optional<::FlexFlow::stack_string>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml new file mode 100644 index 0000000000..618bcb0dc4 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "parallel_layer_guid_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h" +] + +[[fields]] +name = "raw_graph_node" +type = "::FlexFlow::Node" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml deleted file mode 100644 index 85436460aa..0000000000 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "parallel_layer_guid_t" -features = [ - "eq", - "ord", - "hash", - # "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h" -] - -[[fields]] -name = "raw_graph_node" -type = "::FlexFlow::Node" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.toml new file mode 100644 index 0000000000..bf6a1d70f1 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "ParallelTensorAttrs" +type = "struct" +features = [ + "eq", 
+ "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "pcg/create_grad.dtg.h", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "create_grad" +type = "::FlexFlow::CreateGrad" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml deleted file mode 100644 index 877a576f3a..0000000000 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelTensorAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/parallel_tensor_shape.dtg.h", - "pcg/create_grad.dtg.h", -] - -[[fields]] -name = "shape" -type = "::FlexFlow::ParallelTensorShape" - -[[fields]] -name = "create_grad" -type = "::FlexFlow::CreateGrad" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.toml new file mode 100644 index 0000000000..4494a31ac2 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "parallel_tensor_guid_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_graph_output" +type = "::FlexFlow::KwargDataflowOutput<::FlexFlow::TensorSlotName>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml deleted file mode 100644 index a9e8bbc917..0000000000 --- 
a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "parallel_tensor_guid_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_output.dtg.h" -] - -[[fields]] -name = "raw_graph_output" -type = "::FlexFlow::DataflowOutput" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.toml new file mode 100644 index 0000000000..4d6a607e1d --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "parallel_tensor_use_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_input.dtg.h", + "op-attrs/tensor_slot_name.dtg.h" +] + +[[fields]] +name = "raw_dataflow_input" +type = "::FlexFlow::KwargDataflowInput<::FlexFlow::TensorSlotName>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml deleted file mode 100644 index 6d5e007650..0000000000 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "parallel_tensor_use_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_input.dtg.h", -] - -[[fields]] -name = "raw_dataflow_input" -type = "::FlexFlow::DataflowInput" diff --git a/lib/pcg/include/pcg/pcg_operator_plus_signature.dtg.toml b/lib/pcg/include/pcg/pcg_operator_plus_signature.dtg.toml new file mode 100644 index 0000000000..8faaad7518 --- /dev/null +++ b/lib/pcg/include/pcg/pcg_operator_plus_signature.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = 
"PCGOperatorPlusSignature" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "pcg/pcg_operator_tensor_shape_signature.dtg.h", + "", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "tensor_shape_signature" +type = "::FlexFlow::PCGOperatorTensorShapeSignature" diff --git a/lib/pcg/include/pcg/pcg_operator_plus_signature.struct.toml b/lib/pcg/include/pcg/pcg_operator_plus_signature.struct.toml deleted file mode 100644 index e827dae891..0000000000 --- a/lib/pcg/include/pcg/pcg_operator_plus_signature.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "PCGOperatorPlusSignature" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", -] - -includes = [ - "op-attrs/pcg_operator_attrs.dtg.h", - "pcg/pcg_operator_tensor_shape_signature.dtg.h", - "", -] - -[[fields]] -name = "op_attrs" -type = "::FlexFlow::PCGOperatorAttrs" - -[[fields]] -name = "tensor_shape_signature" -type = "::FlexFlow::PCGOperatorTensorShapeSignature" diff --git a/lib/pcg/include/pcg/pcg_operator_tensor_shape_signature.dtg.toml b/lib/pcg/include/pcg/pcg_operator_tensor_shape_signature.dtg.toml new file mode 100644 index 0000000000..537f1ae480 --- /dev/null +++ b/lib/pcg/include/pcg/pcg_operator_tensor_shape_signature.dtg.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "PCGOperatorTensorShapeSignature" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "input_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" diff --git 
a/lib/pcg/include/pcg/pcg_operator_tensor_shape_signature.struct.toml b/lib/pcg/include/pcg/pcg_operator_tensor_shape_signature.struct.toml deleted file mode 100644 index 3e99bdde64..0000000000 --- a/lib/pcg/include/pcg/pcg_operator_tensor_shape_signature.struct.toml +++ /dev/null @@ -1,31 +0,0 @@ -namespace = "FlexFlow" -name = "PCGOperatorTensorShapeSignature" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", -] - -includes = [ - "op-attrs/parallel_tensor_shape.dtg.h", - "", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "input_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "weight_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" - -[[fields]] -name = "output_shapes" -type = "std::vector<::FlexFlow::ParallelTensorShape>" diff --git a/lib/pcg/include/pcg/start_invariant_machine_view.h b/lib/pcg/include/pcg/start_invariant_machine_view.h deleted file mode 100644 index cdf17213f9..0000000000 --- a/lib/pcg/include/pcg/start_invariant_machine_view.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_START_INVARIANT_MACHINE_VIEW_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_START_INVARIANT_MACHINE_VIEW_H - -#include "pcg/machine_space_offset.h" -#include "pcg/machine_specification.dtg.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/operator_task_space.dtg.h" -#include "pcg/start_invariant_machine_view.dtg.h" -#include "pcg/task_space_coordinate.dtg.h" -#include - -namespace FlexFlow { - -MachineView - machine_view_from_start_invariant(StartInvariantMachineView const &mv, - MachineSpaceCoordinate const &start); -StartInvariantMachineView - start_invariant_from_machine_view(MachineView const &mv); - -nonnegative_int num_dims(StartInvariantMachineView const &mv); - -DeviceType get_device_type(StartInvariantMachineView const &mv); - -std::vector get_strides(StartInvariantMachineView const &mv); - -std::vector - get_dimensions(StartInvariantMachineView 
const &mv); - -StartInvariantMachineView - start_invariant_machine_view_from_strides_and_machine_spec_dimensions( - std::vector const &strides, - std::vector const &dims); - -std::optional - get_machine_space_offset(OperatorTaskSpace const &task, - StartInvariantMachineView const &mv, - TaskSpaceCoordinate const &coordinates, - MachineSpecification const &ms); - -std::unordered_set - get_machine_space_offsets(OperatorTaskSpace const &task, - StartInvariantMachineView const &mv, - MachineSpecification const &ms); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/start_invariant_machine_view.struct.toml b/lib/pcg/include/pcg/start_invariant_machine_view.struct.toml deleted file mode 100644 index a1b2b40524..0000000000 --- a/lib/pcg/include/pcg/start_invariant_machine_view.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "StartInvariantMachineView" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "pcg/machine_view_dimension.dtg.h", - "pcg/device_type.dtg.h" -] - -src_includes = [ - "utils/fmt/vector.h", - "utils/hash/vector.h", -] - -[[fields]] -name = "dimensions" -type = "std::vector<::FlexFlow::MachineViewDimension>" - - -[[fields]] -name = "device_type" -type = "::FlexFlow::DeviceType" diff --git a/lib/pcg/include/pcg/stride_t.struct.toml b/lib/pcg/include/pcg/stride_t.struct.toml deleted file mode 100644 index 3f07ec6b01..0000000000 --- a/lib/pcg/include/pcg/stride_t.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "stride_t" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "utils/positive_int/positive_int.h", -] - -[[fields]] -name = "unwrapped" -type = "::FlexFlow::positive_int" diff --git a/lib/pcg/include/pcg/task_space_coordinate.struct.toml b/lib/pcg/include/pcg/task_space_coordinate.struct.toml deleted file mode 100644 index 1057676b8e..0000000000 --- 
a/lib/pcg/include/pcg/task_space_coordinate.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "TaskSpaceCoordinate" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "", - "utils/nonnegative_int/nonnegative_int.h", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "raw_coord" -type = "std::vector<::FlexFlow::nonnegative_int>" diff --git a/lib/pcg/include/pcg/tensor_attrs.dtg.toml b/lib/pcg/include/pcg/tensor_attrs.dtg.toml new file mode 100644 index 0000000000..72631bf77a --- /dev/null +++ b/lib/pcg/include/pcg/tensor_attrs.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "TensorAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_shape.dtg.h", + "pcg/create_grad.dtg.h", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::TensorShape" + +[[fields]] +name = "create_grad" +type = "::FlexFlow::CreateGrad" diff --git a/lib/pcg/include/pcg/tensor_attrs.struct.toml b/lib/pcg/include/pcg/tensor_attrs.struct.toml deleted file mode 100644 index dbafa69c75..0000000000 --- a/lib/pcg/include/pcg/tensor_attrs.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "TensorAttrs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/tensor_shape.dtg.h", - "pcg/create_grad.dtg.h", -] - -[[fields]] -name = "shape" -type = "::FlexFlow::TensorShape" - -[[fields]] -name = "create_grad" -type = "::FlexFlow::CreateGrad" diff --git a/lib/pcg/include/pcg/tensor_direction.dtg.toml b/lib/pcg/include/pcg/tensor_direction.dtg.toml new file mode 100644 index 0000000000..be8cc45eb2 --- /dev/null +++ b/lib/pcg/include/pcg/tensor_direction.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "TensorDirection" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = 
"INCOMING" + +[[values]] +name = "OUTPUT" diff --git a/lib/pcg/include/pcg/tensor_guid_t.dtg.toml b/lib/pcg/include/pcg/tensor_guid_t.dtg.toml new file mode 100644 index 0000000000..151f7b1f0f --- /dev/null +++ b/lib/pcg/include/pcg/tensor_guid_t.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "tensor_guid_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_graph_output" +type = "::FlexFlow::KwargDataflowOutput<::FlexFlow::TensorSlotName>" diff --git a/lib/pcg/include/pcg/tensor_guid_t.struct.toml b/lib/pcg/include/pcg/tensor_guid_t.struct.toml deleted file mode 100644 index 0f710c81e6..0000000000 --- a/lib/pcg/include/pcg/tensor_guid_t.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "tensor_guid_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_output.dtg.h" -] - -[[fields]] -name = "raw_graph_output" -type = "::FlexFlow::DataflowOutput" diff --git a/lib/pcg/include/pcg/tensor_mapping.h b/lib/pcg/include/pcg/tensor_mapping.h deleted file mode 100644 index eff48e5e06..0000000000 --- a/lib/pcg/include/pcg/tensor_mapping.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_MAPPING_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_MAPPING_H - -#include "parallel_tensor_guid_t.h" -#include "tensor_guid_t.h" - -namespace FlexFlow { - -struct TensorMapping - : public strong_typedef< - TensorMapping, - std::unordered_map> { -public: - TensorMapping(); - - parallel_tensor_guid_t at(tensor_guid_t) const; - void add_dependence(tensor_guid_t, parallel_tensor_guid_t); - -private: - std::unordered_map contents; -}; - -} // namespace FlexFlow - -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::TensorMapping, "TensorMapping"); -MAKE_TYPEDEF_HASHABLE(::FlexFlow::TensorMapping); - -namespace FlexFlow { 
-static_assert(is_well_behaved_value_type::value, ""); -} - -#endif diff --git a/lib/pcg/include/pcg/tensor_role.enum.toml b/lib/pcg/include/pcg/tensor_role.enum.toml deleted file mode 100644 index 98d18b3ce4..0000000000 --- a/lib/pcg/include/pcg/tensor_role.enum.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "TensorRole" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "INPUT" - -[[values]] -name = "WEIGHT" - -[[values]] -name = "OUTPUT" diff --git a/lib/pcg/src/pcg/cg_operator_tensor_shape_signature.cc b/lib/pcg/src/pcg/cg_operator_tensor_shape_signature.cc deleted file mode 100644 index 90ffb85c9b..0000000000 --- a/lib/pcg/src/pcg/cg_operator_tensor_shape_signature.cc +++ /dev/null @@ -1,28 +0,0 @@ -#include "pcg/cg_operator_tensor_shape_signature.h" - -namespace FlexFlow { - -std::vector - tensor_shapes_for_role(CGOperatorTensorShapeSignature const &signature, - TensorRole tensor_role) { - switch (tensor_role) { - case TensorRole::INPUT: - return signature.input_shapes; - case TensorRole::WEIGHT: - return signature.weight_shapes; - case TensorRole::OUTPUT: - return signature.output_shapes; - default: - PANIC("Unhandled tensor role", tensor_role); - }; -} - -TensorShape tensor_shape_for_role_and_index( - CGOperatorTensorShapeSignature const &signature, - TensorRole tensor_role, - nonnegative_int index) { - return tensor_shapes_for_role(signature, tensor_role) - .at(index.unwrap_nonnegative()); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index b8917eed35..a78f179e66 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -2,12 +2,17 @@ #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/shape_inference.h" +#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/concat_vectors.h" +#include 
"utils/containers/filter_values.h" #include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" +#include "utils/containers/map_values.h" #include "utils/containers/repeat_element.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" +#include "utils/containers/zip_values_strict.h" +#include "utils/containers/zip_values_strict_with.h" #include "utils/containers/zip_with_strict.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" @@ -15,9 +20,18 @@ #include "utils/graph/digraph/algorithms/get_subgraph_successors.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_outputs_for_node.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_incoming_edges.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_outgoing_edges.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" #include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/view_as_labelled_open_kwarg_dataflow_graph.h" #include 
"utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/node/algorithms.h" #include "utils/record_formatter.h" @@ -26,8 +40,11 @@ namespace FlexFlow { ComputationGraph make_empty_computation_graph() { return ComputationGraph{ - LabelledDataflowGraph::create< - UnorderedSetLabelledOpenDataflowGraph>()}; + LabelledKwargDataflowGraph:: + create>()}; } std::unordered_set get_layers(ComputationGraph const &cg) { @@ -38,51 +55,61 @@ std::unordered_set get_layers(ComputationGraph const &cg) { LayerAddedResult add_layer( ComputationGraph &computation_graph, LayerAttrs const &layer_attrs, - std::vector const &inputs, - std::vector const &weights, - std::optional> const &maybe_output_flags) { - std::vector input_shapes = - transform(inputs, [&](tensor_guid_t const &i) { + std::unordered_map const &inputs, + std::unordered_map const &weights, + std::optional> const + &maybe_output_flags) { + + std::unordered_map input_shapes = + map_values(inputs, [&](tensor_guid_t const &i) { return get_tensor_attrs(computation_graph, i).shape; }); - std::vector provided_weight_shapes = - transform(weights, [&](tensor_guid_t const &w) { + std::unordered_map provided_weight_shapes = + map_values(weights, [&](tensor_guid_t const &w) { return get_tensor_attrs(computation_graph, w).shape; }); - std::vector expected_weight_shapes = + std::unordered_map expected_weight_shapes = get_weight_shapes(layer_attrs.op_attrs, input_shapes); - std::vector raw_inputs = transform( - inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); - - std::vector raw_weights = transform( - weights, [](tensor_guid_t const &t) { return t.raw_graph_output; }); + std::unordered_map> + raw_inputs = map_values( + inputs, [&](tensor_guid_t const &t) { return t.raw_graph_output; }); - std::vector output_shapes = + std::unordered_map> + raw_weights = map_values( + weights, [&](tensor_guid_t const &t) { return t.raw_graph_output; }); + std::unordered_map output_shapes = 
get_output_shapes(layer_attrs.op_attrs, input_shapes); - std::vector output_flags = maybe_output_flags.value_or( - repeat_element(num_elements(output_shapes), CreateGrad::YES)); - - std::vector output_attrs = zip_with_strict( - output_shapes, - output_flags, - [](TensorShape const &shape, CreateGrad const &create_grad) { - return TensorAttrs{ - /*shape=*/shape, - /*create_grad=*/create_grad, - }; - }); - - NodeAddedResult added = computation_graph.raw_graph.add_node( - layer_attrs, concat_vectors(raw_inputs, raw_weights), output_attrs); + std::unordered_map output_flags = + maybe_output_flags.value_or(map_values( + output_shapes, [&](TensorShape const &) { return CreateGrad::YES; })); + + std::unordered_map output_attrs = + zip_values_strict_with(output_shapes, + output_flags, + [](TensorShape const &shape, + CreateGrad const &create_grad) -> TensorAttrs { + return TensorAttrs{ + /*shape=*/shape, + /*create_grad=*/create_grad, + }; + }); + + KwargNodeAddedResult added = + computation_graph.raw_graph.add_node( + layer_attrs, + binary_merge_disjoint_maps(raw_inputs, raw_weights), + output_attrs); return LayerAddedResult{ layer_guid_t{added.node}, - transform(added.outputs, - [](DataflowOutput const &o) { return tensor_guid_t{o}; }), + map_values(added.outputs, + [](KwargDataflowOutput const &o) { + return tensor_guid_t{o}; + }), }; } @@ -97,7 +124,10 @@ LayerAddedResult add_input_layer(ComputationGraph &cg, layer_attrs, /*inputs=*/{}, /*weights=*/{}, - /*outputs=*/std::vector{CreateGrad::NO}); + /*outputs=*/ + std::unordered_map{ + {TensorSlotName::OUTPUT, CreateGrad::NO}, + }); } LayerAddedResult add_input_layer_with_grad(ComputationGraph &cg, @@ -111,7 +141,10 @@ LayerAddedResult add_input_layer_with_grad(ComputationGraph &cg, layer_attrs, /*inputs=*/{}, /*weights=*/{}, - /*outputs=*/std::vector{CreateGrad::YES}); + /*outputs=*/ + std::unordered_map{ + {TensorSlotName::OUTPUT, CreateGrad::YES}, + }); } TensorAttrs get_tensor_attrs(ComputationGraph const &cg, @@ 
-138,67 +171,69 @@ std::vector layers, [&](Node const &e) -> layer_guid_t { return layer_guid_t{e}; }); } -std::vector get_outgoing_tensors(ComputationGraph const &cg, - layer_guid_t n) { - return transform(get_outputs(cg.raw_graph, n.raw_node), - [](DataflowOutput const &o) { return tensor_guid_t{o}; }); +std::unordered_map + get_outgoing_tensors(ComputationGraph const &cg, layer_guid_t n) { + return map_values( + get_outgoing_kwarg_dataflow_outputs_for_node(cg.raw_graph, n.raw_node), + [](KwargDataflowOutput const &o) { + return tensor_guid_t{o}; + }); } -std::vector get_incoming_tensors(ComputationGraph const &cg, - layer_guid_t n) { - return transform(get_input_values(cg.raw_graph, n.raw_node), - [](DataflowOutput const &o) { return tensor_guid_t{o}; }); +std::unordered_map + get_incoming_tensors(ComputationGraph const &cg, layer_guid_t n) { + return map_values( + get_incoming_kwarg_dataflow_outputs_for_node(cg.raw_graph, n.raw_node), + [](KwargDataflowOutput const &o) { + return tensor_guid_t{o}; + }); } -std::vector get_incoming_input_shapes(ComputationGraph const &cg, - layer_guid_t const &n) { - return transform(get_incoming_inputs(cg, n), [&](tensor_guid_t const &t) { +std::unordered_map + get_incoming_input_shapes(ComputationGraph const &cg, + layer_guid_t const &n) { + return map_values(get_incoming_inputs(cg, n), [&](tensor_guid_t const &t) { return get_tensor_attrs(cg, t).shape; }); } -static std::vector +static std::unordered_map get_incoming_tensors_with_role(ComputationGraph const &cg, layer_guid_t const &l, IncomingTensorRole desired_role) { ComputationGraphOpAttrs attrs = get_layer_attrs(cg, l).op_attrs; - std::vector incoming_tensors = get_incoming_tensors(cg, l); + std::unordered_map incoming_tensors = + get_incoming_tensors(cg, l); - std::vector incoming_tensor_roles = - get_incoming_tensor_roles(attrs, incoming_tensors.size()); + std::unordered_map incoming_slot_roles = + get_incoming_tensor_roles(attrs); - assert(incoming_tensors.size() == 
incoming_tensor_roles.size()); + ASSERT(incoming_tensors.size() == incoming_slot_roles.size()); - std::vector result = - filtrans(zip(incoming_tensors, incoming_tensor_roles), - [&](std::pair const &p) - -> std::optional { - tensor_guid_t tensor = p.first; - IncomingTensorRole role = p.second; + std::unordered_set slots_with_desired_role = + keys(filter_values(incoming_slot_roles, [&](IncomingTensorRole role) { + return role == desired_role; + })); - if (role == desired_role) { - return tensor; - } else { - return std::nullopt; - } - }); - return result; + return restrict_keys(incoming_tensors, slots_with_desired_role); } -std::vector get_incoming_inputs(ComputationGraph const &cg, - layer_guid_t const &l) { +std::unordered_map + get_incoming_inputs(ComputationGraph const &cg, layer_guid_t const &l) { return get_incoming_tensors_with_role(cg, l, IncomingTensorRole::INPUT); } -std::vector get_incoming_weights(ComputationGraph const &cg, - layer_guid_t const &l) { +std::unordered_map + get_incoming_weights(ComputationGraph const &cg, layer_guid_t const &l) { return get_incoming_tensors_with_role(cg, l, IncomingTensorRole::WEIGHT); } std::unordered_set get_all_tensors(ComputationGraph const &cg) { - return transform(get_all_dataflow_outputs(cg.raw_graph), - [](DataflowOutput const &t) { return tensor_guid_t(t); }); + return transform(get_all_kwarg_dataflow_outputs(cg.raw_graph), + [](KwargDataflowOutput const &t) { + return tensor_guid_t(t); + }); } std::unordered_map @@ -217,12 +252,14 @@ std::unordered_set get_subgraph_incoming_edges( std::unordered_set raw_subgraph_nodes = transform( subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); - std::unordered_set raw_incoming_edges = - get_subgraph_incoming_edges(cg.raw_graph, raw_subgraph_nodes); - - return transform(raw_incoming_edges, [](DataflowEdge const &e) { - return ComputationGraphEdge{e}; - }); + std::unordered_set> raw_incoming_edges = + get_kwarg_dataflow_subgraph_incoming_edges(cg.raw_graph, + 
raw_subgraph_nodes); + + return transform(raw_incoming_edges, + [](KwargDataflowEdge const &e) { + return ComputationGraphEdge{e}; + }); } std::unordered_set get_subgraph_outgoing_edges( @@ -231,12 +268,14 @@ std::unordered_set get_subgraph_outgoing_edges( std::unordered_set raw_subgraph_nodes = transform( subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); - std::unordered_set raw_outgoing_edges = - get_subgraph_outgoing_edges(cg.raw_graph, raw_subgraph_nodes); - - return transform(raw_outgoing_edges, [](DataflowEdge const &e) { - return ComputationGraphEdge{e}; - }); + std::unordered_set> raw_outgoing_edges = + get_kwarg_dataflow_subgraph_outgoing_edges(cg.raw_graph, + raw_subgraph_nodes); + + return transform(raw_outgoing_edges, + [](KwargDataflowEdge const &e) { + return ComputationGraphEdge{e}; + }); } std::unordered_set get_subgraph_successors( @@ -275,22 +314,30 @@ layer_guid_t get_layer_by_name(ComputationGraph const &cg, } ComputationGraph without_layer_names(ComputationGraph const &cg) { + LabelledKwargDataflowGraphView + relabelled = rewrite_labelled_kwarg_dataflow_graph_node_labels( + cg.raw_graph, + [](Node const &n, LayerAttrs const &old_attrs) -> LayerAttrs { + LayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + }); return ComputationGraph{ - LabelledDataflowGraph::create_copy_of< - UnorderedSetLabelledOpenDataflowGraph>( - rewrite_node_labels(cg.raw_graph, - [](Node const &n, LayerAttrs const &old_attrs) { - LayerAttrs new_attrs = old_attrs; - new_attrs.name = std::nullopt; - return new_attrs; - })), + LabelledKwargDataflowGraph:: + create_copy_of< + UnorderedSetLabelledOpenKwargDataflowGraph>( + relabelled), }; } bool computation_graphs_are_isomorphic(ComputationGraph const &lhs, ComputationGraph const &rhs) { - return find_isomorphism(without_layer_names(lhs).raw_graph, - without_layer_names(rhs).raw_graph) + return find_isomorphism_between_kwarg_dataflow_graphs( + 
without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) .has_value(); } @@ -321,9 +368,13 @@ std::string as_dot(ComputationGraph const &cg) { return oss.str(); }; - return as_dot(view_as_labelled_open_dataflow_graph(cg.raw_graph), - get_node_label, - get_input_label); + return labelled_open_kwarg_dataflow_graph_view_as_dot( + view_as_labelled_open_kwarg_dataflow_graph(cg.raw_graph), + get_node_label, + get_input_label); } void debug_print_dot(ComputationGraph const &cg) { diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 4feefa713e..b687aa11b6 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -30,15 +30,19 @@ #include "op-attrs/shape_inference.h" #include "op-attrs/tensor_dims.h" #include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_slot_name.h" #include "pcg/computation_graph.h" #include "utils/containers/any_of.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/enumerate_vector.h" #include "utils/containers/get_only.h" +#include "utils/containers/repeat_element.h" +#include "utils/containers/require_only_key.h" #include "utils/containers/transform.h" #include "utils/containers/transform_until.h" #include "utils/containers/vector_of.h" #include "utils/containers/without_nullopts.h" +#include "utils/containers/zip_values_strict_with.h" #include "utils/containers/zip_with_strict.h" #include "utils/expected.h" #include "utils/fmt/set.h" @@ -72,8 +76,18 @@ tensor_guid_t ComputationGraphBuilder::create_input( maybe_name, }; - return get_only( - this->add_layer(layer_attrs, {}, {}, std::vector{create_grad})); + return require_only_key( + this->add_layer(/*layer=*/layer_attrs, + /*inputs=*/{}, + /*weights=*/{}, + /*outputs=*/ + std::unordered_map{ + { + TensorSlotName::OUTPUT, + create_grad, + }, + }), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::create_weight( @@ -88,46 +102,54 
@@ tensor_guid_t ComputationGraphBuilder::create_weight( maybe_name, }; - return get_only(this->add_layer(layer_attrs, {}, {})); + return require_only_key(this->add_layer(layer_attrs, {}, {}), + TensorSlotName::OUTPUT); } -static void check_incoming_tensor_roles(LayerAttrs const &layer, - int num_inputs, - int num_weights) { - std::vector correct = - get_incoming_tensor_roles(layer.op_attrs, num_inputs + num_weights); - std::vector current = concat_vectors( - std::vector(num_inputs, IncomingTensorRole::INPUT), - std::vector(num_weights, IncomingTensorRole::WEIGHT)); - - if (correct != current) { - throw mk_runtime_error( - fmt::format("check_incoming_tensor_roles found deviation in incoming " - "tensors: expected {}, received {}", - correct, - current)); - } -} - -std::vector ComputationGraphBuilder::add_layer( +static void check_incoming_tensor_roles( LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weight_initializers, - std::optional> const &outputs) { - check_incoming_tensor_roles(layer, inputs.size(), weight_initializers.size()); - - std::vector input_shapes = transform( + std::unordered_set const &input_slots, + std::unordered_set const &weight_slots) { + std::unordered_map correct = + restrict_keys(get_incoming_tensor_roles(layer.op_attrs), + set_union(input_slots, weight_slots)); + std::unordered_map current = + binary_merge_disjoint_maps( + generate_map( + input_slots, + [](TensorSlotName) { return IncomingTensorRole::INPUT; }), + generate_map(weight_slots, [](TensorSlotName) { + return IncomingTensorRole::WEIGHT; + })); + + ASSERT(correct == current, + "check_incoming_tensor_roles found deviation in incoming tensors"); +} + +std::unordered_map + ComputationGraphBuilder::add_layer( + LayerAttrs const &layer, + std::unordered_map const &inputs, + std::unordered_map const + &weight_initializers, + std::optional> const + &outputs) { + ASSERT(are_disjoint(keys(inputs), keys(weight_initializers))); + check_incoming_tensor_roles(layer, 
keys(inputs), keys(weight_initializers)); + + std::unordered_map input_shapes = map_values( inputs, [&](tensor_guid_t const &t) { return this->get_shape(t); }); - std::vector weight_shapes = + std::unordered_map weight_shapes = get_weight_shapes(layer.op_attrs, input_shapes); - std::vector weights = zip_with_strict( - weight_shapes, - weight_initializers, - [&](TensorShape const &shape, InitializerAttrs const &initializer) { - return this->create_weight(shape, initializer); - }); + std::unordered_map weights = + zip_values_strict_with( + weight_shapes, + weight_initializers, + [&](TensorShape const &shape, InitializerAttrs const &initializer) { + return this->create_weight(shape, initializer); + }); LayerAddedResult added = ::FlexFlow::add_layer( this->computation_graph, layer, inputs, weights, outputs); @@ -159,18 +181,24 @@ tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &input, return input; } - if (!tensor_dims_is_broadcastable_to(input_shape.dims, target_dims)) { - throw mk_runtime_error(fmt::format( - "Cannot broadcast input tensor of dims {} to target dims {}", - input_shape.dims, - target_dims)); - } + ASSERT(tensor_dims_is_broadcastable_to(input_shape.dims, target_dims), + "Cannot broadcast input tensor to target dims", + input_shape.dims, + target_dims); BroadcastAttrs attrs = BroadcastAttrs{target_dims}; LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::cast( @@ -185,7 +213,15 @@ tensor_guid_t ComputationGraphBuilder::cast( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } 
tensor_guid_t ComputationGraphBuilder::element_unary( @@ -204,7 +240,15 @@ tensor_guid_t ComputationGraphBuilder::element_unary( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::element_binary( @@ -235,7 +279,19 @@ tensor_guid_t ComputationGraphBuilder::element_binary( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {lhs_input, rhs_input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::LHS_INPUT, + lhs_input, + }, + { + TensorSlotName::RHS_INPUT, + rhs_input, + }, + }, + {}), + TensorSlotName::OUTPUT); } tensor_guid_t @@ -414,13 +470,21 @@ tensor_guid_t ComputationGraphBuilder::conv2d( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - std::vector initializers = + std::unordered_map initializers = get_initializers(attrs, this->get_shape(input), maybe_kernel_initializer, maybe_bias_initializer); - return get_only(this->add_layer(layer, {input}, initializers)); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + initializers), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::dropout( @@ -436,7 +500,15 @@ tensor_guid_t ComputationGraphBuilder::dropout( tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::embedding( @@ -460,10 +532,18 @@ tensor_guid_t ComputationGraphBuilder::embedding( TensorShape input_shape = this->get_shape(input); - std::vector initializers = + 
std::unordered_map initializers = get_initializers(attrs, initializer); - return get_only(this->add_layer(layer, {input}, initializers)); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + initializers), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::gather( @@ -473,12 +553,11 @@ tensor_guid_t ComputationGraphBuilder::gather( std::optional const &maybe_name) { if (this->get_shape(index).data_type != DataType::INT32 && this->get_shape(index).data_type != DataType::INT64) { - throw mk_runtime_error( - fmt::format("Invalid data type for input tensor 2 for Gather: " - "{} (should be {} or {})", - this->get_shape(input).data_type, - DataType::INT32, - DataType::INT64)); + PANIC(fmt::format("Invalid data type for input tensor 2 for Gather: " + "{} (should be {} or {})", + this->get_shape(input).data_type, + DataType::INT32, + DataType::INT64)); } GatherAttrs attrs = GatherAttrs{ff_dim_t_from_relative_ff_dim_t( @@ -488,8 +567,17 @@ tensor_guid_t ComputationGraphBuilder::gather( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } + tensor_guid_t ComputationGraphBuilder::pool2d( tensor_guid_t const &x, positive_int kernelH, @@ -521,7 +609,15 @@ tensor_guid_t ComputationGraphBuilder::pool2d( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::adaptive_pool2d( @@ -548,7 +644,15 @@ tensor_guid_t ComputationGraphBuilder::adaptive_pool2d( TensorShape output_shape = throw_if_unexpected( get_output_shape(attrs, this->get_shape(casted_input))); - return 
get_only(this->add_layer(layer, {casted_input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + casted_input, + }, + }, + {}), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::batch_norm( @@ -560,7 +664,7 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( std::optional const &maybe_name) { if (activation.has_value() && activation.value() != Activation::RELU) { - throw mk_runtime_error(fmt::format( + PANIC(fmt::format( "batch_norm currently only supports (1) no activation function, or (2) " "relu activation function, but received {}. " "If you need support for additional activation functions, please " @@ -582,10 +686,18 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( TensorShape input_shape = this->get_shape(input); - std::vector initializers = + std::unordered_map initializers = throw_if_unexpected(get_initializers(attrs)); - return get_only(this->add_layer(layer, {input}, initializers)); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + initializers), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::multihead_attention( @@ -603,19 +715,15 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( std::optional initializer, std::optional const &maybe_name) { - if (add_bias_kv) { - throw mk_runtime_error( - "ComputationGraphBuilder::multihead_attention received currently " - "unsupported argument add_bias_kv=true. " - "If you need this functionality, please create an issue."); - } + ASSERT(!add_bias_kv, + "ComputationGraphBuilder::multihead_attention received currently " + "unsupported argument add_bias_kv=true. " + "If you need this functionality, please create an issue."); - if (add_zero_attn) { - throw mk_runtime_error( - "ComputationGraphBuilder::multihead_attention received currently " - "unsupported argument add_zero_attn=true. 
" - "If you need this functionality, please create an issue."); - } + ASSERT(!add_zero_attn, + "ComputationGraphBuilder::multihead_attention received currently " + "unsupported argument add_zero_attn=true. " + "If you need this functionality, please create an issue."); MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ /*embed_dim=*/embed_dim, @@ -633,14 +741,30 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - std::vector initializers = + std::unordered_map initializers = throw_if_unexpected(get_initializers(attrs, this->get_shape(query), this->get_shape(key), this->get_shape(value), initializer)); - return get_only(this->add_layer(layer, {query, key, value}, initializers)); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::KEY, + query, + }, + { + TensorSlotName::QUERY, + key, + }, + { + TensorSlotName::VALUE, + value, + }, + }, + initializers), + TensorSlotName::OUTPUT); } TensorDims ComputationGraphBuilder::get_broadcast_target_dims( @@ -688,13 +812,21 @@ tensor_guid_t ComputationGraphBuilder::dense( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - std::vector initializers = + std::unordered_map initializers = throw_if_unexpected(get_initializers(attrs, this->get_shape(input), maybe_projection_initializer, maybe_bias_initializer)); - return get_only(this->add_layer(layer, {input}, initializers)); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + initializers), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::concat( @@ -702,17 +834,27 @@ tensor_guid_t ComputationGraphBuilder::concat( relative_ff_dim_t axis, std::optional const &maybe_name) { + std::vector input_slot_names = + get_variadic_inputs_slot_name_sequence(); + ASSERT(inputs.size() <= input_slot_names.size()); + ff_dim_t abs_axis = ff_dim_t_from_relative_ff_dim_t( axis, 
get_num_dims(this->get_shape(inputs.at(0)).dims)); - ConcatAttrs attrs = ConcatAttrs{abs_axis}; + ConcatAttrs attrs = ConcatAttrs{ + /*axis=*/abs_axis, + /*num_inputs=*/int_ge_two{inputs.size()}, + }; std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - return get_only(this->add_layer(layer, inputs, {})); + return require_only_key( + this->add_layer( + layer, unordered_map_from_pairs(zip(input_slot_names, inputs)), {}), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::flat( @@ -720,13 +862,14 @@ tensor_guid_t ComputationGraphBuilder::flat( relative_ff_dim_t start_dim, std::optional const &end_dim, std::optional const &maybe_name) { - nonnegative_int input_num_dims = get_num_dims(this->get_shape(input).dims); + num_tensor_dims_t input_num_dims = get_num_dims(this->get_shape(input).dims); ff_dim_t abs_start_dim = ff_dim_t_from_relative_ff_dim_t(start_dim, input_num_dims); ff_dim_t abs_end_dim = ff_dim_t_from_relative_ff_dim_t( - end_dim.value_or(relative_ff_dim_t{input_num_dims.unwrap_nonnegative()}), + end_dim.value_or( + relative_ff_dim_t{input_num_dims.int_from_num_tensor_dims()}), input_num_dims); FlatAttrs attrs = FlatAttrs{ @@ -739,7 +882,15 @@ tensor_guid_t ComputationGraphBuilder::flat( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::layer_norm( @@ -779,9 +930,18 @@ tensor_guid_t ComputationGraphBuilder::layer_norm( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - std::vector initializers = get_initializers(attrs); + std::unordered_map initializers = + get_initializers(attrs); - return get_only(this->add_layer(layer, {input}, initializers)); + return 
require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + initializers), + TensorSlotName::OUTPUT); } tensor_guid_t ComputationGraphBuilder::softmax( @@ -792,7 +952,7 @@ tensor_guid_t ComputationGraphBuilder::softmax( TensorShape input_shape = this->get_shape(input); relative_ff_dim_t dim = maybe_dim.value_or(relative_ff_dim_t{ - get_num_dims(input_shape.dims).unwrap_nonnegative() - 1}); + get_num_dims(input_shape.dims).int_from_num_tensor_dims() - 1}); SoftmaxAttrs attrs = SoftmaxAttrs{ ff_dim_t_from_relative_ff_dim_t(dim, get_num_dims(input_shape.dims))}; @@ -807,7 +967,15 @@ tensor_guid_t ComputationGraphBuilder::softmax( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/device_id_t.cc b/lib/pcg/src/pcg/device_id_t.cc new file mode 100644 index 0000000000..eecaf1c81d --- /dev/null +++ b/lib/pcg/src/pcg/device_id_t.cc @@ -0,0 +1,17 @@ +#include "pcg/device_id_t.h" + +namespace FlexFlow { + +device_id_t make_device_id_t_from_idx(nonnegative_int idx, + DeviceType device_type) { + switch (device_type) { + case DeviceType::GPU: + return device_id_t{gpu_id_t{idx}}; + case DeviceType::CPU: + return device_id_t{cpu_id_t{idx}}; + default: + PANIC("Unhandled device_type", device_type); + } +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc deleted file mode 100644 index 064e2d81d3..0000000000 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc +++ /dev/null @@ -1,34 +0,0 @@ -#include "pcg/file_format/v1/graphs/v1_dataflow_graph.h" -#include "utils/bidict/algorithms/bidict_from_enumerating.h" -#include "utils/containers/enumerate.h" -#include 
"utils/containers/sorted.h" -#include "utils/containers/values.h" -#include "utils/graph/dataflow_graph/algorithms.h" -#include "utils/graph/node/algorithms.h" -#include "utils/integer_conversions.h" - -namespace FlexFlow { - -V1DataflowGraph to_v1(DataflowGraphView const &g) { - bidict node_enumeration_bidict = - bidict_from_enumerating(get_nodes(g)); - std::unordered_map node_enumeration = - node_enumeration_bidict.reversed().as_unordered_map(); - return to_v1(g, node_enumeration); -} - -V1DataflowGraph to_v1(DataflowGraphView const &g, - std::unordered_map const &nodes) { - std::unordered_set edges; - for (DataflowEdge const &e : get_edges(g)) { - edges.insert(V1GraphEdge{ - nodes.at(e.src.node), e.src.idx, nodes.at(e.dst.node), e.dst.idx}); - } - - return V1DataflowGraph{ - sorted(values(nodes)), - edges, - }; -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.cc new file mode 100644 index 0000000000..9e4a46b87a --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.cc @@ -0,0 +1,15 @@ +#include "pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template V1KwargDataflowGraph + to_v1(KwargDataflowGraphView const &); + +template V1KwargDataflowGraph + to_v1(KwargDataflowGraphView const &, + std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.cc deleted file mode 100644 index ac819db342..0000000000 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using 
NodeLabel = value_type<0>; -using OutputLabel = value_type<1>; - -template std::pair, - bidict> - to_v1_including_node_numbering( - LabelledDataflowGraphView const &); - -template V1LabelledDataflowGraph - to_v1(LabelledDataflowGraphView const &); - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc new file mode 100644 index 0000000000..4e7b9b651f --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc @@ -0,0 +1,21 @@ +#include "pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using NodeLabel = value_type<0>; +using OutputLabel = value_type<1>; +using SlotName = ordered_value_type<2>; + +template std::pair< + V1LabelledKwargDataflowGraph, + bidict> + to_v1_including_node_numbering( + LabelledKwargDataflowGraphView const + &); + +template V1LabelledKwargDataflowGraph to_v1( + LabelledKwargDataflowGraphView const &); + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc index 3511ccc269..852ca73a36 100644 --- a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc @@ -1,5 +1,6 @@ #include "pcg/file_format/v1/v1_computation_graph.h" -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" +#include "pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h" +#include "utils/bidict/algorithms/transform_values.h" namespace FlexFlow { @@ -11,13 +12,15 @@ V1ComputationGraph to_v1(ComputationGraph const &g) { std::pair> to_v1_including_node_numbering(ComputationGraph const &cg) { - std::pair, - bidict> - raw = - to_v1_including_node_numbering(cg.raw_graph); + std::pair< + 
V1LabelledKwargDataflowGraph, + bidict> + raw = to_v1_including_node_numbering(cg.raw_graph); V1ComputationGraph v1_cg = V1ComputationGraph{raw.first}; - bidict v1_node_ids = - map_values(raw.second, [](Node const &n) { return layer_guid_t{n}; }); + bidict v1_node_ids = transform_values( + raw.second, [](Node const &n) { return layer_guid_t{n}; }); return {v1_cg, v1_node_ids}; } diff --git a/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc index 9da58fcf6e..e14d15d66a 100644 --- a/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -1,11 +1,12 @@ #include "pcg/file_format/v1/v1_parallel_computation_graph.h" -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" +#include "pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h" namespace FlexFlow { V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { return V1ParallelComputationGraph{ - to_v1(g.raw_graph), + to_v1( + g.raw_graph), }; } diff --git a/lib/pcg/src/pcg/machine_compute_specification.cc b/lib/pcg/src/pcg/machine_compute_specification.cc new file mode 100644 index 0000000000..a8dfb27524 --- /dev/null +++ b/lib/pcg/src/pcg/machine_compute_specification.cc @@ -0,0 +1,56 @@ +#include "pcg/machine_compute_specification.h" +#include "pcg/device_id.h" +#include "utils/containers/transform.h" +#include + +namespace FlexFlow { + +positive_int get_num_gpus(MachineComputeSpecification const &ms) { + return ms.num_nodes * ms.num_gpus_per_node; +} + +positive_int get_num_cpus(MachineComputeSpecification const &ms) { + return ms.num_nodes * ms.num_cpus_per_node; +} + +positive_int get_num_devices(MachineComputeSpecification const &ms, + DeviceType const &device_type) { + switch (device_type) { + case DeviceType::GPU: + return get_num_gpus(ms); + case DeviceType::CPU: + return get_num_cpus(ms); + default: + PANIC("Unknown 
DeviceType", device_type); + } +} + +positive_int get_num_devices_per_node(MachineComputeSpecification const &ms, + DeviceType const &device_type) { + switch (device_type) { + case DeviceType::GPU: + return ms.num_gpus_per_node; + case DeviceType::CPU: + return ms.num_cpus_per_node; + default: + PANIC("Unknown DeviceType", device_type); + } +} + +bool is_valid_machine_space_coordinate(MachineComputeSpecification const &ms, + MachineSpaceCoordinate const &coord) { + return (coord.node_idx < ms.num_nodes) && + (coord.device_idx < get_num_devices_per_node(ms, coord.device_type)); +} + +device_id_t get_device_id(MachineComputeSpecification const &ms, + MachineSpaceCoordinate const &coord) { + ASSERT(is_valid_machine_space_coordinate(ms, coord)); + + nonnegative_int raw_idx = + coord.node_idx * get_num_devices_per_node(ms, coord.device_type) + + coord.device_idx; + return device_id_from_index(raw_idx, coord.device_type); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_space_offset.cc b/lib/pcg/src/pcg/machine_space_offset.cc index 4aa79b3d1b..953dc38bc6 100644 --- a/lib/pcg/src/pcg/machine_space_offset.cc +++ b/lib/pcg/src/pcg/machine_space_offset.cc @@ -2,26 +2,26 @@ #include "utils/exception.h" namespace FlexFlow { + MachineSpaceOffset get_machine_space_offset_from_coordinate( MachineSpaceCoordinate const &start, MachineSpaceCoordinate const &coord) { - if ((coord.device_idx < start.device_idx) || - (coord.node_idx < start.node_idx)) { - throw mk_runtime_error(fmt::format( - "One of the coordinates of start {} is greater than one of the " - "coordinates of coord {}, are you sure you didn't swap them?", - start, - coord)); - } - if (start.device_type != coord.device_type) { - throw mk_runtime_error( - fmt::format("{} has different DeviceType from {}", start, coord)); - } + ASSERT(start.device_idx <= coord.device_idx, + "The start device_idx is greater than one of the coord device_idx." 
+ "Are you sure you didn't swap them?"); + + ASSERT(start.node_idx <= coord.device_idx, + "The start node_idx is greater than one of the coord node_idx." + "Are you sure you didn't swap them?"); + + ASSERT(start.device_type == coord.device_type); - return MachineSpaceOffset{coord.node_idx.unwrap_nonnegative() - - start.node_idx.unwrap_nonnegative(), - coord.device_idx.unwrap_nonnegative() - - start.device_idx.unwrap_nonnegative(), - coord.device_type}; + return MachineSpaceOffset{ + /*node_offset=*/coord.node_idx.unwrap_nonnegative() - + start.node_idx.unwrap_nonnegative(), + /*device_offset=*/coord.device_idx.unwrap_nonnegative() - + start.device_idx.unwrap_nonnegative(), + /*device_type=*/coord.device_type, + }; } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_specification.cc b/lib/pcg/src/pcg/machine_specification.cc deleted file mode 100644 index 3db949b99d..0000000000 --- a/lib/pcg/src/pcg/machine_specification.cc +++ /dev/null @@ -1,58 +0,0 @@ -#include "pcg/machine_specification.h" -#include "pcg/device_id.h" -#include "utils/containers/transform.h" -#include "utils/exception.h" - -namespace FlexFlow { - -positive_int get_num_gpus(MachineSpecification const &ms) { - return ms.num_nodes * ms.num_gpus_per_node; -} - -positive_int get_num_cpus(MachineSpecification const &ms) { - return ms.num_nodes * ms.num_cpus_per_node; -} - -positive_int get_num_devices(MachineSpecification const &ms, - DeviceType const &device_type) { - switch (device_type) { - case DeviceType::GPU: - return get_num_gpus(ms); - case DeviceType::CPU: - return get_num_cpus(ms); - default: - throw mk_runtime_error(fmt::format("Unknown DeviceType {}", device_type)); - } -} - -positive_int get_num_devices_per_node(MachineSpecification const &ms, - DeviceType const &device_type) { - switch (device_type) { - case DeviceType::GPU: - return ms.num_gpus_per_node; - case DeviceType::CPU: - return ms.num_cpus_per_node; - default: - throw mk_runtime_error(fmt::format("Unknown DeviceType 
{}", device_type)); - } -} - -bool is_valid_machine_space_coordinate(MachineSpecification const &ms, - MachineSpaceCoordinate const &coord) { - return (coord.node_idx < ms.num_nodes) && - (coord.device_idx < get_num_devices_per_node(ms, coord.device_type)); -} - -device_id_t get_device_id(MachineSpecification const &ms, - MachineSpaceCoordinate const &coord) { - if (!is_valid_machine_space_coordinate(ms, coord)) { - throw mk_runtime_error(fmt::format( - "Invalid coordinate {} for machine specification {}", ms, coord)); - } - nonnegative_int raw_idx = - coord.node_idx * get_num_devices_per_node(ms, coord.device_type) + - coord.device_idx; - return device_id_from_index(raw_idx, coord.device_type); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc deleted file mode 100644 index 0fbb021a55..0000000000 --- a/lib/pcg/src/pcg/machine_view.cc +++ /dev/null @@ -1,178 +0,0 @@ -#include "pcg/machine_view.h" -#include "pcg/machine_space_coordinate.dtg.h" -#include "pcg/machine_specification.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_specification_dimension.dtg.h" -#include "pcg/machine_view_dimension.dtg.h" -#include "pcg/operator_task_space.dtg.h" -#include "pcg/operator_task_space.h" -#include "pcg/stride_t.dtg.h" -#include "utils/containers/contains.h" -#include "utils/containers/count.h" -#include "utils/containers/filter.h" -#include "utils/containers/get_only.h" -#include "utils/containers/scanl.h" -#include "utils/containers/sum.h" -#include "utils/containers/transform.h" -#include "utils/containers/zip3_strict.h" -#include "utils/containers/zip_with_strict.h" -#include "utils/exception.h" -#include "utils/nonnegative_int/nonnegative_range.h" -#include "utils/nonnegative_int/num_elements.h" - -namespace FlexFlow { - -size_t num_dims(MachineView const &mv) { - return get_strides(mv).size(); -} - -DeviceType get_device_type(MachineView const &mv) { - return mv.start.device_type; -} - 
-std::vector get_strides(MachineView const &mv) { - return transform(mv.dimensions, - [](MachineViewDimension const &dim) { return dim.stride; }); -} - -std::vector - get_dimensions(MachineView const &mv) { - return transform(mv.dimensions, [](MachineViewDimension const &dim) { - return dim.projection; - }); -} - -MachineView machine_view_from_strides_and_machine_spec_dimensions( - MachineSpaceCoordinate const &start, - std::vector const &strides, - std::vector const &dims) { - if (strides.size() != dims.size()) { - throw mk_runtime_error(fmt::format( - "Length of strides ({}) and dims ({}) must match when calling " - "machine_view_from_strides_and_machine_spec_dimensions", - start, - strides)); - } - std::vector dimensions = zip_with_strict( - strides, dims, [](stride_t s, MachineSpecificationDimension d) { - return MachineViewDimension{s, d}; - }); - return MachineView{start, dimensions}; -} - -std::optional get_machine_space_coordinate( - OperatorTaskSpace const &task, - MachineView const &machine_view, - TaskSpaceCoordinate const &coord, - MachineSpecification const &machine_specification) { - - if (num_dims(machine_view) != task.degrees.size()) { - throw mk_runtime_error( - fmt::format("Dimension of machine_view ({}) must match dimension of " - "task ({}) when computing machine space coordinate", - machine_view, - task.degrees)); - } - - auto get_dimension_indices_for_dimension = - [&](MachineSpecificationDimension dimension) - -> std::vector { - std::vector mv_dimensions = - get_dimensions(machine_view); - return filter(nonnegative_range(num_elements(mv_dimensions)), - [&](nonnegative_int idx) { - return mv_dimensions.at(idx.unwrap_nonnegative()) == - dimension; - }); - }; - - auto compute_index = - [&](nonnegative_int start_idx, - std::vector const &dimension_indices) { - std::vector mv_strides = get_strides(machine_view); - - std::vector sizes = - transform(dimension_indices, [&](nonnegative_int i) { - return task.degrees.at(i.unwrap_nonnegative()) * - 
mv_strides.at(i.unwrap_nonnegative()).unwrapped; - }); - std::vector coord_points = - transform(dimension_indices, [&](nonnegative_int i) { - return coord.raw_coord.at(i.unwrap_nonnegative()); - }); - std::vector strides = - transform(dimension_indices, [&](nonnegative_int i) { - return mv_strides.at(i.unwrap_nonnegative()).unwrapped; - }); - - std::vector coeffs = - scanl(sizes, 1_p, std::multiplies()); - - nonnegative_int index = start_idx; - for (auto [coeff, coord_point, stride] : - zip3(coeffs, coord_points, strides)) { - index += coeff * coord_point * stride; - } - return index; - }; - - std::vector inter_dimension_indices = - get_dimension_indices_for_dimension( - MachineSpecificationDimension::INTER_NODE); - std::vector intra_dimension_indices = - get_dimension_indices_for_dimension( - MachineSpecificationDimension::INTRA_NODE); - - nonnegative_int node_idx = - compute_index(machine_view.start.node_idx, inter_dimension_indices); - nonnegative_int device_idx = - compute_index(machine_view.start.device_idx, intra_dimension_indices); - MachineSpaceCoordinate ms_coord = MachineSpaceCoordinate{ - node_idx, device_idx, get_device_type(machine_view)}; - - if (!is_valid_machine_space_coordinate(machine_specification, ms_coord)) { - return std::nullopt; - } - return ms_coord; -} - -std::unordered_set get_machine_space_coordinates( - OperatorTaskSpace const &task, - MachineView const &machine_view, - MachineSpecification const &machine_specification) { - return transform( - get_task_space_coordinates(task), [&](TaskSpaceCoordinate const &coord) { - std::optional maybe_coordinate = - get_machine_space_coordinate( - task, machine_view, coord, machine_specification); - if (!maybe_coordinate.has_value()) { - throw mk_runtime_error( - fmt::format("In get_machine_space_coordinates, the given " - "OperatorTaskSpace {} and MachineView {} are not " - "compatible with the given MachineSpecification {}", - task, - machine_view, - machine_specification)); - } - return 
maybe_coordinate.value(); - }); -} - -std::unordered_set get_device_ids(OperatorTaskSpace const &task, - MachineView const &mv, - MachineSpecification const &ms) { - return transform(get_machine_space_coordinates(task, mv, ms), - [&](MachineSpaceCoordinate const &coord) { - return get_device_id(ms, coord); - }); -} - -MachineView make_1d_machine_view(MachineSpaceCoordinate const &start, - MachineSpecificationDimension const &dim, - stride_t stride) { - - return machine_view_from_strides_and_machine_spec_dimensions( - start, {stride}, {dim}); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc new file mode 100644 index 0000000000..b96a447383 --- /dev/null +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc @@ -0,0 +1,92 @@ +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" +#include "op-attrs/get_operator_task_space.h" +#include "op-attrs/operator_task_space.h" +#include "op-attrs/parallel_tensor_space_coordinate.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/are_all_distinct.h" +#include "utils/containers/require_all_same.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/hash/tuple.h" +#include "utils/nonnegative_int/num_elements.h" + +namespace FlexFlow { + +MappedOperatorTaskGroup::MappedOperatorTaskGroup( + bidict const + &shard_bindings) + : shard_bindings(shard_bindings) { + std::vector> binding_slot_sets = + transform(vector_of(shard_bindings.right_values()), + [&](OperatorAtomicTaskShardBinding const &s) + -> std::unordered_set { + return keys(s.tensor_coords); + }); + + std::unordered_set slot_names = + require_all_same(binding_slot_sets).value(); + + for (TensorSlotName const 
&slot_name : slot_names) { + std::vector signatures_for_key = + vector_of(shard_bindings.right_values()); + + std::vector coords_for_key = transform( + signatures_for_key, + [&](OperatorAtomicTaskShardBinding const &signature) { + return ptensor_space_coord_for_slot_name(signature, slot_name); + }); + + ASSERT(are_all_distinct(coords_for_key)); + + std::vector coord_dims_for_key = + transform(coords_for_key, [](ParallelTensorSpaceCoordinate const &c) { + return ptensor_coord_num_dims(c); + }); + + require_all_same(coord_dims_for_key); + } +} + +bool MappedOperatorTaskGroup::operator==( + MappedOperatorTaskGroup const &other) const { + return this->tie() == other.tie(); +} + +bool MappedOperatorTaskGroup::operator!=( + MappedOperatorTaskGroup const &other) const { + return this->tie() == other.tie(); +} + +std::tuple< + bidict const &> + MappedOperatorTaskGroup::tie() const { + + return std::tie(this->shard_bindings); +} + +bidict const & + MappedOperatorTaskGroup::get_shard_bindings() const { + return this->shard_bindings; +} + +std::string format_as(::FlexFlow::MappedOperatorTaskGroup const &m) { + return fmt::format("", + m.get_shard_bindings()); +} + +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::MappedOperatorTaskGroup const &x) { + return (s << fmt::to_string(x)); +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::MappedOperatorTaskGroup>::operator()( + ::FlexFlow::MappedOperatorTaskGroup const &x) const { + return ::FlexFlow::get_std_hash(x.tie()); +} + +} // namespace std diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc new file mode 100644 index 0000000000..17ac533162 --- /dev/null +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -0,0 +1,17 @@ +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" + 
+namespace FlexFlow { + +std::string format_as(MappedParallelComputationGraph const &mapped_pcg) { + return fmt::format( + "", + as_dot(mapped_pcg.pcg), + mapped_pcg.mapped_tasks); +} + +std::ostream &operator<<(std::ostream &s, + MappedParallelComputationGraph const &mapped_pcg) { + return (s << fmt::to_string(mapped_pcg)); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.cc new file mode 100644 index 0000000000..e97fabc22b --- /dev/null +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.cc @@ -0,0 +1,15 @@ +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.h" +#include "op-attrs/get_operator_space_to_parallel_tensor_space_mappings.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.h" +#include "utils/containers/at_idx.h" +#include + +namespace FlexFlow { + +ParallelTensorSpaceCoordinate ptensor_space_coord_for_slot_name( + OperatorAtomicTaskShardBinding const &op_task_signature, + TensorSlotName const &slot_name) { + return op_task_signature.tensor_coords.at(slot_name); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_task_space.cc b/lib/pcg/src/pcg/operator_task_space.cc deleted file mode 100644 index d612680de6..0000000000 --- a/lib/pcg/src/pcg/operator_task_space.cc +++ /dev/null @@ -1,65 +0,0 @@ -#include "pcg/operator_task_space.h" -#include "op-attrs/parallel_tensor_shape.dtg.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/operator_task_space.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" -#include "utils/containers/cartesian_product.h" -#include "utils/containers/extend.h" -#include 
"utils/containers/maximum.h" -#include "utils/containers/product.h" -#include "utils/containers/range.h" -#include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/containers/vector_of.h" -#include "utils/fmt/unordered_set.h" -#include "utils/nonnegative_int/nonnegative_range.h" -#include "utils/nonnegative_int/num_elements.h" - -namespace FlexFlow { - -std::unordered_set - get_task_space_coordinates(OperatorTaskSpace const &task) { - - std::vector> coordinate_ranges = - transform(task.degrees, [&](positive_int num_points) { - return nonnegative_range( - num_points.nonnegative_int_from_positive_int()); - }); - - std::unordered_set> raw_coordinates = - unordered_set_of(cartesian_product(coordinate_ranges)); - std::unordered_set task_space_coordinates = - transform(raw_coordinates, [](std::vector const &point) { - return TaskSpaceCoordinate{point}; - }); - return task_space_coordinates; -} - -TaskSpaceCoordinate - get_task_space_maximum_coordinate(OperatorTaskSpace const &task) { - return maximum(get_task_space_coordinates(task)); -} - -nonnegative_int num_dims(OperatorTaskSpace const &task) { - return num_elements(task.degrees); -} - -positive_int num_tasks(OperatorTaskSpace const &task) { - return product(task.degrees); -} - -OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, - parallel_layer_guid_t const &layer) { - parallel_tensor_guid_t out_tensor = get_layer_outputs(pcg, layer).at(0); - ParallelTensorShape shape = get_parallel_tensor_shape(pcg, out_tensor); - - std::vector degrees; - extend(degrees, vector_of(ff_ordered_shard_degrees(shape))); - degrees.push_back(get_sum_degree(shape)); - degrees.push_back(get_discard_copy_degree(shape)); - return OperatorTaskSpace{degrees}; -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/optimizer_attrs.cc b/lib/pcg/src/pcg/optimizer_attrs.cc index b99fcd600b..4651292f6e 100644 --- a/lib/pcg/src/pcg/optimizer_attrs.cc +++ 
b/lib/pcg/src/pcg/optimizer_attrs.cc @@ -23,16 +23,25 @@ OptimizerAttrs } } -nonnegative_int get_num_optimizer_tensors(OptimizerAttrs const &attrs) { - return attrs.visit( - overload{[&](SGDOptimizerAttrs const &o) { - if (o.momentum > 0.0f) { - return 1_n; - } else { - return 0_n; - } - }, - [&](AdamOptimizerAttrs const &) { return 2_n; }}); +std::unordered_set + get_slot_names_for_optimizer(OptimizerAttrs const &attrs) { + return attrs.visit>(overload{ + [](SGDOptimizerAttrs const &sgd_attrs) + -> std::unordered_set { + if (sgd_attrs.momentum > 0.0f) { + return {OptimizerSlotName::SGD_V}; + ; + } else { + return {}; + } + }, + [](AdamOptimizerAttrs const &) -> std::unordered_set { + return { + OptimizerSlotName::ADAM_M, + OptimizerSlotName::ADAM_V, + }; + }, + }); } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc b/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc index e3caffe260..50cbea9ca0 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc @@ -1,6 +1,7 @@ #include "pcg/parallel_computation_graph/generate_weight_transform.h" #include "op-attrs/ff_ordered/enumerate.h" #include "op-attrs/parallel_tensor_shape.h" +#include namespace FlexFlow { @@ -10,12 +11,8 @@ std::unordered_set std::unordered_set result; positive_int sum_degree = get_sum_degree(goal); - if (sum_degree != 1) { - throw mk_runtime_error( - fmt::format("generate_weight_transform currently only supports " - "sum_degree = 1, but received {}", - sum_degree)); - } + ASSERT(sum_degree == 1, + "generate_weight_transform currently only supports sum_degree = 1"); positive_int discard_copy_degree = get_discard_copy_degree(goal); if (discard_copy_degree != 1) { diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc 
b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 052d30df0f..f83628b8e1 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,30 +1,37 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "op-attrs/get_incoming_tensor_roles.h" +#include "op-attrs/get_operator_space_to_parallel_tensor_space_mappings.h" +#include "op-attrs/get_operator_task_space.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" +#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/pcg_operator_attrs.h" #include "op-attrs/shape_inference.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "utils/containers/concat_vectors.h" +#include "utils/containers/extend.h" +#include "utils/containers/filter_values.h" #include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" #include "utils/containers/repeat_element.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" +#include "utils/containers/zip_values_strict_with.h" #include "utils/containers/zip_with_strict.h" -#include "utils/graph/dataflow_graph/algorithms.h" -#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" -#include "utils/graph/dataflow_graph/algorithms/get_incoming_edges.h" -#include "utils/graph/dataflow_graph/algorithms/get_outgoing_edges.h" -#include "utils/graph/dataflow_graph/dataflow_edge.dtg.h" -#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" #include "utils/graph/digraph/algorithms/get_subgraph_successors.h" #include 
"utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" -#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" -#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" -#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_outputs_for_node.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/node/node.dtg.h" #include "utils/record_formatter.h" @@ -34,9 +41,13 @@ namespace FlexFlow { ParallelComputationGraph empty_parallel_computation_graph() { return ParallelComputationGraph{ - LabelledDataflowGraph::create< - UnorderedSetLabelledOpenDataflowGraph>()}; + LabelledKwargDataflowGraph:: + create>()}; } std::unordered_set @@ -48,58 +59,66 @@ std::unordered_set ParallelLayerAddedResult add_parallel_layer( ParallelComputationGraph &pcg, ParallelLayerAttrs const &layer_attrs, - std::vector const &inputs, - std::vector const &weights, - std::optional> const &maybe_output_flags) { - std::vector input_shapes = - transform(inputs, [&](parallel_tensor_guid_t const &i) 
{ + std::unordered_map const &inputs, + std::unordered_map const &weights, + std::optional> const + &maybe_output_flags) { + std::unordered_map input_shapes = + map_values(inputs, [&](parallel_tensor_guid_t const &i) { return get_parallel_tensor_shape(pcg, i); }); - std::vector weight_shapes = - transform(weights, [&](parallel_tensor_guid_t const &i) { + std::unordered_map weight_shapes = + map_values(weights, [&](parallel_tensor_guid_t const &i) { return get_parallel_tensor_shape(pcg, i); }); - std::vector correct_weight_shapes = - get_weight_shapes(layer_attrs.op_attrs, input_shapes); + std::unordered_map + correct_weight_shapes = + get_weight_shapes(layer_attrs.op_attrs, input_shapes); ASSERT(weight_shapes == correct_weight_shapes, "add_parallel_layer received incorrect weight shapes"); - std::vector output_shapes = + std::unordered_map output_shapes = get_output_shapes(layer_attrs.op_attrs, input_shapes); - std::vector unwrapped_inputs = - transform(inputs, [](parallel_tensor_guid_t const &t) { - return t.raw_graph_output; - }); - - std::vector unwrapped_weights = - transform(weights, [](parallel_tensor_guid_t const &t) { - return t.raw_graph_output; - }); - - std::vector output_flags = maybe_output_flags.value_or( - repeat_element(num_elements(output_shapes), CreateGrad::YES)); - - std::vector output_attrs = zip_with_strict( - output_shapes, - output_flags, - [](ParallelTensorShape const &shape, CreateGrad const &create_grad) { - return ParallelTensorAttrs{shape, create_grad}; - }); - - NodeAddedResult op_added = pcg.raw_graph.add_node( + std::unordered_map> + unwrapped_inputs = + map_values(inputs, [](parallel_tensor_guid_t const &t) { + return t.raw_graph_output; + }); + + std::unordered_map> + unwrapped_weights = + map_values(weights, [](parallel_tensor_guid_t const &t) { + return t.raw_graph_output; + }); + + std::unordered_map output_flags = + maybe_output_flags.value_or( + generate_map(keys(output_shapes), + [](TensorSlotName const &) { return 
CreateGrad::YES; })); + + std::unordered_map output_attrs = + zip_values_strict_with( + output_shapes, + output_flags, + [](ParallelTensorShape const &shape, CreateGrad const &create_grad) { + return ParallelTensorAttrs{shape, create_grad}; + }); + + KwargNodeAddedResult op_added = pcg.raw_graph.add_node( layer_attrs, - concat_vectors(unwrapped_inputs, unwrapped_weights), + binary_merge_disjoint_maps(unwrapped_inputs, unwrapped_weights), output_attrs); return ParallelLayerAddedResult{ parallel_layer_guid_t{op_added.node}, - transform( - op_added.outputs, - [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }), + map_values(op_added.outputs, + [](KwargDataflowOutput const &o) { + return parallel_tensor_guid_t{o}; + }), }; } @@ -114,24 +133,51 @@ ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, /*layer_attrs=*/layer_attrs, /*inputs=*/{}, /*weights=*/{}, - /*output_flags=*/std::vector{CreateGrad::NO}); + /*output_flags=*/ + std::unordered_map{ + { + TensorSlotName::OUTPUT, + CreateGrad::NO, + }, + }); +} + +OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &layer) { + PCGOperatorAttrs op_attrs = pcg_get_op_attrs(pcg, layer); + + ASSERT(!is_parallel_op(op_attrs)); + + std::unordered_map inputs = + get_incoming_inputs(pcg, layer); + + std::unordered_map input_degrees = + map_values(get_incoming_inputs(pcg, layer), + [&](parallel_tensor_guid_t input_guid) { + return get_parallel_degrees( + get_parallel_tensor_shape(pcg, input_guid)); + }); + + return get_operator_task_space( + compgraph_op_attrs_from_pcg_op_attrs(op_attrs).value(), input_degrees); } std::unordered_set get_edges(ParallelComputationGraph const &pcg) { - return transform(get_edges(pcg.raw_graph), [](DataflowEdge const &e) { - return ParallelComputationGraphEdge{e}; - }); + return transform(get_all_kwarg_dataflow_edges(pcg.raw_graph), + [](KwargDataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; 
+ }); } std::unordered_set get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &src, parallel_layer_guid_t const &dst) { - std::unordered_set raw_edges = - get_dataflow_edges_from_node_to_node( + std::unordered_set> raw_edges = + get_kwarg_dataflow_edges_from_node_to_node( pcg.raw_graph, src.raw_graph_node, dst.raw_graph_node); - return transform(raw_edges, [](DataflowEdge const &e) { + return transform(raw_edges, [](KwargDataflowEdge const &e) { return ParallelComputationGraphEdge{e}; }); } @@ -139,19 +185,22 @@ std::unordered_set std::unordered_set get_outgoing_edges(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { - std::unordered_set raw_edges = - get_outgoing_edges(pcg.raw_graph, l.raw_graph_node); - return transform(raw_edges, [](DataflowEdge const &e) { + std::unordered_set> raw_edges = + get_outgoing_kwarg_dataflow_edges_for_node(pcg.raw_graph, + l.raw_graph_node) + .right_values(); + return transform(raw_edges, [](KwargDataflowEdge const &e) { return ParallelComputationGraphEdge{e}; }); } -std::unordered_set +std::unordered_map get_incoming_edges(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { - std::unordered_set raw_edges = - unordered_set_of(get_incoming_edges(pcg.raw_graph, l.raw_graph_node)); - return transform(raw_edges, [](DataflowEdge const &e) { + std::unordered_map> + raw_edges = get_incoming_kwarg_dataflow_edges_for_node(pcg.raw_graph, + l.raw_graph_node); + return map_values(raw_edges, [](KwargDataflowEdge const &e) { return ParallelComputationGraphEdge{e}; }); } @@ -163,64 +212,117 @@ std::unordered_set [](Node const &n) { return parallel_layer_guid_t{n}; }); } -std::vector +std::unordered_map get_incoming_tensors(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { - return transform( - get_input_values(pcg.raw_graph, l.raw_graph_node), - [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); + return 
map_values(get_incoming_kwarg_dataflow_outputs_for_node( + pcg.raw_graph, l.raw_graph_node), + [](KwargDataflowOutput const &o) { + return parallel_tensor_guid_t{o}; + }); } -std::vector +std::unordered_map get_layer_outputs(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { - return transform( - get_outputs(pcg.raw_graph, l.raw_graph_node), - [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); + return map_values(get_outgoing_kwarg_dataflow_outputs_for_node( + pcg.raw_graph, l.raw_graph_node), + [](KwargDataflowOutput const &o) { + return parallel_tensor_guid_t{o}; + }); +} + +std::unordered_map + pcg_get_operator_to_incoming_mappings(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + ComputationGraphOpAttrs op_attrs = + compgraph_op_attrs_from_pcg_op_attrs(pcg_get_op_attrs(pcg, l)).value(); + + return get_operator_to_incoming_mappings( + /*attrs=*/op_attrs, + /*input_degrees=*/get_incoming_input_degrees(pcg, l)); +} + +std::unordered_map + pcg_get_operator_to_output_mappings(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + ComputationGraphOpAttrs op_attrs = + compgraph_op_attrs_from_pcg_op_attrs(pcg_get_op_attrs(pcg, l)).value(); + + return get_operator_to_output_mappings( + /*attrs=*/op_attrs, + /*input_degrees=*/get_incoming_input_degrees(pcg, l)); +} + +OperatorTaskSpaceToOperatorTaskSpaceMapping + pcg_get_mapping_along_edge(ParallelComputationGraph const &pcg, + ParallelComputationGraphEdge const &edge) { + + parallel_layer_guid_t src_layer = get_src_layer(edge); + TensorSlotName src_slot_name = get_src_layer_output_slot_name(edge); + parallel_tensor_guid_t tensor = parallel_tensor_guid_t{edge.raw_edge.src}; + parallel_layer_guid_t dst_layer = get_dst_layer(edge); + TensorSlotName dst_slot_name = get_dst_layer_input_slot_name(edge); + + ParallelTensorShape tensor_shape = get_parallel_tensor_shape(pcg, tensor); + + OperatorTaskSpace src_task_space = 
get_operator_task_space(pcg, src_layer); + + OperatorTaskSpace dst_task_space = get_operator_task_space(pcg, dst_layer); + + OperatorSpaceToParallelTensorSpaceMapping src_to_tensor_mapping = + pcg_get_operator_to_output_mappings(pcg, src_layer).at(src_slot_name); + + OperatorSpaceToParallelTensorSpaceMapping dst_to_tensor_mapping = + pcg_get_operator_to_incoming_mappings(pcg, dst_layer).at(dst_slot_name); + + return op_to_op_mapping_from_composition_through_tensor( + src_to_tensor_mapping, dst_to_tensor_mapping); } -static std::vector +static std::unordered_map get_incoming_tensors_with_role(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l, IncomingTensorRole desired_role) { PCGOperatorAttrs attrs = get_parallel_layer_attrs(pcg, l).op_attrs; - std::vector incoming_tensors = + std::unordered_map incoming_tensors = get_incoming_tensors(pcg, l); - std::vector incoming_tensor_roles = - get_incoming_tensor_roles(attrs, incoming_tensors.size()); + std::unordered_map incoming_slot_roles = + get_incoming_tensor_roles(attrs); - assert(incoming_tensors.size() == incoming_tensor_roles.size()); + ASSERT(incoming_tensors.size() == incoming_slot_roles.size()); - std::vector result = filtrans( - zip(incoming_tensors, incoming_tensor_roles), - [&](std::pair const &p) - -> std::optional { - parallel_tensor_guid_t tensor = p.first; - IncomingTensorRole role = p.second; + std::unordered_set slots_with_desired_role = + keys(filter_values(incoming_slot_roles, [&](IncomingTensorRole role) { + return role == desired_role; + })); - if (role == desired_role) { - return tensor; - } else { - return std::nullopt; - } - }); - return result; + return restrict_keys(incoming_tensors, slots_with_desired_role); } -std::vector +std::unordered_map get_incoming_inputs(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { return get_incoming_tensors_with_role(pcg, l, IncomingTensorRole::INPUT); } -std::vector +std::unordered_map 
get_incoming_weights(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { return get_incoming_tensors_with_role(pcg, l, IncomingTensorRole::WEIGHT); } +std::unordered_map + get_incoming_input_degrees(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + + return map_values(get_incoming_inputs(pcg, l), [&](parallel_tensor_guid_t t) { + return get_parallel_degrees(get_parallel_tensor_shape(pcg, t)); + }); +} + std::unordered_set get_successors(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { @@ -289,11 +391,15 @@ parallel_layer_guid_t ParallelComputationGraph without_layer_names(ParallelComputationGraph const &pcg) { return ParallelComputationGraph{ - LabelledDataflowGraph:: + LabelledKwargDataflowGraph:: create_copy_of< - UnorderedSetLabelledOpenDataflowGraph>( - rewrite_node_labels( + UnorderedSetLabelledOpenKwargDataflowGraph>( + rewrite_labelled_kwarg_dataflow_graph_node_labels( pcg.raw_graph, [](Node const &n, ParallelLayerAttrs const &old_attrs) { ParallelLayerAttrs new_attrs = old_attrs; @@ -305,8 +411,9 @@ ParallelComputationGraph bool pcgs_are_isomorphic(ParallelComputationGraph const &lhs, ParallelComputationGraph const &rhs) { - return find_isomorphism(without_layer_names(lhs).raw_graph, - without_layer_names(rhs).raw_graph) + return find_isomorphism_between_kwarg_dataflow_graphs( + without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) .has_value(); } @@ -337,9 +444,13 @@ std::string as_dot(ParallelComputationGraph const &cg) { return oss.str(); }; - return as_dot(view_as_labelled_open_dataflow_graph(cg.raw_graph), - get_node_label, - get_input_label); + return labelled_open_kwarg_dataflow_graph_view_as_dot( + view_as_labelled_open_kwarg_dataflow_graph(cg.raw_graph), + get_node_label, + get_input_label); } void debug_print_dot(ParallelComputationGraph const &cg) { diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc 
b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index f7f3cfdcfd..d18fc17621 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -29,7 +29,10 @@ #include "utils/containers/count.h" #include "utils/containers/enumerate_vector.h" #include "utils/containers/get_only.h" +#include "utils/containers/repeat_element.h" +#include "utils/containers/require_only_key.h" #include "utils/containers/transform.h" +#include "utils/containers/zip_values_strict_with.h" #include "utils/containers/zip_with.h" namespace FlexFlow { @@ -53,10 +56,19 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::create_input_tensor( name, }; - return get_only( - add_parallel_layer( - this->pcg, layer_attrs, {}, {}, std::vector{CreateGrad::NO}) - .outputs); + return require_only_key( + add_parallel_layer(this->pcg, + layer_attrs, + {}, + {}, + std::unordered_map{ + { + TensorSlotName::OUTPUT, + CreateGrad::NO, + }, + }) + .outputs, + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::add( @@ -90,7 +102,19 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::add( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {lhs, rhs}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::LHS_INPUT, + lhs, + }, + { + TensorSlotName::RHS_INPUT, + rhs, + }, + }, + {}), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_matmul( @@ -108,7 +132,17 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_matmul( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {a, b}, {})); + return require_only_key(this->add_layer(layer, + {{ + TensorSlotName::LHS_INPUT, + a, + }, + { + TensorSlotName::RHS_INPUT, + b, 
+ }}, + {}), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::cast( @@ -123,7 +157,15 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::cast( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::conv2d( @@ -165,13 +207,21 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::conv2d( ParallelTensorShape input_shape = this->get_shape(input); - std::vector initializers = + std::unordered_map initializers = get_initializers(attrs, get_reduced_shape(input_shape), maybe_kernel_initializer, maybe_bias_initializer); - return get_only(this->add_layer(layer, {input}, initializers)); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + initializers), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( @@ -198,13 +248,21 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( ParallelTensorShape input_shape = this->get_shape(input); - std::vector initializers = + std::unordered_map initializers = throw_if_unexpected(get_initializers(attrs, get_reduced_shape(input_shape), maybe_projection_initializer, maybe_bias_initializer)); - return get_only(this->add_layer(layer, {input}, initializers)); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + initializers), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::embedding( @@ -228,10 +286,18 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::embedding( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - std::vector initializers = + std::unordered_map initializers = get_initializers(attrs, 
maybe_kernel_initializer); - return get_only(this->add_layer(layer, {input}, initializers)); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + initializers), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( @@ -270,16 +336,33 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - std::vector initializers = throw_if_unexpected( - get_initializers(attrs, - get_reduced_shape(this->get_shape(query)), - get_reduced_shape(this->get_shape(key)), - get_reduced_shape(this->get_shape(value)), - maybe_weights_initializer, - maybe_input_bias_initializer, - maybe_output_bias_initializer)); - - return get_only(this->add_layer(layer, {query, key, value}, initializers)); + std::unordered_map initializers = + throw_if_unexpected( + get_initializers(attrs, + get_reduced_shape(this->get_shape(query)), + get_reduced_shape(this->get_shape(key)), + get_reduced_shape(this->get_shape(value)), + maybe_weights_initializer, + maybe_input_bias_initializer, + maybe_output_bias_initializer)); + + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::QUERY, + query, + }, + { + TensorSlotName::KEY, + key, + }, + { + TensorSlotName::VALUE, + value, + }, + }, + initializers), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( @@ -315,10 +398,18 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( std::vector weights; - std::vector initializers = + std::unordered_map initializers = throw_if_unexpected(get_initializers(attrs)); - return get_only(this->add_layer(layer, {input}, initializers)); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + initializers), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::element_unary( 
@@ -331,7 +422,15 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::element_unary( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::relu( @@ -422,7 +521,15 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_partition( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_combine( @@ -441,7 +548,15 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_combine( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_replicate( @@ -456,7 +571,15 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_replicate( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_reduce( @@ -471,7 +594,15 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_reduce( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - return get_only(this->add_layer(layer, {input}, {})); + return 
require_only_key(this->add_layer(layer, + { + { + TensorSlotName::INPUT, + input, + }, + }, + {}), + TensorSlotName::OUTPUT); } parallel_tensor_guid_t ParallelComputationGraphBuilder::as_type( @@ -512,8 +643,9 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::add_weight( weight_name, }; - parallel_tensor_guid_t current_weight_tensor = get_only( - add_parallel_layer(this->pcg, weight_layer_attrs, {}, {}).outputs); + parallel_tensor_guid_t current_weight_tensor = require_only_key( + add_parallel_layer(this->pcg, weight_layer_attrs, {}, {}).outputs, + TensorSlotName::OUTPUT); for (ParallelOpAttrs const ¶llel_op_attr : generate_weight_transform(unpar_weight_shape, par_weight_shape)) { @@ -522,58 +654,68 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::add_weight( pcg_op_attrs_from_parallel_op_attrs(parallel_op_attr), std::nullopt, }; - current_weight_tensor = get_only( - add_parallel_layer(this->pcg, layer_attrs, {current_weight_tensor}, {}) - .outputs); + current_weight_tensor = + require_only_key(add_parallel_layer(this->pcg, + layer_attrs, + { + { + TensorSlotName::INPUT, + current_weight_tensor, + }, + }, + {}) + .outputs, + TensorSlotName::OUTPUT); } return current_weight_tensor; } -static void check_incoming_tensor_roles(ParallelLayerAttrs const &layer, - int num_inputs, - int num_weights) { - std::vector correct = - get_incoming_tensor_roles(layer.op_attrs, num_inputs + num_weights); - std::vector current = concat_vectors( - std::vector(num_inputs, IncomingTensorRole::INPUT), - std::vector(num_weights, IncomingTensorRole::WEIGHT)); - - if (correct != current) { - throw mk_runtime_error( - fmt::format("check_incoming_tensor_roles found deviation in incoming " - "tensors: expected {}, received {}", - correct, - current)); - } -} - -std::vector ParallelComputationGraphBuilder::add_layer( +static void check_incoming_tensor_roles( ParallelLayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weight_initializers) { + 
std::unordered_set const &input_slots, + std::unordered_set const &weight_slots) { + std::unordered_map correct = + get_incoming_tensor_roles(layer.op_attrs); + std::unordered_map current = + binary_merge_disjoint_maps( + generate_map( + input_slots, + [](TensorSlotName) { return IncomingTensorRole::INPUT; }), + generate_map(weight_slots, [](TensorSlotName) { + return IncomingTensorRole::WEIGHT; + })); + + ASSERT(correct == current, + "check_incoming_tensor_roles found deviation in incoming tensors"); +} - int num_weights_provided = - count(weight_initializers, [](std::optional const &i) { - return i.has_value(); - }); +std::unordered_map + ParallelComputationGraphBuilder::add_layer( + ParallelLayerAttrs const &layer, + std::unordered_map const + &inputs, + std::unordered_map const + &weight_initializers) { - check_incoming_tensor_roles(layer, inputs.size(), num_weights_provided); + ASSERT(are_disjoint(keys(inputs), keys(weight_initializers))); + check_incoming_tensor_roles(layer, keys(inputs), keys(weight_initializers)); - std::vector input_shapes = - transform(inputs, [&](parallel_tensor_guid_t const &i) { + std::unordered_map input_shapes = + map_values(inputs, [&](parallel_tensor_guid_t const &i) { return this->get_shape(i); }); - std::vector weight_shapes = + std::unordered_map weight_shapes = get_weight_shapes(layer.op_attrs, input_shapes); - std::vector weight_tensors = - zip_with(weight_shapes, - weight_initializers, - [&](ParallelTensorShape const &weight_shape, - InitializerAttrs const &initializer) { - return this->add_weight(weight_shape, initializer); - }); + std::unordered_map weight_tensors = + zip_values_strict_with(weight_shapes, + weight_initializers, + [&](ParallelTensorShape const &weight_shape, + InitializerAttrs const &initializer) { + return this->add_weight(weight_shape, + initializer); + }); return add_parallel_layer(this->pcg, layer, inputs, weight_tensors, {}) .outputs; diff --git 
a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc index f37d08dc8a..212f4d83ec 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc @@ -16,8 +16,14 @@ parallel_layer_guid_t get_dst_layer(ParallelComputationGraphEdge const &e) { return parallel_layer_guid_t{e.raw_edge.dst.node}; } -nonnegative_int get_dst_layer_input_idx(ParallelComputationGraphEdge const &e) { - return e.raw_edge.dst.idx; +TensorSlotName + get_src_layer_output_slot_name(ParallelComputationGraphEdge const &e) { + return e.raw_edge.src.slot_name; +} + +TensorSlotName + get_dst_layer_input_slot_name(ParallelComputationGraphEdge const &e) { + return e.raw_edge.dst.slot_name; } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/pcg_from_computation_graph.cc b/lib/pcg/src/pcg/pcg_from_computation_graph.cc index 6e2b8668d1..ec93e880bb 100644 --- a/lib/pcg/src/pcg/pcg_from_computation_graph.cc +++ b/lib/pcg/src/pcg/pcg_from_computation_graph.cc @@ -7,29 +7,41 @@ #include "pcg/parallel_tensor_attrs.h" #include "pcg/tensor_attrs.dtg.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" -#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" -#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_value_labels.h" -#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" -#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_value_labels.h" #include 
"utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" namespace FlexFlow { ParallelComputationGraph pcg_from_computation_graph(ComputationGraph const &cg) { + auto layer_map = [&](Node const &_, LayerAttrs const &layer) { return parallel_layer_attrs_from_layer_attrs(layer); }; - auto tensor_map = [&](OpenDataflowValue const &_, TensorAttrs const &tensor) { + + auto tensor_map = [&](KwargDataflowOutput const &, + TensorAttrs const &tensor) { return parallel_tensor_attrs_from_tensor_attrs(tensor); }; - auto graph_view = rewrite_value_labels( - rewrite_node_labels(cg.raw_graph, layer_map), tensor_map); + + LabelledKwargDataflowGraphView + graph_view = rewrite_labelled_kwarg_dataflow_graph_value_labels( + rewrite_labelled_kwarg_dataflow_graph_node_labels(cg.raw_graph, + layer_map), + tensor_map); return ParallelComputationGraph{ - LabelledDataflowGraph:: + LabelledKwargDataflowGraph:: create_copy_of< - UnorderedSetLabelledOpenDataflowGraph>( + UnorderedSetLabelledOpenKwargDataflowGraph>( graph_view)}; } diff --git a/lib/pcg/src/pcg/start_invariant_machine_view.cc b/lib/pcg/src/pcg/start_invariant_machine_view.cc deleted file mode 100644 index e9f864d416..0000000000 --- a/lib/pcg/src/pcg/start_invariant_machine_view.cc +++ /dev/null @@ -1,87 +0,0 @@ -#include "pcg/start_invariant_machine_view.h" -#include "pcg/machine_space_offset.h" -#include "pcg/machine_view.h" -#include "pcg/operator_task_space.h" -#include "utils/containers/count.h" -#include "utils/containers/filter.h" -#include "utils/containers/scanl.h" -#include "utils/containers/transform.h" -#include "utils/containers/zip.h" -#include "utils/nonnegative_int/num_elements.h" -namespace FlexFlow { - -MachineView machine_view_from_start_invariant( - StartInvariantMachineView const &start_inv_mv, - MachineSpaceCoordinate const &start) { - return MachineView{start, start_inv_mv.dimensions}; -} - -StartInvariantMachineView - start_invariant_from_machine_view(MachineView const &mv) { - return 
StartInvariantMachineView{mv.dimensions, get_device_type(mv)}; -} - -nonnegative_int num_dims(StartInvariantMachineView const &start_inv_mv) { - return num_elements(start_inv_mv.dimensions); -} - -DeviceType get_device_type(StartInvariantMachineView const &start_inv_mv) { - return start_inv_mv.device_type; -} - -std::vector - get_strides(StartInvariantMachineView const &start_inv_mv) { - return transform(start_inv_mv.dimensions, - [](MachineViewDimension const &dim) { return dim.stride; }); -} - -std::vector - get_dimensions(StartInvariantMachineView const &start_inv_mv) { - return transform( - start_inv_mv.dimensions, - [](MachineViewDimension const &dim) { return dim.projection; }); -} - -StartInvariantMachineView - start_invariant_machine_view_from_strides_and_machine_spec_dimensions( - std::vector const &strides, - std::vector const &dims, - DeviceType device_type) { - std::vector dimensions = - transform(zip(strides, dims), [&](auto const &p) { - return MachineViewDimension{p.first, p.second}; - }); - return StartInvariantMachineView{dimensions, device_type}; -} - -std::optional get_machine_space_offset( - OperatorTaskSpace const &task, - StartInvariantMachineView const &start_inv_machine_view, - TaskSpaceCoordinate const &coord, - MachineSpecification const &machine_specification) { - MachineSpaceCoordinate dummy_start = - MachineSpaceCoordinate{0_n, 0_n, get_device_type(start_inv_machine_view)}; - MachineView mv = - machine_view_from_start_invariant(start_inv_machine_view, dummy_start); - std::optional ms_coord = - get_machine_space_coordinate(task, mv, coord, machine_specification); - if (ms_coord == std::nullopt) { - return std::nullopt; - } - return get_machine_space_offset_from_coordinate(dummy_start, - ms_coord.value()); -} - -std::unordered_set get_machine_space_offsets( - OperatorTaskSpace const &task, - StartInvariantMachineView const &start_inv_machine_view, - MachineSpecification const &machine_specification) { - return transform( - 
get_task_space_coordinates(task), [&](TaskSpaceCoordinate const &coord) { - return get_machine_space_offset( - task, start_inv_machine_view, coord, machine_specification) - .value(); - }); -} - -} // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/computation_graph.cc b/lib/pcg/test/src/pcg/computation_graph.cc index 8451545e32..721179b647 100644 --- a/lib/pcg/test/src/pcg/computation_graph.cc +++ b/lib/pcg/test/src/pcg/computation_graph.cc @@ -1,7 +1,7 @@ #include "pcg/computation_graph.h" #include "op-attrs/ops/linear.h" #include "pcg/computation_graph_builder.h" -#include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include using namespace ::FlexFlow; @@ -29,8 +29,9 @@ TEST_SUITE(FF_TEST_SUITE) { layer_guid_t input_layer = get_layer_by_name(cg, input_name); - std::vector result = get_incoming_inputs(cg, input_layer); - std::vector correct = {}; + std::unordered_map result = + get_incoming_inputs(cg, input_layer); + std::unordered_map correct = {}; CHECK(result == correct); } @@ -55,8 +56,14 @@ TEST_SUITE(FF_TEST_SUITE) { layer_guid_t layer = get_layer_by_name(cg, layer_name); - std::vector result = get_incoming_inputs(cg, layer); - std::vector correct = {input}; + std::unordered_map result = + get_incoming_inputs(cg, layer); + std::unordered_map correct = { + { + TensorSlotName::INPUT, + input, + }, + }; CHECK(result == correct); } @@ -88,9 +95,13 @@ TEST_SUITE(FF_TEST_SUITE) { layer_guid_t dense_layer = get_layer_by_name(cg, layer_name); - std::vector result = get_incoming_inputs(cg, dense_layer); - std::vector correct = { - input, + std::unordered_map result = + get_incoming_inputs(cg, dense_layer); + std::unordered_map correct = { + { + TensorSlotName::INPUT, + input, + }, }; CHECK(result == correct); @@ -119,8 +130,9 @@ TEST_SUITE(FF_TEST_SUITE) { layer_guid_t input_layer = get_layer_by_name(cg, input_name); - std::vector result = get_incoming_weights(cg, input_layer); - std::vector correct = {}; + std::unordered_map 
result = + get_incoming_weights(cg, input_layer); + std::unordered_map correct = {}; CHECK(result == correct); } @@ -147,8 +159,9 @@ TEST_SUITE(FF_TEST_SUITE) { layer_guid_t layer = get_layer_by_name(cg, layer_name); - std::vector result = get_incoming_weights(cg, layer); - std::vector correct = {}; + std::unordered_map result = + get_incoming_weights(cg, layer); + std::unordered_map correct = {}; CHECK(result == correct); } @@ -194,28 +207,49 @@ TEST_SUITE(FF_TEST_SUITE) { }; LayerAddedResult input_added = add_input_layer(cg, input_shape); - tensor_guid_t t_input = get_only(input_added.outputs); + tensor_guid_t t_input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); LayerAddedResult projection_weight_added = add_layer(cg, make_layer_attrs(projection_weight_attrs), {}, {}); - tensor_guid_t t_projection_weight = - get_only(projection_weight_added.outputs); + tensor_guid_t t_projection_weight = require_only_key( + projection_weight_added.outputs, TensorSlotName::OUTPUT); LayerAddedResult bias_weight_added = add_layer(cg, make_layer_attrs(bias_weight_attrs), {}, {}); - tensor_guid_t t_bias_weight = get_only(bias_weight_added.outputs); - - LayerAddedResult linear_added = - add_layer(cg, - make_layer_attrs(linear_attrs), - {t_input}, - {t_projection_weight, t_bias_weight}); - - std::vector result = + tensor_guid_t t_bias_weight = + require_only_key(bias_weight_added.outputs, TensorSlotName::OUTPUT); + + LayerAddedResult linear_added = add_layer(cg, + make_layer_attrs(linear_attrs), + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + { + { + TensorSlotName::WEIGHT, + t_projection_weight, + }, + { + TensorSlotName::BIAS, + t_bias_weight, + }, + }); + + std::unordered_map result = get_incoming_weights(cg, linear_added.layer); - std::vector correct = { - t_projection_weight, - t_bias_weight, + std::unordered_map correct = { + { + TensorSlotName::WEIGHT, + t_projection_weight, + }, + { + TensorSlotName::BIAS, + t_bias_weight, + }, }; CHECK(result == 
correct); diff --git a/lib/pcg/test/src/pcg/computation_graph_builder.cc b/lib/pcg/test/src/pcg/computation_graph_builder.cc index f7430b3403..513cfbfe18 100644 --- a/lib/pcg/test/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/computation_graph_builder.cc @@ -1,6 +1,6 @@ #include "pcg/computation_graph_builder.h" -#include "doctest/doctest.h" #include "pcg/computation_graph.h" +#include using namespace ::FlexFlow; diff --git a/lib/pcg/test/src/pcg/machine_compute_specification.cc b/lib/pcg/test/src/pcg/machine_compute_specification.cc new file mode 100644 index 0000000000..c725da80ed --- /dev/null +++ b/lib/pcg/test/src/pcg/machine_compute_specification.cc @@ -0,0 +1,51 @@ +#include "pcg/machine_compute_specification.h" +#include "pcg/device_id.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("MachineComputeSpecification") { + MachineComputeSpecification ms = MachineComputeSpecification{ + /*num_nodes=*/4_p, + /*num_cpus_per_node=*/16_p, + /*num_gpus_per_node=*/8_p, + }; + + SUBCASE("get_num_gpus") { + CHECK(get_num_gpus(ms) == 4 * 8); + } + + SUBCASE("get_num_cpus") { + CHECK(get_num_cpus(ms) == 4 * 16); + } + + SUBCASE("get_num_devices") { + CHECK(get_num_devices(ms, DeviceType::GPU) == 4 * 8); + CHECK(get_num_devices(ms, DeviceType::CPU) == 16 * 4); + } + + SUBCASE("get_device_id") { + SUBCASE("valid MachineSpaceCoordinate") { + MachineSpaceCoordinate coord = MachineSpaceCoordinate{ + /*node_idx=*/2_n, + /*device_idx=*/12_n, + DeviceType::CPU, + }; + device_id_t correct = + device_id_from_index(nonnegative_int{2 * 16 + 12}, DeviceType::CPU); + device_id_t result = get_device_id(ms, coord); + CHECK(correct == result); + } + SUBCASE("MachineSpaceCoordinate out of bounds for given machine spec") { + MachineSpaceCoordinate coord = MachineSpaceCoordinate{ + /*node_idx=*/2_n, + /*device_idx=*/18_n, + DeviceType::CPU, + }; + CHECK_THROWS(get_device_id(ms, coord)); + } + } + } +} diff --git 
a/lib/pcg/test/src/pcg/machine_specification.cc b/lib/pcg/test/src/pcg/machine_specification.cc deleted file mode 100644 index 4064f36679..0000000000 --- a/lib/pcg/test/src/pcg/machine_specification.cc +++ /dev/null @@ -1,53 +0,0 @@ -#include "pcg/machine_specification.h" -#include "pcg/device_id.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - - TEST_CASE("MachineSpecification") { - MachineSpecification ms = MachineSpecification{ - /*num_nodes=*/4_p, - /*num_cpus_per_node=*/16_p, - /*num_gpus_per_node=*/8_p, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0, - }; - - SUBCASE("get_num_gpus") { - CHECK(get_num_gpus(ms) == 4 * 8); - } - - SUBCASE("get_num_cpus") { - CHECK(get_num_cpus(ms) == 4 * 16); - } - - SUBCASE("get_num_devices") { - CHECK(get_num_devices(ms, DeviceType::GPU) == 4 * 8); - CHECK(get_num_devices(ms, DeviceType::CPU) == 16 * 4); - } - - SUBCASE("get_device_id") { - SUBCASE("valid MachineSpaceCoordinate") { - MachineSpaceCoordinate coord = MachineSpaceCoordinate{ - /*node_idx=*/2_n, - /*device_idx=*/12_n, - DeviceType::CPU, - }; - device_id_t correct = - device_id_from_index(nonnegative_int{2 * 16 + 12}, DeviceType::CPU); - device_id_t result = get_device_id(ms, coord); - CHECK(correct == result); - } - SUBCASE("MachineSpaceCoordinate out of bounds for given machine spec") { - MachineSpaceCoordinate coord = MachineSpaceCoordinate{ - /*node_idx=*/2_n, - /*device_idx=*/18_n, - DeviceType::CPU, - }; - CHECK_THROWS(get_device_id(ms, coord)); - } - } - } -} diff --git a/lib/pcg/test/src/pcg/machine_view.cc b/lib/pcg/test/src/pcg/machine_view.cc deleted file mode 100644 index ecc196a118..0000000000 --- a/lib/pcg/test/src/pcg/machine_view.cc +++ /dev/null @@ -1,392 +0,0 @@ -#include "pcg/machine_view.h" -#include "pcg/gpu_id_t.dtg.h" -#include "test/utils/doctest/fmt/optional.h" -#include "utils/containers/transform.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" -#include - -using namespace 
FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("MachineView - utility functions") { - MachineView mv = MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/0_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTER_NODE}}}; - - SUBCASE("num_dims") { - CHECK(num_dims(mv) == 2); - } - SUBCASE("get_device_type") { - CHECK(get_device_type(mv) == DeviceType::GPU); - } - } - - TEST_CASE("get_machine_space_coordinate") { - SUBCASE("1D case") { - - // This operator has shape (3,), and thus 3 tasks. - // The (only) dimension is projected on the INTER (device) dimension with - // a stride of 2. The start of the projection defined by MachineView - // starts at MachineSpaceCoordinate (0,1), and the machine space has 1 - // node and 6 devices per node. - - /** - * The tasks will thus be distributed like this: - * +-------+-------+-------+-------+-------+-------+ - * | | (0,) | | (1,) | | (2,) | - * +-------+-------+-------+-------+-------+-------+ - * Where the (x,) are the `TaskSpaceCoordinate`s, and the underlying grid - * is the machine space. 
- */ - OperatorTaskSpace task = OperatorTaskSpace{{3_p}}; - MachineView mv = MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/1_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE}}}; - MachineSpecification ms = - MachineSpecification{/*num_nodes=*/1_p, - /*num_cpus_per_node=*/6_p, - /*num_gpus_per_node=*/6_p, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0}; - - SUBCASE("Task with TaskSpaceCoordinate = (0,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/1_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (1,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/3_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (2,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{2_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/5_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("TaskSpaceCoordinate is out of bounds") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{4_n}}; - std::optional result = - get_machine_space_coordinate(task, mv, coord, ms); - std::optional correct = std::nullopt; - CHECK(result == correct); - } - - SUBCASE("2D case - projection on different dimensions") { - // This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. 
- // The first dimension is projected onto the INTER (node) dimension with - // stride 1, while the second dimension is projected onto the INTRA - // (device) dimension with stride 2. The start of the projection defined - // by MachineView is at MachineSpaceCoordinates (1, 2), and the machine - // space has 3 nodes and 5 devices per node. - - /** - * The tasks will thus be distributed like this: - * +-------+-------+-------+-------+-------+ - * | | | | | | - * +-------+-------+-------+-------+-------+ - * | | | (0,0) | | (0,1) | - * +-------+-------+-------+-------+-------+ - * | | | (1,0) | | (1,1) | - * +-------+-------+-------+-------+-------+ - * Where the (x,y) are the `TaskSpaceCoordinate`s, and the underlying - * grid is the machine space. - */ - - OperatorTaskSpace task = OperatorTaskSpace{{2_p, 2_p}}; - MachineView mv = MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/1_n, /*device_idx=*/2_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE}}}; - MachineSpecification ms = - MachineSpecification{/*num_nodes=*/3_p, - /*num_cpus_per_node=*/5_p, - /*num_gpus_per_node=*/5_p, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0}; - - SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 0_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/1_n, /*device_idx=*/2_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 1_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/1_n, /*device_idx=*/4_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - 
CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 0_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/2_n, /*device_idx=*/2_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 1_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/2_n, /*device_idx=*/4_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - } - - SUBCASE("2D case - projection on same dimension") { - // This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. - // Both dimensions are projected on the INTRA (device) dimension, with - // strides 1 and 2 respectively. The start of the projection defined by - // MachineView is at MachineSpaceCoordinates (1, 0), and the machine - // space has 2 nodes and 6 devices per node. - - /** - * +-------+-------+-------+-------+-------+-------+ - * | (0,0) | (1,0) | | | (0,1) | (1,1) | - * +-------+-------+-------+-------+-------+-------+ - * Where the (x,y) are the `TaskSpaceCoordinate`s, and the underlying - * grid is the machine space. 
- */ - - OperatorTaskSpace task = OperatorTaskSpace{{2_p, 2_p}}; - MachineView mv = MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/1_n, /*device_idx=*/0_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTRA_NODE}, - MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE}}}; - MachineSpecification ms = - MachineSpecification{/*num_nodes=*/2_p, - /*num_cpus_per_node=*/6_p, - /*num_gpus_per_node=*/6_p, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0}; - - SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 0_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/1_n, /*device_idx=*/0_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 1_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/1_n, /*device_idx=*/4_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 0_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/1_n, /*device_idx=*/1_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 1_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/1_n, /*device_idx=*/5_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - } - - 
SUBCASE("3D case") { - // This operator has shape (2, 2, 2), and thus 2 * 2 * 2 = 8 tasks. - // - The first dimension is projected onto the INTER (node) dimension - // with stride 1, - // - The second dimension is projected onto the INTRA (device) dimension - // with stride 2, - // - The third dimension is projected onto the INTRA (device) dimension - // with stride 1. The start of the projection defined by MachineView is - // at MachineSpaceCoordinates (0, 1), and the machine space has 2 nodes - // and 8 devices per node. - - /** - * The tasks will thus be distributed like this: - * +-------+-------+-------+-------+-------+-------+-------+-------+ - * | |(0,0,0)| |(0,0,1)| |(0,1,0)| |(0,1,1)| - * +-------+-------+-------+-------+-------+-------+-------+-------+ - * | |(1,0,0)| |(1,0,1)| |(1,1,0)| |(1,1,1)| - * +-------+-------+-------+-------+-------+-------+-------+-------+ - * Where the (x,y,z) are the `TaskSpaceCoordinate`s, and the underlying - * grid is the machine space. - */ - - OperatorTaskSpace task = OperatorTaskSpace{{2_p, 2_p, 2_p}}; - MachineView mv = MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/1_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE}, - MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTRA_NODE}}}; - MachineSpecification ms = - MachineSpecification{/*num_nodes=*/2_p, - /*num_cpus_per_node=*/8_p, - /*num_gpus_per_node=*/8_p, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0}; - - SUBCASE("Task with TaskSpaceCoordinate = (0,0,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 1_n, 0_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/3_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - - 
SUBCASE("Task with TaskSpaceCoordinate = (1,1,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 0_n, 1_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/1_n, /*device_idx=*/5_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (1,1,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 1_n, 1_n}}; - MachineSpaceCoordinate correct = MachineSpaceCoordinate{ - /*node_idx=*/1_n, /*device_idx=*/7_n, DeviceType::GPU}; - MachineSpaceCoordinate result = - get_machine_space_coordinate(task, mv, coord, ms).value(); - CHECK(correct == result); - } - } - } - } - - TEST_CASE("get_device_ids") { - - SUBCASE("1D machine view") { - - // This operator has shape (3,), and thus 3 tasks. - // The (only) dimension is projected onto the INTRA (device) dimension - // with a stride of 2. The start of the projection defined by MachineView - // is at MachineSpaceCoordinate (0, 1), and the machine space has 1 node - // and 6 devices per node. 
- - /** - * The tasks will thus be distributed like this: - * +-------+-------+-------+-------+-------+-------+ - * | 0 | ((1)) | 2 | ((3)) | 4 | ((5)) | - * +-------+-------+-------+-------+-------+-------+ - * Where the integers are the device ids and ((x)) are the devices we - * select - */ - MachineSpecification ms = - MachineSpecification{/*num_nodes=*/1_p, - /*num_cpus_per_node=*/6_p, - /*num_gpus_per_node=*/6_p, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0}; - - OperatorTaskSpace task = OperatorTaskSpace{{3_p}}; - MachineView mv = MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/1_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE}}}; - - std::unordered_set correct = { - device_id_t{gpu_id_t{1_n}}, - device_id_t{gpu_id_t{3_n}}, - device_id_t{gpu_id_t{5_n}}, - }; - std::unordered_set result = get_device_ids(task, mv, ms); - CHECK(result == correct); - } - - SUBCASE("2D machine view") { - // This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. - // - The first dimension is projected onto the INTER (node) dimension with - // stride 1, - // - The second dimension is projected onto the INTRA (device) dimension - // with stride 2. The start of the projection defined by MachineView is at - // MachineSpaceCoordinate (1, 2), and the machine space has 3 nodes and 5 - // devices per node. 
- - /** - * The tasks will thus be distributed like this: - * +-------+-------+-------+-------+-------+ - * | 0 | 1 | 2 | 3 | 4 | - * +-------+-------+-------+-------+-------+ - * | 5 | 6 | ((7)) | 8 | ((9)) | - * +-------+-------+-------+-------+-------+ - * | 10 | 11 | ((12))| 13 | ((14))| - * +-------+-------+-------+-------+-------+ - * Where the integers are the device ids and ((x)) are the devices we - * select - */ - - MachineSpecification ms = - MachineSpecification{/*num_nodes=*/3_p, - /*num_cpus_per_node=*/5_p, - /*num_gpus_per_node=*/5_p, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0}; - - OperatorTaskSpace task = OperatorTaskSpace{{2_p, 2_p}}; - MachineView mv = MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/1_n, /*device_idx=*/2_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE}}}; - - std::unordered_set correct = { - device_id_t{gpu_id_t{7_n}}, - device_id_t{gpu_id_t{9_n}}, - device_id_t{gpu_id_t{12_n}}, - device_id_t{gpu_id_t{14_n}}, - }; - std::unordered_set result = get_device_ids(task, mv, ms); - CHECK(result == correct); - } - } -} diff --git a/lib/pcg/test/src/pcg/operator_task_space.cc b/lib/pcg/test/src/pcg/operator_task_space.cc deleted file mode 100644 index 4b01ed02fb..0000000000 --- a/lib/pcg/test/src/pcg/operator_task_space.cc +++ /dev/null @@ -1,66 +0,0 @@ -#include "pcg/operator_task_space.h" -#include "utils/fmt/unordered_set.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_task_space_coordinates") { - - SUBCASE("OperatorTaskSpace has 0 dimensions") { - OperatorTaskSpace task = OperatorTaskSpace{{}}; - - std::unordered_set correct = { - TaskSpaceCoordinate{{}}}; - std::unordered_set result = - get_task_space_coordinates(task); - CHECK(correct == result); - } - SUBCASE("OperatorTaskSpace has 2 dimensions") { - - OperatorTaskSpace task = 
OperatorTaskSpace{{2_p, 2_p}}; - - std::unordered_set correct = {{ - TaskSpaceCoordinate{{0_n, 0_n}}, - TaskSpaceCoordinate{{0_n, 1_n}}, - TaskSpaceCoordinate{{1_n, 0_n}}, - TaskSpaceCoordinate{{1_n, 1_n}}, - }}; - std::unordered_set result = - get_task_space_coordinates(task); - CHECK(correct == result); - } - SUBCASE("OperatorTaskSpace has 3 dimensions") { - - OperatorTaskSpace task = OperatorTaskSpace{{1_p, 2_p, 2_p}}; - - std::unordered_set correct = {{ - TaskSpaceCoordinate{{0_n, 0_n, 0_n}}, - TaskSpaceCoordinate{{0_n, 0_n, 1_n}}, - TaskSpaceCoordinate{{0_n, 1_n, 0_n}}, - TaskSpaceCoordinate{{0_n, 1_n, 1_n}}, - }}; - std::unordered_set result = - get_task_space_coordinates(task); - CHECK(correct == result); - } - } - TEST_CASE("get_task_space_maximum_coordinate") { - SUBCASE("OperatorTaskSpace has 2 dimensions") { - - OperatorTaskSpace task = OperatorTaskSpace{{3_p, 2_p}}; - - TaskSpaceCoordinate correct = TaskSpaceCoordinate{{2_n, 1_n}}; - TaskSpaceCoordinate result = get_task_space_maximum_coordinate(task); - CHECK(correct == result); - } - SUBCASE("OperatorTaskSpace has 3 dimensions") { - - OperatorTaskSpace task = OperatorTaskSpace{{3_p, 2_p, 4_p}}; - - TaskSpaceCoordinate correct = TaskSpaceCoordinate{{2_n, 1_n, 3_n}}; - TaskSpaceCoordinate result = get_task_space_maximum_coordinate(task); - CHECK(correct == result); - } - } -} diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index f223558868..97ef5fa46e 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,10 +1,12 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" #include "op-attrs/ops/element_unary.h" #include "op-attrs/ops/linear.h" #include 
"op-attrs/ops/replicate.h" #include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include using namespace ::FlexFlow; @@ -18,7 +20,7 @@ static ParallelLayerAttrs make_layer_attrs(T const &op_attrs) { }; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("topological_ordering") { + TEST_CASE("topological_ordering(ParallelComputationGraph)") { // TODO(@lockshaw) should probably be replaced with a rapidcheck test that // compares ParallelComputationGraph to DataflowGraph, but since we // currently don't have rapidcheck generation for DataflowGraph this will @@ -41,17 +43,36 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult layer1_added = pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t layer1 = layer1_added.parallel_layer; - parallel_tensor_guid_t tensor1 = get_only(layer1_added.outputs); + parallel_tensor_guid_t tensor1 = + require_only_key(layer1_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult layer2_added = - add_parallel_layer(pcg, make_layer_attrs(relu_attrs), {tensor1}, {}); + add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + { + { + TensorSlotName::INPUT, + tensor1, + }, + }, + {}); parallel_layer_guid_t layer2 = layer2_added.parallel_layer; - parallel_tensor_guid_t tensor2 = get_only(layer2_added.outputs); + parallel_tensor_guid_t tensor2 = + require_only_key(layer2_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult layer3_added = - add_parallel_layer(pcg, make_layer_attrs(relu_attrs), {tensor2}, {}); + add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + { + { + TensorSlotName::INPUT, + tensor2, + }, + }, + {}); parallel_layer_guid_t layer3 = layer3_added.parallel_layer; - parallel_tensor_guid_t tensor3 = get_only(layer3_added.outputs); + parallel_tensor_guid_t tensor3 = + require_only_key(layer3_added.outputs, TensorSlotName::OUTPUT); std::vector result = 
topological_ordering(pcg); std::vector correct = {layer1, layer2, layer3}; @@ -73,9 +94,9 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); - std::vector result = + std::unordered_map result = get_incoming_inputs(pcg, input_added.parallel_layer); - std::vector correct = {}; + std::unordered_map correct = {}; CHECK(result == correct); } @@ -105,26 +126,49 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult projection_weight_added = add_parallel_layer( pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_projection = - get_only(projection_weight_added.outputs); + parallel_tensor_guid_t t_projection = require_only_key( + projection_weight_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult bias_weight_added = add_parallel_layer(pcg, make_layer_attrs(bias_weight_attrs), {}, {}); - parallel_tensor_guid_t t_bias = get_only(bias_weight_added.outputs); - - ParallelLayerAddedResult linear_added = - add_parallel_layer(pcg, - make_layer_attrs(linear_attrs), - {t_input}, - {t_projection, t_bias}); - - std::vector result = + parallel_tensor_guid_t t_bias = + require_only_key(bias_weight_added.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult linear_added = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(linear_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_projection, + }, + { + TensorSlotName::BIAS, + t_bias, + }, + }); + + std::unordered_map result = get_incoming_inputs(pcg, linear_added.parallel_layer); - std::vector correct = {t_input}; + std::unordered_map correct = { + { + TensorSlotName::INPUT, + t_input, + }, + 
}; CHECK(result == correct); } @@ -150,7 +194,8 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult layer1_added = pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t layer1 = layer1_added.parallel_layer; - parallel_tensor_guid_t tensor1 = get_only(layer1_added.outputs); + parallel_tensor_guid_t tensor1 = + require_only_key(layer1_added.outputs, TensorSlotName::OUTPUT); parallel_layer_guid_t result = get_source_layer(pcg, tensor1); parallel_layer_guid_t correct = layer1; @@ -161,10 +206,20 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult layer1_added = pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t layer1 = layer1_added.parallel_layer; - parallel_tensor_guid_t tensor1 = get_only(layer1_added.outputs); - - ParallelLayerAddedResult layer2_added = - add_parallel_layer(pcg, make_layer_attrs(relu_attrs), {tensor1}, {}); + parallel_tensor_guid_t tensor1 = + require_only_key(layer1_added.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult layer2_added = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(relu_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + tensor1, + }, + }, + /*weights=*/{}); parallel_layer_guid_t layer2 = layer2_added.parallel_layer; parallel_layer_guid_t result = get_source_layer(pcg, tensor1); @@ -176,15 +231,35 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult layer1_added = pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t layer1 = layer1_added.parallel_layer; - parallel_tensor_guid_t tensor1 = get_only(layer1_added.outputs); - - ParallelLayerAddedResult layer2_added = - add_parallel_layer(pcg, make_layer_attrs(relu_attrs), {tensor1}, {}); + parallel_tensor_guid_t tensor1 = + require_only_key(layer1_added.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult layer2_added = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(relu_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + tensor1, + }, + }, + /*weights=*/{}); 
parallel_layer_guid_t layer2 = layer2_added.parallel_layer; - parallel_tensor_guid_t tensor2 = get_only(layer2_added.outputs); - - ParallelLayerAddedResult layer3_added = - add_parallel_layer(pcg, make_layer_attrs(relu_attrs), {tensor1}, {}); + parallel_tensor_guid_t tensor2 = + require_only_key(layer2_added.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult layer3_added = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(relu_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + tensor1, + }, + }, + /*weights=*/{}); parallel_layer_guid_t layer3 = layer3_added.parallel_layer; SUBCASE("tensor 1") { @@ -219,9 +294,9 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); - std::vector result = + std::unordered_map result = get_incoming_weights(pcg, input_added.parallel_layer); - std::vector correct = {}; + std::unordered_map correct = {}; CHECK(result == correct); } @@ -229,14 +304,24 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("layer has inputs but no weights") { ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult relu_added = add_parallel_layer( - pcg, make_layer_attrs(make_relu_attrs()), {t_input}, {}); - - std::vector result = + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(make_relu_attrs()), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/{}); + + std::unordered_map result = get_incoming_weights(pcg, relu_added.parallel_layer); - std::vector correct = {}; + std::unordered_map correct = {}; CHECK(result == correct); } @@ -257,7 +342,8 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t 
t_input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); RepartitionAttrs partition_input_attrs = RepartitionAttrs{ /*repartition_dim=*/ff_dim_t{0_n}, @@ -265,9 +351,18 @@ TEST_SUITE(FF_TEST_SUITE) { }; ParallelLayerAddedResult partition_input_added = add_parallel_layer( - pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); - parallel_tensor_guid_t t_partitioned_input = - get_only(partition_input_added.outputs); + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(partition_input_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/{}); + parallel_tensor_guid_t t_partitioned_input = require_only_key( + partition_input_added.outputs, TensorSlotName::OUTPUT); WeightAttrs projection_weight_attrs = WeightAttrs{ /*tensor_shape=*/throw_if_unexpected( @@ -276,31 +371,53 @@ TEST_SUITE(FF_TEST_SUITE) { }; ParallelLayerAddedResult projection_weight_added = add_parallel_layer( - pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_projection_weight = - get_only(projection_weight_added.outputs); + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(projection_weight_attrs), + /*inputs=*/{}, + /*weights=*/{}); + parallel_tensor_guid_t t_projection_weight = require_only_key( + projection_weight_added.outputs, TensorSlotName::OUTPUT); ReplicateAttrs replicate_projection_attrs = ReplicateAttrs{ /*replicate_degree=*/2_p, }; - ParallelLayerAddedResult replicate_projection_added = - add_parallel_layer(pcg, - make_layer_attrs(replicate_projection_attrs), - {t_projection_weight}, - {}); - parallel_tensor_guid_t t_replicated_projection_weight = - get_only(replicate_projection_added.outputs); - - ParallelLayerAddedResult linear_added = - add_parallel_layer(pcg, - make_layer_attrs(linear_attrs), - {t_partitioned_input}, - {t_replicated_projection_weight}); - - std::vector result = + ParallelLayerAddedResult replicate_projection_added = add_parallel_layer( + /*pcg=*/pcg, + 
/*layer_attrs=*/make_layer_attrs(replicate_projection_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_projection_weight, + }, + }, + /*weights=*/{}); + parallel_tensor_guid_t t_replicated_projection_weight = require_only_key( + replicate_projection_added.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult linear_added = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/make_layer_attrs(linear_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_partitioned_input, + }, + }, + /*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_replicated_projection_weight, + }, + }); + + std::unordered_map result = get_incoming_weights(pcg, linear_added.parallel_layer); - std::vector correct = { - t_replicated_projection_weight}; + std::unordered_map correct = { + {TensorSlotName::WEIGHT, t_replicated_projection_weight}, + }; CHECK(result == correct); } @@ -335,11 +452,236 @@ TEST_SUITE(FF_TEST_SUITE) { /*layer_attrs=*/layer_attrs, /*inputs=*/{}, /*weights=*/{}, - /*output_labels=*/std::vector{CreateGrad::NO}); + /*output_labels=*/ + std::unordered_map{ + { + TensorSlotName::OUTPUT, + CreateGrad::NO, + }, + }); return pcg; }(); CHECK(pcgs_are_isomorphic(result, correct)); } + + TEST_CASE("pcg_get_mapping_along_edge") { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 12_p, + }, + }, + DataType::FLOAT, + }; + + ParallelTensorShape par_input_shape = lift_to_parallel(input_shape); + + ParallelLayerAttrs partition_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{0_n}, + /*repartition_degree=*/2_p, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelLayerAttrs relu_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }, + /*name=*/std::nullopt, + }; + + SUBCASE("trivial mapping (relu into relu)") { + 
ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_input = + require_only_key(input.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult partition_input = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/partition_attrs, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/{}); + parallel_tensor_guid_t t_partition_input = + require_only_key(partition_input.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/relu_attrs, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_partition_input, + }, + }, + /*weights=*/{}); + parallel_tensor_guid_t t_layer_1 = + require_only_key(layer_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/relu_attrs, + { + { + TensorSlotName::INPUT, + t_layer_1, + }, + }, + {}); + + ParallelComputationGraphEdge edge = + get_only(get_pcg_edges_from_layer_to_layer( + /*pcg=*/pcg, + /*src=*/layer_1.parallel_layer, + /*dst=*/layer_2.parallel_layer)); + + OperatorTaskSpaceToOperatorTaskSpaceMapping result = + pcg_get_mapping_along_edge(pcg, edge); + + DimDomain layer_1_task_space = + DimDomain{{ + {operator_task_space_dim_idx_t{0_n}, 2_p}, + }}; + + DimDomain layer_2_task_space = + layer_1_task_space; + + auto make_coord = [](nonnegative_int x) { + return DimCoord{ + std::unordered_map{ + {operator_task_space_dim_idx_t{0_n}, x}, + }, + }; + }; + + OperatorTaskSpaceToOperatorTaskSpaceMapping correct = + OperatorTaskSpaceToOperatorTaskSpaceMapping{ + DimDomainMapping{ + bidict, + DimCoord>{ + {make_coord(0_n), make_coord(0_n)}, + {make_coord(1_n), make_coord(1_n)}, + }, + layer_1_task_space, + layer_2_task_space, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("nontrivial mapping (linear into linear)") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + parallel_tensor_guid_t 
t_input = + require_only_key(input.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult partition_input = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/partition_attrs, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/{}); + parallel_tensor_guid_t t_partition_input = + require_only_key(partition_input.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAttrs transpose_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + TransposeAttrs{ + TensorDimPermutation{ + bidict{ + {ff_dim_t{0_n}, ff_dim_t{1_n}}, + {ff_dim_t{1_n}, ff_dim_t{0_n}}, + }, + }, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/relu_attrs, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_partition_input, + }, + }, + /*weights=*/{}); + parallel_tensor_guid_t t_layer_1 = + require_only_key(layer_1.outputs, TensorSlotName::OUTPUT); + ParallelLayerAddedResult layer_2 = add_parallel_layer( + /*pcg=*/pcg, + /*layer_attrs=*/transpose_attrs, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_layer_1, + }, + }, + /*weights=*/{}); + + ParallelComputationGraphEdge edge = + get_only(get_pcg_edges_from_layer_to_layer( + /*pcg=*/pcg, + /*src=*/layer_1.parallel_layer, + /*dst=*/layer_2.parallel_layer)); + + OperatorTaskSpaceToOperatorTaskSpaceMapping result = + pcg_get_mapping_along_edge(pcg, edge); + + DimDomain layer_1_task_space = + DimDomain{{ + {operator_task_space_dim_idx_t{0_n}, 2_p}, + }}; + + DimDomain layer_2_task_space = + layer_1_task_space; + + auto make_coord = [](nonnegative_int x) { + return DimCoord{ + std::unordered_map{ + {operator_task_space_dim_idx_t{0_n}, x}, + }, + }; + }; + + OperatorTaskSpaceToOperatorTaskSpaceMapping correct = + OperatorTaskSpaceToOperatorTaskSpaceMapping{ + DimDomainMapping{ + bidict, + DimCoord>{ + {make_coord(0_n), make_coord(1_n)}, + {make_coord(1_n), make_coord(0_n)}, + }, + layer_1_task_space, + layer_2_task_space, + }, + }; + 
} + } } diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 1682ac6254..aae34c8080 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -8,6 +8,7 @@ #include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" #include "utils/containers/items.h" +#include "utils/containers/require_only_key.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" #include "utils/containers/without_nullopts.h" @@ -16,12 +17,14 @@ using namespace ::FlexFlow; -// Stylistically these tests are not great (they're rather complicated -// and hard to read) and should not be used as a model for other FlexFlow -// tests. -// -// Improving them is being tracked in -// https://github.com/flexflow/FlexFlow/issues/1474 +/** + * Stylistically these tests are not great (they're rather complicated + * and hard to read) and should not be used as a model for other FlexFlow + * tests. 
+ * + * Improving them is being tracked in + * https://github.com/flexflow/FlexFlow/issues/1474 + */ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ParallelComputationGraphBuilder::add") { ParallelComputationGraphBuilder b; @@ -39,20 +42,6 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - // ParallelTensorShape lhs_shape = ParallelTensorShape{ - // ParallelTensorDims{ - // FFOrdered{ - // ShardParallelDim{10_p, 2_p}, - // ShardParallelDim{15_p, 3_p}, - // }, - // ReplicaParallelDimSet{ - // SumDegree{2_p}, - // DiscardCopyDegree{1_p}, - // }, - // }, - // DataType::FLOAT, - // }; - TensorShape rhs_shape = lhs_shape; parallel_tensor_guid_t lhs = b.create_input_tensor(lhs_shape); @@ -62,23 +51,46 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(out); SUBCASE("incoming") { - std::vector result = + std::unordered_map result = get_incoming_tensors(b.pcg, layer); - std::vector correct = {lhs, rhs}; + std::unordered_map correct = { + { + TensorSlotName::LHS_INPUT, + lhs, + }, + { + TensorSlotName::RHS_INPUT, + rhs, + }, + }; + CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = + std::unordered_map result = get_layer_outputs(b.pcg, layer); - std::vector correct = {out}; + std::unordered_map correct = { + { + TensorSlotName::OUTPUT, + out, + }, + }; + CHECK(result == correct); } SUBCASE("op attrs") { PCGOperatorAttrs result = get_parallel_layer_attrs(b.pcg, layer).op_attrs; - PCGOperatorAttrs correct = PCGOperatorAttrs{ElementBinaryAttrs{ - OperatorType::EW_ADD, DataType::FLOAT, false, false}}; + PCGOperatorAttrs correct = PCGOperatorAttrs{ + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, + }; + CHECK(result == correct); } } @@ -115,23 +127,42 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(out); SUBCASE("incoming") { - std::vector result = + std::unordered_map result = 
get_incoming_tensors(b.pcg, layer); - std::vector correct = {a_tensor, b_tensor}; + std::unordered_map correct = { + { + TensorSlotName::LHS_INPUT, + a_tensor, + }, + { + TensorSlotName::RHS_INPUT, + b_tensor, + }, + }; + CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = + std::unordered_map result = get_layer_outputs(b.pcg, layer); - std::vector correct = {out}; + std::unordered_map correct = { + { + TensorSlotName::OUTPUT, + out, + }, + }; + CHECK(result == correct); } SUBCASE("op attrs") { PCGOperatorAttrs result = get_parallel_layer_attrs(b.pcg, layer).op_attrs; - PCGOperatorAttrs correct = - PCGOperatorAttrs{BatchMatmulAttrs{std::nullopt, std::nullopt}}; + PCGOperatorAttrs correct = PCGOperatorAttrs{ + BatchMatmulAttrs{/*a_seq_length_dim=*/std::nullopt, + /*b_seq_length_dim=*/std::nullopt}, + }; + CHECK(result == correct); } } @@ -155,16 +186,27 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("incoming") { - std::vector result = + std::unordered_map result = get_incoming_tensors(b.pcg, layer); - std::vector correct = {input}; + std::unordered_map correct = { + { + TensorSlotName::INPUT, + input, + }, + }; + CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = + std::unordered_map result = get_layer_outputs(b.pcg, layer); - std::vector correct = {output}; + std::unordered_map correct = { + { + TensorSlotName::OUTPUT, + output, + }, + }; CHECK(result == correct); ParallelTensorShape output_shape = @@ -212,25 +254,27 @@ TEST_SUITE(FF_TEST_SUITE) { }); CHECK_MESSAGE(layers.size() == 7, "Incorrect layers ", layers); - auto num_attrs_of_type = [&](OperatorType op_type) -> int { + auto num_attrs_of_type = [&](OperatorType op_type) -> nonnegative_int { return count(values(layers), [&](ParallelLayerAttrs const &l) { return get_op_type(l) == op_type; }); }; - int num_weight_attrs = num_attrs_of_type(OperatorType::WEIGHT); + nonnegative_int num_weight_attrs = 
num_attrs_of_type(OperatorType::WEIGHT); CHECK(num_weight_attrs == 2); - int num_input_attrs = num_attrs_of_type(OperatorType::INPUT); + nonnegative_int num_input_attrs = num_attrs_of_type(OperatorType::INPUT); CHECK(num_input_attrs == 1); - int num_conv_attrs = num_attrs_of_type(OperatorType::CONV2D); + nonnegative_int num_conv_attrs = num_attrs_of_type(OperatorType::CONV2D); CHECK(num_conv_attrs == 1); - int num_replicate_attrs = num_attrs_of_type(OperatorType::REPLICATE); + nonnegative_int num_replicate_attrs = + num_attrs_of_type(OperatorType::REPLICATE); CHECK(num_replicate_attrs == 2); - int num_partition_attrs = num_attrs_of_type(OperatorType::REPARTITION); + nonnegative_int num_partition_attrs = + num_attrs_of_type(OperatorType::REPARTITION); CHECK(num_partition_attrs == 1); parallel_layer_guid_t conv_guid = get_only(without_nullopts(transform( @@ -265,29 +309,31 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape correct_bias_shape = get_bias_shape(correct_attrs, par_input_shape); - std::vector conv_incoming = + std::unordered_map conv_incoming = get_incoming_tensors(b.pcg, conv_guid); - parallel_tensor_guid_t conv_input = conv_incoming.at(0); + parallel_tensor_guid_t conv_input = conv_incoming.at(TensorSlotName::INPUT); ParallelTensorShape conv_input_shape = get_parallel_tensor_attrs(b.pcg, conv_input).shape; CHECK(conv_input_shape == par_input_shape); - parallel_tensor_guid_t conv_kernel = conv_incoming.at(1); + parallel_tensor_guid_t conv_kernel = + conv_incoming.at(TensorSlotName::FILTER); ParallelTensorShape conv_kernel_shape = get_parallel_tensor_attrs(b.pcg, conv_kernel).shape; CHECK(conv_kernel_shape == correct_kernel_shape); - parallel_tensor_guid_t conv_bias = conv_incoming.at(2); + parallel_tensor_guid_t conv_bias = conv_incoming.at(TensorSlotName::BIAS); ParallelTensorShape conv_bias_shape = get_parallel_tensor_attrs(b.pcg, conv_bias).shape; CHECK(conv_bias_shape == correct_bias_shape); - std::vector conv_outputs = + std::unordered_map 
conv_outputs = get_layer_outputs(b.pcg, conv_guid); CHECK(conv_outputs.size() == 1); - parallel_tensor_guid_t conv_output = get_only(conv_outputs); + parallel_tensor_guid_t conv_output = + require_only_key(conv_outputs, TensorSlotName::OUTPUT); ParallelTensorShape conv_output_shape = get_parallel_tensor_attrs(b.pcg, conv_output).shape; CHECK(conv_output_shape == correct_output_shape); @@ -316,17 +362,22 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("incoming") { - std::vector result = + std::unordered_map result = get_incoming_tensors(b.pcg, layer); - CHECK(result.at(0) == input); + CHECK(result.at(TensorSlotName::INPUT) == input); CHECK(result.size() == 3); } SUBCASE("outputs") { - std::vector result = + std::unordered_map result = get_layer_outputs(b.pcg, layer); - std::vector correct = {output}; + std::unordered_map correct = { + { + TensorSlotName::OUTPUT, + output, + }, + }; CHECK(result == correct); } } @@ -353,17 +404,23 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("incoming") { - std::vector result = + std::unordered_map result = get_incoming_tensors(b.pcg, layer); - CHECK(result.at(0) == input); CHECK(result.size() == 2); + CHECK(result.at(TensorSlotName::INPUT) == input); } SUBCASE("outputs") { - std::vector result = + std::unordered_map result = get_layer_outputs(b.pcg, layer); - std::vector correct = {output}; + std::unordered_map correct = { + { + TensorSlotName::OUTPUT, + output, + }, + }; + CHECK(result == correct); } } @@ -396,18 +453,25 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("incoming") { - std::vector result = + std::unordered_map result = get_incoming_tensors(b.pcg, layer); - CHECK(result.at(0) == query); - CHECK(result.at(1) == key); - CHECK(result.at(2) == value); + CHECK(result.size() == 6); + CHECK(result.at(TensorSlotName::QUERY) == query); + CHECK(result.at(TensorSlotName::KEY) == key); + 
CHECK(result.at(TensorSlotName::VALUE) == value); } SUBCASE("outputs") { - std::vector result = + std::unordered_map result = get_layer_outputs(b.pcg, layer); - std::vector correct = {output}; + std::unordered_map correct = { + { + TensorSlotName::OUTPUT, + output, + }, + }; + CHECK(result == correct); } } @@ -430,16 +494,28 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("incoming") { - std::vector result = + std::unordered_map result = get_incoming_tensors(b.pcg, layer); - std::vector correct = {input}; + std::unordered_map correct = { + { + TensorSlotName::INPUT, + input, + }, + }; + CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = + std::unordered_map result = get_layer_outputs(b.pcg, layer); - std::vector correct = {output}; + std::unordered_map correct = { + { + TensorSlotName::OUTPUT, + output, + }, + }; + CHECK(result == correct); } } @@ -466,16 +542,28 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("incoming") { - std::vector result = + std::unordered_map result = get_incoming_tensors(b.pcg, layer); - std::vector correct = {input}; + std::unordered_map correct = { + { + TensorSlotName::INPUT, + input, + }, + }; + CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = + std::unordered_map result = get_layer_outputs(b.pcg, layer); - std::vector correct = {output}; + std::unordered_map correct = { + { + TensorSlotName::OUTPUT, + output, + }, + }; + CHECK(result == correct); } } @@ -500,16 +588,28 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("incoming") { - std::vector result = + std::unordered_map result = get_incoming_tensors(b.pcg, layer); - std::vector correct = {input}; + std::unordered_map correct = { + { + TensorSlotName::INPUT, + input, + }, + }; + CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = + std::unordered_map result = 
get_layer_outputs(b.pcg, layer); - std::vector correct = {output}; + std::unordered_map correct = { + { + TensorSlotName::OUTPUT, + output, + }, + }; + CHECK(result == correct); } } @@ -532,16 +632,28 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("incoming") { - std::vector result = + std::unordered_map result = get_incoming_tensors(b.pcg, layer); - std::vector correct = {input}; + std::unordered_map correct = { + { + TensorSlotName::INPUT, + input, + }, + }; + CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = + std::unordered_map result = get_layer_outputs(b.pcg, layer); - std::vector correct = {output}; + std::unordered_map correct = { + { + TensorSlotName::OUTPUT, + output, + }, + }; + CHECK(result == correct); } } @@ -569,16 +681,28 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("incoming") { - std::vector result = + std::unordered_map result = get_incoming_tensors(b.pcg, layer); - std::vector correct = {input}; + std::unordered_map correct = { + { + TensorSlotName::INPUT, + input, + }, + }; + CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = + std::unordered_map result = get_layer_outputs(b.pcg, layer); - std::vector correct = {output}; + std::unordered_map correct = { + { + TensorSlotName::OUTPUT, + output, + }, + }; + CHECK(result == correct); } } diff --git a/lib/pcg/test/src/pcg/pcg_from_computation_graph.cc b/lib/pcg/test/src/pcg/pcg_from_computation_graph.cc index d037d64672..87c79fd341 100644 --- a/lib/pcg/test/src/pcg/pcg_from_computation_graph.cc +++ b/lib/pcg/test/src/pcg/pcg_from_computation_graph.cc @@ -3,7 +3,7 @@ #include "op-attrs/ops/linear.h" #include "pcg/computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include using namespace ::FlexFlow; @@ -53,32 +53,57 @@ 
TEST_SUITE(FF_TEST_SUITE) { ComputationGraph cg = make_empty_computation_graph(); LayerAddedResult input_added = add_input_layer(cg, input_shape); - tensor_guid_t t_input = get_only(input_added.outputs); + tensor_guid_t t_input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); LayerAddedResult projection_weight_added = add_layer(cg, make_layer_attrs(projection_weight_attrs), /*inputs=*/{}, /*weights=*/{}); - tensor_guid_t t_projection = get_only(projection_weight_added.outputs); + tensor_guid_t t_projection = require_only_key( + projection_weight_added.outputs, TensorSlotName::OUTPUT); LayerAddedResult bias_weight_added = add_layer(cg, make_layer_attrs(bias_weight_attrs), /*inputs=*/{}, /*weights=*/{}); - tensor_guid_t t_bias = get_only(bias_weight_added.outputs); - - LayerAddedResult linear_added = - add_layer(cg, - make_layer_attrs(linear_attrs), - /*inputs=*/{t_input}, - /*weights=*/{t_projection, t_bias}); - tensor_guid_t t_linear = get_only(linear_added.outputs); + tensor_guid_t t_bias = + require_only_key(bias_weight_added.outputs, TensorSlotName::OUTPUT); + + LayerAddedResult linear_added = add_layer(cg, + make_layer_attrs(linear_attrs), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_projection, + }, + { + TensorSlotName::BIAS, + t_bias, + }, + }); + tensor_guid_t t_linear = + require_only_key(linear_added.outputs, TensorSlotName::OUTPUT); add_layer(cg, make_layer_attrs(make_relu_attrs()), - /*inputs=*/{t_linear}, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_linear, + }, + }, /*weights=*/{}); return cg; @@ -96,33 +121,58 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = + require_only_key(input_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult projection_weight_added = 
add_parallel_layer(pcg, make_layer_attrs(projection_weight_attrs), /*inputs=*/{}, /*weights=*/{}); - parallel_tensor_guid_t t_projection = - get_only(projection_weight_added.outputs); + parallel_tensor_guid_t t_projection = require_only_key( + projection_weight_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult bias_weight_added = add_parallel_layer(pcg, make_layer_attrs(bias_weight_attrs), /*inputs=*/{}, /*weights=*/{}); - parallel_tensor_guid_t t_bias = get_only(bias_weight_added.outputs); + parallel_tensor_guid_t t_bias = + require_only_key(bias_weight_added.outputs, TensorSlotName::OUTPUT); ParallelLayerAddedResult linear_added = add_parallel_layer(pcg, make_layer_attrs(linear_attrs), - /*inputs=*/{t_input}, - /*weights=*/{t_projection, t_bias}); - parallel_tensor_guid_t t_linear = get_only(linear_added.outputs); + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + /*weights=*/ + { + { + TensorSlotName::WEIGHT, + t_projection, + }, + { + TensorSlotName::BIAS, + t_bias, + }, + }); + parallel_tensor_guid_t t_linear = + require_only_key(linear_added.outputs, TensorSlotName::OUTPUT); add_parallel_layer(pcg, make_layer_attrs(make_relu_attrs()), - /*inputs=*/{t_linear}, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_linear, + }, + }, /*weights=*/{}); return pcg; }(); diff --git a/lib/pcg/test/src/pcg/start_invariant_machine_view.cc b/lib/pcg/test/src/pcg/start_invariant_machine_view.cc deleted file mode 100644 index afd6ad6b33..0000000000 --- a/lib/pcg/test/src/pcg/start_invariant_machine_view.cc +++ /dev/null @@ -1,229 +0,0 @@ -#include "pcg/start_invariant_machine_view.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("StartInvariantMachineView - utility functions") { - StartInvariantMachineView simv = StartInvariantMachineView{ - {MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTER_NODE}, - 
MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTER_NODE}}, - DeviceType::GPU}; - - SUBCASE("num_dims") { - nonnegative_int result = num_dims(simv); - nonnegative_int correct = 2_n; - CHECK(result == correct); - } - - SUBCASE("get_device_type") { - DeviceType result = get_device_type(simv); - DeviceType correct = DeviceType::GPU; - CHECK(result == correct); - } - - SUBCASE("get_strides") { - std::vector result = get_strides(simv); - std::vector correct = {stride_t{2_p}, stride_t{2_p}}; - CHECK(result == correct); - } - - SUBCASE("get_dimensions") { - std::vector result = get_dimensions(simv); - std::vector correct = { - MachineSpecificationDimension::INTER_NODE, - MachineSpecificationDimension::INTER_NODE}; - CHECK(result == correct); - } - } - - TEST_CASE("StartInvariantMachineView - conversions") { - MachineSpaceCoordinate start = - MachineSpaceCoordinate{1_n, 2_n, DeviceType::GPU}; - std::vector dimensions = { - MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{3_p}, - MachineSpecificationDimension::INTRA_NODE}}; - - MachineView mv = MachineView{start, dimensions}; - StartInvariantMachineView simv = - StartInvariantMachineView{dimensions, DeviceType::GPU}; - - SUBCASE("start_invariant_from_machine_view") { - StartInvariantMachineView result = start_invariant_from_machine_view(mv); - StartInvariantMachineView correct = simv; - CHECK(result == correct); - } - - SUBCASE("machine_view_from_start_invariant") { - MachineView result = machine_view_from_start_invariant(simv, start); - MachineView correct = mv; - CHECK(result == correct); - } - - SUBCASE("conversion is invertible") { - SUBCASE("MachineView -> StartInvariant -> MachineView") { - MachineView result = machine_view_from_start_invariant( - start_invariant_from_machine_view(mv), start); - MachineView correct = mv; - CHECK(result == correct); - } - - SUBCASE("StartInvariant -> MachineView -> StartInvariant") { - 
StartInvariantMachineView result = start_invariant_from_machine_view( - machine_view_from_start_invariant(simv, start)); - StartInvariantMachineView correct = simv; - CHECK(result == correct); - } - } - } - - TEST_CASE("StartInvariantMachineView - get_machine_space_offset") { - SUBCASE("1D case") { - // This operator has shape (3,), and thus 3 tasks. - // The (only) dimension is projected on the INTRA (device) dimension with - // a stride of 2. The machine space has 1 node and 6 devices per node. - /** - * The tasks will thus be distributed like this: - * +-------+-------+-------+-------+-------+-------+ - * | (0,) | | (1,) | | (2,) | | - * +-------+-------+-------+-------+-------+-------+ - */ - OperatorTaskSpace task = OperatorTaskSpace{{3_p}}; - StartInvariantMachineView simv = StartInvariantMachineView{ - {MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE}}, - DeviceType::GPU}; - MachineSpecification ms = - MachineSpecification{/*num_nodes=*/1_p, - /*num_cpus_per_node=*/6_p, - /*num_gpus_per_node=*/6_p, - /*inter_node_bandwidth=*/0.0, - /*intra_node_bandwidth=*/0.0}; - - SUBCASE("get_machine_space_offset") { - SUBCASE("Task with TaskSpaceCoordinate = (0,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n}}; - MachineSpaceOffset correct = - MachineSpaceOffset{0, 0, DeviceType::GPU}; - MachineSpaceOffset result = - get_machine_space_offset(task, simv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (1,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n}}; - MachineSpaceOffset correct = - MachineSpaceOffset{0, 2, DeviceType::GPU}; - MachineSpaceOffset result = - get_machine_space_offset(task, simv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (2,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{2_n}}; - MachineSpaceOffset correct = - MachineSpaceOffset{0, 4, DeviceType::GPU}; - MachineSpaceOffset result = - 
get_machine_space_offset(task, simv, coord, ms).value(); - CHECK(correct == result); - } - } - - SUBCASE("get_machine_space_offsets") { - std::unordered_set correct = { - MachineSpaceOffset{0, 0, DeviceType::GPU}, - MachineSpaceOffset{0, 2, DeviceType::GPU}, - MachineSpaceOffset{0, 4, DeviceType::GPU}}; - std::unordered_set result = - get_machine_space_offsets(task, simv, ms); - CHECK(correct == result); - } - } - - SUBCASE("2D case") { - // This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. - // The first dimension is projected onto the INTER (node) dimension with - // stride 1, while the second dimension is projected onto the INTRA - // (device) dimension with stride 2. The machine space has 2 nodes and 4 - // devices per node. - - /** - * The tasks will thus be distributed like this: - * +-------+-------+-------+-------+ - * | (0,0) | | (0,1) | | - * +-------+-------+-------+-------+ - * | (1,0) | | (1,1) | | - * +-------+-------+-------+-------+ - */ - - OperatorTaskSpace task = OperatorTaskSpace{{2_p, 2_p}}; - StartInvariantMachineView simv = StartInvariantMachineView{ - {MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTER_NODE}, - MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE}}, - DeviceType::GPU}; - MachineSpecification ms = - MachineSpecification{/*num_nodes=*/2_p, - /*num_cpus_per_node=*/4_p, - /*num_gpus_per_node=*/4_p, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0}; - - SUBCASE("get_machine_space_offset") { - SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 0_n}}; - MachineSpaceOffset correct = - MachineSpaceOffset{0, 0, DeviceType::GPU}; - MachineSpaceOffset result = - get_machine_space_offset(task, simv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 1_n}}; - MachineSpaceOffset correct = - 
MachineSpaceOffset{0, 2, DeviceType::GPU}; - MachineSpaceOffset result = - get_machine_space_offset(task, simv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 0_n}}; - MachineSpaceOffset correct = - MachineSpaceOffset{1, 0, DeviceType::GPU}; - MachineSpaceOffset result = - get_machine_space_offset(task, simv, coord, ms).value(); - CHECK(correct == result); - } - - SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 1_n}}; - MachineSpaceOffset correct = - MachineSpaceOffset{1, 2, DeviceType::GPU}; - MachineSpaceOffset result = - get_machine_space_offset(task, simv, coord, ms).value(); - CHECK(correct == result); - } - } - - SUBCASE("get_machine_space_offsets") { - std::unordered_set correct = { - MachineSpaceOffset{0, 0, DeviceType::GPU}, - MachineSpaceOffset{0, 2, DeviceType::GPU}, - MachineSpaceOffset{1, 0, DeviceType::GPU}, - MachineSpaceOffset{1, 2, DeviceType::GPU}}; - std::unordered_set result = - get_machine_space_offsets(task, simv, ms); - CHECK(correct == result); - } - } - } -} diff --git a/lib/runtime/include/runtime/legion_backing.h b/lib/runtime/include/runtime/legion_backing.h index ae930b014c..e6d26c6bb4 100644 --- a/lib/runtime/include/runtime/legion_backing.h +++ b/lib/runtime/include/runtime/legion_backing.h @@ -1,11 +1,11 @@ #ifndef _FLEXFLOW_RUNTIME_INCLUDE_RUNTIME_RUNTIME_BACKING_H #define _FLEXFLOW_RUNTIME_INCLUDE_RUNTIME_RUNTIME_BACKING_H +#include "compiler/machine_mapping/machine_view.h" #include "legion.h" #include "mapping_id_t.h" #include "op-attrs/parallel_tensor_shape.h" #include "parallel_computation_graph.h" -#include "pcg/machine_view.h" #include "task_spec/task_return_accessor.h" #include "task_spec/tensorless_task_invocation.h" #include "utils/visitable.h" diff --git a/lib/runtime/include/runtime/task_spec/tensorless_task_invocation.h 
b/lib/runtime/include/runtime/task_spec/tensorless_task_invocation.h index 996f4e9946..5bb874a6d3 100644 --- a/lib/runtime/include/runtime/task_spec/tensorless_task_invocation.h +++ b/lib/runtime/include/runtime/task_spec/tensorless_task_invocation.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_RUNTIME_SRC_TENSORLESS_TASK_INVOCATION_H #define _FLEXFLOW_RUNTIME_SRC_TENSORLESS_TASK_INVOCATION_H +#include "compiler/machine_mapping/machine_view.h" #include "concrete_arg.h" #include "index_arg.h" -#include "pcg/machine_view.h" #include "slot_id.h" #include "typed_future.h" #include "typed_future_map.h" diff --git a/lib/runtime/src/fused_op_attrs.h b/lib/runtime/src/fused_op_attrs.h index a8ea524165..a1ab876167 100644 --- a/lib/runtime/src/fused_op_attrs.h +++ b/lib/runtime/src/fused_op_attrs.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_RUNTIME_SRC_FUSED_OP_ATTRS_H #include "op-attrs/get_op_type.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" #include "operator.h" #include "utils/visitable.h" diff --git a/lib/runtime/src/initializer.cc b/lib/runtime/src/initializer.cc index da9b28edaf..90356aa7a3 100644 --- a/lib/runtime/src/initializer.cc +++ b/lib/runtime/src/initializer.cc @@ -370,7 +370,7 @@ static void template <> void register_task() { TaskSignature sig; - sig.add_slot(TENSOR, {SlotType::TENSOR, Permissions::WO}); + sig.add_slot(TENSOR, {TensorSlotArity::TENSOR, Permissions::WO}); sig.add_arg_slot(INITIALIZER); sig.add_arg_slot(TENSOR_DIMS); @@ -380,7 +380,7 @@ void register_task() { template <> void register_task() { TaskSignature sig; - sig.add_slot(TENSOR, {SlotType::TENSOR, Permissions::WO}); + sig.add_slot(TENSOR, {TensorSlotArity::TENSOR, Permissions::WO}); register_task( ZERO_INIT_TASK_ID, "Zero Init", sig, zero_init_task, zero_init_task_cpu); @@ -389,7 +389,7 @@ void register_task() { template <> void register_task() { TaskSignature sig; - sig.add_slot(TENSOR, {SlotType::TENSOR, Permissions::WO}); + sig.add_slot(TENSOR, {TensorSlotArity::TENSOR, 
Permissions::WO}); sig.add_arg_slot(INITIALIZER); register_task(UNIFORM_INIT_TASK_ID, @@ -401,7 +401,7 @@ void register_task() { template <> void register_task() { TaskSignature sig; - sig.add_slot(TENSOR, {SlotType::TENSOR, Permissions::WO}); + sig.add_slot(TENSOR, {TensorSlotArity::TENSOR, Permissions::WO}); sig.add_arg_slot(INITIALIZER); register_task( @@ -411,7 +411,7 @@ void register_task() { template <> void register_task() { TaskSignature sig; - sig.add_slot(TENSOR, {SlotType::TENSOR, Permissions::WO}); + sig.add_slot(TENSOR, {TensorSlotArity::TENSOR, Permissions::WO}); sig.add_arg_slot(INITIALIZER); register_task(CONSTANT_INIT_TASK_ID, diff --git a/lib/runtime/src/mapper.cc b/lib/runtime/src/mapper.cc index fe68a224e1..ca6b0ed319 100644 --- a/lib/runtime/src/mapper.cc +++ b/lib/runtime/src/mapper.cc @@ -14,9 +14,9 @@ */ #include "mapper.h" +#include "compiler/machine_mapping/machine_view.h" #include "default_mapper.h" #include "loggers.h" -#include "pcg/machine_view.h" #include "tasks.h" #include "utils/exception.h" diff --git a/lib/runtime/src/metrics_functions.cc b/lib/runtime/src/metrics_functions.cc index 33e15baed2..23e66ba17e 100644 --- a/lib/runtime/src/metrics_functions.cc +++ b/lib/runtime/src/metrics_functions.cc @@ -216,8 +216,8 @@ static PerfMetrics template <> void register_task() { TaskSignature sig; - sig.add_slot(LOGIT, {SlotType::TENSOR, Permissions::RO}); - sig.add_slot(LABEL, {SlotType::TENSOR, Permissions::RO}); + sig.add_slot(LOGIT, {TensorSlotArity::TENSOR, Permissions::RO}); + sig.add_slot(LABEL, {TensorSlotArity::TENSOR, Permissions::RO}); sig.add_arg_slot(PROFILING_SETTINGS); sig.add_arg_slot(METRICS_STRUCT); sig.add_return_value(); diff --git a/lib/runtime/src/operator.h b/lib/runtime/src/operator.h index 2db40de78b..ff6a10ce49 100644 --- a/lib/runtime/src/operator.h +++ b/lib/runtime/src/operator.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_RUNTIME_SRC_OPERATOR_H #define _FLEXFLOW_RUNTIME_SRC_OPERATOR_H +#include 
"compiler/machine_mapping/machine_view.h" #include "kernels/profiling.h" #include "layer_id.h" #include "op-attrs/operator_attrs.h" -#include "pcg/machine_view.h" #include "profiling.h" #include "runtime/config.h" #include "tasks.h" diff --git a/lib/runtime/src/ops/fused_parallel_op.h b/lib/runtime/src/ops/fused_parallel_op.h index 439c6d55e4..844fcc7b0b 100644 --- a/lib/runtime/src/ops/fused_parallel_op.h +++ b/lib/runtime/src/ops/fused_parallel_op.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_FUSED_PARALLEL_OP_H #include "fused_parallel_op_attrs.h" -#include "op_task_invocation.h" #include "sim_environment.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/fused_parallel_op_attrs.h b/lib/runtime/src/ops/fused_parallel_op_attrs.h index 454b7caead..4dd64199d3 100644 --- a/lib/runtime/src/ops/fused_parallel_op_attrs.h +++ b/lib/runtime/src/ops/fused_parallel_op_attrs.h @@ -1,7 +1,6 @@ -#ifndef _FLEXFLOW_FUSED_PARALLEL_OP_ATTRS_H -#define _FLEXFLOW_FUSED_PARALLEL_OP_ATTRS_H +#ifndef _FLEXFLOW_LIB_RUNTIME_INCLUDE_OPS_FUSED_PARALLEL_OP_ATTRS_H +#define _FLEXFLOW_LIB_RUNTIME_INCLUDE_OPS_FUSED_PARALLEL_OP_ATTRS_H -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" #include "parallel_op_info.h" #include "utils/visitable.h" diff --git a/lib/runtime/src/optimizer.cc b/lib/runtime/src/optimizer.cc index edad16021a..7845a573e5 100644 --- a/lib/runtime/src/optimizer.cc +++ b/lib/runtime/src/optimizer.cc @@ -661,7 +661,7 @@ static void adam_nccl_update_task(Task const *task, template <> void register_task() { TaskSignature sig; - sig.add_slot(TENSOR, {SlotType::TENSOR, READ_ONLY}); + sig.add_slot(TENSOR, {TensorSlotArity::TENSOR, READ_ONLY}); register_task( PS_PREFETCH_TASK_ID, "Weights Prefetch", sig, UtilityTasks::dummy_task); @@ -670,9 +670,9 @@ void register_task() { template <> void register_task() { TaskSignature sig; - sig.add_slot(TENSOR, {SlotType::TENSOR, READ_WRITE}); - sig.add_slot(GRADIENT, {SlotType::TENSOR, READ_ONLY}); - 
 sig.add_slot(MOMENTUM_V, {SlotType::TENSOR, READ_WRITE}); + sig.add_slot(TENSOR, {TensorSlotArity::TENSOR, READ_WRITE}); + sig.add_slot(GRADIENT, {TensorSlotArity::TENSOR, READ_ONLY}); + sig.add_slot(MOMENTUM_V, {TensorSlotArity::TENSOR, READ_WRITE}); sig.add_arg_slot(OPTIMIZER); register_task(SGD_UPD_PS_TASK_ID, @@ -684,9 +684,9 @@ void register_task() { TaskSignature sig; - sig.add_slot(TENSOR, {SlotType::TENSOR, READ_WRITE}); - sig.add_slot(GRADIENT, {SlotType::TENSOR < READ_ONLY}); - sig.add_slot(MOMENTUM_V, {SlotType::TENSOR, READ_WRITE}); + sig.add_slot(TENSOR, {TensorSlotArity::TENSOR, READ_WRITE}); + sig.add_slot(GRADIENT, {TensorSlotArity::TENSOR, READ_ONLY}); + sig.add_slot(MOMENTUM_V, {TensorSlotArity::TENSOR, READ_WRITE}); sig.add_arg_slot(OPTIMIZER); sig.add_arg_slot(HANDLE); @@ -697,10 +697,10 @@ void register_task() { TaskSignature sig; - sig.add_slot(TENSOR, {SlotType::TENSOR, READ_WRITE}); - sig.add_slot(GRADIENT, {SlotType::TENSOR, READ_ONLY}); - sig.add_slot(ADAM_W, {SlotType::TENSOR, READ_WRITE}); - sig.add_slot(ADAM_M, {SlotType::TENSOR, READ_WRITE}); + sig.add_slot(TENSOR, {TensorSlotArity::TENSOR, READ_WRITE}); + sig.add_slot(GRADIENT, {TensorSlotArity::TENSOR, READ_ONLY}); + sig.add_slot(ADAM_W, {TensorSlotArity::TENSOR, READ_WRITE}); + sig.add_slot(ADAM_M, {TensorSlotArity::TENSOR, READ_WRITE}); sig.add_slot(OPTIMIZER); register_task(ADAM_UPD_PS_TASK_ID, @@ -712,10 +712,10 @@ void register_task() { TaskSignature sig; - sig.add_slot(TENSOR, {SlotType::TENSOR, READ_WRITE}); - sig.add_slot(GRADIENT, {SlotType::TENSOR, READ_ONLY}); - sig.add_slot(ADAM_W, {SlotType::TENSOR, READ_WRITE}); - sig.add_slot(ADAM_M, {SlotType::TENSOR, READ_WRITE}); + sig.add_slot(TENSOR, {TensorSlotArity::TENSOR, READ_WRITE}); + sig.add_slot(GRADIENT, {TensorSlotArity::TENSOR, READ_ONLY}); + sig.add_slot(ADAM_W, {TensorSlotArity::TENSOR, READ_WRITE}); + 
sig.add_slot(ADAM_M, {TensorSlotArity::TENSOR, READ_WRITE}); sig.add_slot(OPTIMIZER); sig.add_slot(HANDLE); diff --git a/lib/runtime/src/parallel_computation_graph.h b/lib/runtime/src/parallel_computation_graph.h index 5ffd6f7cad..0a06e32dfb 100644 --- a/lib/runtime/src/parallel_computation_graph.h +++ b/lib/runtime/src/parallel_computation_graph.h @@ -7,7 +7,6 @@ #include "pcg/optimizer.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_tensor.h" -#include "task_spec/op_task_invocation.h" #include "utils/graph.h" #include "utils/strong_typedef.h" #include diff --git a/lib/runtime/src/simulator.h b/lib/runtime/src/simulator.h index cc5905f218..c4dcb49e4c 100644 --- a/lib/runtime/src/simulator.h +++ b/lib/runtime/src/simulator.h @@ -15,10 +15,10 @@ #ifndef _FLEXFLOW_SIMULATOR_H_ #define _FLEXFLOW_SIMULATOR_H_ +#include "compiler/machine_mapping/machine_view.h" #include "cost_metrics.h" #include "kernels/ff_handle.h" #include "op-attrs/operator_attrs.h" -#include "pcg/machine_view.h" #include "pcg/operator_guid_t.h" #include "pcg/parallel_tensor.h" #include "runtime/config.h" diff --git a/lib/runtime/src/task_spec/index_task_invocation.h b/lib/runtime/src/task_spec/index_task_invocation.h index 795de0ae94..42eaadf769 100644 --- a/lib/runtime/src/task_spec/index_task_invocation.h +++ b/lib/runtime/src/task_spec/index_task_invocation.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_RUNTIME_INCLUDE_RUNTIME_TASK_SPEC_INDEX_TASK_INVOCATION_H #include "arg_ref.h" +#include "compiler/machine_mapping/machine_view.h" #include "parallel_tensor_spec.h" -#include "pcg/machine_view.h" #include "pcg/parallel_tensor_guid_t.h" #include "runtime/task_spec/concrete_arg.h" #include "runtime/task_spec/index_arg.h" diff --git a/lib/runtime/src/task_spec/task_signature.h b/lib/runtime/src/task_spec/task_signature.h index 2910d4f652..f0efd34ce0 100644 --- a/lib/runtime/src/task_spec/task_signature.h +++ b/lib/runtime/src/task_spec/task_signature.h @@ -16,10 
+16,10 @@ namespace FlexFlow { struct ParallelTensorSlotSpec { public: ParallelTensorSlotSpec() = delete; - ParallelTensorSlotSpec(SlotType, Permissions perm); + ParallelTensorSlotSpec(TensorSlotArity, Permissions perm); public: - SlotType slot_type; + TensorSlotArity slot_type; Permissions perm; }; diff --git a/lib/runtime/test/src/main.cc b/lib/runtime/test/src/main.cc index 9522fa7fdb..0a3f254ea8 100644 --- a/lib/runtime/test/src/main.cc +++ b/lib/runtime/test/src/main.cc @@ -1,2 +1,2 @@ #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN -#include "doctest/doctest.h" +#include <doctest/doctest.h> diff --git a/lib/runtime/test/src/test_op_task_spec.cc b/lib/runtime/test/src/test_op_task_spec.cc deleted file mode 100644 index bb0bee567c..0000000000 --- a/lib/runtime/test/src/test_op_task_spec.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "doctest/doctest.h" -#include "op_task_invocation.h" -#include "op_task_signature.h" - -using namespace FlexFlow; - -TEST_CASE("OpTaskSignature") { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(0); - fwd.add_input_slot(1); - fwd.add_output_slot(2); - - OpTaskSignature bwd = infer_bwd_signature(fwd); - - OpTaskSignature correct_bwd = {OpTaskType::BWD}; - - correct_bwd.add_input_slot(0); - correct_bwd.add_input_grad_slot(0); - correct_bwd.add_input_slot(1); - correct_bwd.add_input_grad_slot(1); - correct_bwd.add_output_slot(2); - correct_bwd.add_output_grad_slot(2); - - CHECK(bwd == correct_bwd); -} - -TEST_CASE("OpTaskBinding") { - OpTaskBinding fwd; - - binding.bind(0, input_tensor(0)); - binding.bind(1, input_tensor(1)); - binding.bind(2, input_tensor(2)); - - OpTaskBinding bwd = infer_bwd_binding(fwd); - - OpTaskBinding correct_bwd; - - correct_bwd.bind(0, input_tensor(0)); - correct_bwd.bind_grad(0, input_tensor(0).grad()); - correct_bwd.bind(1, input_tensor(1)); - correct_bwd.bind_grad(1, input_tensor(1).grad()); - correct_bwd.bind(2, input_tensor(2)); - correct_bwd.bind_grad(2, input_tensor(2).grad()); - - CHECK(correct_bwd == bwd); -} 
diff --git a/lib/runtime/test/src/test_serialization.cc b/lib/runtime/test/src/test_serialization.cc index 471f2a2709..2c5d680071 100644 --- a/lib/runtime/test/src/test_serialization.cc +++ b/lib/runtime/test/src/test_serialization.cc @@ -1,7 +1,7 @@ -#include "doctest/doctest.h" #include "legion/legion_utilities.h" #include "op-attrs/ffconst.h" #include "serialization.h" +#include #include using namespace FlexFlow; diff --git a/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.toml b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.toml new file mode 100644 index 0000000000..e2e4e80e87 --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.toml @@ -0,0 +1,96 @@ +namespace = "FlexFlow" +name = "LegacyOperatorType" +type = "enum" + +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +values = [ + { name = "NOOP", json_key = "OP_NOOP" }, + { name = "INPUT", json_key = "OP_INPUT" }, + { name = "WEIGHT", json_key = "OP_WEIGHT" }, + { name = "CONV2D", json_key = "OP_CONV2D" }, + { name = "DROPOUT", json_key = "OP_DROPOUT" }, + { name = "LINEAR", json_key = "OP_LINEAR" }, + { name = "BATCHMATMUL", json_key = "OP_BATCHMATMUL" }, + { name = "POOL2D", json_key = "OP_POOL2D" }, + { name = "SCALAR_MULTIPLY", json_key = "OP_SCALAR_MULTIPLY" }, + { name = "SCALAR_ADD", json_key = "OP_SCALAR_ADD" }, + { name = "SCALAR_FLOOR_DIV", json_key = "OP_SCALAR_FLOOR_DIV" }, + { name = "SCALAR_TRUE_DIV", json_key = "OP_SCALAR_TRUE_DIV" }, + { name = "SCALAR_SUB", json_key = "OP_SCALAR_SUB" }, + { name = "RELU", json_key = "OP_RELU" }, + { name = "IDENTITY", json_key = "OP_IDENTITY" }, + { name = "SIGMOID", json_key = "OP_SIGMOID" }, + { name = "TANH", json_key = "OP_TANH" }, + { name = "ELU", json_key = "OP_ELU" }, + { name = "FLAT", json_key = "OP_FLAT" }, + { name = "SOFTMAX", json_key = "OP_SOFTMAX" }, + { name = "BATCHNORM", json_key = "OP_BATCHNORM" }, + { 
name = "CONCAT", json_key = "OP_CONCAT" }, + { name = "SPLIT", json_key = "OP_SPLIT" }, + { name = "EMBEDDING", json_key = "OP_EMBEDDING" }, + { name = "CACHE", json_key = "OP_CACHE" }, + { name = "RESHAPE", json_key = "OP_RESHAPE" }, + { name = "REVERSE", json_key = "OP_REVERSE" }, + { name = "TRANSPOSE", json_key = "OP_TRANSPOSE" }, + { name = "EW_ADD", json_key = "OP_EW_ADD" }, + { name = "EW_MUL", json_key = "OP_EW_MUL" }, + { name = "MATMUL", json_key = "OP_MATMUL" }, + { name = "MUL", json_key = "OP_MUL" }, + { name = "ENLARGE", json_key = "OP_ENLARGE" }, + { name = "SQUEEZE", json_key = "OP_SQUEEZE" }, + { name = "UNSQUEEZE", json_key = "OP_UNSQUEEZE" }, + { name = "EW_SUB", json_key = "OP_EW_SUB" }, + { name = "EW_DIV", json_key = "OP_EW_DIV" }, + { name = "EW_EQUAL", json_key = "OP_EW_EQUAL" }, + { name = "EW_GREATER", json_key = "OP_EW_GREATER" }, + { name = "EW_LESS", json_key = "OP_EW_LESS" }, + { name = "EW_MAX", json_key = "OP_EW_MAX" }, + { name = "EW_MIN", json_key = "OP_EW_MIN" }, + { name = "REDUCE_ARGMAX", json_key = "OP_REDUCE_ARGMAX" }, + { name = "REDUCE_ARGMIN", json_key = "OP_REDUCE_ARGMIN" }, + { name = "REDUCE_MAX", json_key = "OP_REDUCE_MAX" }, + { name = "REDUCE_MEAN", json_key = "OP_REDUCE_MEAN" }, + { name = "REDUCE_MIN", json_key = "OP_REDUCE_MIN" }, + { name = "REDUCE_PROD", json_key = "OP_REDUCE_PROD" }, + { name = "REDUCE_SUM", json_key = "OP_REDUCE_SUM" }, + { name = "PAD", json_key = "OP_PAD" }, + { name = "SHAPE", json_key = "OP_SHAPE" }, + { name = "SIZE", json_key = "OP_SIZE" }, + { name = "TOPK", json_key = "OP_TOPK" }, + { name = "WHERE", json_key = "OP_WHERE" }, + { name = "CEIL", json_key = "OP_CEIL" }, + { name = "CAST", json_key = "OP_CAST" }, + { name = "EXP", json_key = "OP_EXP" }, + { name = "ROUND", json_key = "OP_ROUND" }, + { name = "LOG", json_key = "OP_LOG" }, + { name = "LOGICAL_NOT", json_key = "OP_LOGICAL_NOT" }, + { name = "SQRT", json_key = "OP_SQRT" }, + { name = "SIN", json_key = "OP_SIN" }, + { name = 
"COS", json_key = "OP_COS" }, + { name = "LEAKYRELU", json_key = "OP_LEAKYRELU" }, + { name = "SLICE", json_key = "OP_SLICE" }, + { name = "RESIZE", json_key = "OP_RESIZE" }, + { name = "PRELU", json_key = "OP_PRELU" }, + { name = "GELU", json_key = "OP_GELU" }, + { name = "MULTIHEAD_ATTENTION", json_key = "OP_MULTIHEAD_ATTENTION" }, + { name = "FUSED", json_key = "OP_FUSED" }, + { name = "RSQRT", json_key = "OP_RSQRT" }, + { name = "POW", json_key = "OP_POW" }, + { name = "MEAN", json_key = "OP_MEAN" }, + { name = "LAYERNORM", json_key = "OP_LAYERNORM" }, + { name = "GATHER", json_key = "OP_GATHER" }, + { name = "BROADCAST", json_key = "OP_BROADCAST" }, + { name = "REPARTITION", json_key = "OP_PARTITION" }, + { name = "COMBINE", json_key = "OP_COMBINE" }, + { name = "REPLICATE", json_key = "OP_REPLICATE" }, + { name = "REDUCTION", json_key = "OP_REDUCE" }, + { name = "BATCH", json_key = "OP_BATCH" }, + { name = "PIPELINE", json_key = "OP_PIPELINE" }, + { name = "FUSED_PARALLEL", json_key = "OP_FUSED_PARALLEL" }, +] diff --git a/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml deleted file mode 100644 index 3f0bcccf6f..0000000000 --- a/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml +++ /dev/null @@ -1,95 +0,0 @@ -namespace = "FlexFlow" -name = "LegacyOperatorType" - -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -values = [ - { name = "NOOP", json_key = "OP_NOOP" }, - { name = "INPUT", json_key = "OP_INPUT" }, - { name = "WEIGHT", json_key = "OP_WEIGHT" }, - { name = "CONV2D", json_key = "OP_CONV2D" }, - { name = "DROPOUT", json_key = "OP_DROPOUT" }, - { name = "LINEAR", json_key = "OP_LINEAR" }, - { name = "BATCHMATMUL", json_key = "OP_BATCHMATMUL" }, - { name = "POOL2D", json_key = "OP_POOL2D" }, - { name = "SCALAR_MULTIPLY", json_key = "OP_SCALAR_MULTIPLY" }, - { name = 
"SCALAR_ADD", json_key = "OP_SCALAR_ADD" }, - { name = "SCALAR_FLOOR_DIV", json_key = "OP_SCALAR_FLOOR_DIV" }, - { name = "SCALAR_TRUE_DIV", json_key = "OP_SCALAR_TRUE_DIV" }, - { name = "SCALAR_SUB", json_key = "OP_SCALAR_SUB" }, - { name = "RELU", json_key = "OP_RELU" }, - { name = "IDENTITY", json_key = "OP_IDENTITY" }, - { name = "SIGMOID", json_key = "OP_SIGMOID" }, - { name = "TANH", json_key = "OP_TANH" }, - { name = "ELU", json_key = "OP_ELU" }, - { name = "FLAT", json_key = "OP_FLAT" }, - { name = "SOFTMAX", json_key = "OP_SOFTMAX" }, - { name = "BATCHNORM", json_key = "OP_BATCHNORM" }, - { name = "CONCAT", json_key = "OP_CONCAT" }, - { name = "SPLIT", json_key = "OP_SPLIT" }, - { name = "EMBEDDING", json_key = "OP_EMBEDDING" }, - { name = "CACHE", json_key = "OP_CACHE" }, - { name = "RESHAPE", json_key = "OP_RESHAPE" }, - { name = "REVERSE", json_key = "OP_REVERSE" }, - { name = "TRANSPOSE", json_key = "OP_TRANSPOSE" }, - { name = "EW_ADD", json_key = "OP_EW_ADD" }, - { name = "EW_MUL", json_key = "OP_EW_MUL" }, - { name = "MATMUL", json_key = "OP_MATMUL" }, - { name = "MUL", json_key = "OP_MUL" }, - { name = "ENLARGE", json_key = "OP_ENLARGE" }, - { name = "SQUEEZE", json_key = "OP_SQUEEZE" }, - { name = "UNSQUEEZE", json_key = "OP_UNSQUEEZE" }, - { name = "EW_SUB", json_key = "OP_EW_SUB" }, - { name = "EW_DIV", json_key = "OP_EW_DIV" }, - { name = "EW_EQUAL", json_key = "OP_EW_EQUAL" }, - { name = "EW_GREATER", json_key = "OP_EW_GREATER" }, - { name = "EW_LESS", json_key = "OP_EW_LESS" }, - { name = "EW_MAX", json_key = "OP_EW_MAX" }, - { name = "EW_MIN", json_key = "OP_EW_MIN" }, - { name = "REDUCE_ARGMAX", json_key = "OP_REDUCE_ARGMAX" }, - { name = "REDUCE_ARGMIN", json_key = "OP_REDUCE_ARGMIN" }, - { name = "REDUCE_MAX", json_key = "OP_REDUCE_MAX" }, - { name = "REDUCE_MEAN", json_key = "OP_REDUCE_MEAN" }, - { name = "REDUCE_MIN", json_key = "OP_REDUCE_MIN" }, - { name = "REDUCE_PROD", json_key = "OP_REDUCE_PROD" }, - { name = "REDUCE_SUM", json_key 
= "OP_REDUCE_SUM" }, - { name = "PAD", json_key = "OP_PAD" }, - { name = "SHAPE", json_key = "OP_SHAPE" }, - { name = "SIZE", json_key = "OP_SIZE" }, - { name = "TOPK", json_key = "OP_TOPK" }, - { name = "WHERE", json_key = "OP_WHERE" }, - { name = "CEIL", json_key = "OP_CEIL" }, - { name = "CAST", json_key = "OP_CAST" }, - { name = "EXP", json_key = "OP_EXP" }, - { name = "ROUND", json_key = "OP_ROUND" }, - { name = "LOG", json_key = "OP_LOG" }, - { name = "LOGICAL_NOT", json_key = "OP_LOGICAL_NOT" }, - { name = "SQRT", json_key = "OP_SQRT" }, - { name = "SIN", json_key = "OP_SIN" }, - { name = "COS", json_key = "OP_COS" }, - { name = "LEAKYRELU", json_key = "OP_LEAKYRELU" }, - { name = "SLICE", json_key = "OP_SLICE" }, - { name = "RESIZE", json_key = "OP_RESIZE" }, - { name = "PRELU", json_key = "OP_PRELU" }, - { name = "GELU", json_key = "OP_GELU" }, - { name = "MULTIHEAD_ATTENTION", json_key = "OP_MULTIHEAD_ATTENTION" }, - { name = "FUSED", json_key = "OP_FUSED" }, - { name = "RSQRT", json_key = "OP_RSQRT" }, - { name = "POW", json_key = "OP_POW" }, - { name = "MEAN", json_key = "OP_MEAN" }, - { name = "LAYERNORM", json_key = "OP_LAYERNORM" }, - { name = "GATHER", json_key = "OP_GATHER" }, - { name = "BROADCAST", json_key = "OP_BROADCAST" }, - { name = "REPARTITION", json_key = "OP_PARTITION" }, - { name = "COMBINE", json_key = "OP_COMBINE" }, - { name = "REPLICATE", json_key = "OP_REPLICATE" }, - { name = "REDUCTION", json_key = "OP_REDUCE" }, - { name = "BATCH", json_key = "OP_BATCH" }, - { name = "PIPELINE", json_key = "OP_PIPELINE" }, - { name = "FUSED_PARALLEL", json_key = "OP_FUSED_PARALLEL" }, -] diff --git a/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.toml b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.toml new file mode 100644 index 0000000000..e51216233f --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.toml @@ -0,0 +1,45 @@ 
+namespace = "FlexFlow" +name = "LegacyPMParameter" +type = "enum" + +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +values = [ + { name = "OP_TYPE", json_key = "PM_OP_TYPE" }, + { name = "NUM_INPUTS", json_key = "PM_NUM_INPUTS" }, + { name = "NUM_OUTPUTS", json_key = "PM_NUM_OUTPUTS" }, + { name = "GROUP", json_key = "PM_GROUP" }, + { name = "KERNEL_H", json_key = "PM_KERNEL_H" }, + { name = "KERNEL_W", json_key = "PM_KERNEL_W" }, + { name = "STRIDE_H", json_key = "PM_STRIDE_H" }, + { name = "STRIDE_W", json_key = "PM_STRIDE_W" }, + { name = "PADDING_H", json_key = "PM_PADDING_H" }, + { name = "PADDING_W", json_key = "PM_PADDING_W" }, + { name = "ACTI", json_key = "PM_ACTI" }, + { name = "NUMDIM", json_key = "PM_NUMDIM" }, + { name = "AXIS", json_key = "PM_AXIS" }, + { name = "PERM", json_key = "PM_PERM" }, + { name = "OUTSHUFFLE", json_key = "PM_OUTSHUFFLE" }, + { name = "MERGE_GCONV_COUNT", json_key = "PM_MERGE_GCONV_COUNT" }, + { name = "AXES", json_key = "PM_AXES" }, + { name = "KEEP_DIMS", json_key = "PM_KEEP_DIMS" }, + { name = "EPSILON", json_key = "PM_EPSILON" }, + { name = "REPARTITION_DIM", json_key = "PM_REPARTITION_DIM" }, + { name = "REPARTITION_DEGREE", json_key = "PM_REPARTITION_DEGREE" }, + { name = "REPLICATE_DIM", json_key = "PM_REPLICATE_DIM" }, + { name = "REPLICATE_DEGREE", json_key = "PM_REPLICATE_DEGREE" }, + { name = "COMBINE_DIM", json_key = "PM_COMBINE_DIM" }, + { name = "COMBINE_DEGREE", json_key = "PM_COMBINE_DEGREE" }, + { name = "REDUCTION_DIM", json_key = "PM_REDUCTION_DIM" }, + { name = "REDUCTION_DEGREE", json_key = "PM_REDUCTION_DEGREE" }, + { name = "SOFTMAX_DIM", json_key = "PM_SOFTMAX_DIM" }, + { name = "NUM_HEADS", json_key = "PM_NUM_HEADS" }, + { name = "PARALLEL_DIM", json_key = "PM_PARALLEL_DIM" }, + { name = "PARALLEL_DEGREE", json_key = "PM_PARALLEL_DEGREE" }, + { name = "PAD", json_key = "PM_PAD" }, +] diff --git a/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml 
b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml deleted file mode 100644 index e71a71a5a8..0000000000 --- a/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml +++ /dev/null @@ -1,44 +0,0 @@ -namespace = "FlexFlow" -name = "LegacyPMParameter" - -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -values = [ - { name = "OP_TYPE", json_key = "PM_OP_TYPE" }, - { name = "NUM_INPUTS", json_key = "PM_NUM_INPUTS" }, - { name = "NUM_OUTPUTS", json_key = "PM_NUM_OUTPUTS" }, - { name = "GROUP", json_key = "PM_GROUP" }, - { name = "KERNEL_H", json_key = "PM_KERNEL_H" }, - { name = "KERNEL_W", json_key = "PM_KERNEL_W" }, - { name = "STRIDE_H", json_key = "PM_STRIDE_H" }, - { name = "STRIDE_W", json_key = "PM_STRIDE_W" }, - { name = "PADDING_H", json_key = "PM_PADDING_H" }, - { name = "PADDING_W", json_key = "PM_PADDING_W" }, - { name = "ACTI", json_key = "PM_ACTI" }, - { name = "NUMDIM", json_key = "PM_NUMDIM" }, - { name = "AXIS", json_key = "PM_AXIS" }, - { name = "PERM", json_key = "PM_PERM" }, - { name = "OUTSHUFFLE", json_key = "PM_OUTSHUFFLE" }, - { name = "MERGE_GCONV_COUNT", json_key = "PM_MERGE_GCONV_COUNT" }, - { name = "AXES", json_key = "PM_AXES" }, - { name = "KEEP_DIMS", json_key = "PM_KEEP_DIMS" }, - { name = "EPSILON", json_key = "PM_EPSILON" }, - { name = "REPARTITION_DIM", json_key = "PM_REPARTITION_DIM" }, - { name = "REPARTITION_DEGREE", json_key = "PM_REPARTITION_DEGREE" }, - { name = "REPLICATE_DIM", json_key = "PM_REPLICATE_DIM" }, - { name = "REPLICATE_DEGREE", json_key = "PM_REPLICATE_DEGREE" }, - { name = "COMBINE_DIM", json_key = "PM_COMBINE_DIM" }, - { name = "COMBINE_DEGREE", json_key = "PM_COMBINE_DEGREE" }, - { name = "REDUCTION_DIM", json_key = "PM_REDUCTION_DIM" }, - { name = "REDUCTION_DEGREE", json_key = "PM_REDUCTION_DEGREE" }, - { name = "SOFTMAX_DIM", json_key = "PM_SOFTMAX_DIM" }, - { name = "NUM_HEADS", json_key = "PM_NUM_HEADS" }, - { name = 
"PARALLEL_DIM", json_key = "PM_PARALLEL_DIM" }, - { name = "PARALLEL_DEGREE", json_key = "PM_PARALLEL_DEGREE" }, - { name = "PAD", json_key = "PM_PAD" }, -] diff --git a/lib/substitution-generator/test/CMakeLists.txt b/lib/substitution-generator/test/CMakeLists.txt index a7374cdf78..166c7ab51f 100644 --- a/lib/substitution-generator/test/CMakeLists.txt +++ b/lib/substitution-generator/test/CMakeLists.txt @@ -8,7 +8,7 @@ ff_add_test_executable( NAME substitution-generator-tests SRC_PATTERNS - substitution-generator/*.cc + src/*.cc PRIVATE_INCLUDE src/ DEPS diff --git a/lib/substitution-generator/test/substitution-generator/legacy_rules.cc b/lib/substitution-generator/test/src/substitution-generator/legacy_rules.cc similarity index 97% rename from lib/substitution-generator/test/substitution-generator/legacy_rules.cc rename to lib/substitution-generator/test/src/substitution-generator/legacy_rules.cc index 4dd9bb8cc4..19102d9670 100644 --- a/lib/substitution-generator/test/substitution-generator/legacy_rules.cc +++ b/lib/substitution-generator/test/src/substitution-generator/legacy_rules.cc @@ -1,5 +1,5 @@ #include "substitution-generator/legacy_rules.h" -#include "doctest/doctest.h" +#include using namespace FlexFlow; using nlohmann::json; diff --git a/lib/substitutions/include/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.dtg.toml b/lib/substitutions/include/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.dtg.toml new file mode 100644 index 0000000000..f961ba30c7 --- /dev/null +++ b/lib/substitutions/include/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "OutputExprToResultSubPCGMapping" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "substitutions/input_parallel_tensor_guid_t.dtg.h", + 
"substitutions/output_graph/output_graph_expr_node.dtg.h", + "substitutions/output_graph/output_graph_expr_input.dtg.h", +] + +[[fields]] +name = "node_mapping" +type = "::FlexFlow::bidict<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::OutputGraphExprNode>" + +[[fields]] +name = "input_mapping" +type = "::FlexFlow::bidict<::FlexFlow::input_parallel_tensor_guid_t, ::FlexFlow::OutputGraphExprInput>" diff --git a/lib/substitutions/include/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.struct.toml b/lib/substitutions/include/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.struct.toml deleted file mode 100644 index 1fac79a91d..0000000000 --- a/lib/substitutions/include/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "OutputExprToResultSubPCGMapping" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/bidict/bidict.h", - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "substitutions/input_parallel_tensor_guid_t.dtg.h", - "substitutions/output_graph/output_graph_expr_node.dtg.h", - "substitutions/output_graph/output_graph_expr_input.dtg.h", -] - -[[fields]] -name = "node_mapping" -type = "::FlexFlow::bidict<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::OutputGraphExprNode>" - -[[fields]] -name = "input_mapping" -type = "::FlexFlow::bidict<::FlexFlow::input_parallel_tensor_guid_t, ::FlexFlow::OutputGraphExprInput>" diff --git a/lib/substitutions/include/substitutions/apply_substitution/perform_shape_inference.h b/lib/substitutions/include/substitutions/apply_substitution/perform_shape_inference.h index c3f9eff349..c3ebc0f77f 100644 --- a/lib/substitutions/include/substitutions/apply_substitution/perform_shape_inference.h +++ b/lib/substitutions/include/substitutions/apply_substitution/perform_shape_inference.h @@ -2,8 +2,10 @@ #define 
_FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_PERFORM_SHAPE_INFERENCE_H #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" -#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h" namespace FlexFlow { @@ -22,12 +24,17 @@ namespace FlexFlow { * Exists only to enable apply_substitution(SubParallelComputationGraph const &, * Substitution const &, PCGPatternMatch const &) */ -LabelledOpenDataflowGraphView +LabelledOpenKwargDataflowGraphView perform_shape_inference( - LabelledOpenDataflowGraphView const - &g, - std::unordered_map const - &input_shapes); + LabelledOpenKwargDataflowGraphView const &g, + std::unordered_map, + ParallelTensorShape> const &input_shapes); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/constraint_type.dtg.toml b/lib/substitutions/include/substitutions/constraint_type.dtg.toml new file mode 100644 index 0000000000..06da6136b0 --- /dev/null +++ b/lib/substitutions/include/substitutions/constraint_type.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "ConstraintType" +type = "enum" +features = [ + "json", + "hash", + "rapidcheck", + "fmt", +] + +[[values]] +name = "EQUAL" + +[[values]] +name = "DIVISIBLE_BY" diff --git a/lib/substitutions/include/substitutions/constraint_type.enum.toml b/lib/substitutions/include/substitutions/constraint_type.enum.toml deleted file mode 100644 index f366a17725..0000000000 --- a/lib/substitutions/include/substitutions/constraint_type.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "ConstraintType" -features = [ - "json", - "hash", - "rapidcheck", - "fmt", -] - -[[values]] -name = "EQUAL" - -[[values]] -name = 
"DIVISIBLE_BY" diff --git a/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.dtg.toml b/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.dtg.toml new file mode 100644 index 0000000000..0d79274d12 --- /dev/null +++ b/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "input_parallel_tensor_guid_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "raw_dataflow_graph_input" +type = "::FlexFlow::KwargDataflowGraphInput" diff --git a/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml b/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml deleted file mode 100644 index dd2e850aed..0000000000 --- a/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "input_parallel_tensor_guid_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", -] - -[[fields]] -name = "raw_dataflow_graph_input" -type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.dtg.toml b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.dtg.toml new file mode 100644 index 0000000000..75bfd04e86 --- /dev/null +++ b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "open_parallel_tensor_guid_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_open_dataflow_value" +type = 
"::FlexFlow::OpenKwargDataflowValue" diff --git a/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h index ad60d50db1..3fca3050df 100644 --- a/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h +++ b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h @@ -17,8 +17,10 @@ template > Ret visit_open_parallel_tensor_guid(open_parallel_tensor_guid_t t, F f) { return t.raw_open_dataflow_value.visit(overload{ - [&](DataflowOutput const &o) { return f(parallel_tensor_guid_t{o}); }, - [&](DataflowGraphInput const &i) { + [&](KwargDataflowOutput const &o) { + return f(parallel_tensor_guid_t{o}); + }, + [&](KwargDataflowGraphInput const &i) { return f(input_parallel_tensor_guid_t{i}); }, }); diff --git a/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml deleted file mode 100644 index f07dc12d62..0000000000 --- a/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "open_parallel_tensor_guid_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" -] - -[[fields]] -name = "raw_open_dataflow_value" -type = "::FlexFlow::OpenDataflowValue" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.toml new file mode 100644 index 0000000000..8c6223050c --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "OperatorAttributeConstraint" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + 
+includes = [ + "substitutions/constraint_type.dtg.h", + "substitutions/operator_pattern/operator_attribute_expr.dtg.h", + "substitutions/operator_pattern/operator_attribute_value.dtg.h", +] + +[[fields]] +name = "constraint_type" +type = "::FlexFlow::ConstraintType" + +[[fields]] +name = "attribute_expr" +type = "::FlexFlow::OperatorAttributeExpr" + +[[fields]] +name = "attribute_value" +type = "::FlexFlow::OperatorAttributeValue" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml deleted file mode 100644 index 646faf878e..0000000000 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml +++ /dev/null @@ -1,28 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorAttributeConstraint" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "substitutions/constraint_type.dtg.h", - "substitutions/operator_pattern/operator_attribute_expr.dtg.h", - "substitutions/operator_pattern/operator_attribute_value.dtg.h", -] - -[[fields]] -name = "constraint_type" -type = "::FlexFlow::ConstraintType" - -[[fields]] -name = "attribute_expr" -type = "::FlexFlow::OperatorAttributeExpr" - -[[fields]] -name = "attribute_value" -type = "::FlexFlow::OperatorAttributeValue" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.toml new file mode 100644 index 0000000000..34f212cdae --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "OperatorAttributeExpr" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + 
"substitutions/operator_pattern/operator_attribute_key.dtg.h", + "substitutions/operator_pattern/operator_attribute_list_access.dtg.h", + "substitutions/operator_pattern/operator_attribute_list_size.dtg.h", +] + +[[values]] +type = "::FlexFlow::OperatorAttributeKey" +key = "key" + +[[values]] +type = "::FlexFlow::OperatorAttributeListSize" +key = "list_size" + +[[values]] +type = "::FlexFlow::OperatorAttributeListIndexAccess" +key = "list_idx" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml deleted file mode 100644 index ff79ecaaa5..0000000000 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorAttributeExpr" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -includes = [ - "substitutions/operator_pattern/operator_attribute_key.dtg.h", - "substitutions/operator_pattern/operator_attribute_list_access.dtg.h", - "substitutions/operator_pattern/operator_attribute_list_size.dtg.h", -] - -[[values]] -type = "::FlexFlow::OperatorAttributeKey" -key = "key" - -[[values]] -type = "::FlexFlow::OperatorAttributeListSize" -key = "list_size" - -[[values]] -type = "::FlexFlow::OperatorAttributeListIndexAccess" -key = "list_idx" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.toml new file mode 100644 index 0000000000..14cf9b6c34 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.toml @@ -0,0 +1,70 @@ +namespace = "FlexFlow" +name = "OperatorAttributeKey" +type = "enum" +features = [ + "json", + "hash", + "fmt", + "rapidcheck", +] + +values = [ + { name = "OP_TYPE" }, + { name = "USE_BIAS" }, + { name = 
"GROUPS" }, + { name = "POOL_TYPE" }, + { name = "KERNEL_H" }, + { name = "KERNEL_W" }, + { name = "DATA_TYPE" }, + { name = "SCALAR" }, + { name = "STRIDE_H" }, + { name = "STRIDE_W" }, + { name = "PADDING_H" }, + { name = "PADDING_W" }, + { name = "AGGR" }, + { name = "NUM_ENTRIES" }, + { name = "OUT_CHANNELS" }, + { name = "ACTIVATION" }, + { name = "NUMDIM" }, + { name = "AXIS" }, + { name = "PERMUTATION" }, + { name = "OUTSHUFFLE" }, + { name = "MERGE_GCONV_COUNT" }, + { name = "AXES" }, + { name = "KEEP_DIMS" }, + { name = "EPSILON" }, + { name = "PARALLEL_OP_DIM" }, + { name = "PARALLEL_OP_DEGREE" }, + { name = "SOFTMAX_DIM" }, + { name = "NUM_HEADS" }, + { name = "PARALLEL_DIM" }, + { name = "PARALLEL_DEGREE" }, + { name = "PAD" }, + { name = "EMBED_DIM" }, + { name = "KDIM" }, + { name = "VDIM" }, + { name = "DROPOUT" }, + { name = "BIAS" }, + { name = "ADD_BIAS_KV" }, + { name = "ADD_ZERO_ATTN" }, + { name = "A_SEQ_LENGTH_DIM" }, + { name = "B_SEQ_LENGTH_DIM" }, + { name = "RELU" }, + { name = "TARGET_DIMS" }, + { name = "RATE" }, + { name = "SEED" }, + { name = "SHOULD_BROADCAST_LHS" }, + { name = "SHOULD_BROADCAST_RHS" }, + { name = "DIM" }, + { name = "AFFINE" }, + { name = "ELEMENTWISE_AFFINE" }, + { name = "MOMENTUM" }, + { name = "REGULARIZER" }, + { name = "SHAPE" }, + { name = "SPLITS" }, + { name = "K" }, + { name = "SORTED" }, + { name = "COMBINE_DIM" }, + { name = "COMBINE_DEGREE" }, + { name = "NUM_INPUTS" }, +] diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml deleted file mode 100644 index af3666d46f..0000000000 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml +++ /dev/null @@ -1,69 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorAttributeKey" -features = [ - "json", - "hash", - "fmt", - "rapidcheck", -] - -values = [ - { name = "OP_TYPE" }, - { name = 
"USE_BIAS" }, - { name = "GROUPS" }, - { name = "POOL_TYPE" }, - { name = "KERNEL_H" }, - { name = "KERNEL_W" }, - { name = "DATA_TYPE" }, - { name = "SCALAR" }, - { name = "STRIDE_H" }, - { name = "STRIDE_W" }, - { name = "PADDING_H" }, - { name = "PADDING_W" }, - { name = "AGGR" }, - { name = "NUM_ENTRIES" }, - { name = "OUT_CHANNELS" }, - { name = "ACTIVATION" }, - { name = "NUMDIM" }, - { name = "AXIS" }, - { name = "PERMUTATION" }, - { name = "OUTSHUFFLE" }, - { name = "MERGE_GCONV_COUNT" }, - { name = "AXES" }, - { name = "KEEP_DIMS" }, - { name = "EPSILON" }, - { name = "PARALLEL_OP_DIM" }, - { name = "PARALLEL_OP_DEGREE" }, - { name = "SOFTMAX_DIM" }, - { name = "NUM_HEADS" }, - { name = "PARALLEL_DIM" }, - { name = "PARALLEL_DEGREE" }, - { name = "PAD" }, - { name = "EMBED_DIM" }, - { name = "KDIM" }, - { name = "VDIM" }, - { name = "DROPOUT" }, - { name = "BIAS" }, - { name = "ADD_BIAS_KV" }, - { name = "ADD_ZERO_ATTN" }, - { name = "A_SEQ_LENGTH_DIM" }, - { name = "B_SEQ_LENGTH_DIM" }, - { name = "RELU" }, - { name = "TARGET_DIMS" }, - { name = "RATE" }, - { name = "SEED" }, - { name = "SHOULD_BROADCAST_LHS" }, - { name = "SHOULD_BROADCAST_RHS" }, - { name = "DIM" }, - { name = "AFFINE" }, - { name = "ELEMENTWISE_AFFINE" }, - { name = "MOMENTUM" }, - { name = "REGULARIZER" }, - { name = "SHAPE" }, - { name = "SPLITS" }, - { name = "K" }, - { name = "SORTED" }, - { name = "COMBINE_DIM" }, - { name = "COMBINE_DEGREE" }, - { name = "NUM_INPUTS" }, -] diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.toml new file mode 100644 index 0000000000..a0bff04d2f --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "OperatorAttributeListIndexAccess" +type = "struct" +features = [ + "eq", + "ord", + 
"hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::OperatorAttributeKey" + +[[fields]] +name = "index" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml deleted file mode 100644 index 4ed226907e..0000000000 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorAttributeListIndexAccess" -features = [ - "eq", - "ord", - "hash", - "rapidcheck", - "json", - "fmt", -] - -includes = [ - "substitutions/operator_pattern/operator_attribute_key.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "attribute_key" -type = "::FlexFlow::OperatorAttributeKey" - -[[fields]] -name = "index" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.toml new file mode 100644 index 0000000000..e61e545839 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorAttributeListSize" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h", +] + + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::OperatorAttributeKey" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml 
b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml deleted file mode 100644 index 271b545fda..0000000000 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorAttributeListSize" -features = [ - "eq", - "ord", - "hash", - "rapidcheck", - "json", - "fmt", -] - -includes = [ - "substitutions/operator_pattern/operator_attribute_key.dtg.h", -] - - -[[fields]] -name = "attribute_key" -type = "::FlexFlow::OperatorAttributeKey" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.toml new file mode 100644 index 0000000000..44dfd9e61d --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "OperatorAttributePattern" +type = "struct" +features = [ + "eq", + # "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "utils/fmt/unordered_set.h", + "substitutions/operator_pattern/operator_attribute_constraint.dtg.h", + "utils/hash/unordered_set.h", +] + +[[fields]] +name = "attribute_constraints" +type = "std::unordered_set<::FlexFlow::OperatorAttributeConstraint>" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml deleted file mode 100644 index 8b7797af99..0000000000 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorAttributePattern" -features = [ - "eq", - # "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "", - "utils/fmt/unordered_set.h", 
- "substitutions/operator_pattern/operator_attribute_constraint.dtg.h", - "utils/hash/unordered_set.h", -] - -[[fields]] -name = "attribute_constraints" -type = "std::unordered_set<::FlexFlow::OperatorAttributeConstraint>" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.toml new file mode 100644 index 0000000000..f440a9f90b --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.toml @@ -0,0 +1,85 @@ +namespace = "FlexFlow" +name = "OperatorAttributeValue" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +includes = [ + "", + "", + "op-attrs/operator_type.dtg.h", + "op-attrs/ff_dim_t.dtg.h", + "op-attrs/activation.dtg.h", + "op-attrs/aggregate_op.dtg.h", + "op-attrs/regularizer_attrs.dtg.h", + "op-attrs/pool_op.dtg.h", + "op-attrs/tensor_shape.dtg.h", + "op-attrs/datatype.dtg.h", + "", + "utils/nonnegative_int/nonnegative_int.h", + "utils/positive_int/positive_int.h", + "op-attrs/tensor_dim_permutation.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[values]] +type = "::FlexFlow::nonnegative_int" + +[[values]] +type = "::FlexFlow::positive_int" + +[[values]] +type = "bool" + +[[values]] +type = "float" + +[[values]] +type = "std::optional" + +[[values]] +type = "std::vector<::FlexFlow::positive_int>" + +[[values]] +type = "std::vector<::FlexFlow::ff_dim_t>" + +[[values]] +type = "::FlexFlow::OperatorType" + +[[values]] +type = "std::optional<::FlexFlow::Activation>" + +[[values]] +type = "::FlexFlow::ff_dim_t" + +[[values]] +type = "std::optional<::FlexFlow::AggregateOp>" + +[[values]] +type = "std::optional<::FlexFlow::RegularizerAttrs>" + +[[values]] +type = "::FlexFlow::PoolOp" + +[[values]] +type = "::FlexFlow::TensorShape" + +[[values]] +type = 
"::FlexFlow::TensorDims" + +[[values]] +type = "::FlexFlow::DataType" + +[[values]] +type = "::FlexFlow::TensorDimPermutation" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml deleted file mode 100644 index 1994d54f38..0000000000 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ /dev/null @@ -1,80 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorAttributeValue" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", -] - -includes = [ - "", - "", - "op-attrs/operator_type.dtg.h", - "op-attrs/ff_dim_t.dtg.h", - "op-attrs/activation.dtg.h", - "op-attrs/aggregate_op.dtg.h", - "op-attrs/regularizer_attrs.dtg.h", - "op-attrs/pool_op.dtg.h", - "op-attrs/tensor_shape.dtg.h", - "op-attrs/datatype.dtg.h", - "", - "utils/nonnegative_int/nonnegative_int.h", - "utils/positive_int/positive_int.h", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/json/optional.h", - "utils/fmt/vector.h", - "utils/hash/vector.h", -] - -[[values]] -type = "::FlexFlow::nonnegative_int" - -[[values]] -type = "::FlexFlow::positive_int" - -[[values]] -type = "bool" - -[[values]] -type = "float" - -[[values]] -type = "std::optional" - -[[values]] -type = "std::vector<::FlexFlow::positive_int>" - -[[values]] -type = "std::vector<::FlexFlow::ff_dim_t>" - -[[values]] -type = "::FlexFlow::OperatorType" - -[[values]] -type = "std::optional<::FlexFlow::Activation>" - -[[values]] -type = "::FlexFlow::ff_dim_t" - -[[values]] -type = "std::optional<::FlexFlow::AggregateOp>" - -[[values]] -type = "std::optional<::FlexFlow::RegularizerAttrs>" - -[[values]] -type = "::FlexFlow::PoolOp" - -[[values]] -type = "::FlexFlow::TensorShape" - -[[values]] -type = "::FlexFlow::TensorDims" - -[[values]] -type = "::FlexFlow::DataType" diff --git 
a/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.toml b/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.toml new file mode 100644 index 0000000000..f02730b653 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "AttrConstant" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_value.dtg.h", +] + +[[fields]] +name = "value" +type = "::FlexFlow::OperatorAttributeValue" diff --git a/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml b/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml deleted file mode 100644 index 68973f9c0c..0000000000 --- a/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "AttrConstant" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "substitutions/operator_pattern/operator_attribute_value.dtg.h", -] - -[[fields]] -name = "value" -type = "::FlexFlow::OperatorAttributeValue" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.toml new file mode 100644 index 0000000000..3eac8b0858 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "OutputGraphExpr" +type = "struct" +features = [] + +includes = [ + "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph.h", + "substitutions/output_graph/output_operator_attrs_assignment.dtg.h", + "", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::LabelledOpenKwargDataflowGraph<::FlexFlow::OutputOperatorAttrsAssignment, std::monostate, int, 
::FlexFlow::TensorSlotName>" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h index 8c047fc44d..e5a897330f 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h @@ -10,7 +10,7 @@ namespace FlexFlow { std::unordered_set get_nodes(OutputGraphExpr const &); -std::vector +std::unordered_map get_node_outputs(OutputGraphExpr const &, OutputGraphExprNode const &); std::unordered_set get_inputs(OutputGraphExpr const &); diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml deleted file mode 100644 index 9ad65369a9..0000000000 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml +++ /dev/null @@ -1,13 +0,0 @@ -namespace = "FlexFlow" -name = "OutputGraphExpr" -features = [] - -includes = [ - "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h", - "substitutions/output_graph/output_operator_attrs_assignment.dtg.h", - "", -] - -[[fields]] -name = "raw_graph" -type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::OutputOperatorAttrsAssignment, std::monostate>" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.dtg.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.dtg.toml new file mode 100644 index 0000000000..f00f4cd78f --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "OutputGraphExprInput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h", +] + +[[fields]] +name = 
"raw_dataflow_graph_input" +type = "::FlexFlow::KwargDataflowGraphInput" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml deleted file mode 100644 index fe7a861f0a..0000000000 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "OutputGraphExprInput" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", -] - -[[fields]] -name = "raw_dataflow_graph_input" -type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.dtg.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.dtg.toml new file mode 100644 index 0000000000..0c2ece3530 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "OutputGraphExprNode" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h" +] + +[[fields]] +name = "raw_graph_node" +type = "::FlexFlow::Node" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml deleted file mode 100644 index 37c2a1f563..0000000000 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "OutputGraphExprNode" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h" -] - -[[fields]] -name = "raw_graph_node" -type = "::FlexFlow::Node" diff --git 
a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.dtg.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.dtg.toml new file mode 100644 index 0000000000..8cd881fb4c --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "OutputGraphExprNodeOutput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_dataflow_output" +type = "::FlexFlow::KwargDataflowOutput<::FlexFlow::TensorSlotName>" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml deleted file mode 100644 index 7a2072e385..0000000000 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "OutputGraphExprNodeOutput" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_output.dtg.h", -] - -[[fields]] -name = "raw_dataflow_output" -type = "::FlexFlow::DataflowOutput" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_value.dtg.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_value.dtg.toml new file mode 100644 index 0000000000..40f4d999f3 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_value.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OutputGraphExprValue" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/output_graph/output_graph_expr_input.dtg.h", + 
"substitutions/output_graph/output_graph_expr_node_output.dtg.h", +] + +[[values]] +type = "::FlexFlow::OutputGraphExprNodeOutput" + +[[values]] +type = "::FlexFlow::OutputGraphExprInput" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_value.h b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_value.h index e172edb025..9a7bb05267 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_value.h +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_value.h @@ -2,14 +2,15 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_VALUE_H #include "substitutions/output_graph/output_graph_expr_value.dtg.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" namespace FlexFlow { -OpenDataflowValue raw_open_dataflow_value_from_output_graph_expr_value( - OutputGraphExprValue const &); -OutputGraphExprValue output_graph_expr_value_from_raw_open_dataflow_value( - OpenDataflowValue const &); +OpenKwargDataflowValue + raw_open_kwarg_dataflow_value_from_output_graph_expr_value( + OutputGraphExprValue const &); +OutputGraphExprValue output_graph_expr_value_from_raw_open_kwarg_dataflow_value( + OpenKwargDataflowValue const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_value.variant.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_value.variant.toml deleted file mode 100644 index 641250e1f0..0000000000 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_value.variant.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "OutputGraphExprValue" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "substitutions/output_graph/output_graph_expr_input.dtg.h", - 
"substitutions/output_graph/output_graph_expr_node_output.dtg.h", -] - -[[values]] -type = "::FlexFlow::OutputGraphExprNodeOutput" - -[[values]] -type = "::FlexFlow::OutputGraphExprInput" diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.toml new file mode 100644 index 0000000000..9c08bba621 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "OutputOperatorAttrAccess" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/unlabelled/pattern_node.dtg.h", + "substitutions/operator_pattern/operator_attribute_expr.dtg.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::PatternNode" + +# NOTE(@wmdi) I am not sure whether these should be part of attribute expr. +[[fields]] +name = "attr_expr" +type = "::FlexFlow::OperatorAttributeExpr" + diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml deleted file mode 100644 index e856249e50..0000000000 --- a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "OutputOperatorAttrAccess" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "substitutions/unlabelled/pattern_node.dtg.h", - "substitutions/operator_pattern/operator_attribute_expr.dtg.h", -] - -[[fields]] -name = "node" -type = "::FlexFlow::PatternNode" - -# NOTE(@wmdi) I am not sure whether these should be part of attribute expr. 
-[[fields]] -name = "attr_expr" -type = "::FlexFlow::OperatorAttributeExpr" - diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.toml new file mode 100644 index 0000000000..dee7f2e1c2 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "OutputOperatorAttributeExpr" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/output_graph/attr_constant.dtg.h", + "substitutions/output_graph/output_operator_attr_access.dtg.h", +] + +[[values]] +type = "::FlexFlow::OutputOperatorAttrAccess" +key = "attr_ref" + +[[values]] +type = "::FlexFlow::AttrConstant" +key = "constant" diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml deleted file mode 100644 index 19810a0151..0000000000 --- a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "OutputOperatorAttributeExpr" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "substitutions/output_graph/attr_constant.dtg.h", - "substitutions/output_graph/output_operator_attr_access.dtg.h", -] - -[[values]] -type = "::FlexFlow::OutputOperatorAttrAccess" -key = "attr_ref" - -[[values]] -type = "::FlexFlow::AttrConstant" -key = "constant" diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.toml new file mode 100644 index 0000000000..ca613dad91 --- /dev/null +++ 
b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "OutputOperatorAttrsAssignment" +type = "struct" +features = [ + "eq", + # "ord", + "hash", + # "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h", + "substitutions/output_graph/output_operator_attribute_expr.dtg.h", + "substitutions/unlabelled/pattern_node.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", + "utils/fmt/optional.h", +] + +[[fields]] +name = "template_operator" +type = "std::optional<::FlexFlow::PatternNode>" + +# NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can +# define the assignment for each operator type. +[[fields]] +name = "assignments" +type = "std::unordered_map<::FlexFlow::OperatorAttributeKey, ::FlexFlow::OutputOperatorAttributeExpr>" diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml deleted file mode 100644 index 483f27791a..0000000000 --- a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml +++ /dev/null @@ -1,32 +0,0 @@ -namespace = "FlexFlow" -name = "OutputOperatorAttrsAssignment" -features = [ - "eq", - # "ord", - "hash", - # "json", - "fmt", -] - -includes = [ - "substitutions/operator_pattern/operator_attribute_key.dtg.h", - "substitutions/output_graph/output_operator_attribute_expr.dtg.h", - "substitutions/unlabelled/pattern_node.dtg.h", - "", -] - -src_includes = [ - "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", - "utils/fmt/optional.h", -] - -[[fields]] -name = "template_operator" -type = "std::optional<::FlexFlow::PatternNode>" - -# NOTE(@wmdi): Not sure if it aligns with other design. 
Or alternatively we can -# define the assignment for each operator type. -[[fields]] -name = "assignments" -type = "std::unordered_map<::FlexFlow::OperatorAttributeKey, ::FlexFlow::OutputOperatorAttributeExpr>" diff --git a/lib/substitutions/include/substitutions/output_graph/output_pattern_value.dtg.toml b/lib/substitutions/include/substitutions/output_graph/output_pattern_value.dtg.toml new file mode 100644 index 0000000000..5d4753ac79 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_pattern_value.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "OutputPatternValue" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", +] + +[[fields]] +name = "raw_dataflow_value" +type = "::FlexFlow::OpenDataflowValue" diff --git a/lib/substitutions/include/substitutions/output_graph/output_pattern_value.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_pattern_value.struct.toml deleted file mode 100644 index e29eef4cdd..0000000000 --- a/lib/substitutions/include/substitutions/output_graph/output_pattern_value.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "OutputPatternValue" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", -] - -[[fields]] -name = "raw_dataflow_value" -type = "::FlexFlow::OpenDataflowValue" diff --git a/lib/substitutions/include/substitutions/pcg_pattern.dtg.toml b/lib/substitutions/include/substitutions/pcg_pattern.dtg.toml new file mode 100644 index 0000000000..c00e75d68b --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern.dtg.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "PCGPattern" +type = "struct" +features = [] +includes = [ + "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph.h", + 
"substitutions/operator_pattern/operator_attribute_pattern.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::LabelledOpenKwargDataflowGraph<::FlexFlow::OperatorAttributePattern, ::FlexFlow::TensorAttributePattern, int, ::FlexFlow::TensorSlotName>" diff --git a/lib/substitutions/include/substitutions/pcg_pattern.h b/lib/substitutions/include/substitutions/pcg_pattern.h index f0962b15c2..d39fab0f7b 100644 --- a/lib/substitutions/include/substitutions/pcg_pattern.h +++ b/lib/substitutions/include/substitutions/pcg_pattern.h @@ -26,8 +26,8 @@ TensorAttributePattern get_tensor_pattern(PCGPattern const &, OperatorAttributePattern get_operator_pattern(PCGPattern const &, PatternNode const &); std::unordered_set get_inputs(PCGPattern const &); -std::vector get_pattern_node_outputs(PCGPattern const &, - PatternNode const &); +std::unordered_map + get_pattern_node_outputs(PCGPattern const &, PatternNode const &); bool assignment_satisfies(SubParallelComputationGraph const &, PCGPattern const &, diff --git a/lib/substitutions/include/substitutions/pcg_pattern.struct.toml b/lib/substitutions/include/substitutions/pcg_pattern.struct.toml deleted file mode 100644 index 31e8820b09..0000000000 --- a/lib/substitutions/include/substitutions/pcg_pattern.struct.toml +++ /dev/null @@ -1,12 +0,0 @@ -namespace = "FlexFlow" -name = "PCGPattern" -features = [] -includes = [ - "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h", - "substitutions/operator_pattern/operator_attribute_pattern.dtg.h", - "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h", -] - -[[fields]] -name = "raw_graph" -type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::OperatorAttributePattern, ::FlexFlow::TensorAttributePattern>" diff --git a/lib/substitutions/include/substitutions/pcg_pattern_match.dtg.toml 
b/lib/substitutions/include/substitutions/pcg_pattern_match.dtg.toml new file mode 100644 index 0000000000..5e10f5963c --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern_match.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "PCGPatternMatch" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "substitutions/unlabelled/pattern_node.dtg.h", + "substitutions/unlabelled/pattern_input.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "substitutions/open_parallel_tensor_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "node_assignment" +type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::parallel_layer_guid_t>" + +[[fields]] +name = "input_assignment" +type = "std::unordered_map<::FlexFlow::PatternInput, ::FlexFlow::open_parallel_tensor_guid_t>" diff --git a/lib/substitutions/include/substitutions/pcg_pattern_match.h b/lib/substitutions/include/substitutions/pcg_pattern_match.h index b946173422..faf2e4d2a8 100644 --- a/lib/substitutions/include/substitutions/pcg_pattern_match.h +++ b/lib/substitutions/include/substitutions/pcg_pattern_match.h @@ -6,7 +6,7 @@ #include "substitutions/pcg_pattern_match.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/unlabelled/pattern_node_output.dtg.h" -#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h" +#include "substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.h" namespace FlexFlow { @@ -16,9 +16,14 @@ bidict PCGPattern const &pattern, SubParallelComputationGraph const &spcg); -UnlabelledDataflowGraphPatternMatch +UnlabelledKwargDataflowGraphPatternMatch get_unlabelled_pattern_match(PCGPatternMatch const &match); +void assert_pcg_pattern_match_is_valid_for_pattern_and_subpcg( + PCGPatternMatch const &match, + PCGPattern const &pattern, + 
SubParallelComputationGraph const &spcg); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/pcg_pattern_match.struct.toml b/lib/substitutions/include/substitutions/pcg_pattern_match.struct.toml deleted file mode 100644 index f45bedd2be..0000000000 --- a/lib/substitutions/include/substitutions/pcg_pattern_match.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "PCGPatternMatch" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/bidict/bidict.h", - "substitutions/unlabelled/pattern_node.dtg.h", - "substitutions/unlabelled/pattern_input.dtg.h", - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "substitutions/open_parallel_tensor_guid_t.dtg.h", - "", -] - -src_includes = [ - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", -] - -[[fields]] -name = "node_assignment" -type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::parallel_layer_guid_t>" - -[[fields]] -name = "input_assignment" -type = "std::unordered_map<::FlexFlow::PatternInput, ::FlexFlow::open_parallel_tensor_guid_t>" diff --git a/lib/substitutions/include/substitutions/serialization.h b/lib/substitutions/include/substitutions/serialization.h deleted file mode 100644 index ef23f2185d..0000000000 --- a/lib/substitutions/include/substitutions/serialization.h +++ /dev/null @@ -1,144 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_SERIALIZATION_H -#define _FLEXFLOW_SUBSTITUTIONS_SERIALIZATION_H - -#include "substitutions.h" - -NLOHMANN_JSON_SERIALIZE_ENUM(OperatorType, - {{OP_INVALID, nullptr}, - {OP_NOOP, "OP_NOOP"}, - {OP_CONV2D, "OP_CONV2D"}, - {OP_DROPOUT, "OP_DROPOUT"}, - {OP_LINEAR, "OP_LINEAR"}, - {OP_BATCHMATMUL, "OP_BATCHMATMUL"}, - {OP_POOL2D, "OP_POOL2D_MAX"}, - {OP_SCALAR_MULTIPLY, "OP_SCALAR_MULTIPLY"}, - {OP_SCALAR_ADD, "OP_SCALAR_ADD"}, - {OP_SCALAR_FLOOR_DIV, "OP_SCALAR_FLOOR_DIV"}, - {OP_SCALAR_TRUE_DIV, "OP_SCALAR_TRUE_DIV"}, - {OP_SCALAR_SUB, "OP_SCALAR_SUB"}, - {OP_RELU, "OP_RELU"}, - 
{OP_IDENTITY, "OP_IDENTITY"}, - {OP_SIGMOID, "OP_SIGMOID"}, - {OP_TANH, "OP_TANH"}, - {OP_ELU, "OP_ELU"}, - {OP_FLAT, "OP_FLAT"}, - {OP_SOFTMAX, "OP_SOFTMAX"}, - {OP_BATCHNORM, "OP_BATCHNORM"}, - {OP_CONCAT, "OP_CONCAT"}, - {OP_SPLIT, "OP_SPLIT"}, - {OP_EMBEDDING, "OP_EMBEDDING"}, - {OP_CACHE, "OP_CACHE"}, - {OP_RESHAPE, "OP_RESHAPE"}, - {OP_REVERSE, "OP_REVERSE"}, - {OP_TRANSPOSE, "OP_TRANSPOSE"}, - {OP_EW_ADD, "OP_EW_ADD"}, - {OP_EW_MUL, "OP_EW_MUL"}, - {OP_MATMUL, "OP_MATMUL"}, - {OP_MUL, "OP_MUL"}, - {OP_ENLARGE, "OP_ENLARGE"}, - {OP_MERGE_GCONV, "OP_MERGE_GCONV"}, - {OP_CONSTANT_IMM, "OP_CONSTANT_IMM"}, - {OP_CONSTANT_ICONV, "OP_CONSTANT_ICONV"}, - {OP_CONSTANT_ONE, "OP_CONSTANT_ONE"}, - {OP_CONSTANT_POOL, "OP_CONSTANT_POOL"}, - {OP_SQUEEZE, "OP_SQUEEZE"}, - {OP_UNSQUEEZE, "OP_UNSQUEEZE"}, - {OP_EW_SUB, "OP_EW_SUB"}, - {OP_EW_DIV, "OP_EW_DIV"}, - {OP_EW_EQUAL, "OP_EW_EQUAL"}, - {OP_EW_GREATER, "OP_EW_GREATER"}, - {OP_EW_LESS, "OP_EW_LESS"}, - {OP_EW_MAX, "OP_EW_MAX"}, - {OP_EW_MIN, "OP_EW_MIN"}, - {OP_REDUCE_ARGMAX, "OP_REDUCE_ARGMAX"}, - {OP_REDUCE_ARGMIN, "OP_REDUCE_ARGMIN"}, - {OP_REDUCE_MAX, "OP_REDUCE_MAX"}, - {OP_REDUCE_MEAN, "OP_REDUCE_MEAN"}, - {OP_REDUCE_MIN, "OP_REDUCE_MIN"}, - {OP_REDUCE_PROD, "OP_REDUCE_PROD"}, - {OP_REDUCE_SUM, "OP_REDUCE_SUM"}, - {OP_PAD, "OP_PAD"}, - {OP_SHAPE, "OP_SHAPE"}, - {OP_SIZE, "OP_SIZE"}, - {OP_TOPK, "OP_TOPK"}, - {OP_WHERE, "OP_WHERE"}, - {OP_CEIL, "OP_CEIL"}, - {OP_CAST, "OP_CAST"}, - {OP_EXP, "OP_EXP"}, - {OP_ROUND, "OP_ROUND"}, - {OP_LOG, "OP_LOG"}, - {OP_LOGICAL_NOT, "OP_LOGICAL_NOT"}, - {OP_SQRT, "OP_SQRT"}, - {OP_SIN, "OP_SIN"}, - {OP_COS, "OP_COS"}, - {OP_LEAKYRELU, "OP_LEAKYRELU"}, - {OP_SLICE, "OP_SLICE"}, - {OP_RESIZE, "OP_RESIZE"}, - {OP_PRELU, "OP_PRELU"}, - {OP_GELU, "OP_GELU"}, - {OP_MULTIHEAD_ATTENTION, - "OP_MULTIHEAD_ATTENTION"}, - {OP_FUSED, "OP_FUSED"}, - {OP_RSQRT, "OP_RSQRT"}, - {OP_POW, "OP_POW"}, - {OP_MEAN, "OP_MEAN"}, - {OP_LAYERNORM, "OP_LAYERNORM"}, - {OP_REPARTITION, "OP_PARTITION"}, - 
{OP_COMBINE, "OP_COMBINE"}, - {OP_REPLICATE, "OP_REPLICATE"}, - {OP_REDUCTION, "OP_REDUCE"}, - {OP_PIPELINE, "OP_PIPELINE"}, - {OP_FUSED_PARALLEL, "OP_FUSED_PARALLEL"}}) - -namespace FlexFlow { -namespace substitutions { - -NLOHMANN_JSON_SERIALIZE_ENUM( - ParameterAttribute, - {{ParameterAttribute::INVALID, nullptr}, - {ParameterAttribute::OP_TYPE, "PM_OP_TYPE"}, - {ParameterAttribute::NUM_INPUTS, "PM_NUM_INPUTS"}, - {ParameterAttribute::NUM_OUTPUTS, "PM_NUM_OUTPUTS"}, - {ParameterAttribute::GROUP, "PM_GROUP"}, - {ParameterAttribute::KERNEL_H, "PM_KERNEL_H"}, - {ParameterAttribute::KERNEL_W, "PM_KERNEL_W"}, - {ParameterAttribute::STRIDE_H, "PM_STRIDE_H"}, - {ParameterAttribute::STRIDE_W, "PM_STRIDE_W"}, - {ParameterAttribute::PADDING_H, "PM_PADDING_H"}, - {ParameterAttribute::PADDING_W, "PM_PADDING_W"}, - {ParameterAttribute::ACTIVATION, "PM_ACTIVATION"}, - {ParameterAttribute::NUMDIM, "PM_NUMDIM"}, - {ParameterAttribute::AXIS, "PM_AXIS"}, - {ParameterAttribute::PERM, "PM_PERM"}, - {ParameterAttribute::OUTSHUFFLE, "PM_OUTSHUFFLE"}, - {ParameterAttribute::MERGE_GCONV_COUNT, "PM_MERGE_GCONV_COUNT"}, - {ParameterAttribute::AXES, "PM_AXES"}, - {ParameterAttribute::KEEP_DIMS, "PM_KEEP_DIMS"}, - {ParameterAttribute::EPSILON, "PM_EPSILON"}, - {ParameterAttribute::REPARTITION_DIM, "PM_REPARTITION_DIM"}, - {ParameterAttribute::REPARTITION_DEGREE, "PM_REPARTITION_DEGREE"}, - {ParameterAttribute::REPLICATE_DIM, "PM_REPLICATE_DIM"}, - {ParameterAttribute::REPLICATE_DEGREE, "PM_REPLICATE_DEGREE"}, - {ParameterAttribute::COMBINE_DIM, "PM_COMBINE_DIM"}, - {ParameterAttribute::COMBINE_DEGREE, "PM_COMBINE_DEGREE"}, - {ParameterAttribute::REDUCTION_DIM, "PM_REDUCTION_DIM"}, - {ParameterAttribute::REDUCTION_DEGREE, "PM_REDUCTION_DEGREE"}, - {ParameterAttribute::SOFTMAX_DIM, "PM_SOFTMAX_DIM"}, - {ParameterAttribute::NUM_HEADS, "PM_NUM_HEADS"}, - {ParameterAttribute::PARALLEL_DIM, "PM_PARALLEL_DIM"}, - {ParameterAttribute::PARALLEL_DEGREE, "PM_PARALLEL_DEGREE"}, - 
{ParameterAttribute::PAD, "PM_PAD"}}) -void from_json(nlohmann::json const &j, OperatorAttributeConstraint &p); -void from_json(nlohmann::json const &j, Tensor &t); -void from_json(nlohmann::json const &j, OperatorConstraint &t); -void from_json(nlohmann::json const &j, MapOutput &t); -void from_json(nlohmann::json const &j, Rule &t); -void from_json(nlohmann::json const &j, RuleCollection &c); - -RuleCollection load_rule_collection(std::istream &s); -RuleCollection load_rule_collection_from_path(std::string const &path); - -} // namespace substitutions -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.toml new file mode 100644 index 0000000000..543db6f7df --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "SubParallelComputationGraph" +type = "struct" +features = [ ] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::LabelledOpenKwargDataflowGraphView<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs, int, ::FlexFlow::TensorSlotName>" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 2d76352ccf..0b4ac7238c 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -14,7 +14,9 @@ namespace FlexFlow { std::unordered_set - get_parallel_layers(SubParallelComputationGraph const &sub_pcg); + 
get_parallel_layers(SubParallelComputationGraph const &); +std::unordered_set + get_parallel_tensors(SubParallelComputationGraph const &); ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &, parallel_layer_guid_t const &); PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &, @@ -31,10 +33,10 @@ parallel_layer_guid_t get_parallel_layer_by_name(SubParallelComputationGraph const &pcg, std::string const &name); -std::vector +std::unordered_map get_layer_inputs(SubParallelComputationGraph const &, parallel_layer_guid_t const &); -std::vector +std::unordered_map get_layer_outputs(SubParallelComputationGraph const &, parallel_layer_guid_t const &); diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml deleted file mode 100644 index 38ce364b49..0000000000 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml +++ /dev/null @@ -1,13 +0,0 @@ -namespace = "FlexFlow" -name = "SubParallelComputationGraph" -features = [ ] - -includes = [ - "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", - "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", - "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h", -] - -[[fields]] -name = "raw_graph" -type = "::FlexFlow::LabelledOpenDataflowGraphView<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.dtg.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.dtg.toml new file mode 100644 index 0000000000..8836d050b5 --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.dtg.toml @@ -0,0 +1,42 @@ +namespace = "FlexFlow" +name = "SubParallelComputationGraphData" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + 
+includes = [ + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "substitutions/open_parallel_tensor_guid_t.dtg.h", + "substitutions/input_parallel_tensor_guid_t.dtg.h", + "substitutions/sub_parallel_computation_graph_edge.dtg.h", + "", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_map.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "node_data" +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::ParallelLayerAttrs>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::SubParallelComputationGraphEdge>" + +[[fields]] +name = "inputs" +type = "std::unordered_set<::FlexFlow::input_parallel_tensor_guid_t>" + +[[fields]] +name = "value_data" +type = "std::unordered_map<::FlexFlow::open_parallel_tensor_guid_t, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.h new file mode 100644 index 0000000000..7911e12c1d --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DATA_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DATA_H + +#include "substitutions/sub_parallel_computation_graph_data.dtg.h" + +namespace FlexFlow { + +void require_sub_parallel_computation_graph_data_is_valid( + SubParallelComputationGraphData const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.struct.toml deleted file mode 100644 index 
537af231bf..0000000000 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.struct.toml +++ /dev/null @@ -1,41 +0,0 @@ -namespace = "FlexFlow" -name = "SubParallelComputationGraphData" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", - "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "substitutions/open_parallel_tensor_guid_t.dtg.h", - "substitutions/input_parallel_tensor_guid_t.dtg.h", - "substitutions/sub_parallel_computation_graph_edge.dtg.h", - "", - "", -] - -src_includes = [ - "utils/hash/unordered_map.h", - "utils/hash/unordered_set.h", - "utils/fmt/unordered_map.h", - "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "node_data" -type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::ParallelLayerAttrs>" - -[[fields]] -name = "edges" -type = "std::unordered_set<::FlexFlow::SubParallelComputationGraphEdge>" - -[[fields]] -name = "inputs" -type = "std::unordered_set<::FlexFlow::input_parallel_tensor_guid_t>" - -[[fields]] -name = "value_data" -type = "std::unordered_map<::FlexFlow::open_parallel_tensor_guid_t, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.dtg.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.dtg.toml new file mode 100644 index 0000000000..a2fcd22309 --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "SubParallelComputationGraphEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::OpenKwargDataflowEdge" diff --git 
a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.h index c0544abe1b..5b0ea7c7c3 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.h @@ -12,7 +12,7 @@ namespace FlexFlow { SubParallelComputationGraphEdge subpcg_edge_from_tensor_and_dst(parallel_tensor_guid_t const &tensor, parallel_layer_guid_t const &layer, - nonnegative_int input_idx); + TensorSlotName input_slot_name); SubParallelComputationGraphEdge subpcg_edge_from_tensor_and_use(open_parallel_tensor_guid_t const &tensor, parallel_tensor_use_t const &use); diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.struct.toml deleted file mode 100644 index 6d8f72bae8..0000000000 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "SubParallelComputationGraphEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", -] - -[[fields]] -name = "raw_edge" -type = "::FlexFlow::OpenDataflowEdge" diff --git a/lib/substitutions/include/substitutions/substitution.dtg.toml b/lib/substitutions/include/substitutions/substitution.dtg.toml new file mode 100644 index 0000000000..5daeaceded --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution.dtg.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "Substitution" +type = "struct" +features = [] + +includes = [ + "substitutions/pcg_pattern.dtg.h", + "substitutions/output_graph/output_graph_expr.dtg.h", + "substitutions/output_graph/output_graph_expr_input.dtg.h", + "substitutions/output_graph/output_graph_expr_node_output.dtg.h", + 
"substitutions/unlabelled/pattern_input.dtg.h", + "substitutions/unlabelled/pattern_node_output.dtg.h", +] + +[[fields]] +name = "pcg_pattern" +type = "::FlexFlow::PCGPattern" + +[[fields]] +name = "output_graph_expr" +type = "::FlexFlow::OutputGraphExpr" + +[[fields]] +name = "inputs_mapping" +type = "::FlexFlow::bidict<::FlexFlow::PatternInput, ::FlexFlow::OutputGraphExprInput>" + +[[fields]] +name = "outputs_mapping" +type = "::FlexFlow::bidict<::FlexFlow::PatternNodeOutput, ::FlexFlow::OutputGraphExprNodeOutput>" diff --git a/lib/substitutions/include/substitutions/substitution.struct.toml b/lib/substitutions/include/substitutions/substitution.struct.toml deleted file mode 100644 index 49bef62747..0000000000 --- a/lib/substitutions/include/substitutions/substitution.struct.toml +++ /dev/null @@ -1,28 +0,0 @@ -namespace = "FlexFlow" -name = "Substitution" -features = [] - -includes = [ - "substitutions/pcg_pattern.dtg.h", - "substitutions/output_graph/output_graph_expr.dtg.h", - "substitutions/output_graph/output_graph_expr_input.dtg.h", - "substitutions/output_graph/output_graph_expr_node_output.dtg.h", - "substitutions/unlabelled/pattern_input.dtg.h", - "substitutions/unlabelled/pattern_node_output.dtg.h", -] - -[[fields]] -name = "pcg_pattern" -type = "::FlexFlow::PCGPattern" - -[[fields]] -name = "output_graph_expr" -type = "::FlexFlow::OutputGraphExpr" - -[[fields]] -name = "inputs_mapping" -type = "::FlexFlow::bidict<::FlexFlow::PatternInput, ::FlexFlow::OutputGraphExprInput>" - -[[fields]] -name = "outputs_mapping" -type = "::FlexFlow::bidict<::FlexFlow::PatternNodeOutput, ::FlexFlow::OutputGraphExprNodeOutput>" diff --git a/lib/substitutions/include/substitutions/substitution_builder.h b/lib/substitutions/include/substitutions/substitution_builder.h index 1548b2269b..248c23ecc8 100644 --- a/lib/substitutions/include/substitutions/substitution_builder.h +++ b/lib/substitutions/include/substitutions/substitution_builder.h @@ -17,16 +17,19 @@ struct 
SubstitutionBuilder { std::optional const &name = std::nullopt); void equate_outputs(PatternValue const &, OutputGraphExprValue const &); - std::vector add_pattern_node( + std::unordered_map add_pattern_node( OperatorAttributePattern const &node_pattern, - std::vector const &inputs, - std::vector const &output_patterns, + std::unordered_map const &inputs, + std::unordered_map const + &output_patterns, std::optional const &name = std::nullopt); - std::vector - add_output_graph_node(OutputOperatorAttrsAssignment const &node_expr, - std::vector const &inputs, - nonnegative_int num_outputs); + std::unordered_map + add_output_graph_node( + OutputOperatorAttrsAssignment const &node_expr, + std::unordered_map const + &inputs, + std::unordered_set const &output_slots); PatternNode pattern_node_named(std::string const &) const; PatternInput pattern_input_named(std::string const &) const; @@ -34,14 +37,24 @@ struct SubstitutionBuilder { Substitution get_substitution() const; private: - LabelledOpenDataflowGraph + int get_fresh_graph_input_name(); + +private: + LabelledOpenKwargDataflowGraph pattern_g; - LabelledOpenDataflowGraph + LabelledOpenKwargDataflowGraph output_g; bidict input_mapping; bidict pattern_node_names; bidict pattern_input_names; bidict output_mapping; + int next_graph_input_id; }; } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.toml new file mode 100644 index 0000000000..a8e1aca724 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "TensorAttributeConstraint" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "substitutions/constraint_type.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h", + 
"substitutions/tensor_pattern/tensor_attribute_value.dtg.h", +] + +[[fields]] +name = "constraint_type" +type = "::FlexFlow::ConstraintType" + +[[fields]] +name = "attribute_expr" +type = "::FlexFlow::TensorAttributeExpr" + +[[fields]] +name = "attribute_value" +type = "::FlexFlow::TensorAttributeValue" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml deleted file mode 100644 index 6aba719e08..0000000000 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml +++ /dev/null @@ -1,28 +0,0 @@ -namespace = "FlexFlow" -name = "TensorAttributeConstraint" -features = [ - "eq", - "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "substitutions/constraint_type.dtg.h", - "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h", - "substitutions/tensor_pattern/tensor_attribute_value.dtg.h", -] - -[[fields]] -name = "constraint_type" -type = "::FlexFlow::ConstraintType" - -[[fields]] -name = "attribute_expr" -type = "::FlexFlow::TensorAttributeExpr" - -[[fields]] -name = "attribute_value" -type = "::FlexFlow::TensorAttributeValue" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.toml new file mode 100644 index 0000000000..5462e79b7e --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "TensorAttributeExpr" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "substitutions/tensor_pattern/tensor_attribute_key.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h", +] + +[[values]] +type = 
"::FlexFlow::TensorAttributeKey" +key = "key" + +[[values]] +type = "::FlexFlow::TensorAttributeListSize" +key = "list_size" + +[[values]] +type = "::FlexFlow::TensorAttributeListIndexAccess" +key = "list_idx" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml deleted file mode 100644 index 03ec0eb624..0000000000 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "TensorAttributeExpr" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -includes = [ - "substitutions/tensor_pattern/tensor_attribute_key.dtg.h", - "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h", - "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h", -] - -[[values]] -type = "::FlexFlow::TensorAttributeKey" -key = "key" - -[[values]] -type = "::FlexFlow::TensorAttributeListSize" -key = "list_size" - -[[values]] -type = "::FlexFlow::TensorAttributeListIndexAccess" -key = "list_idx" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.toml new file mode 100644 index 0000000000..e2fc03c166 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TensorAttributeKey" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "DIM_SIZES" + +[[values]] +name = "DIM_DEGREES" + +[[values]] +name = "DISCARD_COPY_DEGREE_DIM" + +[[values]] +name = "SUM_DEGREE_DIM" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml deleted file mode 
100644 index 541888038b..0000000000 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "TensorAttributeKey" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[values]] -name = "DIM_SIZES" - -[[values]] -name = "DIM_DEGREES" - -[[values]] -name = "DISCARD_COPY_DEGREE_DIM" - -[[values]] -name = "SUM_DEGREE_DIM" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.toml new file mode 100644 index 0000000000..ef74289e9d --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "TensorAttributeListIndexAccess" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/tensor_pattern/tensor_attribute_key.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::TensorAttributeKey" + +[[fields]] +name = "index" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml deleted file mode 100644 index 71e11a12d5..0000000000 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "TensorAttributeListIndexAccess" -features = [ - "eq", - "ord", - "hash", - "rapidcheck", - "json", - "fmt", -] - -includes = [ - "substitutions/tensor_pattern/tensor_attribute_key.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "attribute_key" -type = "::FlexFlow::TensorAttributeKey" - 
-[[fields]] -name = "index" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.toml new file mode 100644 index 0000000000..36f0530c44 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "TensorAttributeListSize" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/tensor_pattern/tensor_attribute_key.dtg.h", +] + + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::TensorAttributeKey" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml deleted file mode 100644 index c876696343..0000000000 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "TensorAttributeListSize" -features = [ - "eq", - "ord", - "hash", - "rapidcheck", - "json", - "fmt", -] - -includes = [ - "substitutions/tensor_pattern/tensor_attribute_key.dtg.h", -] - - -[[fields]] -name = "attribute_key" -type = "::FlexFlow::TensorAttributeKey" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.toml new file mode 100644 index 0000000000..c81d28c8a0 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "TensorAttributePattern" +type = "struct" +features = [ + "eq", + # "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + 
"substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "attribute_constraints" +type = "std::unordered_set<::FlexFlow::TensorAttributeConstraint>" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml deleted file mode 100644 index 139774979e..0000000000 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "TensorAttributePattern" -features = [ - "eq", - # "ord", - "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "", - "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h", - "utils/hash/unordered_set.h", - "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "attribute_constraints" -type = "std::unordered_set<::FlexFlow::TensorAttributeConstraint>" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.toml new file mode 100644 index 0000000000..ffacfafbdf --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "TensorAttributeValue" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "", + "utils/hash/vector.h", + "utils/fmt/vector.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[values]] +type = "::FlexFlow::nonnegative_int" + +[[values]] +type = "std::vector<::FlexFlow::nonnegative_int>" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml deleted file mode 100644 
index d2b931fb2d..0000000000 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "TensorAttributeValue" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -includes = [ - "", - "utils/hash/vector.h", - "utils/fmt/vector.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -[[values]] -type = "::FlexFlow::nonnegative_int" - -[[values]] -type = "std::vector<::FlexFlow::nonnegative_int>" diff --git a/lib/substitutions/include/substitutions/unity_substitution_set.h b/lib/substitutions/include/substitutions/unity_substitution_set.h index 183f76ac8a..be1a2101d0 100644 --- a/lib/substitutions/include/substitutions/unity_substitution_set.h +++ b/lib/substitutions/include/substitutions/unity_substitution_set.h @@ -1,14 +1,14 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNITY_SUBSTITUTION_SET_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNITY_SUBSTITUTION_SET_H -#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" #include "substitutions/substitution.dtg.h" #include "utils/fmt/vector.h" namespace FlexFlow { std::vector - get_substitution_set(MachineSpecification const &resources); + get_substitution_set(MachineComputeSpecification const &resources); Substitution create_combine_inception(nonnegative_int num_convs, nonnegative_int num_dims, diff --git a/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h index 42d45c1e0d..8b547c0ae1 100644 --- a/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h +++ b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h @@ -2,15 +2,16 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_FIND_PATTERN_MATCHES_H #include "substitutions/unlabelled/match_additional_criterion.dtg.h" -#include 
"substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +#include "substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.dtg.h" namespace FlexFlow { -std::vector - find_pattern_matches(UnlabelledGraphPattern const &pattern, - OpenDataflowGraphView const &graph, - MatchAdditionalCriterion const &additional_criterion); +std::vector + find_unlabelled_pattern_matches( + UnlabelledGraphPattern const &pattern, + OpenKwargDataflowGraphView const &graph, + MatchAdditionalCriterion const &additional_criterion); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.toml new file mode 100644 index 0000000000..7484565da7 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "InputPatternEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::KwargDataflowInputEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h index 8c58cb991c..4fe4213e7e 100644 --- a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h @@ -9,7 +9,7 @@ namespace FlexFlow { PatternInput get_src_input(InputPatternEdge const &); PatternNode get_dst_node(InputPatternEdge const &); -nonnegative_int get_dst_idx(InputPatternEdge const &); +TensorSlotName get_dst_slot_name(InputPatternEdge const &); } // namespace FlexFlow diff --git 
a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml deleted file mode 100644 index e4203cf495..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "InputPatternEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h", -] - -[[fields]] -name = "raw_edge" -type = "::FlexFlow::DataflowInputEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.toml new file mode 100644 index 0000000000..680140c33f --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "MatchAdditionalCriterion" +type = "struct" +features = [] + +includes = [ + "", + "utils/graph/node/node.dtg.h", + "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h", + "substitutions/unlabelled/pattern_node.dtg.h", + "substitutions/unlabelled/pattern_value.dtg.h", +] + +[[fields]] +name = "node_criterion" +type = "std::function" + +[[fields]] +name = "value_criterion" +type = "std::function const &)>" diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml deleted file mode 100644 index 9eb62933f1..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "MatchAdditionalCriterion" -features = [] - -includes = [ - "", - "utils/graph/node/node.dtg.h", - "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", - 
"substitutions/unlabelled/pattern_node.dtg.h", - "substitutions/unlabelled/pattern_value.dtg.h", -] - -[[fields]] -name = "node_criterion" -type = "std::function" - -[[fields]] -name = "value_criterion" -type = "std::function" diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h deleted file mode 100644 index 1b30f274f9..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H -#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H - -// #include "substitutions/unlabelled/edge_splits.dtg.h" -// #include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" - -namespace FlexFlow { - -// MultiDiGraphPatternMatch empty_multidigraph_pattern_match(); -// std::optional -// unsplit_matches(MultiDiGraphPatternMatch const &prefix, -// MultiDiGraphPatternMatch const &postfix, -// UnlabelledPatternEdgeSplits const &edge_splits); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.toml new file mode 100644 index 0000000000..baccfa45a9 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "PatternEdge" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/unlabelled/input_pattern_edge.dtg.h", + "substitutions/unlabelled/standard_pattern_edge.dtg.h", +] + +[[values]] +type = "::FlexFlow::InputPatternEdge" + +[[values]] +type = "::FlexFlow::StandardPatternEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h 
b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h index 1d6f1302ed..13c6e36bc8 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h @@ -6,6 +6,7 @@ #include "substitutions/unlabelled/pattern_node.dtg.h" #include "substitutions/unlabelled/standard_pattern_edge.dtg.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.h" #include namespace FlexFlow { @@ -22,7 +23,8 @@ InputPatternEdge require_input_edge(PatternEdge const &); PatternEdge pattern_edge_from_input_edge(InputPatternEdge const &); PatternEdge pattern_edge_from_standard_edge(StandardPatternEdge const &); -PatternEdge pattern_edge_from_raw_open_dataflow_edge(OpenDataflowEdge const &); +PatternEdge pattern_edge_from_raw_open_dataflow_edge( + OpenKwargDataflowEdge const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.variant.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.variant.toml deleted file mode 100644 index 143ea78ac1..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.variant.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "PatternEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "substitutions/unlabelled/input_pattern_edge.dtg.h", - "substitutions/unlabelled/standard_pattern_edge.dtg.h", -] - -[[values]] -type = "::FlexFlow::InputPatternEdge" - -[[values]] -type = "::FlexFlow::StandardPatternEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_input.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_input.dtg.toml new file mode 100644 index 0000000000..317fb97d56 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_input.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" 
+name = "PatternInput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_dataflow_graph_input" +type = "::FlexFlow::KwargDataflowGraphInput" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_input.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_input.struct.toml deleted file mode 100644 index e91e5673af..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_input.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "PatternInput" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", -] - -[[fields]] -name = "raw_dataflow_graph_input" -type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h index ce30b18f55..4999ddcfef 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h @@ -2,29 +2,29 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_MATCHING_H #include "substitutions/unlabelled/match_additional_criterion.dtg.h" -#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" -#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_subgraph_result.dtg.h" namespace FlexFlow { 
-OpenDataflowSubgraphResult - subgraph_matched(OpenDataflowGraphView const &graph, - UnlabelledDataflowGraphPatternMatch const &match); +OpenKwargDataflowSubgraphResult subgraph_matched( + OpenKwargDataflowGraphView const &graph, + UnlabelledKwargDataflowGraphPatternMatch const &match); bool pattern_matches_subgraph_under( UnlabelledGraphPattern const &pattern, - OpenDataflowGraphView const &subgraph, - bidict const + OpenKwargDataflowGraphView const &subgraph, + bidict, + KwargDataflowGraphInput> const &full_graph_values_to_subgraph_inputs, - UnlabelledDataflowGraphPatternMatch const &match, + UnlabelledKwargDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion); bool unlabelled_pattern_does_match( UnlabelledGraphPattern const &pattern, - OpenDataflowGraphView const &graph, - UnlabelledDataflowGraphPatternMatch const &match, + OpenKwargDataflowGraphView const &graph, + UnlabelledKwargDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.toml new file mode 100644 index 0000000000..57029483f1 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "PatternNode" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "raw_node" +type = "::FlexFlow::Node" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml deleted file mode 100644 index a3bcc83249..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "PatternNode" -features = [ 
- "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "raw_node" -type = "::FlexFlow::Node" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.dtg.toml new file mode 100644 index 0000000000..020e3ce531 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "PatternNodeOutput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_dataflow_output" +type = "::FlexFlow::KwargDataflowOutput<::FlexFlow::TensorSlotName>" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.h b/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.h index 67f513b8b1..d99fb34486 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.h @@ -6,7 +6,7 @@ namespace FlexFlow { PatternNode get_src_node(PatternNodeOutput const &); -nonnegative_int get_idx(PatternNodeOutput const &); +TensorSlotName get_slot_name(PatternNodeOutput const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.struct.toml deleted file mode 100644 index c2b85ae4fb..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "PatternNodeOutput" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_output.dtg.h", -] - -[[fields]] -name = 
"raw_dataflow_output" -type = "::FlexFlow::DataflowOutput" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.toml new file mode 100644 index 0000000000..3358f320e6 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "PatternSplit" +type = "struct" +features = [ + "eq", + # "ord", + "hash", + # "json", + "fmt", +] + +includes = [ + "", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", + "substitutions/unlabelled/pattern_node.dtg.h", +] + +[[fields]] +name = "first" +type = "std::unordered_set<::FlexFlow::PatternNode>" + +[[fields]] +name = "second" +type = "std::unordered_set<::FlexFlow::PatternNode>" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml deleted file mode 100644 index 1fbe8c241b..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "PatternSplit" -features = [ - "eq", - # "ord", - "hash", - # "json", - "fmt", -] - -includes = [ - "", - "utils/hash/unordered_set.h", - "utils/fmt/unordered_set.h", - "substitutions/unlabelled/pattern_node.dtg.h", -] - -[[fields]] -name = "first" -type = "std::unordered_set<::FlexFlow::PatternNode>" - -[[fields]] -name = "second" -type = "std::unordered_set<::FlexFlow::PatternNode>" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split_result.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_split_result.dtg.toml new file mode 100644 index 0000000000..301a465a93 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split_result.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "PatternSplitResult" +type = "struct" 
+features = [ ] + +includes = [ + "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h", + "substitutions/unlabelled/pattern_value.dtg.h", + "substitutions/unlabelled/pattern_input.dtg.h", + "utils/bidict/bidict.h", +] + +[[fields]] +name = "subpattern_1" +type = "::FlexFlow::UnlabelledGraphPattern" + +[[fields]] +name = "subpattern_2" +type = "::FlexFlow::UnlabelledGraphPattern" + +[[fields]] +name = "full_pattern_values_to_subpattern_1_inputs" +type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::PatternInput>" + +[[fields]] +name = "full_pattern_values_to_subpattern_2_inputs" +type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::PatternInput>" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split_result.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_split_result.struct.toml deleted file mode 100644 index d2e20343be..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_split_result.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "PatternSplitResult" -features = [ ] - -includes = [ - "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h", - "substitutions/unlabelled/pattern_value.dtg.h", - "substitutions/unlabelled/pattern_input.dtg.h", - "utils/bidict/bidict.h", -] - -[[fields]] -name = "subpattern_1" -type = "::FlexFlow::UnlabelledGraphPattern" - -[[fields]] -name = "subpattern_2" -type = "::FlexFlow::UnlabelledGraphPattern" - -[[fields]] -name = "full_pattern_values_to_subpattern_1_inputs" -type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::PatternInput>" - -[[fields]] -name = "full_pattern_values_to_subpattern_2_inputs" -type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::PatternInput>" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_value.dtg.toml new file mode 100644 index 0000000000..1b795636cc --- 
/dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_value.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "PatternValue" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/unlabelled/pattern_input.dtg.h", + "substitutions/unlabelled/pattern_node_output.dtg.h", +] + +[[values]] +type = "::FlexFlow::PatternNodeOutput" +key = "node_output" + +[[values]] +type = "::FlexFlow::PatternInput" +key = "pattern_input" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value.h b/lib/substitutions/include/substitutions/unlabelled/pattern_value.h index 15dd299c6b..bb4dac43fd 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_value.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_value.h @@ -3,13 +3,14 @@ #include "substitutions/unlabelled/pattern_value.dtg.h" #include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" namespace FlexFlow { -OpenDataflowValue +OpenKwargDataflowValue raw_open_dataflow_value_from_pattern_value(PatternValue const &); -PatternValue - pattern_value_from_raw_open_dataflow_value(OpenDataflowValue const &); +PatternValue pattern_value_from_raw_open_kwarg_dataflow_value( + OpenKwargDataflowValue const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value.variant.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_value.variant.toml deleted file mode 100644 index f9abc85c4b..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_value.variant.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "PatternValue" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "substitutions/unlabelled/pattern_input.dtg.h", - "substitutions/unlabelled/pattern_node_output.dtg.h", -] - -[[values]] -type = 
"::FlexFlow::PatternNodeOutput" - -[[values]] -type = "::FlexFlow::PatternInput" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.dtg.toml new file mode 100644 index 0000000000..c5b1c57de4 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "PatternValueUse" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_input.dtg.h", +] + +[[fields]] +name = "raw_dataflow_input" +type = "::FlexFlow::DataflowInput" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.struct.toml deleted file mode 100644 index 35630eac70..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "PatternValueUse" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_input.dtg.h", -] - -[[fields]] -name = "raw_dataflow_input" -type = "::FlexFlow::DataflowInput" diff --git a/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.dtg.toml new file mode 100644 index 0000000000..1eeb2952f1 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "StandardPatternEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::KwargDataflowEdge<::FlexFlow::TensorSlotName>" diff --git 
a/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.h index 817e829709..4119f5c47c 100644 --- a/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.h +++ b/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.h @@ -8,8 +8,8 @@ namespace FlexFlow { PatternNode get_src_node(StandardPatternEdge const &); PatternNode get_dst_node(StandardPatternEdge const &); -nonnegative_int get_src_idx(StandardPatternEdge const &); -nonnegative_int get_dst_idx(StandardPatternEdge const &); +TensorSlotName get_src_slot_name(StandardPatternEdge const &); +TensorSlotName get_dst_slot_name(StandardPatternEdge const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.struct.toml deleted file mode 100644 index 4a2e193544..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "StandardPatternEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_edge.dtg.h", -] - -[[fields]] -name = "raw_edge" -type = "::FlexFlow::DataflowEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h deleted file mode 100644 index 09d6a12716..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_DATAFLOW_GRAPH_PATTERN_MATCH_H -#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_DATAFLOW_GRAPH_PATTERN_MATCH_H - -#include 
"substitutions/pcg_pattern.dtg.h" -#include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "substitutions/unlabelled/pattern_value.dtg.h" -#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" -#include -#include - -namespace FlexFlow { - -UnlabelledDataflowGraphPatternMatch empty_unlabelled_pattern_match(); -std::unordered_set - matched_nodes(UnlabelledDataflowGraphPatternMatch const &); -std::optional - merge_unlabelled_dataflow_graph_pattern_matches( - UnlabelledDataflowGraphPatternMatch const &subpattern_1, - UnlabelledDataflowGraphPatternMatch const &subpattern_2, - bidict const - &merged_graph_values_to_inputs_of_1, - bidict const - &merged_graph_values_to_inputs_of_2); - -std::unordered_map - get_output_assignment(SubParallelComputationGraph const &, - PCGPattern const &, - UnlabelledDataflowGraphPatternMatch const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml deleted file mode 100644 index 5e8538811c..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "UnlabelledDataflowGraphPatternMatch" -features = [ - "eq", - # "ord", - "hash", - "fmt", -] - -includes = [ - "utils/bidict/bidict.h", - "utils/graph/node/node.dtg.h", - "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", - "substitutions/unlabelled/pattern_input.dtg.h", - "substitutions/unlabelled/pattern_node.dtg.h", - "", -] - -src_includes = [ - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", -] - -[[fields]] -name = "node_assignment" -type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>" - -[[fields]] -name = "input_assignment" -type = "std::unordered_map<::FlexFlow::PatternInput, 
::FlexFlow::OpenDataflowValue>" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.toml new file mode 100644 index 0000000000..20971865d0 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "UnlabelledGraphPattern" +type = "struct" +features = [] +includes = [ + "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::OpenKwargDataflowGraphView" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h index 949fbf455b..716714e2d9 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h @@ -12,26 +12,29 @@ namespace FlexFlow { size_t num_nodes(UnlabelledGraphPattern const &); bool is_singleton_pattern(UnlabelledGraphPattern const &); -std::unordered_set get_nodes(UnlabelledGraphPattern const &); -std::unordered_set get_values(UnlabelledGraphPattern const &); +std::unordered_set + get_pattern_nodes(UnlabelledGraphPattern const &); +std::unordered_set + get_pattern_values(UnlabelledGraphPattern const &); std::vector get_topological_ordering(UnlabelledGraphPattern const &); std::unordered_set - get_graph_inputs(UnlabelledGraphPattern const &); + get_pattern_inputs(UnlabelledGraphPattern const &); -std::unordered_set get_edges(UnlabelledGraphPattern const &); +std::unordered_set + get_pattern_edges(UnlabelledGraphPattern const &); -std::vector +std::unordered_map get_inputs_to_pattern_node(UnlabelledGraphPattern const &, PatternNode const &); -std::vector +std::unordered_map 
get_outputs_from_pattern_node(UnlabelledGraphPattern const &, PatternNode const &); UnlabelledGraphPatternSubgraphResult - get_subgraph(UnlabelledGraphPattern const &, - std::unordered_set const &); + get_pattern_subgraph(UnlabelledGraphPattern const &, + std::unordered_set const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml deleted file mode 100644 index 74371f21ef..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml +++ /dev/null @@ -1,10 +0,0 @@ -namespace = "FlexFlow" -name = "UnlabelledGraphPattern" -features = [] -includes = [ - "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" -] - -[[fields]] -name = "raw_graph" -type = "::FlexFlow::OpenDataflowGraphView" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern_subgraph_result.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern_subgraph_result.dtg.toml new file mode 100644 index 0000000000..6c6ecfc2e7 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern_subgraph_result.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "UnlabelledGraphPatternSubgraphResult" +type = "struct" +features = [ ] + +includes = [ + "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h", + "substitutions/unlabelled/pattern_value.dtg.h", + "substitutions/unlabelled/pattern_input.dtg.h", + "utils/bidict/bidict.h", +] + +[[fields]] +name = "subpattern" +type = "::FlexFlow::UnlabelledGraphPattern" + +[[fields]] +name = "full_pattern_values_to_subpattern_inputs" +type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::PatternInput>" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern_subgraph_result.struct.toml 
b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern_subgraph_result.struct.toml deleted file mode 100644 index d718035f3e..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern_subgraph_result.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "UnlabelledGraphPatternSubgraphResult" -features = [ ] - -includes = [ - "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h", - "substitutions/unlabelled/pattern_value.dtg.h", - "substitutions/unlabelled/pattern_input.dtg.h", - "utils/bidict/bidict.h", -] - -[[fields]] -name = "subpattern" -type = "::FlexFlow::UnlabelledGraphPattern" - -[[fields]] -name = "full_pattern_values_to_subpattern_inputs" -type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::PatternInput>" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.dtg.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.dtg.toml new file mode 100644 index 0000000000..c7f3b20394 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.dtg.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "UnlabelledKwargDataflowGraphPatternMatch" +type = "struct" +features = [ + "eq", + # "ord", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "utils/graph/node/node.dtg.h", + "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h", + "substitutions/unlabelled/pattern_input.dtg.h", + "substitutions/unlabelled/pattern_node.dtg.h", + "", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "node_assignment" +type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>" + +[[fields]] +name = "input_assignment" +type = "std::unordered_map<::FlexFlow::PatternInput, 
::FlexFlow::OpenKwargDataflowValue>" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.h new file mode 100644 index 0000000000..d175747852 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_KWARG_DATAFLOW_GRAPH_PATTERN_MATCH_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_KWARG_DATAFLOW_GRAPH_PATTERN_MATCH_H + +#include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/unlabelled/pattern_value.dtg.h" +#include "substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.dtg.h" +#include +#include + +namespace FlexFlow { + +UnlabelledKwargDataflowGraphPatternMatch empty_unlabelled_pattern_match(); +std::unordered_set + matched_nodes(UnlabelledKwargDataflowGraphPatternMatch const &); +std::optional + merge_unlabelled_dataflow_graph_pattern_matches( + UnlabelledKwargDataflowGraphPatternMatch const &subpattern_1, + UnlabelledKwargDataflowGraphPatternMatch const &subpattern_2, + bidict const + &merged_graph_values_to_inputs_of_1, + bidict const + &merged_graph_values_to_inputs_of_2); + +std::unordered_map, PatternValue> + get_output_assignment(SubParallelComputationGraph const &, + PCGPattern const &, + UnlabelledKwargDataflowGraphPatternMatch const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/src/substitution_loader.cc b/lib/substitutions/src/substitution_loader.cc deleted file mode 100644 index 49e8ff69ed..0000000000 --- a/lib/substitutions/src/substitution_loader.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* #include "substitutions/substitution_loader.h" */ -/* #include */ -/* #include */ -/* #include */ - -/* 
using json = nlohmann::json; */ - -/* namespace FlexFlow { */ -/* namespace substitution_loader { */ - -/* void from_json(json const &j, Parameter &p) { */ -/* j.at("key").get_to(p.key); */ -/* j.at("value").get_to(p.value); */ -/* if (p.key == PM_INVALID) { */ -/* std::ostringstream oss; */ -/* oss << "Attempted to load invalid PMParameter: " << j.at("key"); */ -/* throw std::runtime_error(oss.str()); */ -/* } */ -/* } */ - -/* void from_json(json const &j, Tensor &t) { */ -/* j.at("opId").get_to(t.opId); */ -/* j.at("tsId").get_to(t.tsId); */ -/* } */ - -/* tl::optional Operator::at(PMParameter key) const { */ -/* tl::optional value = tl::nullopt; */ -/* for (Parameter const &p : this->para) { */ -/* if (p.key == key) { */ -/* assert(!value.has_value()); */ -/* value = p.key; */ -/* } */ -/* } */ - -/* return value; */ -/* } */ - -/* void from_json(json const &j, Operator &o) { */ -/* j.at("type").get_to(o.op_type); */ -/* j.at("input").get_to(o.input); */ -/* j.at("para").get_to(o.para); */ -/* if (o.op_type == OP_INVALID) { */ -/* std::ostringstream oss; */ -/* oss << "Attempted to load invalid OperatorType: " << j.at("type"); */ -/* throw std::runtime_error(oss.str()); */ -/* } */ -/* } */ - -/* void from_json(json const &j, MapOutput &m) { */ -/* j.at("dstOpId").get_to(m.dstOpId); */ -/* j.at("dstTsId").get_to(m.dstTsId); */ -/* j.at("srcOpId").get_to(m.srcOpId); */ -/* j.at("srcTsId").get_to(m.srcTsId); */ -/* } */ - -/* void from_json(json const &j, Rule &r) { */ -/* j.at("name").get_to(r.name); */ -/* j.at("srcOp").get_to(r.srcOp); */ -/* j.at("dstOp").get_to(r.dstOp); */ -/* j.at("mappedOutput").get_to(r.mappedOutput); */ -/* } */ - -/* void from_json(json const &j, RuleCollection &c) { */ -/* j.at("rule").get_to(c.rules); */ -/* } */ - -/* RuleCollection load_rule_collection(std::istream &s) { */ -/* json j; */ -/* s >> j; */ -/* RuleCollection rule_collection = j; */ -/* return rule_collection; */ -/* } */ - -/* RuleCollection 
load_rule_collection_from_path(std::string const &path) { */ -/* std::ifstream input(path); */ -/* return load_rule_collection(input); */ -/* } */ - -/* } */ -/* } */ diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index 61bfe15d7b..4c355acb4b 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -7,9 +7,10 @@ #include "substitutions/pcg_pattern_match.h" #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/sub_parallel_computation_graph_data.dtg.h" +#include "substitutions/sub_parallel_computation_graph_data.h" #include "substitutions/sub_parallel_computation_graph_edge.h" +#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/keys.h" -#include "utils/containers/merge_maps.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/set_minus.h" #include "utils/containers/values.h" @@ -20,6 +21,9 @@ SubParallelComputationGraph apply_substitution(SubParallelComputationGraph const &spcg, Substitution const &sub, PCGPatternMatch const &match) { + assert_pcg_pattern_match_is_valid_for_pattern_and_subpcg( + match, sub.pcg_pattern, spcg); + auto substitution_output_result = evaluate_substitution_output(spcg, sub, match); SubParallelComputationGraph substitution_output_graph = @@ -29,7 +33,10 @@ SubParallelComputationGraph SubParallelComputationGraphData output_graph_data = get_sub_pcg_data(substitution_output_graph); + require_sub_parallel_computation_graph_data_is_valid(output_graph_data); + SubParallelComputationGraphData pre_data = get_sub_pcg_data(spcg); + require_sub_parallel_computation_graph_data_is_valid(pre_data); std::unordered_set pre_nodes = keys(pre_data.node_data); @@ -46,17 +53,21 @@ SubParallelComputationGraph std::unordered_map 
post_node_data_from_sub = output_graph_data.node_data; - return merge_disjoint_maps(post_node_data_from_orig, - post_node_data_from_sub); + return binary_merge_disjoint_maps(post_node_data_from_orig, + post_node_data_from_sub); }(); + std::unordered_set post_inputs = + pre_data.inputs; + std::unordered_set post_edges = [&] { std::unordered_set post_edges_from_orig = filter(pre_data.edges, [&](SubParallelComputationGraphEdge const &e) { - if (e.raw_edge.has()) { + if (e.raw_edge.is_input_edge()) { return true; } else { - DataflowEdge dfe = e.raw_edge.get(); + KwargDataflowEdge dfe = + e.raw_edge.require_internal_edge(); parallel_layer_guid_t src = parallel_layer_guid_t{dfe.src.node}; parallel_layer_guid_t dst = parallel_layer_guid_t{dfe.dst.node}; return !(contains(matched_nodes, src) || @@ -67,7 +78,7 @@ SubParallelComputationGraph std::unordered_set post_edges_from_sub = filter(output_graph_data.edges, [&](SubParallelComputationGraphEdge const &e) { - return !e.raw_edge.has(); + return e.raw_edge.is_internal_edge(); }); bidict @@ -113,7 +124,7 @@ SubParallelComputationGraph subpcg_edge_from_tensor_and_dst( new_tensor, get_dst_layer(outgoing_edge), - get_dst_layer_input_idx(outgoing_edge)); + get_dst_layer_input_slot_name(outgoing_edge)); outgoing_from_sub_edges.insert(new_edge); } @@ -125,9 +136,6 @@ SubParallelComputationGraph }); }(); - std::unordered_set post_inputs = - pre_data.inputs; - std::unordered_map post_value_data = [&] { std::unordered_map @@ -148,8 +156,8 @@ SubParallelComputationGraph std::unordered_map post_value_data_from_sub = output_graph_data.value_data; - return merge_disjoint_maps(post_value_data_from_orig, - post_value_data_from_sub); + return binary_merge_disjoint_maps(post_value_data_from_orig, + post_value_data_from_sub); }(); SubParallelComputationGraphData post_data = SubParallelComputationGraphData{ diff --git a/lib/substitutions/src/substitutions/apply_substitution/evaluate_substitution_output.cc 
b/lib/substitutions/src/substitutions/apply_substitution/evaluate_substitution_output.cc index 272c5f2dd5..57a93daefc 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/evaluate_substitution_output.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/evaluate_substitution_output.cc @@ -2,16 +2,21 @@ #include "substitutions/apply_substitution/perform_shape_inference.h" #include "substitutions/output_graph/output_operator_attrs_assignment.h" #include "substitutions/sub_parallel_computation_graph.h" +#include "utils/bidict/algorithms/transform_keys.h" +#include "utils/bidict/algorithms/transform_values.h" +#include "utils/bidict/generate_bidict.h" #include "utils/containers/map_keys.h" #include "utils/containers/map_values.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_node_labels.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_value_labels.h" #include "utils/graph/node/algorithms/generate_new_node_id_permutation.h" #include "utils/graph/node/algorithms/new_node.dtg.h" #include "utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h" #include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include 
"utils/graph/open_kwarg_dataflow_graph/algorithms/generate_new_kwarg_dataflow_graph_input_id_permutation.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/new_kwarg_dataflow_graph_input.dtg.h" namespace FlexFlow { @@ -27,16 +32,34 @@ std::pair bidict new_node_id_permutation = generate_new_node_id_permutation(sub.output_graph_expr.raw_graph); - bidict new_input_id_permutation = - generate_new_input_id_permutation(sub.output_graph_expr.raw_graph); - LabelledOpenDataflowGraphView - permuted = - permute_input_ids(permute_node_ids(sub.output_graph_expr.raw_graph, - new_node_id_permutation), - new_input_id_permutation); - LabelledOpenDataflowGraphView - without_shapes = rewrite_node_labels( + int graph_input_ctr = 0; + auto graph_input_source = [&]() -> int { + int result = graph_input_ctr; + graph_input_ctr++; + return result; + }; + + bidict, KwargDataflowGraphInput> + new_input_id_permutation = + generate_new_kwarg_dataflow_graph_input_id_permutation( + sub.output_graph_expr.raw_graph, + std::function{graph_input_source}); + + LabelledOpenKwargDataflowGraphView + permuted = permute_labelled_open_kwarg_dataflow_graph_input_ids( + permute_labelled_open_kwarg_dataflow_graph_node_ids( + sub.output_graph_expr.raw_graph, new_node_id_permutation), + new_input_id_permutation); + + LabelledOpenKwargDataflowGraphView + without_shapes = rewrite_labelled_open_kwarg_dataflow_graph_node_labels( permuted, [&](Node const &n, OutputOperatorAttrsAssignment const &attrs) { return ParallelLayerAttrs{ @@ -47,35 +70,45 @@ std::pair }); bidict result_input_map = - map_keys(map_values(new_input_id_permutation, - [](DataflowGraphInput const &i) { - return OutputGraphExprInput{i}; - }), - [](NewDataflowGraphInput const &i) { - return input_parallel_tensor_guid_t{i.raw_input}; - }); + transform_keys( + transform_values(new_input_id_permutation, + [](KwargDataflowGraphInput const &i) { + return OutputGraphExprInput{i}; + }), + [](KwargDataflowGraphInput const &i) { + return 
input_parallel_tensor_guid_t{i}; + }); - bidict result_node_map = map_keys( - map_values(new_node_id_permutation, - [](Node const &n) { return OutputGraphExprNode{n}; }), - [](NewNode const &n) { return parallel_layer_guid_t{n.raw_node}; }); + bidict result_node_map = + transform_keys( + transform_values( + new_node_id_permutation, + [](Node const &n) { return OutputGraphExprNode{n}; }), + [](NewNode const &n) { return parallel_layer_guid_t{n.raw_node}; }); - std::unordered_map input_shapes = - map_values(map_keys(match.input_assignment, - [&](PatternInput const &i) { - return result_input_map - .at_r(sub.inputs_mapping.at_l(i)) - .raw_dataflow_graph_input; - }), - [&](open_parallel_tensor_guid_t const &v) { - return spcg.raw_graph.at(v.raw_open_dataflow_value).shape; - }); - LabelledOpenDataflowGraphView + std::unordered_map, ParallelTensorShape> + input_shapes = map_values( + map_keys(match.input_assignment, + [&](PatternInput const &i) { + return result_input_map.at_r(sub.inputs_mapping.at_l(i)) + .raw_dataflow_graph_input; + }), + [&](open_parallel_tensor_guid_t const &v) { + return spcg.raw_graph.at(v.raw_open_dataflow_value).shape; + }); + LabelledOpenKwargDataflowGraphView with_shapes = perform_shape_inference(without_shapes, input_shapes); - LabelledOpenDataflowGraphView - with_attrs = rewrite_value_labels( + LabelledOpenKwargDataflowGraphView + with_attrs = rewrite_labelled_open_kwarg_dataflow_graph_value_labels( with_shapes, - [](OpenDataflowValue const &, ParallelTensorShape const &s) { + [](OpenKwargDataflowValue const &, + ParallelTensorShape const &s) { return ParallelTensorAttrs{ s, CreateGrad::YES, diff --git a/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc b/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc index a5fc9a2e06..d263bb842a 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc +++ 
b/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc @@ -1,8 +1,10 @@ #include "substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.h" #include "substitutions/output_graph/output_graph_expr.h" #include "substitutions/sub_parallel_computation_graph.h" -#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/bidict/algorithms/bidict_from_pairs.h" #include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/containers/values.h" +#include "utils/containers/zip_values_strict.h" namespace FlexFlow { @@ -14,14 +16,15 @@ bidict bidict result; for (auto const &[parallel_layer, output_graph_expr_node] : m.node_mapping) { - std::vector layer_outputs = + std::unordered_map layer_outputs = get_layer_outputs(spcg, parallel_layer); - std::vector output_graph_expr_outputs = - get_node_outputs(output_graph_expr, output_graph_expr_node); + std::unordered_map + output_graph_expr_outputs = + get_node_outputs(output_graph_expr, output_graph_expr_node); bidict - mapping_for_layer = bidict_from_keys_and_values( - layer_outputs, output_graph_expr_outputs); + mapping_for_layer = bidict_from_pairs(values( + zip_values_strict(layer_outputs, output_graph_expr_outputs))); result = merge_disjoint_bidicts(result, mapping_for_layer); } diff --git a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc index 87dd5e6cbd..e7dc926682 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc @@ -1,77 +1,96 @@ #include "substitutions/apply_substitution/perform_shape_inference.h" #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/shape_inference.h" +#include "utils/containers/filter_values.h" #include "utils/containers/filtrans.h" #include 
"utils/containers/map_keys.h" +#include "utils/containers/map_values.h" +#include "utils/containers/restrict_keys.h" #include "utils/containers/transform.h" +#include "utils/containers/values.h" #include "utils/containers/zip.h" +#include "utils/containers/zip_values_strict.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_value_labels.h" #include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h" +#include "utils/nonnegative_int/num_elements.h" namespace FlexFlow { -LabelledOpenDataflowGraphView +LabelledOpenKwargDataflowGraphView perform_shape_inference( - LabelledOpenDataflowGraphView const - &g, - std::unordered_map const - &input_shapes) { + LabelledOpenKwargDataflowGraphView const &g, + std::unordered_map, + ParallelTensorShape> const &input_shapes) { - std::unordered_map inferred = - map_keys(input_shapes, [](DataflowGraphInput const &i) { - return OpenDataflowValue{i}; - }); + std::unordered_map, + ParallelTensorShape> + inferred = + map_keys(input_shapes, + [](KwargDataflowGraphInput const &i) + -> OpenKwargDataflowValue { + return OpenKwargDataflowValue{i}; + }); for (Node const &n : get_topological_ordering(g)) { - std::vector incoming_shapes = - transform(get_inputs(g, n), - [&](OpenDataflowValue const &v) { return inferred.at(v); }); + std::unordered_map incoming_shapes = + map_values(get_incoming_open_kwarg_dataflow_values_for_node(g, n), + [&](OpenKwargDataflowValue const &v) { + return inferred.at(v); + }); ParallelLayerAttrs n_attrs = g.at(n); - std::vector 
incoming_tensor_roles = - get_incoming_tensor_roles(n_attrs.op_attrs, incoming_shapes.size()); - - auto incoming_shapes_with_role = - [&](IncomingTensorRole role) -> std::vector { - return filtrans( - zip(incoming_shapes, incoming_tensor_roles), - [&](std::pair const &t) - -> std::optional { - if (t.second == role) { - return t.first; - } else { - return std::nullopt; - } - }); + std::unordered_map + incoming_tensor_roles = get_incoming_tensor_roles(n_attrs.op_attrs); + + auto incoming_shapes_with_role = [&](IncomingTensorRole role) + -> std::unordered_map { + std::unordered_set slots_with_desired_role = + keys(filter_values(incoming_tensor_roles, + [&](IncomingTensorRole r) { return r == role; })); + + return restrict_keys(incoming_shapes, slots_with_desired_role); }; - std::vector input_shapes = + std::unordered_map input_shapes = incoming_shapes_with_role(IncomingTensorRole::INPUT); - std::vector weight_shapes = + std::unordered_map weight_shapes = incoming_shapes_with_role(IncomingTensorRole::WEIGHT); - std::vector inferred_weight_shapes = - get_weight_shapes(n_attrs.op_attrs, input_shapes); + std::unordered_map + inferred_weight_shapes = + get_weight_shapes(n_attrs.op_attrs, input_shapes); - assert(weight_shapes == inferred_weight_shapes); + ASSERT(weight_shapes == inferred_weight_shapes); - std::vector output_shapes = + std::unordered_map output_shapes = get_output_shapes(n_attrs.op_attrs, input_shapes); - std::vector outputs = get_outputs(g, n); + std::unordered_map> + outputs = get_outgoing_kwarg_dataflow_outputs_for_node(g, n); - for (auto const &[output, shape] : zip(outputs, output_shapes)) { - inferred.insert({OpenDataflowValue{output}, shape}); + for (auto const &[output, shape] : + values(zip_values_strict(outputs, output_shapes))) { + inferred.insert( + {OpenKwargDataflowValue{output}, shape}); } } - return rewrite_value_labels( - g, [&](OpenDataflowValue const &v, std::monostate const &) { - return inferred.at(v); - }); + return 
rewrite_labelled_open_kwarg_dataflow_graph_value_labels( + g, + [&](OpenKwargDataflowValue const &v, + std::monostate const &) { return inferred.at(v); }); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/open_parallel_tensor_guid_t.cc b/lib/substitutions/src/substitutions/open_parallel_tensor_guid_t.cc index 76329229a4..c8b939a17b 100644 --- a/lib/substitutions/src/substitutions/open_parallel_tensor_guid_t.cc +++ b/lib/substitutions/src/substitutions/open_parallel_tensor_guid_t.cc @@ -4,13 +4,14 @@ namespace FlexFlow { open_parallel_tensor_guid_t open_parallel_tensor_guid_from_closed(parallel_tensor_guid_t t) { - return open_parallel_tensor_guid_t{OpenDataflowValue{t.raw_graph_output}}; + return open_parallel_tensor_guid_t{ + OpenKwargDataflowValue{t.raw_graph_output}}; } open_parallel_tensor_guid_t open_parallel_tensor_guid_from_input(input_parallel_tensor_guid_t i) { return open_parallel_tensor_guid_t{ - OpenDataflowValue{i.raw_dataflow_graph_input}}; + OpenKwargDataflowValue{i.raw_dataflow_graph_input}}; } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc index 6f41772a9e..67b83c9569 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc @@ -1,9 +1,10 @@ #include "substitutions/operator_pattern/eval_list_access.h" #include "substitutions/operator_pattern/get_attribute.h" -#include "utils/containers/at_idx.h" #include "utils/containers/make.h" #include "utils/containers/transform.h" +#include "utils/containers/try_at_idx.h" #include "utils/overload.h" +#include namespace FlexFlow { @@ -22,13 +23,13 @@ std::optional using T = std::decay_t; if constexpr (std::is_same_v>) { - return transform(at_idx(v, acc.index), + return transform(try_at_idx(v, acc.index), make()); } else if constexpr 
(std::is_same_v>) { - return transform(at_idx(v, acc.index), + return transform(try_at_idx(v, acc.index), make()); } else { - throw mk_runtime_error("Invalid operand"); + PANIC("Invalid operand"); } }); } diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index cb733e16ff..f7fce1aca7 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -384,7 +384,7 @@ std::optional get_attribute(TransposeAttrs const &p, case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; case OperatorAttributeKey::PERMUTATION: - return OperatorAttributeValue{vector_of(p.perm)}; + return OperatorAttributeValue{p.permutation}; default: return std::nullopt; } diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc index 194ae49255..96c33989fe 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -1,5 +1,6 @@ #include "substitutions/operator_pattern/satisfies_constraint.h" #include "substitutions/operator_pattern/operator_attribute_expr.h" +#include namespace FlexFlow { @@ -17,9 +18,7 @@ bool operator_satisfies_constraint( case ConstraintType::EQUAL: return expr_val.value() == constraint.attribute_value; default: - throw mk_runtime_error( - fmt::format("Unknown constraint type {}", - static_cast(constraint.constraint_type))); + PANIC("Unknown constraint type", constraint.constraint_type); } } diff --git a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc index cf5a1e17f9..11ef85984c 100644 --- 
a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc +++ b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc @@ -1,6 +1,7 @@ #include "substitutions/output_graph/materialize_operator_from_attrs_map.h" #include "utils/containers/contains_key.h" #include "utils/fmt/unordered_map.h" +#include namespace FlexFlow { @@ -16,8 +17,7 @@ struct Accessor { if (contains_key(this->m, k)) { return this->m.at(k).get(); } else { - throw mk_runtime_error( - fmt::format("Could not find key {} in attrs map: {}", k, this->m)); + PANIC("Could not find key in attrs map", k, this->m); } } }; @@ -151,8 +151,7 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::PIPELINE: case OperatorType::FUSED_PARALLEL: default: - throw mk_runtime_error( - fmt::format("Unsupported operator type {}", op_type)); + PANIC("Unsupported operator type", op_type); } } diff --git a/lib/substitutions/src/substitutions/output_graph/output_graph_expr.cc b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.cc index f6d1410a07..3bc1d04abc 100644 --- a/lib/substitutions/src/substitutions/output_graph/output_graph_expr.cc +++ b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.cc @@ -1,8 +1,9 @@ #include "substitutions/output_graph/output_graph_expr.h" +#include "utils/containers/map_values.h" #include "utils/containers/transform.h" -#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h" namespace FlexFlow { @@ -13,21 +14,23 @@ std::unordered_set get_nodes(OutputGraphExpr const &g) { [](Node const &n) { return OutputGraphExprNode{n}; }); } -std::vector 
+std::unordered_map get_node_outputs(OutputGraphExpr const &g, OutputGraphExprNode const &n) { - std::vector raw_outputs = - get_outputs(g.raw_graph, n.raw_graph_node); - - return transform(raw_outputs, [](DataflowOutput const &o) { - return OutputGraphExprNodeOutput{o}; - }); + std::unordered_map> + raw_outputs = get_outgoing_kwarg_dataflow_outputs_for_node( + g.raw_graph, n.raw_graph_node); + + return map_values(raw_outputs, + [](KwargDataflowOutput const &o) { + return OutputGraphExprNodeOutput{o}; + }); } std::unordered_set get_inputs(OutputGraphExpr const &g) { - std::unordered_set raw_inputs = - get_open_dataflow_graph_inputs(g.raw_graph); + std::unordered_set> raw_inputs = + get_all_kwarg_dataflow_graph_inputs(g.raw_graph); - return transform(raw_inputs, [](DataflowGraphInput const &i) { + return transform(raw_inputs, [](KwargDataflowGraphInput const &i) { return OutputGraphExprInput{i}; }); } diff --git a/lib/substitutions/src/substitutions/output_graph/output_graph_expr_value.cc b/lib/substitutions/src/substitutions/output_graph/output_graph_expr_value.cc index b35f3bbeae..c0fb979683 100644 --- a/lib/substitutions/src/substitutions/output_graph/output_graph_expr_value.cc +++ b/lib/substitutions/src/substitutions/output_graph/output_graph_expr_value.cc @@ -3,25 +3,28 @@ namespace FlexFlow { -OpenDataflowValue raw_open_dataflow_value_from_output_graph_expr_value( - OutputGraphExprValue const &v) { - return v.visit(overload{ +OpenKwargDataflowValue + raw_open_kwarg_dataflow_value_from_output_graph_expr_value( + OutputGraphExprValue const &v) { + return v.visit>(overload{ [](OutputGraphExprNodeOutput const &o) { - return OpenDataflowValue{o.raw_dataflow_output}; + return OpenKwargDataflowValue{ + o.raw_dataflow_output}; }, [](OutputGraphExprInput const &i) { - return OpenDataflowValue{i.raw_dataflow_graph_input}; + return OpenKwargDataflowValue{ + i.raw_dataflow_graph_input}; }, }); } -OutputGraphExprValue output_graph_expr_value_from_raw_open_dataflow_value( 
- OpenDataflowValue const &v) { +OutputGraphExprValue output_graph_expr_value_from_raw_open_kwarg_dataflow_value( + OpenKwargDataflowValue const &v) { return v.visit(overload{ - [](DataflowOutput const &o) { + [](KwargDataflowOutput const &o) { return OutputGraphExprValue{OutputGraphExprNodeOutput{o}}; }, - [](DataflowGraphInput const &i) { + [](KwargDataflowGraphInput const &i) { return OutputGraphExprValue{OutputGraphExprInput{i}}; }, }); diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc index f6b90ef054..647362ee4d 100644 --- a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc @@ -2,8 +2,9 @@ #include "substitutions/operator_pattern/get_attribute_map.h" #include "substitutions/output_graph/materialize_operator_from_attrs_map.h" #include "substitutions/output_graph/output_operator_attribute_expr.h" +#include "utils/containers/binary_merge_maps_with_right_dominating.h" #include "utils/containers/map_values.h" -#include "utils/containers/merge_maps.h" +#include "utils/exception.h" namespace FlexFlow { @@ -35,8 +36,8 @@ PCGOperatorAttrs materialize_output_operator_from_attrs_assignment( }); std::unordered_map - joined_attrs_map = - merge_map_right_dominates(template_attrs_map, assignments_attrs_map); + joined_attrs_map = binary_merge_maps_with_right_dominating( + template_attrs_map, assignments_attrs_map); return materialize_operator_from_attrs_map(joined_attrs_map); } diff --git a/lib/substitutions/src/substitutions/pcg_pattern.cc b/lib/substitutions/src/substitutions/pcg_pattern.cc index a0af875848..15bce488ea 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.cc @@ -5,12 +5,12 @@ #include "substitutions/tensor_pattern/satisfies_pattern.h" 
#include "substitutions/unlabelled/find_pattern_matches.h" #include "substitutions/unlabelled/pattern_value.h" +#include "utils/bidict/algorithms/transform_values.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" -#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h" namespace FlexFlow { @@ -29,7 +29,8 @@ static MatchAdditionalCriterion get_operator_attrs(pcg, parallel_layer_guid_t{pcgNode}), get_operator_pattern(pattern, patternNode)); }, - [&](PatternValue const &patternValue, OpenDataflowValue const &pcgValue) { + [&](PatternValue const &patternValue, + OpenKwargDataflowValue const &pcgValue) { return parallel_tensor_satisfies_pattern( get_parallel_tensor_attrs(pcg, open_parallel_tensor_guid_t{pcgValue}), @@ -40,19 +41,21 @@ static MatchAdditionalCriterion std::vector find_pattern_matches(PCGPattern const &pattern, SubParallelComputationGraph const &pcg) { - std::vector unlabelled_matches = - find_pattern_matches(get_unlabelled_pattern(pattern), - pcg.raw_graph, - pcg_pattern_criteria(pattern, pcg)); + std::vector unlabelled_matches = + find_unlabelled_pattern_matches(get_unlabelled_pattern(pattern), + pcg.raw_graph, + pcg_pattern_criteria(pattern, pcg)); auto pcg_match_from_unlabelled_match = - [](UnlabelledDataflowGraphPatternMatch const &m) { + [](UnlabelledKwargDataflowGraphPatternMatch const &m) { return PCGPatternMatch{ - map_values(m.node_assignment, - [](Node const &n) { return parallel_layer_guid_t{n}; }), - map_values(m.input_assignment, - [](OpenDataflowValue const &i) { - return open_parallel_tensor_guid_t{i}; - }), + 
transform_values( + m.node_assignment, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + map_values( + m.input_assignment, + [](OpenKwargDataflowValue const &i) { + return open_parallel_tensor_guid_t{i}; + }), }; }; @@ -74,22 +77,25 @@ OperatorAttributePattern get_operator_pattern(PCGPattern const &p, } std::unordered_set get_inputs(PCGPattern const &p) { - std::unordered_set raw_inputs = - get_open_dataflow_graph_inputs(p.raw_graph); + std::unordered_set> raw_inputs = + get_all_kwarg_dataflow_graph_inputs(p.raw_graph); - return transform(raw_inputs, - [](DataflowGraphInput const &i) { return PatternInput{i}; }); + return transform(raw_inputs, [](KwargDataflowGraphInput const &i) { + return PatternInput{i}; + }); } -std::vector +std::unordered_map get_pattern_node_outputs(PCGPattern const &pattern, PatternNode const &node) { - std::vector raw_outputs = - get_outputs(pattern.raw_graph, node.raw_node); + std::unordered_map> + raw_outputs = get_outgoing_kwarg_dataflow_outputs_for_node( + pattern.raw_graph, node.raw_node); - return transform(raw_outputs, [](DataflowOutput const &o) { - return PatternNodeOutput{o}; - }); + return map_values(raw_outputs, + [](KwargDataflowOutput const &o) { + return PatternNodeOutput{o}; + }); } bool assignment_satisfies(SubParallelComputationGraph const &pcg, diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index b701be65cf..0596764747 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -1,9 +1,15 @@ #include "substitutions/pcg_pattern_match.h" #include "substitutions/pcg_pattern.h" #include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/bidict/algorithms/bidict_from_map.h" +#include 
"utils/bidict/algorithms/exhaustive_relational_join.h" #include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/transform_values.h" +#include "utils/containers/is_subseteq_of.h" #include "utils/containers/map_values.h" +#include "utils/containers/values.h" #include "utils/containers/zip.h" namespace FlexFlow { @@ -16,16 +22,17 @@ bidict bidict result; for (auto const &[pattern_node, matched_layer] : match.node_assignment) { - std::vector matched_layer_output_tensors = - get_layer_outputs(spcg, matched_layer); - std::vector pattern_node_outputs = - get_pattern_node_outputs(pattern, pattern_node); + bidict + matched_layer_output_tensors = + bidict_from_map(get_layer_outputs(spcg, matched_layer)); + bidict pattern_node_outputs = + bidict_from_map(get_pattern_node_outputs(pattern, pattern_node)); assert(matched_layer_output_tensors.size() == pattern_node_outputs.size()); bidict mapping = - bidict_from_keys_and_values(pattern_node_outputs, - matched_layer_output_tensors); + exhaustive_relational_join(pattern_node_outputs.reversed(), + matched_layer_output_tensors); result = merge_disjoint_bidicts(result, mapping); } @@ -33,10 +40,10 @@ bidict return result; } -UnlabelledDataflowGraphPatternMatch +UnlabelledKwargDataflowGraphPatternMatch get_unlabelled_pattern_match(PCGPatternMatch const &match) { - return UnlabelledDataflowGraphPatternMatch{ - map_values( + return UnlabelledKwargDataflowGraphPatternMatch{ + transform_values( match.node_assignment, [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }), map_values(match.input_assignment, @@ -46,4 +53,31 @@ UnlabelledDataflowGraphPatternMatch }; } +void assert_pcg_pattern_match_is_valid_for_pattern_and_subpcg( + PCGPatternMatch const &match, + PCGPattern const &pattern, + SubParallelComputationGraph const &spcg) { + std::unordered_set spcg_nodes = + get_parallel_layers(spcg); + std::unordered_set match_nodes = + match.node_assignment.right_values(); + 
ASSERT(is_subseteq_of(match_nodes, spcg_nodes)); + + std::unordered_set spcg_values = + get_parallel_tensors(spcg); + std::unordered_set match_values = + unordered_set_of(values(match.input_assignment)); + ASSERT(is_subseteq_of(match_values, spcg_values)); + + std::unordered_set pattern_nodes = get_nodes(pattern); + std::unordered_set match_pattern_nodes = + match.node_assignment.left_values(); + ASSERT(match_pattern_nodes == pattern_nodes); + + std::unordered_set pattern_inputs = get_inputs(pattern); + std::unordered_set match_pattern_inputs = + keys(match.input_assignment); + ASSERT(pattern_inputs == match_pattern_inputs); +} + } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 83df74f21b..12074bff33 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -2,18 +2,18 @@ #include "op-attrs/pcg_operator_attrs.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/values.h" -#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" -#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" -#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h" -#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" +#include 
"utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_outgoing_edges.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/view_as_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/find_isomorphism_between_labelled_open_kwarg_dataflow_graphs.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_node_labels.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/view_from_labelled_open_kwarg_dataflow_graph_data.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_incoming_edges.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_value_uses.h" namespace FlexFlow { @@ -23,6 +23,15 @@ std::unordered_set [](Node const &n) { return parallel_layer_guid_t{n}; }); } +std::unordered_set + get_parallel_tensors(SubParallelComputationGraph const &sub_pcg) { + return transform(get_all_open_kwarg_dataflow_values(sub_pcg.raw_graph), + [](OpenKwargDataflowValue const &v) + -> open_parallel_tensor_guid_t { + return open_parallel_tensor_guid_t{v}; + }); +} + ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &spcg, parallel_layer_guid_t const &layer) { @@ -43,16 +52,24 @@ ParallelTensorAttrs 
SubParallelComputationGraph sub_pcg_from_full_pcg(ParallelComputationGraph const &pcg) { return SubParallelComputationGraph{ - view_as_labelled_open_dataflow_graph(pcg.raw_graph)}; + view_as_labelled_open_kwarg_dataflow_graph( + pcg.raw_graph)}; } ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs( SubParallelComputationGraph const &sub_pcg) { return ParallelComputationGraph{ - LabelledDataflowGraph:: + LabelledKwargDataflowGraph:: create_copy_of< - UnorderedSetLabelledOpenDataflowGraph>( + UnorderedSetLabelledOpenKwargDataflowGraph>( sub_pcg.raw_graph)}; } @@ -63,31 +80,35 @@ parallel_layer_guid_t name); } -std::vector +std::unordered_map get_layer_inputs(SubParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer) { - return transform(get_inputs(pcg.raw_graph, layer.raw_graph_node), - [](OpenDataflowValue const &v) { - return open_parallel_tensor_guid_t{v}; - }); + return map_values(get_incoming_open_kwarg_dataflow_values_for_node( + pcg.raw_graph, layer.raw_graph_node), + [](OpenKwargDataflowValue const &v) { + return open_parallel_tensor_guid_t{v}; + }); } -std::vector +std::unordered_map get_layer_outputs(SubParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer) { - return transform( - get_outputs(pcg.raw_graph, layer.raw_graph_node), - [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); + return map_values(get_outgoing_kwarg_dataflow_outputs_for_node( + pcg.raw_graph, layer.raw_graph_node), + [](KwargDataflowOutput const &o) { + return parallel_tensor_guid_t{o}; + }); } std::unordered_set get_subgraph_outgoing_edges( SubParallelComputationGraph const &spcg, std::unordered_set const &layers) { - std::unordered_set raw_edges = get_subgraph_outgoing_edges( - spcg.raw_graph, transform(layers, [](parallel_layer_guid_t const &l) { - return l.raw_graph_node; - })); - return transform(raw_edges, [](DataflowEdge const &e) { + std::unordered_set> raw_edges = + get_kwarg_dataflow_subgraph_outgoing_edges( + 
spcg.raw_graph, transform(layers, [](parallel_layer_guid_t const &l) { + return l.raw_graph_node; + })); + return transform(raw_edges, [](KwargDataflowEdge const &e) { return ParallelComputationGraphEdge{e}; }); } @@ -99,42 +120,50 @@ std::unordered_set get_subgraph_incoming_edges( transform(subgraph, [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }); - std::unordered_set raw_incoming_edges = - get_subgraph_incoming_edges(spcg.raw_graph, raw_subgraph); + std::unordered_set> + raw_incoming_edges = get_open_kwarg_dataflow_subgraph_incoming_edges( + spcg.raw_graph, raw_subgraph); - return transform(raw_incoming_edges, [](OpenDataflowEdge const &e) { - return SubParallelComputationGraphEdge{e}; - }); + return transform(raw_incoming_edges, + [](OpenKwargDataflowEdge const &e) { + return SubParallelComputationGraphEdge{e}; + }); } std::unordered_set get_parallel_tensor_uses(SubParallelComputationGraph const &spcg, open_parallel_tensor_guid_t const &t) { - std::unordered_set raw_uses = - get_open_dataflow_value_uses(spcg.raw_graph, t.raw_open_dataflow_value); - return transform(raw_uses, [](DataflowInput const &i) { + std::unordered_set> raw_uses = + get_open_kwarg_dataflow_value_uses(spcg.raw_graph, + t.raw_open_dataflow_value); + return transform(raw_uses, [](KwargDataflowInput const &i) { return parallel_tensor_use_t{i}; }); } SubParallelComputationGraphData get_sub_pcg_data(SubParallelComputationGraph const &pcg) { - LabelledOpenDataflowGraphData - raw_data = get_graph_data(pcg.raw_graph); + LabelledOpenKwargDataflowGraphData + raw_data = get_labelled_open_kwarg_dataflow_graph_data(pcg.raw_graph); + + require_labelled_open_kwarg_dataflow_graph_data_is_valid(raw_data); return SubParallelComputationGraphData{ map_keys(raw_data.node_data, [](Node const &n) { return parallel_layer_guid_t{n}; }), transform(raw_data.edges, - [](OpenDataflowEdge const &e) { + [](OpenKwargDataflowEdge const &e) { return SubParallelComputationGraphEdge{e}; }), 
transform(raw_data.inputs, - [](DataflowGraphInput const &i) { + [](KwargDataflowGraphInput const &i) { return input_parallel_tensor_guid_t{i}; }), map_keys(raw_data.value_data, - [](OpenDataflowValue const &v) { + [](OpenKwargDataflowValue const &v) { return open_parallel_tensor_guid_t{v}; }), }; @@ -142,9 +171,14 @@ SubParallelComputationGraphData SubParallelComputationGraph sub_pcg_from_graph_data(SubParallelComputationGraphData const &data) { - LabelledOpenDataflowGraphData - raw_data = LabelledOpenDataflowGraphData{ + LabelledOpenKwargDataflowGraphData + raw_data = LabelledOpenKwargDataflowGraphData{ map_keys( data.node_data, [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }), @@ -162,15 +196,17 @@ SubParallelComputationGraph }), }; + require_labelled_open_kwarg_dataflow_graph_data_is_valid(raw_data); + return SubParallelComputationGraph{ - from_labelled_open_dataflow_graph_data(raw_data), + view_from_labelled_open_kwarg_dataflow_graph_data(raw_data), }; } SubParallelComputationGraph without_layer_names(SubParallelComputationGraph const &spcg) { return SubParallelComputationGraph{ - rewrite_node_labels( + rewrite_labelled_open_kwarg_dataflow_graph_node_labels( spcg.raw_graph, [](Node const &n, ParallelLayerAttrs const &old_attrs) { ParallelLayerAttrs new_attrs = old_attrs; @@ -182,8 +218,9 @@ SubParallelComputationGraph bool sub_pcgs_are_isomorphic(SubParallelComputationGraph const &lhs, SubParallelComputationGraph const &rhs) { - return find_isomorphism(without_layer_names(lhs).raw_graph, - without_layer_names(rhs).raw_graph) + return find_isomorphism_between_labelled_open_kwarg_dataflow_graphs( + without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) .has_value(); } diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph_data.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph_data.cc new file mode 100644 index 0000000000..8439ac9984 --- /dev/null +++ 
b/lib/substitutions/src/substitutions/sub_parallel_computation_graph_data.cc @@ -0,0 +1,47 @@ +#include "substitutions/sub_parallel_computation_graph_data.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/transform.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h" + +namespace FlexFlow { + +void require_sub_parallel_computation_graph_data_is_valid( + SubParallelComputationGraphData const &d) { + LabelledOpenKwargDataflowGraphData + labelled_graph_data = + LabelledOpenKwargDataflowGraphData{ + /*node_data=*/map_keys(d.node_data, + [](parallel_layer_guid_t l) -> Node { + return l.raw_graph_node; + }), + /*edges=*/ + transform(d.edges, + [](SubParallelComputationGraphEdge const &e) + -> OpenKwargDataflowEdge { + return e.raw_edge; + }), + /*inputs=*/ + transform(d.inputs, + [](input_parallel_tensor_guid_t const &i) + -> KwargDataflowGraphInput { + return i.raw_dataflow_graph_input; + }), + /*value_data=*/ + map_keys(d.value_data, + [](open_parallel_tensor_guid_t t) + -> OpenKwargDataflowValue { + return t.raw_open_dataflow_value; + }), + }; + + require_labelled_open_kwarg_dataflow_graph_data_is_valid(labelled_graph_data); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph_edge.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph_edge.cc index 0d2b912049..03d6ff557a 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph_edge.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph_edge.cc @@ -1,19 +1,20 @@ #include "substitutions/sub_parallel_computation_graph_edge.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h" namespace FlexFlow { 
SubParallelComputationGraphEdge subpcg_edge_from_tensor_and_dst(parallel_tensor_guid_t const &tensor, parallel_layer_guid_t const &layer, - nonnegative_int input_idx) { + TensorSlotName input_slot_name) { return SubParallelComputationGraphEdge{ - OpenDataflowEdge{ - DataflowEdge{ + OpenKwargDataflowEdge{ + KwargDataflowEdge{ tensor.raw_graph_output, - DataflowInput{ + KwargDataflowInput{ layer.raw_graph_node, - input_idx, + input_slot_name, }, }, }, @@ -24,14 +25,15 @@ SubParallelComputationGraphEdge subpcg_edge_from_tensor_and_use(open_parallel_tensor_guid_t const &tensor, parallel_tensor_use_t const &use) { return SubParallelComputationGraphEdge{ - open_dataflow_edge_from_src_and_dst(tensor.raw_open_dataflow_value, - use.raw_dataflow_input), + mk_open_kwarg_dataflow_edge_from_src_val_and_dst( + tensor.raw_open_dataflow_value, use.raw_dataflow_input), }; } open_parallel_tensor_guid_t get_parallel_tensor(SubParallelComputationGraphEdge const &e) { - OpenDataflowValue raw_value = get_open_dataflow_edge_src(e.raw_edge); + OpenKwargDataflowValue raw_value = + get_src_of_open_kwarg_dataflow_edge(e.raw_edge); return open_parallel_tensor_guid_t{raw_value}; } diff --git a/lib/substitutions/src/substitutions/substitution.cc b/lib/substitutions/src/substitutions/substitution.cc index 874700d303..9bcfd054a7 100644 --- a/lib/substitutions/src/substitutions/substitution.cc +++ b/lib/substitutions/src/substitutions/substitution.cc @@ -3,18 +3,21 @@ #include "substitutions/pcg_pattern.h" #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/algorithms/transform.h" #include "utils/containers/map_values.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" -#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" -#include 
"utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_node_labels.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/find_isomorphisms_between_open_kwarg_dataflow_graphs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/try_find_isomorphism_between_open_kwarg_dataflow_graphs.h" namespace FlexFlow { bool is_isomorphic_to(Substitution const &l, Substitution const &r) { - OpenDataflowGraphIsomorphism pcg_pattern_isomorphism = ({ - std::optional maybe_isomorphism = - find_isomorphism(l.pcg_pattern.raw_graph, r.pcg_pattern.raw_graph); + OpenKwargDataflowGraphIsomorphism pcg_pattern_isomorphism = ({ + std::optional> maybe_isomorphism = + try_find_isomorphism_between_open_kwarg_dataflow_graphs( + l.pcg_pattern.raw_graph, r.pcg_pattern.raw_graph); if (!maybe_isomorphism.has_value()) { return false; @@ -54,11 +57,11 @@ bool is_isomorphic_to(Substitution const &l, Substitution const &r) { }; }; - OpenDataflowGraphIsomorphism output_graph_expr_isomorphism = ({ - std::optional maybe_isomorphism = - find_isomorphism( + OpenKwargDataflowGraphIsomorphism output_graph_expr_isomorphism = ({ + std::optional> maybe_isomorphism = + try_find_isomorphism_between_open_kwarg_dataflow_graphs( l.output_graph_expr.raw_graph, - rewrite_node_labels( + rewrite_labelled_open_kwarg_dataflow_graph_node_labels( r.output_graph_expr.raw_graph, [&](Node const &, OutputOperatorAttrsAssignment const &a) { return l_from_r_output_attrs_assignment(a); @@ -86,15 +89,15 @@ bool is_isomorphic_to(Substitution const &l, Substitution const &r) { auto l_from_r_pattern_output = [&](PatternNodeOutput const &r_output) { return PatternNodeOutput{ - 
isomorphism_map_l_dataflow_output_from_r(pcg_pattern_isomorphism, - r_output.raw_dataflow_output), + isomorphism_map_l_kwarg_dataflow_output_from_r( + pcg_pattern_isomorphism, r_output.raw_dataflow_output), }; }; auto l_from_r_output_graph_output = [&](OutputGraphExprNodeOutput const &r_output) { return OutputGraphExprNodeOutput{ - isomorphism_map_l_dataflow_output_from_r( + isomorphism_map_l_kwarg_dataflow_output_from_r( output_graph_expr_isomorphism, r_output.raw_dataflow_output), }; }; diff --git a/lib/substitutions/src/substitutions/substitution_builder.cc b/lib/substitutions/src/substitutions/substitution_builder.cc index a267b8113f..f2860326ab 100644 --- a/lib/substitutions/src/substitutions/substitution_builder.cc +++ b/lib/substitutions/src/substitutions/substitution_builder.cc @@ -3,32 +3,43 @@ #include "substitutions/substitution.h" #include "substitutions/unlabelled/pattern_value.h" #include "utils/containers/repeat_element.h" -#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" #include "utils/overload.h" namespace FlexFlow { SubstitutionBuilder::SubstitutionBuilder() - : pattern_g(LabelledOpenDataflowGraph:: - create:: + create>()), - output_g(LabelledOpenDataflowGraph:: - create>()), + output_g(LabelledOpenKwargDataflowGraph:: + create>()) {} + std::monostate, + int, + TensorSlotName>>()), + next_graph_input_id{0} {} std::pair SubstitutionBuilder::add_input( TensorAttributePattern const &input_tensor_pattern, std::optional const &name) { PatternInput pattern_input = PatternInput{ - this->pattern_g.add_input(input_tensor_pattern), + this->pattern_g.add_input(this->get_fresh_graph_input_name(), + input_tensor_pattern), }; OutputGraphExprInput output_graph_expr_input = OutputGraphExprInput{ - this->output_g.add_input(std::monostate{}), + this->output_g.add_input(this->get_fresh_graph_input_name(), + std::monostate{}), }; 
this->input_mapping.equate(pattern_input, output_graph_expr_input); @@ -43,14 +54,16 @@ std::pair SubstitutionBuilder::add_input( }; } -std::vector SubstitutionBuilder::add_pattern_node( - OperatorAttributePattern const &node_pattern, - std::vector const &inputs, - std::vector const &output_patterns, - std::optional const &maybe_name) { - NodeAddedResult node_added = this->pattern_g.add_node( +std::unordered_map + SubstitutionBuilder::add_pattern_node( + OperatorAttributePattern const &node_pattern, + std::unordered_map const &inputs, + std::unordered_map const + &output_patterns, + std::optional const &maybe_name) { + KwargNodeAddedResult node_added = this->pattern_g.add_node( node_pattern, - transform(inputs, raw_open_dataflow_value_from_pattern_value), + map_values(inputs, raw_open_dataflow_value_from_pattern_value), output_patterns); if (maybe_name.has_value()) { @@ -65,24 +78,30 @@ std::vector SubstitutionBuilder::add_pattern_node( this->pattern_node_names.equate(PatternNode{node_added.node}, name); } - return transform(node_added.outputs, [](DataflowOutput const &o) { - return pattern_value_from_raw_open_dataflow_value(OpenDataflowValue{o}); - }); + return map_values(node_added.outputs, + [](KwargDataflowOutput const &o) { + return pattern_value_from_raw_open_kwarg_dataflow_value( + OpenKwargDataflowValue{o}); + }); } -std::vector SubstitutionBuilder::add_output_graph_node( - OutputOperatorAttrsAssignment const &node_expr, - std::vector const &inputs, - nonnegative_int num_outputs) { - NodeAddedResult node_added = this->output_g.add_node( +std::unordered_map + SubstitutionBuilder::add_output_graph_node( + OutputOperatorAttrsAssignment const &node_expr, + std::unordered_map const &inputs, + std::unordered_set const &output_slots) { + KwargNodeAddedResult node_added = this->output_g.add_node( node_expr, - transform(inputs, raw_open_dataflow_value_from_output_graph_expr_value), - repeat_element(/*num_times=*/num_outputs, /*element=*/std::monostate{})); - - 
return transform(node_added.outputs, [](DataflowOutput const &o) { - return output_graph_expr_value_from_raw_open_dataflow_value( - OpenDataflowValue{o}); - }); + map_values(inputs, + raw_open_kwarg_dataflow_value_from_output_graph_expr_value), + generate_map(output_slots, + [](TensorSlotName) { return std::monostate{}; })); + + return map_values( + node_added.outputs, [](KwargDataflowOutput const &o) { + return output_graph_expr_value_from_raw_open_kwarg_dataflow_value( + OpenKwargDataflowValue{o}); + }); } void SubstitutionBuilder::equate_outputs( @@ -92,7 +111,7 @@ void SubstitutionBuilder::equate_outputs( maybe_pattern_output.visit(overload{ [](PatternNodeOutput const &o) { return o; }, [&](PatternInput const &) -> PatternNodeOutput { - throw mk_runtime_error(fmt::format( + PANIC(fmt::format( "SubstitutionBuilder::equate_outputs expected a PatternValue " "holding a PatternNodeOutput, but received {}", maybe_pattern_output)); @@ -103,16 +122,15 @@ void SubstitutionBuilder::equate_outputs( maybe_output_graph_expr_output.visit(overload{ [](OutputGraphExprNodeOutput const &o) { return o; }, [&](OutputGraphExprInput const &) -> OutputGraphExprNodeOutput { - throw mk_runtime_error( - fmt::format("SubstitutionBuilder::equate_outputs expected an " - "OutputGraphExprValue holding a " - "OutputGraphExprNodeOutput, but received {}", - maybe_output_graph_expr_output)); + PANIC(fmt::format("SubstitutionBuilder::equate_outputs expected an " + "OutputGraphExprValue holding a " + "OutputGraphExprNodeOutput, but received {}", + maybe_output_graph_expr_output)); }, }); if (this->output_mapping.contains_l(pattern_output)) { - throw mk_runtime_error( + PANIC( fmt::format("SubstitutionBuilder::equate_outputs expected a " "PatternValue holding a PatternValueOutput" "that is not contained in the output_mapping forward graph," @@ -120,7 +138,7 @@ void SubstitutionBuilder::equate_outputs( pattern_output)); } if (this->output_mapping.contains_r(output_graph_expr_output)) { - throw 
mk_runtime_error(fmt::format( + PANIC(fmt::format( "SubstitutionBuilder::output_graph_expr_output expected a " "OutputGraphExprValue holding a OutputGraphExprNodeOutput" "that is not contained in the output_mapping backward graph," @@ -149,13 +167,18 @@ Substitution SubstitutionBuilder::get_substitution() const { this->output_mapping, }; - if (!is_valid_substitution(result)) { - throw mk_runtime_error( - "get_substitution cannot return a Substitution, as the Substitution is " - "currently invalid. Ensure you have finished constructing the " - "Substitution and have mapped all of the outputs."); - } + ASSERT( + is_valid_substitution(result), + "get_substitution cannot return a Substitution, as the Substitution is " + "currently invalid. Ensure you have finished constructing the " + "Substitution and have mapped all of the outputs."); + + return result; +} +int SubstitutionBuilder::get_fresh_graph_input_name() { + int result = this->next_graph_input_id; + this->next_graph_input_id++; return result; } diff --git a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc index 7bfb1f5e9e..d570f23313 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc @@ -2,6 +2,7 @@ #include "substitutions/tensor_pattern/get_attribute.h" #include "utils/containers/at_idx.h" #include "utils/overload.h" +#include namespace FlexFlow { @@ -12,11 +13,9 @@ TensorAttributeValue return from_attr.visit(overload{ [&](std::vector const &v) -> TensorAttributeValue { - return TensorAttributeValue{at_idx(v, acc.index).value()}; - }, - [](auto &&) -> TensorAttributeValue { - throw mk_runtime_error("Invalid operand"); + return TensorAttributeValue{at_idx(v, acc.index)}; }, + [](auto &&x) -> TensorAttributeValue { PANIC("Invalid operand", x); }, }); } diff --git 
a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc index 974bfcabc0..e2f2e211fa 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc @@ -1,5 +1,6 @@ #include "substitutions/tensor_pattern/satisfies_constraint.h" #include "substitutions/tensor_pattern/tensor_attribute_expr.h" +#include namespace FlexFlow { @@ -13,9 +14,7 @@ bool parallel_tensor_satisfies_constraint( case ConstraintType::EQUAL: return expr_val == constraint.attribute_value; default: - throw mk_runtime_error( - fmt::format("Unknown constraint type {}", - static_cast(constraint.constraint_type))); + PANIC("Unknown constraint type", constraint.constraint_type); } } diff --git a/lib/substitutions/src/substitutions/unity_substitution_set.cc b/lib/substitutions/src/substitutions/unity_substitution_set.cc index 4b00cdd95f..469bc02799 100644 --- a/lib/substitutions/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/src/substitutions/unity_substitution_set.cc @@ -1,17 +1,18 @@ #include "substitutions/unity_substitution_set.h" -#include "pcg/machine_specification.h" +#include "pcg/machine_compute_specification.h" #include "substitutions/operator_pattern/operator_attribute_constraint.h" #include "substitutions/output_graph/output_operator_attrs_assignment.h" #include "substitutions/substitution_builder.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" #include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/nonnegative_int/nonnegative_range.h" namespace FlexFlow { std::vector - get_substitution_set(MachineSpecification const &resources) { + get_substitution_set(MachineComputeSpecification const &resources) { std::vector substitutions; for (nonnegative_int num_dims 
: nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) { @@ -49,13 +50,19 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all()); - std::vector p_inputs = {p_input, p_weight}; + std::unordered_map p_inputs = { + {TensorSlotName::INPUT, p_input}, + {TensorSlotName::WEIGHT, p_weight}, + }; std::optional o_bias = std::nullopt; if (use_bias) { std::pair bias = b.add_input(tensor_attribute_pattern_match_all()); - p_inputs.push_back(bias.first); + p_inputs.insert({ + TensorSlotName::BIAS, + bias.first, + }); o_bias = bias.second; } @@ -67,11 +74,18 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, nonnegative_int{degree}), }}; - PatternValue p_linear_output = get_only(b.add_pattern_node( - linear_pattern, - p_inputs, - {tensor_attr_pattern_require_num_dims(nonnegative_int{num_dims})}, - "linear")); + PatternValue p_linear_output = require_only_key( + b.add_pattern_node(linear_pattern, + p_inputs, + { + { + TensorSlotName::OUTPUT, + tensor_attr_pattern_require_num_dims( + nonnegative_int{num_dims}), + }, + }, + "linear"), + TensorSlotName::OUTPUT); OutputOperatorAttrsAssignment replicate_input_expr = OutputOperatorAttrsAssignment{ @@ -82,7 +96,20 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, OperatorAttributeValue{degree}), }}; OutputGraphExprValue o_replicate_input_output = - get_only(b.add_output_graph_node(replicate_input_expr, {o_input}, 1_n)); + require_only_key(b.add_output_graph_node( + /*node_expr=*/replicate_input_expr, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + o_input, + }, + }, + /*output_slots=*/ + { + TensorSlotName::OUTPUT, + }), + TensorSlotName::OUTPUT); OutputOperatorAttrsAssignment partition_weights_expr = OutputOperatorAttrsAssignment{ @@ -94,11 +121,32 @@ Substitution create_replicate_linear_combine(nonnegative_int 
num_dims, set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, OperatorAttributeValue{ff_dim_t{1_n}}), }}; - OutputGraphExprValue o_partition_weights_output = get_only( - b.add_output_graph_node(partition_weights_expr, {o_weight}, 1_n)); - - std::vector o_linear_inputs = { - o_replicate_input_output, o_partition_weights_output}; + OutputGraphExprValue o_partition_weights_output = + require_only_key(b.add_output_graph_node( + /*node_expr=*/partition_weights_expr, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + o_weight, + }, + }, + /*output_slots=*/ + { + TensorSlotName::OUTPUT, + }), + TensorSlotName::OUTPUT); + + std::unordered_map o_linear_inputs = { + { + TensorSlotName::INPUT, + o_replicate_input_output, + }, + { + TensorSlotName::WEIGHT, + o_partition_weights_output, + }, + }; if (use_bias) { OutputOperatorAttrsAssignment partition_bias_expr = @@ -111,9 +159,25 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, OperatorAttributeValue{ff_dim_t{1_n}}), }}; - OutputGraphExprValue o_partition_bias_output = get_only( - b.add_output_graph_node(partition_bias_expr, {o_bias.value()}, 1_n)); - o_linear_inputs.push_back(o_partition_bias_output); + OutputGraphExprValue o_partition_bias_output = + require_only_key(b.add_output_graph_node( + /*node_expr=*/partition_bias_expr, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + o_bias.value(), + }, + }, + /*output_slots=*/ + { + TensorSlotName::OUTPUT, + }), + TensorSlotName::OUTPUT); + o_linear_inputs.insert({ + TensorSlotName::BIAS, + o_partition_bias_output, + }); } OutputOperatorAttrsAssignment linear_expr = OutputOperatorAttrsAssignment{ @@ -121,7 +185,14 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, {}, }; OutputGraphExprValue o_linear_output = - get_only(b.add_output_graph_node(linear_expr, o_linear_inputs, 1_n)); + require_only_key(b.add_output_graph_node( + /*node_expr=*/linear_expr, + 
/*inputs=*/o_linear_inputs, + /*output_slots=*/ + { + TensorSlotName::OUTPUT, + }), + TensorSlotName::OUTPUT); OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ std::nullopt, @@ -136,8 +207,22 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, }}), }, }; + OutputGraphExprValue o_combine_output = - get_only(b.add_output_graph_node(combine_expr, {o_linear_output}, 1_n)); + require_only_key(b.add_output_graph_node( + /*node_expr=*/combine_expr, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + o_linear_output, + }, + }, + /*output_slots=*/ + { + TensorSlotName::OUTPUT, + }), + TensorSlotName::OUTPUT); b.equate_outputs(p_linear_output, o_combine_output); @@ -204,19 +289,51 @@ Substitution create_fuse_linear_activation(Activation activation) { OperatorAttributeValue{std::optional{std::nullopt}}), }}; PatternValue p_mm_output = - get_only(b.add_pattern_node(mm_pattern, - {p_input, p_weight}, - {tensor_attribute_pattern_match_all()}, - "mm")); + require_only_key(b.add_pattern_node( + /*node_expr=*/mm_pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::WEIGHT, + p_weight, + }, + }, + /*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }, + /*name=*/"mm"), + TensorSlotName::OUTPUT); OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ op_type_equals_constraint(OperatorType::RELU), }}; PatternValue p_relu_output = - get_only(b.add_pattern_node(relu_pattern, - {p_mm_output}, - {tensor_attribute_pattern_match_all()}, - "relu")); + require_only_key(b.add_pattern_node( + /*node_expr=*/relu_pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_mm_output, + }, + }, + /*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }, + /*name=*/"relu"), + TensorSlotName::OUTPUT); OutputOperatorAttrsAssignment fused_node_expr = OutputOperatorAttrsAssignment{ b.pattern_node_named("mm"), @@ 
-224,8 +341,25 @@ Substitution create_fuse_linear_activation(Activation activation) { set_attr_to_constant(OperatorAttributeKey::ACTIVATION, OperatorAttributeValue{activation}), }}; - OutputGraphExprValue o_fused_node_output = get_only( - b.add_output_graph_node(fused_node_expr, {o_input, o_weight}, 1_n)); + OutputGraphExprValue o_fused_node_output = + require_only_key(b.add_output_graph_node( + /*node_expr=*/fused_node_expr, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + o_input, + }, + { + TensorSlotName::WEIGHT, + o_weight, + }, + }, + /*output_slots=*/ + { + TensorSlotName::OUTPUT, + }), + TensorSlotName::OUTPUT); b.equate_outputs(p_relu_output, o_fused_node_output); diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index 9d8e4bc259..7d207d9c90 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -1,68 +1,91 @@ #include "substitutions/unlabelled/find_pattern_matches.h" #include "substitutions/unlabelled/match_additional_criterion.h" -#include "substitutions/unlabelled/multidigraph_pattern_match.h" #include "substitutions/unlabelled/pattern_matching.h" #include "substitutions/unlabelled/pattern_split.h" -#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.h" #include "utils/containers/get_only.h" #include "utils/containers/transform.h" +#include "utils/containers/unstructured_exhaustive_relational_join.h" +#include "utils/containers/values.h" #include "utils/containers/zip.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/node/algorithms.h" #include 
"utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h" +#include "utils/many_to_one/invert_many_to_one.h" +#include "utils/many_to_one/many_to_one_from_map.h" +#include "utils/many_to_one/many_to_one_from_unstructured_relation.h" +#include "utils/many_to_one/unstructured_relation_from_many_to_one.h" +#include "utils/one_to_many/unstructured_relation_from_one_to_many.h" #include "utils/overload.h" namespace FlexFlow { -static std::optional - get_candidate_singleton_match(UnlabelledGraphPattern const &pattern, - OpenDataflowGraphView const &graph, - Node const &graph_node) { - assert(is_singleton_pattern(pattern)); +static std::optional + get_candidate_singleton_match( + UnlabelledGraphPattern const &pattern, + OpenKwargDataflowGraphView const &graph, + Node const &graph_node) { + ASSERT(is_singleton_pattern(pattern)); - PatternNode pattern_node = get_only(get_nodes(pattern)); + PatternNode pattern_node = get_only(get_pattern_nodes(pattern)); - UnlabelledDataflowGraphPatternMatch match = empty_unlabelled_pattern_match(); + UnlabelledKwargDataflowGraphPatternMatch match = + empty_unlabelled_pattern_match(); match.node_assignment.equate(pattern_node, graph_node); - std::vector pattern_outputs = + std::unordered_map pattern_outputs = get_outputs_from_pattern_node(pattern, pattern_node); - std::vector graph_outputs = - transform(get_outputs(graph, graph_node), - [](DataflowOutput const &o) { return OpenDataflowValue{o}; }); - - if (pattern_outputs.size() != graph_outputs.size()) { + std::unordered_map> + graph_outputs = map_values( + get_outgoing_kwarg_dataflow_outputs_for_node(graph, graph_node), + [](KwargDataflowOutput const &o) { + return OpenKwargDataflowValue{o}; + }); + + if (keys(pattern_outputs) != keys(graph_outputs)) { return std::nullopt; } - std::vector pattern_node_inputs = + std::unordered_map pattern_node_inputs = 
get_inputs_to_pattern_node(pattern, pattern_node); std::unordered_set pattern_graph_inputs = - get_graph_inputs(pattern); + get_pattern_inputs(pattern); - assert(unordered_set_of(pattern_node_inputs) == + ASSERT(unordered_set_of(values(pattern_node_inputs)) == transform(pattern_graph_inputs, [](PatternInput const &i) { return PatternValue{i}; })); - std::vector graph_node_inputs = - get_inputs(graph, graph_node); + std::unordered_map> + graph_node_inputs = + get_incoming_open_kwarg_dataflow_values_for_node(graph, graph_node); if (graph_node_inputs.size() != pattern_node_inputs.size()) { return std::nullopt; } - for (auto const &[pattern_node_input, graph_node_input] : - zip(pattern_node_inputs, graph_node_inputs)) { - assert(pattern_node_input.has()); + ManyToOne m_pattern_node_inputs = + many_to_one_from_map(map_values( + pattern_node_inputs, [](PatternValue const &v) -> PatternInput { + return v.require_pattern_input(); + })); + ManyToOne> + m_graph_node_inputs = many_to_one_from_map(graph_node_inputs); - match.input_assignment.insert({ - pattern_node_input.get(), - graph_node_input, - }); - } + ManyToOne> + input_assignment = many_to_one_from_unstructured_relation( + unstructured_exhaustive_relational_join( + unstructured_relation_from_one_to_many( + invert_many_to_one(m_pattern_node_inputs)), + unstructured_relation_from_many_to_one(m_graph_node_inputs))); + + match.input_assignment = input_assignment.l_to_r(); - assert(unlabelled_pattern_does_match( + ASSERT(unlabelled_pattern_does_match( pattern, graph, match, match_additional_crition_always_true())); return match; @@ -74,7 +97,8 @@ MatchAdditionalCriterion additional_criterion_for_subpattern( &full_pattern_values_to_subpattern_inputs) { return MatchAdditionalCriterion{ full_additional_criterion.node_criterion, - [&](PatternValue const &patternValue, OpenDataflowValue const &pcgValue) { + [&](PatternValue const &patternValue, + OpenKwargDataflowValue const &pcgValue) { return patternValue.visit( 
overload{[&](PatternNodeOutput const &) -> bool { return full_additional_criterion.value_criterion( @@ -89,14 +113,15 @@ MatchAdditionalCriterion additional_criterion_for_subpattern( }}; } -std::vector - find_pattern_matches(UnlabelledGraphPattern const &pattern, - OpenDataflowGraphView const &graph, - MatchAdditionalCriterion const &additional_criterion) { - std::vector matches; +std::vector + find_unlabelled_pattern_matches( + UnlabelledGraphPattern const &pattern, + OpenKwargDataflowGraphView const &graph, + MatchAdditionalCriterion const &additional_criterion) { + std::vector matches; if (is_singleton_pattern(pattern)) { for (Node const &graph_node : get_nodes(graph)) { - std::optional candidate = + std::optional candidate = get_candidate_singleton_match(pattern, graph, graph_node); if (candidate.has_value() && unlabelled_pattern_does_match( @@ -107,26 +132,26 @@ std::vector } else { PatternSplit split = find_even_split(pattern); PatternSplitResult subpatterns = apply_split(pattern, split); - std::vector prefix_matches = - find_pattern_matches( + std::vector prefix_matches = + find_unlabelled_pattern_matches( subpatterns.subpattern_1, graph, additional_criterion_for_subpattern( additional_criterion, subpatterns.full_pattern_values_to_subpattern_1_inputs)); - std::vector postfix_matches = - find_pattern_matches( + std::vector postfix_matches = + find_unlabelled_pattern_matches( subpatterns.subpattern_2, graph, additional_criterion_for_subpattern( additional_criterion, subpatterns.full_pattern_values_to_subpattern_2_inputs)); - for (UnlabelledDataflowGraphPatternMatch const &prefix_match : + for (UnlabelledKwargDataflowGraphPatternMatch const &prefix_match : prefix_matches) { - for (UnlabelledDataflowGraphPatternMatch const &postfix_match : + for (UnlabelledKwargDataflowGraphPatternMatch const &postfix_match : postfix_matches) { - std::optional unsplit = + std::optional unsplit = merge_unlabelled_dataflow_graph_pattern_matches( prefix_match, postfix_match, diff 
--git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc index dff600ecf0..9436187c90 100644 --- a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc +++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc @@ -11,8 +11,8 @@ PatternNode get_dst_node(InputPatternEdge const &e) { return PatternNode{e.raw_edge.dst.node}; } -nonnegative_int get_dst_idx(InputPatternEdge const &e) { - return e.raw_edge.dst.idx; +TensorSlotName get_dst_slot_name(InputPatternEdge const &e) { + return e.raw_edge.dst.slot_name; } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.cc b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.cc index 8e11932035..445bf5bb9a 100644 --- a/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.cc +++ b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.cc @@ -5,7 +5,8 @@ namespace FlexFlow { MatchAdditionalCriterion match_additional_crition_always_true() { return MatchAdditionalCriterion{ [](PatternNode const &, Node const &) { return true; }, - [](PatternValue const &, OpenDataflowValue const &) { return true; }, + [](PatternValue const &, + OpenKwargDataflowValue const &) { return true; }, }; } diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc index 586c9d79c3..f70d20c7a3 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc @@ -46,13 +46,13 @@ PatternEdge pattern_edge_from_standard_edge(StandardPatternEdge const &e) { return PatternEdge{e}; } -PatternEdge - pattern_edge_from_raw_open_dataflow_edge(OpenDataflowEdge const &e) { +PatternEdge pattern_edge_from_raw_open_dataflow_edge( + OpenKwargDataflowEdge const &e) { return 
e.visit(overload{ - [](DataflowInputEdge const &ee) { + [](KwargDataflowInputEdge const &ee) { return PatternEdge{InputPatternEdge{ee}}; }, - [](DataflowEdge const &ee) { + [](KwargDataflowEdge const &ee) { return PatternEdge{StandardPatternEdge{ee}}; }, }); diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index c7b03e24f2..703d651070 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -9,139 +9,165 @@ #include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" +#include "utils/containers/make_counter_func.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_edges.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_values.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_subgraph.h" #include "utils/overload.h" #include #include namespace FlexFlow { -OpenDataflowSubgraphResult - subgraph_matched(OpenDataflowGraphView const &g, - UnlabelledDataflowGraphPatternMatch const &match) { +OpenKwargDataflowSubgraphResult + 
subgraph_matched(OpenKwargDataflowGraphView const &g, + UnlabelledKwargDataflowGraphPatternMatch const &match) { std::unordered_set matched_nodes = right_entries(match.node_assignment); - return get_subgraph(g, matched_nodes); + return get_open_kwarg_dataflow_graph_subgraph( + g, matched_nodes, make_counter_func()); } struct SubgraphConcreteFromPattern { SubgraphConcreteFromPattern( - UnlabelledDataflowGraphPatternMatch const &match, - bidict const + UnlabelledKwargDataflowGraphPatternMatch const &match, + bidict, + KwargDataflowGraphInput> const &full_graph_values_to_subgraph_inputs) : match(match), full_graph_values_to_subgraph_inputs( full_graph_values_to_subgraph_inputs) {} - UnlabelledDataflowGraphPatternMatch const &match; - bidict const + UnlabelledKwargDataflowGraphPatternMatch const &match; + bidict, + KwargDataflowGraphInput> const &full_graph_values_to_subgraph_inputs; Node operator()(PatternNode const &n) const { return match.node_assignment.at_l(n); } - OpenDataflowValue operator()(PatternInput const &i) const { - OpenDataflowValue mapped_input = match.input_assignment.at(i); + OpenKwargDataflowValue + operator()(PatternInput const &i) const { + OpenKwargDataflowValue mapped_input = + match.input_assignment.at(i); if (full_graph_values_to_subgraph_inputs.contains_l(mapped_input)) { - return OpenDataflowValue{ + return OpenKwargDataflowValue{ full_graph_values_to_subgraph_inputs.at_l(mapped_input)}; } else { return mapped_input; } } - OpenDataflowEdge operator()(InputPatternEdge const &e) const { - return open_dataflow_edge_from_src_and_dst( + OpenKwargDataflowEdge + operator()(InputPatternEdge const &e) const { + return mk_open_kwarg_dataflow_edge_from_src_val_and_dst( this->operator()(get_src_input(e)), - DataflowInput{ + KwargDataflowInput{ this->operator()(get_dst_node(e)), - get_dst_idx(e), + get_dst_slot_name(e), }); } - DataflowEdge operator()(StandardPatternEdge const &e) const { - return DataflowEdge{ - DataflowOutput{ + KwargDataflowEdge + 
operator()(StandardPatternEdge const &e) const { + return KwargDataflowEdge{ + KwargDataflowOutput{ this->operator()(get_src_node(e)), - get_src_idx(e), + get_src_slot_name(e), }, - DataflowInput{ + KwargDataflowInput{ this->operator()(get_dst_node(e)), - get_dst_idx(e), + get_dst_slot_name(e), }, }; } - OpenDataflowEdge operator()(PatternEdge const &pattern_e) const { - return pattern_e.visit( - [&](auto const &e) { return OpenDataflowEdge{this->operator()(e)}; }); + OpenKwargDataflowEdge + operator()(PatternEdge const &pattern_e) const { + return pattern_e.visit>( + [&](auto const &e) { + return OpenKwargDataflowEdge{ + this->operator()(e), + }; + }); } - OpenDataflowValue operator()(PatternValue const &pattern_v) const { - return pattern_v.visit( - [&](auto const &v) { return OpenDataflowValue{this->operator()(v)}; }); + OpenKwargDataflowValue + operator()(PatternValue const &pattern_v) const { + return pattern_v.visit>( + [&](auto const &v) { + return OpenKwargDataflowValue{ + this->operator()(v)}; + }); } - DataflowOutput operator()(PatternNodeOutput const &o) const { - return DataflowOutput{ + KwargDataflowOutput + operator()(PatternNodeOutput const &o) const { + return KwargDataflowOutput{ this->operator()(get_src_node(o)), - get_idx(o), + get_slot_name(o), }; } }; bool pattern_matches_subgraph_under( UnlabelledGraphPattern const &pattern, - OpenDataflowGraphView const &subgraph, - bidict const + OpenKwargDataflowGraphView const &subgraph, + bidict, + KwargDataflowGraphInput> const &full_graph_values_to_subgraph_inputs, - UnlabelledDataflowGraphPatternMatch const &match, + UnlabelledKwargDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion) { SubgraphConcreteFromPattern concrete_from_pattern{ match, full_graph_values_to_subgraph_inputs}; std::unordered_set concrete_nodes = get_nodes(subgraph); std::unordered_set concrete_nodes_from_match = - transform(get_nodes(pattern), concrete_from_pattern); + 
transform(get_pattern_nodes(pattern), concrete_from_pattern); if (concrete_nodes != concrete_nodes_from_match) { return false; } - for (PatternNode const &pattern_node : get_nodes(pattern)) { + for (PatternNode const &pattern_node : get_pattern_nodes(pattern)) { if (!additional_criterion.node_criterion( pattern_node, concrete_from_pattern(pattern_node))) { return false; } } - std::unordered_set concrete_edges = get_edges(subgraph); - std::unordered_set concrete_edge_from_match = - transform(get_edges(pattern), concrete_from_pattern); + std::unordered_set> + concrete_edges = get_all_open_kwarg_dataflow_edges(subgraph); + std::unordered_set> + concrete_edge_from_match = + transform(get_pattern_edges(pattern), + [&](PatternEdge const &e) + -> OpenKwargDataflowEdge { + return concrete_from_pattern(e); + }); if (concrete_edges != concrete_edge_from_match) { return false; } - std::unordered_set concrete_values = - get_open_dataflow_values(subgraph); - std::unordered_set concrete_values_from_match = - transform(get_values(pattern), concrete_from_pattern); + std::unordered_set> + concrete_values = get_all_open_kwarg_dataflow_values(subgraph); + std::unordered_set> + concrete_values_from_match = + transform(get_pattern_values(pattern), + [&](PatternValue const &v) + -> OpenKwargDataflowValue { + return concrete_from_pattern(v); + }); if (concrete_values != concrete_values_from_match) { return false; } - for (PatternValue const &pattern_value : get_values(pattern)) { + for (PatternValue const &pattern_value : get_pattern_values(pattern)) { if (!additional_criterion.value_criterion( pattern_value, concrete_from_pattern(pattern_value))) { return false; @@ -153,42 +179,47 @@ bool pattern_matches_subgraph_under( bool unlabelled_pattern_does_match( UnlabelledGraphPattern const &pattern, - OpenDataflowGraphView const &graph, - UnlabelledDataflowGraphPatternMatch const &match, + OpenKwargDataflowGraphView const &graph, + UnlabelledKwargDataflowGraphPatternMatch const &match, 
MatchAdditionalCriterion const &additional_criterion) { - std::unordered_set matched_by_pattern_inputs = - unordered_set_of(values(match.input_assignment)); + std::unordered_set> + matched_by_pattern_inputs = + unordered_set_of(values(match.input_assignment)); - ASSERT(left_entries(match.node_assignment) == get_nodes(pattern)); + ASSERT(left_entries(match.node_assignment) == get_pattern_nodes(pattern)); ASSERT( is_subseteq_of(right_entries(match.node_assignment), get_nodes(graph))); - ASSERT(keys(match.input_assignment) == get_graph_inputs(pattern)); + ASSERT(keys(match.input_assignment) == get_pattern_inputs(pattern)); ASSERT(is_subseteq_of(matched_by_pattern_inputs, - get_open_dataflow_values(graph))); + get_all_open_kwarg_dataflow_values(graph))); - OpenDataflowSubgraphResult subgraph_result = subgraph_matched(graph, match); - OpenDataflowGraphView matched_subgraph = subgraph_result.graph; + OpenKwargDataflowSubgraphResult subgraph_result = + subgraph_matched(graph, match); + OpenKwargDataflowGraphView matched_subgraph = + subgraph_result.graph; - std::unordered_set full_values_split_by_subgraph = - left_entries(subgraph_result.full_graph_values_to_subgraph_inputs); + std::unordered_set> + full_values_split_by_subgraph = + left_entries(subgraph_result.full_graph_values_to_subgraph_inputs); ASSERT(right_entries(match.node_assignment) == get_nodes(matched_subgraph)); ASSERT(is_subseteq_of(full_values_split_by_subgraph, - get_open_dataflow_values(graph)), + get_all_open_kwarg_dataflow_values(graph)), full_values_split_by_subgraph, - get_open_dataflow_values(graph)); + get_all_open_kwarg_dataflow_values(graph)); MatchAdditionalCriterion through_subgraph_operation = MatchAdditionalCriterion{ additional_criterion.node_criterion, - [&](PatternValue const &pv, OpenDataflowValue const &v) { + [&](PatternValue const &pv, + OpenKwargDataflowValue const &v) { return v.visit(overload{ - [&](DataflowOutput const &) { + [&](KwargDataflowOutput const &) { return 
additional_criterion.value_criterion(pv, v); }, - [&](DataflowGraphInput const &subgraph_input) { - OpenDataflowValue full_graph_value = + [&](KwargDataflowGraphInput const &subgraph_input) { + OpenKwargDataflowValue full_graph_value = subgraph_result.full_graph_values_to_subgraph_inputs.at_r( subgraph_input); return additional_criterion.value_criterion(pv, diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_node_output.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_node_output.cc index 24bbb6f4d1..fb5950ab34 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_node_output.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_node_output.cc @@ -6,8 +6,8 @@ PatternNode get_src_node(PatternNodeOutput const &o) { return PatternNode{o.raw_dataflow_output.node}; } -nonnegative_int get_idx(PatternNodeOutput const &o) { - return o.raw_dataflow_output.idx; +TensorSlotName get_slot_name(PatternNodeOutput const &o) { + return o.raw_dataflow_output.slot_name; } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc index de8cee8dd1..2c46944c99 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc @@ -21,15 +21,16 @@ PatternSplit find_even_split(UnlabelledGraphPattern const &pattern) { PatternSplitResult apply_split(UnlabelledGraphPattern const &p, PatternSplit const &s) { UnlabelledGraphPatternSubgraphResult first_subgraph_result = - get_subgraph(p, s.first); + get_pattern_subgraph(p, s.first); UnlabelledGraphPatternSubgraphResult second_subgraph_result = - get_subgraph(p, s.second); + get_pattern_subgraph(p, s.second); return PatternSplitResult{ first_subgraph_result.subpattern, second_subgraph_result.subpattern, first_subgraph_result.full_pattern_values_to_subpattern_inputs, - 
second_subgraph_result.full_pattern_values_to_subpattern_inputs}; + second_subgraph_result.full_pattern_values_to_subpattern_inputs, + }; } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_value.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_value.cc index 8ff72f07a6..037f90bb76 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_value.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_value.cc @@ -1,27 +1,32 @@ #include "substitutions/unlabelled/pattern_value.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" #include "utils/overload.h" namespace FlexFlow { -OpenDataflowValue +OpenKwargDataflowValue raw_open_dataflow_value_from_pattern_value(PatternValue const &v) { - return v.visit(overload{ + return v.visit>(overload{ [](PatternNodeOutput const &o) { - return OpenDataflowValue{o.raw_dataflow_output}; + return OpenKwargDataflowValue{ + o.raw_dataflow_output}; }, [](PatternInput const &i) { - return OpenDataflowValue{i.raw_dataflow_graph_input}; + return OpenKwargDataflowValue{ + i.raw_dataflow_graph_input}; }, }); } -PatternValue - pattern_value_from_raw_open_dataflow_value(OpenDataflowValue const &v) { +PatternValue pattern_value_from_raw_open_kwarg_dataflow_value( + OpenKwargDataflowValue const &v) { return v.visit(overload{ - [](DataflowOutput const &o) { + [](KwargDataflowOutput const &o) { return PatternValue{PatternNodeOutput{o}}; }, - [](DataflowGraphInput const &i) { return PatternValue{PatternInput{i}}; }, + [](KwargDataflowGraphInput const &i) { + return PatternValue{PatternInput{i}}; + }, }); } diff --git a/lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc index 17d05f1122..ce12dc1a4e 100644 --- a/lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc +++ 
b/lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc @@ -10,12 +10,12 @@ PatternNode get_dst_node(StandardPatternEdge const &e) { return PatternNode{e.raw_edge.dst.node}; } -nonnegative_int get_src_idx(StandardPatternEdge const &e) { - return e.raw_edge.src.idx; +TensorSlotName get_src_slot_name(StandardPatternEdge const &e) { + return e.raw_edge.src.slot_name; } -nonnegative_int get_dst_idx(StandardPatternEdge const &e) { - return e.raw_edge.dst.idx; +TensorSlotName get_dst_slot_name(StandardPatternEdge const &e) { + return e.raw_edge.dst.slot_name; } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.cc deleted file mode 100644 index 4abf40289f..0000000000 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.cc +++ /dev/null @@ -1,69 +0,0 @@ -#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h" -#include "utils/bidict/try_merge_nondisjoint_bidicts.h" -#include "utils/containers/filtermap_keys.h" -#include "utils/containers/map_keys.h" -#include "utils/containers/try_merge_nondisjoint_unordered_maps.h" - -namespace FlexFlow { - -UnlabelledDataflowGraphPatternMatch empty_unlabelled_pattern_match() { - return UnlabelledDataflowGraphPatternMatch{ - bidict{}, - bidict{}, - }; -} - -std::optional - merge_unlabelled_dataflow_graph_pattern_matches( - UnlabelledDataflowGraphPatternMatch const &subpattern_1, - UnlabelledDataflowGraphPatternMatch const &subpattern_2, - bidict const - &merged_graph_values_to_inputs_of_1, - bidict const - &merged_graph_values_to_inputs_of_2) { - bidict merged_node_assignment = ({ - std::optional> result = - try_merge_nondisjoint_bidicts(subpattern_1.node_assignment, - subpattern_2.node_assignment); - if (!result.has_value()) { - return std::nullopt; - } - result.value(); - }); - - 
std::unordered_map merged_input_assignment = - ({ - std::unordered_map - lifted_input_assignment_1 = map_keys( - subpattern_1.input_assignment, [&](PatternInput const &pi1) { - return merged_graph_values_to_inputs_of_1.at_r(pi1); - }); - std::unordered_map - lifted_input_assignment_2 = map_keys( - subpattern_2.input_assignment, [&](PatternInput const &pi2) { - return merged_graph_values_to_inputs_of_2.at_r(pi2); - }); - std::optional> - merged = try_merge_nondisjoint_unordered_maps( - lifted_input_assignment_1, lifted_input_assignment_2); - if (!merged.has_value()) { - return std::nullopt; - } - filtermap_keys( - merged.value(), - [](PatternValue const &v) -> std::optional { - if (v.has()) { - return v.get(); - } else { - return std::nullopt; - } - }); - }); - - return UnlabelledDataflowGraphPatternMatch{ - merged_node_assignment, - merged_input_assignment, - }; -} - -} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index 84e0d91fee..432d41ef1d 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -1,16 +1,17 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include "substitutions/unlabelled/pattern_edge.h" #include "substitutions/unlabelled/pattern_value.h" +#include "utils/bidict/algorithms/transform.h" +#include "utils/containers/make_counter_func.h" #include "utils/containers/transform.h" -#include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" -#include 
"utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_edges.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_values.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_subgraph.h" namespace FlexFlow { @@ -22,24 +23,28 @@ bool is_singleton_pattern(UnlabelledGraphPattern const &pattern) { return num_nodes(pattern) == 1; } -std::unordered_set get_nodes(UnlabelledGraphPattern const &p) { +std::unordered_set + get_pattern_nodes(UnlabelledGraphPattern const &p) { return transform(get_nodes(p.raw_graph), [](Node const &n) { return PatternNode{n}; }); } -std::unordered_set get_values(UnlabelledGraphPattern const &p) { - return transform(get_open_dataflow_values(p.raw_graph), - pattern_value_from_raw_open_dataflow_value); +std::unordered_set + get_pattern_values(UnlabelledGraphPattern const &p) { + return transform(get_all_open_kwarg_dataflow_values(p.raw_graph), + pattern_value_from_raw_open_kwarg_dataflow_value); } std::unordered_set - get_graph_inputs(UnlabelledGraphPattern const &p) { - return transform(get_open_dataflow_graph_inputs(p.raw_graph), - [](DataflowGraphInput const &i) { return PatternInput{i}; }); + get_pattern_inputs(UnlabelledGraphPattern const &p) { + return transform( + get_all_kwarg_dataflow_graph_inputs(p.raw_graph), + [](KwargDataflowGraphInput const &i) { return PatternInput{i}; }); } -std::unordered_set 
get_edges(UnlabelledGraphPattern const &p) { - return transform(get_edges(p.raw_graph), +std::unordered_set + get_pattern_edges(UnlabelledGraphPattern const &p) { + return transform(get_all_open_kwarg_dataflow_edges(p.raw_graph), pattern_edge_from_raw_open_dataflow_edge); } @@ -49,33 +54,41 @@ std::vector [](Node const &n) { return PatternNode{n}; }); } -std::vector +std::unordered_map get_inputs_to_pattern_node(UnlabelledGraphPattern const &p, PatternNode const &n) { - return transform(get_inputs(p.raw_graph, n.raw_node), - pattern_value_from_raw_open_dataflow_value); + return map_values( + get_incoming_open_kwarg_dataflow_values_for_node(p.raw_graph, n.raw_node), + [](OpenKwargDataflowValue const &v) -> PatternValue { + return pattern_value_from_raw_open_kwarg_dataflow_value(v); + }); } -std::vector +std::unordered_map get_outputs_from_pattern_node(UnlabelledGraphPattern const &p, PatternNode const &n) { - return transform( - get_outputs(p.raw_graph, n.raw_node), [](DataflowOutput const &o) { - return pattern_value_from_raw_open_dataflow_value(OpenDataflowValue{o}); + return map_values( + get_outgoing_kwarg_dataflow_outputs_for_node(p.raw_graph, n.raw_node), + [](KwargDataflowOutput const &o) { + return pattern_value_from_raw_open_kwarg_dataflow_value( + OpenKwargDataflowValue{o}); }); } UnlabelledGraphPatternSubgraphResult - get_subgraph(UnlabelledGraphPattern const &p, - std::unordered_set const &n) { - OpenDataflowSubgraphResult raw_result = get_subgraph( - p.raw_graph, - transform(n, [](PatternNode const &pn) { return pn.raw_node; })); + get_pattern_subgraph(UnlabelledGraphPattern const &p, + std::unordered_set const &n) { + OpenKwargDataflowSubgraphResult raw_result = + get_open_kwarg_dataflow_graph_subgraph( + p.raw_graph, + transform(n, [](PatternNode const &pn) { return pn.raw_node; }), + make_counter_func()); bidict full_pattern_values_to_subpattern_inputs = transform(raw_result.full_graph_values_to_subgraph_inputs, - [](OpenDataflowValue const &v, 
DataflowGraphInput const &i) { + [](OpenKwargDataflowValue const &v, + KwargDataflowGraphInput const &i) { return std::make_pair( - pattern_value_from_raw_open_dataflow_value(v), + pattern_value_from_raw_open_kwarg_dataflow_value(v), PatternInput{i}); }); return UnlabelledGraphPatternSubgraphResult{ diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.cc new file mode 100644 index 0000000000..8252f5ef02 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.cc @@ -0,0 +1,73 @@ +#include "substitutions/unlabelled/unlabelled_kwarg_dataflow_graph_pattern_match.h" +#include "utils/bidict/try_merge_nondisjoint_bidicts.h" +#include "utils/containers/filtermap_keys.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/try_merge_nondisjoint_unordered_maps.h" + +namespace FlexFlow { + +UnlabelledKwargDataflowGraphPatternMatch empty_unlabelled_pattern_match() { + return UnlabelledKwargDataflowGraphPatternMatch{ + bidict{}, + bidict>{}, + }; +} + +std::optional + merge_unlabelled_dataflow_graph_pattern_matches( + UnlabelledKwargDataflowGraphPatternMatch const &subpattern_1, + UnlabelledKwargDataflowGraphPatternMatch const &subpattern_2, + bidict const + &merged_graph_values_to_inputs_of_1, + bidict const + &merged_graph_values_to_inputs_of_2) { + bidict merged_node_assignment = ({ + std::optional> result = + try_merge_nondisjoint_bidicts(subpattern_1.node_assignment, + subpattern_2.node_assignment); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + std::unordered_map> + merged_input_assignment = ({ + std::unordered_map> + lifted_input_assignment_1 = map_keys( + subpattern_1.input_assignment, [&](PatternInput const &pi1) { + return merged_graph_values_to_inputs_of_1.at_r(pi1); + }); + std::unordered_map> + 
lifted_input_assignment_2 = map_keys( + subpattern_2.input_assignment, [&](PatternInput const &pi2) { + return merged_graph_values_to_inputs_of_2.at_r(pi2); + }); + std::optional< + std::unordered_map>> + merged = try_merge_nondisjoint_unordered_maps( + lifted_input_assignment_1, lifted_input_assignment_2); + if (!merged.has_value()) { + return std::nullopt; + } + filtermap_keys( + merged.value(), + [](PatternValue const &v) -> std::optional { + if (v.has()) { + return v.get(); + } else { + return std::nullopt; + } + }); + }); + + return UnlabelledKwargDataflowGraphPatternMatch{ + merged_node_assignment, + merged_input_assignment, + }; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/test/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/test/src/substitutions/apply_substitution/apply_substitution.cc index 05fd1a3fc9..89b8eb820e 100644 --- a/lib/substitutions/test/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/test/src/substitutions/apply_substitution/apply_substitution.cc @@ -7,6 +7,7 @@ #include "substitutions/substitution_builder.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" #include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include "utils/integer_conversions.h" #include @@ -16,10 +17,15 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("apply_substitution") { SubstitutionBuilder b; - auto [p_input, o_input] = + auto pair_input = b.add_input(tensor_attribute_pattern_match_all(), "input"); - auto [p_weight, o_weight] = + PatternValue p_input = pair_input.first; + OutputGraphExprValue o_input = pair_input.second; + + auto pair_weight = b.add_input(tensor_attribute_pattern_match_all(), "weight"); + PatternValue p_weight = pair_weight.first; + OutputGraphExprValue o_weight = pair_weight.second; PatternValue p_mm_output = [&] { auto pattern = OperatorAttributePattern{{ @@ -29,10 +35,28 @@ TEST_SUITE(FF_TEST_SUITE) { 
OperatorAttributeValue{std::optional{std::nullopt}}), }}; - return get_only(b.add_pattern_node(pattern, - {p_input, p_weight}, - {tensor_attribute_pattern_match_all()}, - "mm")); + return require_only_key(b.add_pattern_node( + /*node_pattern=*/pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::WEIGHT, + p_weight, + }, + }, + /*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }, + /*name=*/"mm"), + TensorSlotName::OUTPUT); }(); PatternValue p_relu_output = [&] { @@ -40,10 +64,24 @@ TEST_SUITE(FF_TEST_SUITE) { op_type_equals_constraint(OperatorType::RELU), }}; - return get_only(b.add_pattern_node(pattern, - {p_mm_output}, - {tensor_attribute_pattern_match_all()}, - "relu")); + return require_only_key(b.add_pattern_node( + /*node_pattern=*/pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_mm_output, + }, + }, + /*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }, + /*name=*/"relu"), + TensorSlotName::OUTPUT); }(); OutputGraphExprValue o_fused_output = [&] { @@ -54,8 +92,24 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttributeValue{Activation::RELU}), }}; - return get_only( - b.add_output_graph_node(node_expr, {o_input, o_weight}, 1_n)); + return require_only_key(b.add_output_graph_node( + /*node_expr=*/node_expr, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + o_input, + }, + { + TensorSlotName::WEIGHT, + o_weight, + }, + }, + /*output_slots=*/ + { + TensorSlotName::OUTPUT, + }), + TensorSlotName::OUTPUT); }(); b.equate_outputs(p_relu_output, o_fused_output); @@ -110,9 +164,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t relu_match_layer = get_parallel_layer_by_name(pcg, relu_match); open_parallel_tensor_guid_t mm_match_layer_input_activations = - get_layer_inputs(pcg, mm_match_layer).at(0); + get_layer_inputs(pcg, mm_match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t 
mm_match_layer_input_weights = - get_layer_inputs(pcg, mm_match_layer).at(1); + get_layer_inputs(pcg, mm_match_layer).at(TensorSlotName::WEIGHT); return PCGPatternMatch{ bidict{ diff --git a/lib/substitutions/test/src/substitutions/apply_substitution/evaluate_substitution_output.cc b/lib/substitutions/test/src/substitutions/apply_substitution/evaluate_substitution_output.cc index 7419c62965..efebedb5df 100644 --- a/lib/substitutions/test/src/substitutions/apply_substitution/evaluate_substitution_output.cc +++ b/lib/substitutions/test/src/substitutions/apply_substitution/evaluate_substitution_output.cc @@ -6,7 +6,8 @@ #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" #include "utils/containers/get_only.h" -#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/containers/require_only_key.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" #include "utils/integer_conversions.h" #include @@ -17,15 +18,20 @@ TEST_SUITE(FF_TEST_SUITE) { // Currently Substitution creation is very verbose. // This is being addressed in // https://github.com/flexflow/FlexFlow/issues/1473. 
- auto pattern_g = LabelledOpenDataflowGraph:: - create>(); + auto pattern_g = LabelledOpenKwargDataflowGraph:: + create< + UnorderedSetLabelledOpenKwargDataflowGraph>(); - PatternInput pattern_i_activation = - PatternInput{pattern_g.add_input(tensor_attribute_pattern_match_all())}; - PatternInput pattern_i_weights = - PatternInput{pattern_g.add_input(tensor_attribute_pattern_match_all())}; + PatternInput pattern_i_activation = PatternInput{ + pattern_g.add_input(0, tensor_attribute_pattern_match_all())}; + PatternInput pattern_i_weights = PatternInput{ + pattern_g.add_input(1, tensor_attribute_pattern_match_all())}; OperatorAttributePattern mm_pattern = OperatorAttributePattern{{ op_type_equals_constraint(OperatorType::LINEAR), @@ -33,35 +39,75 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttributeKey::ACTIVATION, OperatorAttributeValue{std::optional{std::nullopt}}), }}; - NodeAddedResult mm_added = pattern_g.add_node( - mm_pattern, - {OpenDataflowValue{pattern_i_activation.raw_dataflow_graph_input}, - OpenDataflowValue{pattern_i_weights.raw_dataflow_graph_input}}, - {tensor_attribute_pattern_match_all()}); + KwargNodeAddedResult mm_added = pattern_g.add_node( + /*node_label=*/mm_pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{ + pattern_i_activation.raw_dataflow_graph_input, + }, + }, + { + TensorSlotName::WEIGHT, + OpenKwargDataflowValue{ + pattern_i_weights.raw_dataflow_graph_input, + }, + }, + }, + /*output_labels=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }); PatternNode pattern_mm_node = PatternNode{mm_added.node}; - DataflowOutput mm_output = get_only(mm_added.outputs); + KwargDataflowOutput mm_output = + require_only_key(mm_added.outputs, TensorSlotName::OUTPUT); OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ op_type_equals_constraint(OperatorType::RELU), }}; - NodeAddedResult relu_added = - pattern_g.add_node(relu_pattern, - {OpenDataflowValue{mm_output}}, - 
{tensor_attribute_pattern_match_all()}); + KwargNodeAddedResult relu_added = pattern_g.add_node( + /*node_label=*/relu_pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{mm_output}, + }, + }, + /*output_labels=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }); PatternNode pattern_relu_node = PatternNode{relu_added.node}; - DataflowOutput relu_output = get_only(relu_added.outputs); + KwargDataflowOutput relu_output = + require_only_key(relu_added.outputs, TensorSlotName::OUTPUT); - LabelledOpenDataflowGraph - output_g = LabelledOpenDataflowGraph:: - create + output_g = LabelledOpenKwargDataflowGraph:: + create>(); + std::monostate, + int, + TensorSlotName>>(); OutputGraphExprInput output_i_activation = - OutputGraphExprInput{output_g.add_input({})}; + OutputGraphExprInput{output_g.add_input(0, std::monostate{})}; OutputGraphExprInput output_i_weights = - OutputGraphExprInput{output_g.add_input({})}; + OutputGraphExprInput{output_g.add_input(1, std::monostate{})}; OutputOperatorAttrsAssignment fused_mm_relu_attrs_assignment = OutputOperatorAttrsAssignment{ @@ -81,14 +127,34 @@ TEST_SUITE(FF_TEST_SUITE) { copy_attr_from_pattern_node(OperatorAttributeKey::REGULARIZER, pattern_mm_node), }}; - NodeAddedResult fused_mm_relu_added = output_g.add_node( - fused_mm_relu_attrs_assignment, - {OpenDataflowValue{output_i_activation.raw_dataflow_graph_input}, - OpenDataflowValue{output_i_weights.raw_dataflow_graph_input}}, - {{}}); + KwargNodeAddedResult fused_mm_relu_added = output_g.add_node( + /*node_label=*/fused_mm_relu_attrs_assignment, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{ + output_i_activation.raw_dataflow_graph_input, + }, + }, + { + TensorSlotName::WEIGHT, + OpenKwargDataflowValue{ + output_i_weights.raw_dataflow_graph_input, + }, + }, + }, + /*output_labels=*/ + { + { + TensorSlotName::OUTPUT, + std::monostate{}, + }, + }); OutputGraphExprNode fused_mm_relu_node = 
OutputGraphExprNode{fused_mm_relu_added.node}; - DataflowOutput fused_mm_relu_output = get_only(fused_mm_relu_added.outputs); + KwargDataflowOutput fused_mm_relu_output = + require_only_key(fused_mm_relu_added.outputs, TensorSlotName::OUTPUT); Substitution sub = Substitution{ PCGPattern{pattern_g}, @@ -158,9 +224,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t relu_match_layer = get_parallel_layer_by_name(pcg, relu_match); open_parallel_tensor_guid_t mm_match_layer_input_activations = - get_layer_inputs(pcg, mm_match_layer).at(0); + get_layer_inputs(pcg, mm_match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t mm_match_layer_input_weights = - get_layer_inputs(pcg, mm_match_layer).at(1); + get_layer_inputs(pcg, mm_match_layer).at(TensorSlotName::WEIGHT); PCGPatternMatch match = PCGPatternMatch{ bidict{ @@ -204,12 +270,14 @@ TEST_SUITE(FF_TEST_SUITE) { get_parallel_tensor_attrs( pcg, open_parallel_tensor_guid_from_closed( - get_only(get_layer_outputs(pcg, relu_match_layer)))); + require_only_key(get_layer_outputs(pcg, relu_match_layer), + TensorSlotName::OUTPUT))); parallel_layer_guid_t result_fused_mm_relu_node = result_node_map.at_r(fused_mm_relu_node); - parallel_tensor_guid_t result_fused_mm_relu_output = - get_only(get_layer_outputs(result_graph, result_fused_mm_relu_node)); + parallel_tensor_guid_t result_fused_mm_relu_output = require_only_key( + get_layer_outputs(result_graph, result_fused_mm_relu_node), + TensorSlotName::OUTPUT); input_parallel_tensor_guid_t result_i_activation = result_input_map.at_r(output_i_activation); input_parallel_tensor_guid_t result_i_weights = @@ -226,23 +294,23 @@ TEST_SUITE(FF_TEST_SUITE) { }}, std::unordered_set{ SubParallelComputationGraphEdge{ - OpenDataflowEdge{ - DataflowInputEdge{ + OpenKwargDataflowEdge{ + KwargDataflowInputEdge{ result_i_activation.raw_dataflow_graph_input, - DataflowInput{ + KwargDataflowInput{ result_fused_mm_relu_node.raw_graph_node, - 0_n, + TensorSlotName::INPUT, }, }, }, }, 
SubParallelComputationGraphEdge{ - OpenDataflowEdge{ - DataflowInputEdge{ + OpenKwargDataflowEdge{ + KwargDataflowInputEdge{ result_i_weights.raw_dataflow_graph_input, - DataflowInput{ + KwargDataflowInput{ result_fused_mm_relu_node.raw_graph_node, - 1_n, + TensorSlotName::WEIGHT, }, }, }, diff --git a/lib/substitutions/test/src/substitutions/apply_substitution/perform_shape_inference.cc b/lib/substitutions/test/src/substitutions/apply_substitution/perform_shape_inference.cc index 2bf72d3224..1c0f46bb3f 100644 --- a/lib/substitutions/test/src/substitutions/apply_substitution/perform_shape_inference.cc +++ b/lib/substitutions/test/src/substitutions/apply_substitution/perform_shape_inference.cc @@ -3,9 +3,11 @@ #include "op-attrs/ops/linear.h" #include "op-attrs/parallel_tensor_shape.h" #include "utils/containers/get_only.h" -#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" -#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" +#include "utils/containers/require_only_key.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph.h" #include "utils/integer_conversions.h" #include @@ -13,17 +15,21 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("perform_shape_inference") { - auto g = - LabelledOpenDataflowGraph::create< - UnorderedSetLabelledOpenDataflowGraph>(); + auto g = LabelledOpenKwargDataflowGraph:: + create>(); positive_int in_channels = 24_p; positive_int out_channels = 16_p; positive_int batch_size = 4_p; positive_int batch_degree = 2_p; - 
DataflowGraphInput i0 = g.add_input({}); + KwargDataflowGraphInput i0 = g.add_input(0, std::monostate{}); ParallelTensorShape i0_shape = ParallelTensorShape{ ParallelTensorDims{ FFOrdered{ @@ -69,7 +75,7 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape n1_weight_shape = throw_if_unexpected(get_projection_shape(n1_op_attrs, i0_shape)); ParallelTensorShape n2_output_shape = - throw_if_unexpected(get_output_shape(n2_op_attrs, n1_output_shape)); + get_output_shape(n2_op_attrs, n1_output_shape); ParallelLayerAttrs n1_weight_attrs = ParallelLayerAttrs{ PCGOperatorAttrs{ @@ -88,42 +94,111 @@ TEST_SUITE(FF_TEST_SUITE) { std::nullopt, }; - NodeAddedResult n1_weight_added_result = - g.add_node(n1_weight_attrs, {}, {{}}); + KwargNodeAddedResult n1_weight_added_result = g.add_node( + /*node_labels=*/n1_weight_attrs, + /*inputs=*/{}, + /*output_labels=*/ + { + { + TensorSlotName::OUTPUT, + {}, + }, + }); Node n1_weight_node = n1_weight_added_result.node; - DataflowOutput n1_weight = get_only(n1_weight_added_result.outputs); + KwargDataflowOutput n1_weight = require_only_key( + n1_weight_added_result.outputs, TensorSlotName::OUTPUT); - NodeAddedResult n1_weight_replicate_added_result = g.add_node( - n1_weight_replicate_attrs, {OpenDataflowValue{n1_weight}}, {{}}); + KwargNodeAddedResult n1_weight_replicate_added_result = + g.add_node( + /*node_label=*/n1_weight_replicate_attrs, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{n1_weight}, + }, + }, + /*outupt_labels=*/ + { + { + TensorSlotName::OUTPUT, + std::monostate{}, + }, + }); Node n1_weight_replicate_node = n1_weight_replicate_added_result.node; - DataflowOutput n1_weight_replicated = - get_only(n1_weight_replicate_added_result.outputs); + KwargDataflowOutput n1_weight_replicated = require_only_key( + n1_weight_replicate_added_result.outputs, TensorSlotName::OUTPUT); - NodeAddedResult n1_added_result = g.add_node( - n1_attrs, - {OpenDataflowValue{i0}, OpenDataflowValue{n1_weight_replicated}}, - 
{{}}); + KwargNodeAddedResult n1_added_result = g.add_node( + /*node_label=*/n1_attrs, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{i0}, + }, + { + TensorSlotName::WEIGHT, + OpenKwargDataflowValue{ + n1_weight_replicated}, + }, + }, + /*output_labels=*/ + { + { + TensorSlotName::OUTPUT, + std::monostate{}, + }, + }); Node n1 = n1_added_result.node; - DataflowOutput o1 = get_only(n1_added_result.outputs); + KwargDataflowOutput o1 = + require_only_key(n1_added_result.outputs, TensorSlotName::OUTPUT); - NodeAddedResult n2_added_result = - g.add_node(n2_attrs, {OpenDataflowValue{o1}}, {{}}); + KwargNodeAddedResult n2_added_result = g.add_node( + /*node_labels=*/n2_attrs, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{o1}, + }, + }, + /*output_labels=*/ + { + { + TensorSlotName::OUTPUT, + {}, + }, + }); Node n2 = n2_added_result.node; - DataflowOutput o2 = get_only(n2_added_result.outputs); + KwargDataflowOutput o2 = + require_only_key(n2_added_result.outputs, TensorSlotName::OUTPUT); - std::unordered_map input_shapes = { - {i0, i0_shape}, - }; + std::unordered_map, ParallelTensorShape> + input_shapes = { + {i0, i0_shape}, + }; - LabelledOpenDataflowGraphView + LabelledOpenKwargDataflowGraphView result = perform_shape_inference(g, input_shapes); - LabelledOpenDataflowGraphData - result_data = get_graph_data(result); + LabelledOpenKwargDataflowGraphData + result_data = get_labelled_open_kwarg_dataflow_graph_data(result); - LabelledOpenDataflowGraphData - correct_data = LabelledOpenDataflowGraphData{ + LabelledOpenKwargDataflowGraphData + correct_data = LabelledOpenKwargDataflowGraphData{ { {n1, n1_attrs}, {n2, n2_attrs}, @@ -131,49 +206,94 @@ TEST_SUITE(FF_TEST_SUITE) { {n1_weight_replicate_node, n1_weight_replicate_attrs}, }, { - OpenDataflowEdge{ - DataflowInputEdge{ + OpenKwargDataflowEdge{ + KwargDataflowInputEdge{ i0, - DataflowInput{n1, 0_n}, + KwargDataflowInput{ + n1, + TensorSlotName::INPUT, + }, + }, + }, + 
OpenKwargDataflowEdge{ + KwargDataflowEdge{ + KwargDataflowOutput{ + n1_weight_node, + TensorSlotName::OUTPUT, + }, + KwargDataflowInput{ + n1_weight_replicate_node, + TensorSlotName::INPUT, + }, + }, + }, + OpenKwargDataflowEdge{ + KwargDataflowEdge{ + KwargDataflowOutput{ + n1_weight_replicate_node, + TensorSlotName::OUTPUT, + }, + KwargDataflowInput{ + n1, + TensorSlotName::WEIGHT, + }, }, }, - OpenDataflowEdge{DataflowEdge{ - DataflowOutput{n1_weight_node, 0_n}, - DataflowInput{n1_weight_replicate_node, 0_n}, - }}, - OpenDataflowEdge{ - DataflowEdge{ - DataflowOutput{n1_weight_replicate_node, 0_n}, - DataflowInput{n1, 1_n}, + OpenKwargDataflowEdge{ + KwargDataflowEdge{ + KwargDataflowOutput{ + n1, + TensorSlotName::OUTPUT, + }, + KwargDataflowInput{ + n2, + TensorSlotName::INPUT, + }, }, }, - OpenDataflowEdge{DataflowEdge{ - DataflowOutput{n1, 0_n}, - DataflowInput{n2, 0_n}, - }}, }, {i0}, - {{ - OpenDataflowValue{i0}, - i0_shape, - }, - { - OpenDataflowValue{DataflowOutput{n1_weight_node, 0_n}}, - lift_to_parallel(get_reduced_shape(n1_weight_shape)), - }, - { - OpenDataflowValue{ - DataflowOutput{n1_weight_replicate_node, 0_n}}, - n1_weight_shape, - }, - { - OpenDataflowValue{DataflowOutput{n1, 0_n}}, - n1_output_shape, - }, - { - OpenDataflowValue{DataflowOutput{n2, 0_n}}, - n2_output_shape, - }}}; + { + { + OpenKwargDataflowValue{i0}, + i0_shape, + }, + { + OpenKwargDataflowValue{ + KwargDataflowOutput{ + n1_weight_node, + TensorSlotName::OUTPUT, + }, + }, + lift_to_parallel(get_reduced_shape(n1_weight_shape)), + }, + { + OpenKwargDataflowValue{ + KwargDataflowOutput{ + n1_weight_replicate_node, + TensorSlotName::OUTPUT, + }}, + n1_weight_shape, + }, + { + OpenKwargDataflowValue{ + KwargDataflowOutput{ + n1, + TensorSlotName::OUTPUT, + }, + }, + n1_output_shape, + }, + { + OpenKwargDataflowValue{ + KwargDataflowOutput{ + n2, + TensorSlotName::OUTPUT, + }, + }, + n2_output_shape, + }, + }}; CHECK(result_data == correct_data); } diff --git 
a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc index f4d430077f..b36a5f1d82 100644 --- a/lib/substitutions/test/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -6,7 +6,8 @@ #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" #include "utils/containers/get_only.h" -#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/containers/require_only_key.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" #include using namespace ::FlexFlow; @@ -64,22 +65,29 @@ TEST_SUITE(FF_TEST_SUITE) { get_parallel_layer_by_name(pcg, x_matmul_name); parallel_layer_guid_t y_matmul = get_parallel_layer_by_name(pcg, y_matmul_name); - std::vector x_incoming = + std::unordered_map x_incoming = get_incoming_tensors(pcg, x_matmul); REQUIRE(x_incoming.size() == 2); - parallel_tensor_guid_t x_weights = x_incoming.at(1); - std::vector y_incoming = + + parallel_tensor_guid_t x_weights = x_incoming.at(TensorSlotName::WEIGHT); + std::unordered_map y_incoming = get_incoming_tensors(pcg, y_matmul); REQUIRE(y_incoming.size() == 2); - parallel_tensor_guid_t y_weights = y_incoming.at(1); - - LabelledOpenDataflowGraph - g = LabelledOpenDataflowGraph:: - create + g = LabelledOpenKwargDataflowGraph:: + create>(); + TensorAttributePattern, + int, + TensorSlotName>>(); TensorAttributePattern pattern_tensor_a = tensor_attribute_pattern_match_all(); @@ -98,25 +106,63 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttributePattern op_pattern_2 = op_pattern_1; - DataflowGraphInput pt_a = g.add_input(pattern_tensor_a); - DataflowGraphInput pt_b = g.add_input(pattern_tensor_b); - DataflowGraphInput pt_c = g.add_input(pattern_tensor_c); - - NodeAddedResult op_pattern_1_added = - g.add_node(op_pattern_1, - {OpenDataflowValue{pt_a}, OpenDataflowValue{pt_b}}, - 
{pattern_tensor_x}); + KwargDataflowGraphInput pt_a = g.add_input(0, pattern_tensor_a); + KwargDataflowGraphInput pt_b = g.add_input(1, pattern_tensor_b); + KwargDataflowGraphInput pt_c = g.add_input(2, pattern_tensor_c); + + KwargNodeAddedResult op_pattern_1_added = g.add_node( + /*node_label=*/op_pattern_1, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{pt_a}, + }, + { + TensorSlotName::WEIGHT, + OpenKwargDataflowValue{pt_b}, + }, + }, + /*output_labels=*/ + { + { + TensorSlotName::OUTPUT, + pattern_tensor_x, + }, + }); PatternNode op_pattern_1_node = PatternNode{op_pattern_1_added.node}; - OpenDataflowValue pt_x = - OpenDataflowValue{get_only(op_pattern_1_added.outputs)}; - - NodeAddedResult op_pattern_2_added = - g.add_node(op_pattern_2, - {OpenDataflowValue{pt_a}, OpenDataflowValue{pt_c}}, - {pattern_tensor_y}); + OpenKwargDataflowValue pt_x = + OpenKwargDataflowValue{ + require_only_key(op_pattern_1_added.outputs, + TensorSlotName::OUTPUT), + }; + + KwargNodeAddedResult op_pattern_2_added = g.add_node( + /*node_label=*/op_pattern_2, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{pt_a}, + }, + { + TensorSlotName::WEIGHT, + OpenKwargDataflowValue{pt_c}, + }, + }, + /*outputs_labels=*/ + { + { + TensorSlotName::OUTPUT, + pattern_tensor_y, + }, + }); PatternNode op_pattern_2_node = PatternNode{op_pattern_2_added.node}; - OpenDataflowValue pt_y = - OpenDataflowValue{get_only(op_pattern_2_added.outputs)}; + OpenKwargDataflowValue pt_y = + OpenKwargDataflowValue{ + require_only_key(op_pattern_2_added.outputs, + TensorSlotName::OUTPUT), + }; PCGPattern pattern = PCGPattern{g}; @@ -218,13 +264,19 @@ TEST_SUITE(FF_TEST_SUITE) { /*bias_initializer=*/std::nullopt); ParallelComputationGraph pcg = builder.pcg; - LabelledOpenDataflowGraph - g = LabelledOpenDataflowGraph:: - create + g = LabelledOpenKwargDataflowGraph:: + create>(); + TensorAttributePattern, + int, + TensorSlotName>>(); TensorAttributePattern pattern_tensor_a 
= tensor_attribute_pattern_match_all(); @@ -243,22 +295,57 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttributePattern op_pattern_2 = op_pattern_1; - DataflowGraphInput pt_a = g.add_input(pattern_tensor_a); - DataflowGraphInput pt_b = g.add_input(pattern_tensor_b); - DataflowGraphInput pt_c = g.add_input(pattern_tensor_c); - - NodeAddedResult op_pattern_1_added = - g.add_node(op_pattern_1, - {OpenDataflowValue{pt_a}, OpenDataflowValue{pt_b}}, - {pattern_tensor_x}); + KwargDataflowGraphInput pt_a = g.add_input(0, pattern_tensor_a); + KwargDataflowGraphInput pt_b = g.add_input(1, pattern_tensor_b); + KwargDataflowGraphInput pt_c = g.add_input(2, pattern_tensor_c); + + KwargNodeAddedResult op_pattern_1_added = g.add_node( + /*node_label=*/op_pattern_1, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{pt_a}, + }, + { + TensorSlotName::WEIGHT, + OpenKwargDataflowValue{pt_b}, + }, + }, + /*output_labels=*/ + { + { + TensorSlotName::OUTPUT, + pattern_tensor_x, + }, + }); PatternNode op_pattern_1_node = PatternNode{op_pattern_1_added.node}; - OpenDataflowValue pt_x = - OpenDataflowValue{get_only(op_pattern_1_added.outputs)}; - - NodeAddedResult op_pattern_2_added = - g.add_node(op_pattern_2, - {OpenDataflowValue{pt_x}, OpenDataflowValue{pt_c}}, - {pattern_tensor_y}); + OpenKwargDataflowValue pt_x = + OpenKwargDataflowValue{ + require_only_key(op_pattern_1_added.outputs, + TensorSlotName::OUTPUT), + }; + + KwargNodeAddedResult op_pattern_2_added = g.add_node( + /*node_label=*/op_pattern_2, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{pt_x}, + }, + { + TensorSlotName::WEIGHT, + OpenKwargDataflowValue{pt_c}, + }, + }, + /*output_labels=*/ + { + { + TensorSlotName::OUTPUT, + pattern_tensor_y, + }, + }); PatternNode op_pattern_2_node = PatternNode{op_pattern_2_added.node}; PCGPattern pattern = PCGPattern{g}; diff --git a/lib/substitutions/test/src/substitutions/substitution.cc 
b/lib/substitutions/test/src/substitutions/substitution.cc index ef27cb7606..c8c4fe69e8 100644 --- a/lib/substitutions/test/src/substitutions/substitution.cc +++ b/lib/substitutions/test/src/substitutions/substitution.cc @@ -10,7 +10,9 @@ #include "substitutions/substitution_builder.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" #include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" #include "utils/graph/open_dataflow_graph/algorithms/are_isomorphic.h" #include "utils/integer_conversions.h" @@ -23,10 +25,13 @@ TEST_SUITE(FF_TEST_SUITE) { auto make_substitution = [] { SubstitutionBuilder b; - auto [p_input, o_input] = - b.add_input(tensor_attribute_pattern_match_all()); - auto [p_weight, o_weight] = - b.add_input(tensor_attribute_pattern_match_all()); + auto pair_input = b.add_input(tensor_attribute_pattern_match_all()); + PatternValue p_input = pair_input.first; + OutputGraphExprValue o_input = pair_input.second; + + auto pair_weight = b.add_input(tensor_attribute_pattern_match_all()); + PatternValue p_weight = pair_weight.first; + OutputGraphExprValue o_weight = pair_weight.second; PatternValue p_mm_output = [&] { auto pattern = OperatorAttributePattern{{ @@ -36,11 +41,29 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional{std::nullopt}}), }}; - return get_only( - b.add_pattern_node(pattern, - {p_input, p_weight}, - {tensor_attribute_pattern_match_all()}, - "mm")); + return require_only_key( + b.add_pattern_node( + /*node_pattern=*/pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::WEIGHT, + p_weight, + }, + }, + /*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }, + /*name=*/"mm"), + 
TensorSlotName::OUTPUT); }(); PatternValue p_relu_output = [&] { @@ -48,11 +71,25 @@ TEST_SUITE(FF_TEST_SUITE) { op_type_equals_constraint(OperatorType::RELU), }}; - return get_only( - b.add_pattern_node(pattern, - {p_mm_output}, - {tensor_attribute_pattern_match_all()}, - "relu")); + return require_only_key( + b.add_pattern_node( + /*node_pattern=*/pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_mm_output, + }, + }, + /*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }, + /*name=*/"relu"), + TensorSlotName::OUTPUT); }(); OutputGraphExprValue o_fused_output = [&] { @@ -63,8 +100,24 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttributeValue{Activation::RELU}), }}; - return get_only(b.add_output_graph_node( - node_expr, {o_input, o_weight}, nonnegative_int{1})); + return require_only_key(b.add_output_graph_node( + /*node_expr=*/node_expr, + /*input=*/ + { + { + TensorSlotName::INPUT, + o_input, + }, + { + TensorSlotName::WEIGHT, + o_weight, + }, + }, + /*output_slots=*/ + { + TensorSlotName::OUTPUT, + }), + TensorSlotName::OUTPUT); }(); b.equate_outputs(p_relu_output, o_fused_output); @@ -82,9 +135,13 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("is_valid_substitution") { SubstitutionBuilder b; - auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); - auto [p_weight, o_weight] = - b.add_input(tensor_attribute_pattern_match_all()); + auto pair_input = b.add_input(tensor_attribute_pattern_match_all()); + PatternValue p_input = pair_input.first; + OutputGraphExprValue o_input = pair_input.second; + + auto pair_weight = b.add_input(tensor_attribute_pattern_match_all()); + PatternValue p_weight = pair_weight.first; + OutputGraphExprValue o_weight = pair_weight.second; PatternValue p_mm_output = [&] { auto pattern = OperatorAttributePattern{{ @@ -94,10 +151,28 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttributeValue{std::optional{std::nullopt}}), }}; - return get_only(b.add_pattern_node(pattern, - 
{p_input, p_weight}, - {tensor_attribute_pattern_match_all()}, - "mm")); + return require_only_key(b.add_pattern_node( + /*node_pattern=*/pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::WEIGHT, + p_weight, + }, + }, + /*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }, + /*name=*/"mm"), + TensorSlotName::OUTPUT); }(); PatternValue p_relu_output = [&] { @@ -105,10 +180,24 @@ TEST_SUITE(FF_TEST_SUITE) { op_type_equals_constraint(OperatorType::RELU), }}; - return get_only(b.add_pattern_node(pattern, - {p_mm_output}, - {tensor_attribute_pattern_match_all()}, - "relu")); + return require_only_key(b.add_pattern_node( + /*node_pattern=*/pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_mm_output, + }, + }, + /*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }, + /*name=*/"relu"), + TensorSlotName::OUTPUT); }(); OutputGraphExprValue o_fused_output = [&] { @@ -119,21 +208,35 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttributeValue{Activation::RELU}), }}; - return get_only(b.add_output_graph_node( - node_expr, {o_input, o_weight}, nonnegative_int{1})); + return require_only_key(b.add_output_graph_node( + /*node_expr=*/node_expr, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + o_input, + }, + { + TensorSlotName::OUTPUT, + o_weight, + }, + }, + /*output_slots=*/{TensorSlotName::OUTPUT}), + TensorSlotName::OUTPUT); }(); b.equate_outputs(p_relu_output, o_fused_output); SUBCASE("pattern inputs != mapped inputs") { Substitution sub = b.get_substitution(); - sub.pcg_pattern.raw_graph.add_input(tensor_attribute_pattern_match_all()); + sub.pcg_pattern.raw_graph.add_input(13, + tensor_attribute_pattern_match_all()); CHECK_FALSE(is_valid_substitution(sub)); } SUBCASE("output graph inputs != mapped inputs") { Substitution sub = b.get_substitution(); - sub.output_graph_expr.raw_graph.add_input(std::monostate{}); + 
sub.output_graph_expr.raw_graph.add_input(0, std::monostate{}); CHECK_FALSE(is_valid_substitution(sub)); } @@ -141,14 +244,20 @@ TEST_SUITE(FF_TEST_SUITE) { // Could revamp this test to only trigger the // get_nodes(sub.pcg_pattern).empty() case Substitution sub = b.get_substitution(); - LabelledOpenDataflowGraph + LabelledOpenKwargDataflowGraph zero_node_pattern = - LabelledOpenDataflowGraph:: - create:: + create>(); + TensorAttributePattern, + int, + TensorSlotName>>(); sub.pcg_pattern = PCGPattern{zero_node_pattern}; CHECK_FALSE(is_valid_substitution(sub)); } @@ -157,13 +266,20 @@ TEST_SUITE(FF_TEST_SUITE) { // Could revamp this test to only trigger the // get_nodes(sub.output_graph_expr).empty() case Substitution sub = b.get_substitution(); - LabelledOpenDataflowGraph + LabelledOpenKwargDataflowGraph zero_node_pattern = - LabelledOpenDataflowGraph:: - create:: + create>(); + std::monostate, + int, + TensorSlotName>>(); sub.output_graph_expr = OutputGraphExpr{zero_node_pattern}; CHECK_FALSE(is_valid_substitution(sub)); } diff --git a/lib/substitutions/test/src/substitutions/substitution_builder.cc b/lib/substitutions/test/src/substitutions/substitution_builder.cc index 028a4e59c9..10e08b09e7 100644 --- a/lib/substitutions/test/src/substitutions/substitution_builder.cc +++ b/lib/substitutions/test/src/substitutions/substitution_builder.cc @@ -5,7 +5,9 @@ #include "substitutions/substitution.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" #include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" #include using namespace ::FlexFlow; @@ -30,58 +32,122 @@ TEST_SUITE(FF_TEST_SUITE) { }; Substitution correct = [&] { - auto pattern_g = LabelledOpenDataflowGraph:: - create< - UnorderedSetLabelledOpenDataflowGraph>(); + auto pattern_g = 
LabelledOpenKwargDataflowGraph:: + create>(); PatternInput pattern_i_activation = PatternInput{ - pattern_g.add_input(tensor_attribute_pattern_match_all())}; + pattern_g.add_input(0, tensor_attribute_pattern_match_all()), + }; PatternInput pattern_i_weights = PatternInput{ - pattern_g.add_input(tensor_attribute_pattern_match_all())}; + pattern_g.add_input(1, tensor_attribute_pattern_match_all()), + }; - NodeAddedResult mm_added = pattern_g.add_node( - mm_pattern, - {OpenDataflowValue{pattern_i_activation.raw_dataflow_graph_input}, - OpenDataflowValue{pattern_i_weights.raw_dataflow_graph_input}}, - {tensor_attribute_pattern_match_all()}); + KwargNodeAddedResult mm_added = pattern_g.add_node( + /*node_label=*/mm_pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{ + pattern_i_activation.raw_dataflow_graph_input, + }, + }, + { + TensorSlotName::WEIGHT, + OpenKwargDataflowValue{ + pattern_i_weights.raw_dataflow_graph_input, + }, + }, + }, + /*output_labels=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }); PatternNode pattern_mm_node = PatternNode{mm_added.node}; - DataflowOutput mm_output = get_only(mm_added.outputs); + KwargDataflowOutput mm_output = + require_only_key(mm_added.outputs, TensorSlotName::OUTPUT); - NodeAddedResult relu_added = - pattern_g.add_node(relu_pattern, - {OpenDataflowValue{mm_output}}, - {tensor_attribute_pattern_match_all()}); + KwargNodeAddedResult relu_added = pattern_g.add_node( + /*node_label=*/relu_pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{mm_output}, + }, + }, + /*output_labels=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }); PatternNode pattern_relu_node = PatternNode{relu_added.node}; - DataflowOutput relu_output = get_only(relu_added.outputs); + KwargDataflowOutput relu_output = + require_only_key(relu_added.outputs, TensorSlotName::OUTPUT); - LabelledOpenDataflowGraph - output_g = 
LabelledOpenDataflowGraph:: - create>(); + LabelledOpenKwargDataflowGraph + output_g = + LabelledOpenKwargDataflowGraph:: + create>(); OutputGraphExprInput output_i_activation = - OutputGraphExprInput{output_g.add_input({})}; + OutputGraphExprInput{output_g.add_input(0, {})}; OutputGraphExprInput output_i_weights = - OutputGraphExprInput{output_g.add_input({})}; + OutputGraphExprInput{output_g.add_input(1, {})}; OutputOperatorAttrsAssignment fused_mm_relu_attrs_assignment = OutputOperatorAttrsAssignment{ pattern_mm_node, fused_mm_relu_attr_assignments, }; - NodeAddedResult fused_mm_relu_added = output_g.add_node( - fused_mm_relu_attrs_assignment, - {OpenDataflowValue{output_i_activation.raw_dataflow_graph_input}, - OpenDataflowValue{output_i_weights.raw_dataflow_graph_input}}, - {{}}); + KwargNodeAddedResult fused_mm_relu_added = output_g.add_node( + /*node_label=*/fused_mm_relu_attrs_assignment, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{ + output_i_activation.raw_dataflow_graph_input, + }, + }, + { + TensorSlotName::WEIGHT, + OpenKwargDataflowValue{ + output_i_weights.raw_dataflow_graph_input, + }, + }, + }, + /*output_labels=*/ + {{ + TensorSlotName::OUTPUT, + std::monostate{}, + }}); OutputGraphExprNode fused_mm_relu_node = OutputGraphExprNode{fused_mm_relu_added.node}; - DataflowOutput fused_mm_relu_output = - get_only(fused_mm_relu_added.outputs); + KwargDataflowOutput fused_mm_relu_output = + require_only_key(fused_mm_relu_added.outputs, TensorSlotName::OUTPUT); return Substitution{ PCGPattern{pattern_g}, @@ -114,16 +180,48 @@ TEST_SUITE(FF_TEST_SUITE) { b.add_input(tensor_attribute_pattern_match_all()); PatternValue p_mm_output = - get_only(b.add_pattern_node(mm_pattern, - {p_input, p_weight}, - {tensor_attribute_pattern_match_all()}, - "mm")); + require_only_key(b.add_pattern_node( + /*node_pattern=*/mm_pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::WEIGHT, + p_weight, + }, + }, + 
/*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }, + /*name=*/"mm"), + TensorSlotName::OUTPUT); PatternValue p_relu_output = - get_only(b.add_pattern_node(relu_pattern, - {p_mm_output}, - {tensor_attribute_pattern_match_all()}, - "relu")); + require_only_key(b.add_pattern_node( + /*node_pattern=*/relu_pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_mm_output, + }, + }, + /*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + tensor_attribute_pattern_match_all(), + }, + }, + /*name=*/"relu"), + TensorSlotName::OUTPUT); OutputOperatorAttrsAssignment fused_mm_relu_attrs_assignment = OutputOperatorAttrsAssignment{ @@ -131,9 +229,24 @@ TEST_SUITE(FF_TEST_SUITE) { fused_mm_relu_attr_assignments, }; OutputGraphExprValue o_fused_output = - get_only(b.add_output_graph_node(fused_mm_relu_attrs_assignment, - {o_input, o_weight}, - nonnegative_int{1})); + require_only_key(b.add_output_graph_node( + /*node_expr=*/fused_mm_relu_attrs_assignment, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + o_input, + }, + { + TensorSlotName::WEIGHT, + o_weight, + }, + }, + /*output_slots=*/ + { + TensorSlotName::OUTPUT, + }), + TensorSlotName::OUTPUT); b.equate_outputs(p_relu_output, o_fused_output); diff --git a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc index c86cb7e51f..ea8c8529ba 100644 --- a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc @@ -5,12 +5,10 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_substitution_set") { - MachineSpecification machine_spec = MachineSpecification{ + MachineComputeSpecification machine_spec = MachineComputeSpecification{ /*num_nodes=*/2_p, /*num_cpus_per_node=*/8_p, /*num_gpus_per_node=*/4_p, - /*inter_node_bandwidth=*/0.0, - /*intra_node_bandwidth=*/0.0, }; std::vector result = 
get_substitution_set(machine_spec); diff --git a/lib/substitutions/test/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/test/src/substitutions/unlabelled/find_pattern_matches.cc index ab79ad6ff6..9cb16f2923 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/find_pattern_matches.cc @@ -1,185 +1,215 @@ #include "substitutions/unlabelled/find_pattern_matches.h" #include "substitutions/unlabelled/match_additional_criterion.h" #include "substitutions/unlabelled/pattern_matching.h" +#include "test/utils/doctest/fmt/vector.h" #include "utils/containers/get_only.h" -#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/containers/make_counter_func.h" +#include "utils/containers/require_only_key.h" +#include "utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_values.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_subgraph.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph.h" #include using namespace FlexFlow; -namespace rc { - -// template <> -// struct Arbitrary { -// static int const MAX_GRAPH_SIZE = 200; -// static int const MAX_EDGE_SIZE = 
1000; -// -// static Gen arbitrary() { -// return gen::exec([&] { -// int num_nodes = *gen::inRange(1, MAX_GRAPH_SIZE + 1); -// MultiDiGraph g = MultiDiGraph::template -// create(); -// -// std::vector nodes; -// for (int i = 0; i < num_nodes; ++i) { -// nodes.push_back(g.add_node()); -// } -// -// int num_edges = *gen::inRange(1, MAX_GRAPH_SIZE + 1); -// for (int i = 0; i < num_edges; ++i) { -// int src_id = *gen::inRange(0, num_nodes); -// int dst_id = *gen::inRange(0, num_nodes); -// if (src_id > dst_id) { -// std::swap(src_id, dst_id); -// } -// -// g.add_edge(MultiDiEdge{nodes[dst_id], -// g.add_node_port(), -// nodes[src_id], -// g.add_node_port()}); -// } -// -// return g; -// }); -// } -// }; - -} // namespace rc - -// TEST_CASE("find_pattern_matches") { -// RC_SUBCASE([](MultiDiGraph const &g) { -// std::unordered_set subgraph_nodes = *rc::subset_of(get_nodes(g)); -// OpenMultiDiGraphView subgraph = -// get_subgraph(as_openmultidigraph(g), -// subgraph_nodes); -// -// std::vector matches = -// find_pattern_matches(subgraph, as_openmultidigraph(g), AlwaysTrue{}); -// -// RC_ASSERT(!matches.empty()); -// -// for (MultiDiGraphPatternMatch const &match : matches) { -// RC_ASSERT(pattern_matches(subgraph, as_openmultidigraph(g), match, -// AlwaysTrue{})); -// } -// }); - TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("find_pattern_matches") { - OpenDataflowGraph pattern_graph = - OpenDataflowGraph::create(); + OpenKwargDataflowGraph pattern_graph = + OpenKwargDataflowGraph::create< + UnorderedSetOpenKwargDataflowGraph>(); - NodeAddedResult pattern_n0_added = pattern_graph.add_node({}, 1_n); + KwargNodeAddedResult pattern_n0_added = pattern_graph.add_node( + /*inputs=*/{}, + /*outputs=*/{ + TensorSlotName::OUTPUT, + }); Node pattern_n0 = pattern_n0_added.node; - OpenDataflowValue pattern_v0 = - OpenDataflowValue{get_only(pattern_n0_added.outputs)}; + OpenKwargDataflowValue pattern_v0 = + OpenKwargDataflowValue{ + require_only_key(pattern_n0_added.outputs, 
TensorSlotName::OUTPUT), + }; - NodeAddedResult pattern_n1_added = - pattern_graph.add_node({pattern_v0}, 1_n); + KwargNodeAddedResult pattern_n1_added = pattern_graph.add_node( + /*inputs=*/ + { + { + TensorSlotName::INPUT, + pattern_v0, + }, + }, + /*outputs=*/{TensorSlotName::OUTPUT}); Node pattern_n1 = pattern_n1_added.node; - OpenDataflowValue pattern_v1 = - OpenDataflowValue{get_only(pattern_n1_added.outputs)}; + OpenKwargDataflowValue pattern_v1 = + OpenKwargDataflowValue{ + require_only_key(pattern_n1_added.outputs, TensorSlotName::OUTPUT), + }; UnlabelledGraphPattern pattern = UnlabelledGraphPattern{pattern_graph}; PatternNode p0 = PatternNode{pattern_n0}; PatternNode p1 = PatternNode{pattern_n1}; - OpenDataflowGraph graph = - OpenDataflowGraph::create(); + OpenKwargDataflowGraph graph = + OpenKwargDataflowGraph::create< + UnorderedSetOpenKwargDataflowGraph>(); - NodeAddedResult n0_added = graph.add_node({}, 1_n); + KwargNodeAddedResult n0_added = graph.add_node( + /*inputs=*/{}, + /*outputs=*/{ + TensorSlotName::OUTPUT, + }); Node n0 = n0_added.node; - OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + OpenKwargDataflowValue v0 = + OpenKwargDataflowValue{ + require_only_key(n0_added.outputs, TensorSlotName::OUTPUT), + }; - NodeAddedResult n1_added = graph.add_node({v0}, 1_n); + KwargNodeAddedResult n1_added = graph.add_node( + /*inputs=*/ + { + { + TensorSlotName::INPUT, + v0, + }, + }, + /*outputs=*/{ + TensorSlotName::OUTPUT, + }); Node n1 = n1_added.node; - OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; + OpenKwargDataflowValue v1 = + OpenKwargDataflowValue{ + require_only_key(n1_added.outputs, TensorSlotName::OUTPUT), + }; - NodeAddedResult n2_added = graph.add_node({v1}, 1_n); + KwargNodeAddedResult n2_added = graph.add_node( + /*inputs=*/ + { + { + TensorSlotName::INPUT, + v1, + }, + }, + /*outputs=*/{ + TensorSlotName::OUTPUT, + }); Node n2 = n2_added.node; - OpenDataflowValue v2 = 
OpenDataflowValue{get_only(n2_added.outputs)}; + OpenKwargDataflowValue v2 = + OpenKwargDataflowValue{ + require_only_key(n2_added.outputs, TensorSlotName::OUTPUT), + }; - NodeAddedResult n3_added = graph.add_node({v2}, 1_n); + KwargNodeAddedResult n3_added = graph.add_node( + /*inputs=*/ + { + { + TensorSlotName::INPUT, + v2, + }, + }, + /*outputs=*/{ + TensorSlotName::OUTPUT, + }); Node n3 = n3_added.node; - OpenDataflowValue v3 = OpenDataflowValue{get_only(n3_added.outputs)}; + OpenKwargDataflowValue v3 = + OpenKwargDataflowValue{ + require_only_key(n3_added.outputs, TensorSlotName::OUTPUT), + }; - UnlabelledDataflowGraphPatternMatch match = - UnlabelledDataflowGraphPatternMatch{ + UnlabelledKwargDataflowGraphPatternMatch match = + UnlabelledKwargDataflowGraphPatternMatch{ bidict{ {p0, n0}, {p1, n1}, }, - bidict{}}; + bidict>{}}; - UnlabelledDataflowGraphPatternMatch invalid_match = - UnlabelledDataflowGraphPatternMatch{ + UnlabelledKwargDataflowGraphPatternMatch invalid_match = + UnlabelledKwargDataflowGraphPatternMatch{ bidict{ {p0, n1}, {p1, n2}, }, - bidict{}}; + bidict>{}}; - std::vector n1_incoming = {OpenDataflowEdge{ - DataflowEdge{ - DataflowOutput{n0, 0_n}, - DataflowInput{n1, 0_n}, - }, - }}; + std::unordered_map> + n1_incoming = { + { + TensorSlotName::INPUT, + OpenKwargDataflowEdge{ + KwargDataflowEdge{ + KwargDataflowOutput{n0, TensorSlotName::OUTPUT}, + KwargDataflowInput{n1, TensorSlotName::INPUT}, + }, + }, + }, + }; SUBCASE("get_incoming_edges") { SUBCASE("n0") { - std::vector result = get_incoming_edges(graph, n0); - std::vector correct = {}; + std::unordered_map> + result = get_incoming_open_kwarg_dataflow_edges_for_node(graph, n0); + std::unordered_map> + correct = {}; CHECK(result == correct); } + SUBCASE("n1") { - std::vector result = get_incoming_edges(graph, n1); - std::vector correct = n1_incoming; - CHECK(result == correct); - } - SUBCASE("both") { - std::unordered_map> result = - get_incoming_edges(graph, {n0, n1}); - 
std::unordered_map> correct = { - {n0, {}}, {n1, n1_incoming}}; + std::unordered_map> + result = get_incoming_open_kwarg_dataflow_edges_for_node(graph, n1); + std::unordered_map> + correct = n1_incoming; CHECK(result == correct); } } - SUBCASE("get_subgraph_inputs") { - std::unordered_set result = - get_subgraph_inputs(graph, {n0, n1}); - std::unordered_set correct = {}; + SUBCASE("get_open_kwarg_dataflow_subgraph_inputs") { + std::unordered_set> result = + get_open_kwarg_dataflow_subgraph_inputs(graph, {n0, n1}); + std::unordered_set> correct = + {}; CHECK(result == correct); } - SUBCASE("get_subgraph") { - OpenDataflowGraphView g = get_subgraph(graph, {n0, n1}).graph; + SUBCASE("get_open_kwarg_dataflow_graph_subgraph") { + int graph_input_ctr = 0; + OpenKwargDataflowGraphView g = + get_open_kwarg_dataflow_graph_subgraph( + graph, {n0, n1}, make_counter_func()) + .graph; + SUBCASE("nodes") { std::unordered_set result = get_nodes(g); std::unordered_set correct = {n0, n1}; CHECK(result == correct); } + SUBCASE("inputs") { - std::unordered_set result = g.get_inputs(); - std::unordered_set correct = {}; + std::unordered_set> result = + g.get_inputs(); + std::unordered_set> correct = {}; CHECK(result == correct); } - SUBCASE("get_open_dataflow_values") { - std::unordered_set values = - get_open_dataflow_values(g); + + SUBCASE("get_all_open_kwarg_dataflow_values") { + std::unordered_set> values = + get_all_open_kwarg_dataflow_values(g); CHECK(values.size() == 2); } } SUBCASE("subgraph_matched") { - OpenDataflowGraphView result = subgraph_matched(graph, match).graph; + OpenKwargDataflowGraphView result = + subgraph_matched(graph, match).graph; std::unordered_set result_nodes = get_nodes(result); std::unordered_set correct_nodes = {n0, n1}; CHECK(result_nodes == correct_nodes); @@ -196,24 +226,38 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("unlabelled_pattern_does_match (open)") { - OpenDataflowGraph g = - OpenDataflowGraph::create(); - DataflowGraphInput i0 = g.add_input(); 
+ OpenKwargDataflowGraph g = + OpenKwargDataflowGraph::create< + UnorderedSetOpenKwargDataflowGraph>(); + KwargDataflowGraphInput i0 = g.add_input(0); - NodeAddedResult g_n0_added = g.add_node({OpenDataflowValue{i0}}, 1_n); + KwargNodeAddedResult g_n0_added = g.add_node( + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{i0}, + }, + }, + /*outputs=*/{ + TensorSlotName::OUTPUT, + }); Node g_n0 = g_n0_added.node; - OpenDataflowValue g_v0 = OpenDataflowValue{get_only(g_n0_added.outputs)}; + OpenKwargDataflowValue g_v0 = + OpenKwargDataflowValue{ + require_only_key(g_n0_added.outputs, TensorSlotName::OUTPUT), + }; PatternNode g_p0 = PatternNode{g_n0}; PatternInput g_pi0 = PatternInput{i0}; UnlabelledGraphPattern open_pattern = UnlabelledGraphPattern{g}; - UnlabelledDataflowGraphPatternMatch open_match = - UnlabelledDataflowGraphPatternMatch{ + UnlabelledKwargDataflowGraphPatternMatch open_match = + UnlabelledKwargDataflowGraphPatternMatch{ bidict{ {g_p0, n1}, }, - bidict{ + bidict>{ {g_pi0, v0}, }}; CHECK(unlabelled_pattern_does_match( @@ -224,10 +268,10 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("find_pattern_matches") { - std::vector matches = - find_pattern_matches( + std::vector matches = + find_unlabelled_pattern_matches( pattern, graph, match_additional_crition_always_true()); - std::vector correct = {match}; + std::vector correct = {match}; CHECK(matches == correct); } diff --git a/lib/substitutions/test/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/test/src/substitutions/unlabelled/pattern_matching.cc index 8fd468d186..ef4650c8d5 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/pattern_matching.cc @@ -1,14 +1,15 @@ #include "substitutions/unlabelled/pattern_matching.h" #include "substitutions/unlabelled/find_pattern_matches.h" #include "substitutions/unlabelled/match_additional_criterion.h" -#include 
"utils/containers/get_only.h" -#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/containers/make_counter_func.h" +#include "utils/containers/require_only_key.h" +#include "utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_values.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_subgraph.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph.h" #include "utils/overload.h" #include @@ -55,114 +56,194 @@ namespace rc { TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("find_pattern_matches") { - OpenDataflowGraph pattern_graph = - OpenDataflowGraph::create(); + OpenKwargDataflowGraph pattern_graph = + OpenKwargDataflowGraph::create< + UnorderedSetOpenKwargDataflowGraph>(); - NodeAddedResult pattern_n0_added = pattern_graph.add_node({}, 1_n); + KwargNodeAddedResult pattern_n0_added = pattern_graph.add_node( + /*inputs=*/{}, + /*outputs=*/{TensorSlotName::OUTPUT}); Node pattern_n0 = pattern_n0_added.node; - OpenDataflowValue pattern_v0 = - OpenDataflowValue{get_only(pattern_n0_added.outputs)}; + OpenKwargDataflowValue pattern_v0 = + OpenKwargDataflowValue{ + require_only_key(pattern_n0_added.outputs, TensorSlotName::OUTPUT), + }; - NodeAddedResult pattern_n1_added = - 
pattern_graph.add_node({pattern_v0}, 1_n); + KwargNodeAddedResult pattern_n1_added = pattern_graph.add_node( + /*inputs=*/ + { + { + TensorSlotName::INPUT, + pattern_v0, + }, + }, + /*outputs=*/{ + TensorSlotName::OUTPUT, + }); Node pattern_n1 = pattern_n1_added.node; - OpenDataflowValue pattern_v1 = - OpenDataflowValue{get_only(pattern_n1_added.outputs)}; + OpenKwargDataflowValue pattern_v1 = + OpenKwargDataflowValue{ + require_only_key(pattern_n1_added.outputs, TensorSlotName::OUTPUT), + }; UnlabelledGraphPattern pattern = UnlabelledGraphPattern{pattern_graph}; PatternNode p0 = PatternNode{pattern_n0}; PatternNode p1 = PatternNode{pattern_n1}; - OpenDataflowGraph graph = - OpenDataflowGraph::create(); + OpenKwargDataflowGraph graph = + OpenKwargDataflowGraph::create< + UnorderedSetOpenKwargDataflowGraph>(); - NodeAddedResult n0_added = graph.add_node({}, 1_n); + KwargNodeAddedResult n0_added = + graph.add_node({}, {TensorSlotName::OUTPUT}); Node n0 = n0_added.node; - OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + OpenKwargDataflowValue v0 = + OpenKwargDataflowValue{ + require_only_key(n0_added.outputs, TensorSlotName::OUTPUT), + }; - NodeAddedResult n1_added = graph.add_node({v0}, 1_n); + KwargNodeAddedResult n1_added = graph.add_node( + /*inputs=*/ + { + { + TensorSlotName::INPUT, + v0, + }, + }, + /*outputs=*/{ + TensorSlotName::OUTPUT, + }); Node n1 = n1_added.node; - OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; + OpenKwargDataflowValue v1 = + OpenKwargDataflowValue{ + require_only_key(n1_added.outputs, TensorSlotName::OUTPUT), + }; - NodeAddedResult n2_added = graph.add_node({v1}, 1_n); + KwargNodeAddedResult n2_added = graph.add_node( + /*inputs=*/ + { + { + TensorSlotName::INPUT, + v1, + }, + }, + /*outputs=*/{ + TensorSlotName::OUTPUT, + }); Node n2 = n2_added.node; - OpenDataflowValue v2 = OpenDataflowValue{get_only(n2_added.outputs)}; + OpenKwargDataflowValue v2 = + OpenKwargDataflowValue{ + 
require_only_key(n2_added.outputs, TensorSlotName::OUTPUT), + }; - NodeAddedResult n3_added = graph.add_node({v2}, 1_n); + KwargNodeAddedResult n3_added = graph.add_node( + /*inputs=*/ + { + { + TensorSlotName::INPUT, + v2, + }, + }, + /*outputs=*/{ + TensorSlotName::OUTPUT, + }); Node n3 = n3_added.node; - OpenDataflowValue v3 = OpenDataflowValue{get_only(n3_added.outputs)}; + OpenKwargDataflowValue v3 = + OpenKwargDataflowValue{ + require_only_key(n3_added.outputs, TensorSlotName::OUTPUT), + }; - UnlabelledDataflowGraphPatternMatch match = - UnlabelledDataflowGraphPatternMatch{ + UnlabelledKwargDataflowGraphPatternMatch match = + UnlabelledKwargDataflowGraphPatternMatch{ bidict{ {p0, n0}, {p1, n1}, }, - bidict{}}; + bidict>{}}; - UnlabelledDataflowGraphPatternMatch invalid_match = - UnlabelledDataflowGraphPatternMatch{ + UnlabelledKwargDataflowGraphPatternMatch invalid_match = + UnlabelledKwargDataflowGraphPatternMatch{ bidict{ {p0, n1}, {p1, n2}, }, - bidict{}}; + bidict>{}}; - std::vector n1_incoming = {OpenDataflowEdge{ - DataflowEdge{ - DataflowOutput{n0, 0_n}, - DataflowInput{n1, 0_n}, - }, - }}; + std::unordered_map> + n1_incoming = { + { + TensorSlotName::INPUT, + OpenKwargDataflowEdge{ + KwargDataflowEdge{ + KwargDataflowOutput{n0, TensorSlotName::OUTPUT}, + KwargDataflowInput{n1, TensorSlotName::INPUT}, + }, + }, + }, + }; - SUBCASE("get_incoming_edges") { + SUBCASE("get_incoming_open_kwarg_dataflow_edges_for_node") { SUBCASE("n0") { - std::vector result = get_incoming_edges(graph, n0); - std::vector correct = {}; + std::unordered_map> + result = get_incoming_open_kwarg_dataflow_edges_for_node(graph, n0); + std::unordered_map> + correct = {}; CHECK(result == correct); } + SUBCASE("n1") { - std::vector result = get_incoming_edges(graph, n1); - std::vector correct = n1_incoming; - CHECK(result == correct); - } - SUBCASE("both") { - std::unordered_map> result = - get_incoming_edges(graph, {n0, n1}); - std::unordered_map> correct = { - {n0, {}}, {n1, 
n1_incoming}}; + std::unordered_map> + result = get_incoming_open_kwarg_dataflow_edges_for_node(graph, n1); + std::unordered_map> + correct = n1_incoming; CHECK(result == correct); } } - SUBCASE("get_subgraph_inputs") { - std::unordered_set result = - get_subgraph_inputs(graph, {n0, n1}); - std::unordered_set correct = {}; + SUBCASE("get_open_kwarg_dataflow_subgraph_inputs") { + std::unordered_set> result = + get_open_kwarg_dataflow_subgraph_inputs(graph, {n0, n1}); + std::unordered_set> correct = + {}; CHECK(result == correct); } - SUBCASE("get_subgraph") { - OpenDataflowGraphView g = get_subgraph(graph, {n0, n1}).graph; + SUBCASE("get_open_kwarg_dataflow_graph_subgraph") { + OpenKwargDataflowGraphView g = + get_open_kwarg_dataflow_graph_subgraph( + graph, {n0, n1}, make_counter_func()) + .graph; + SUBCASE("nodes") { std::unordered_set result = get_nodes(g); std::unordered_set correct = {n0, n1}; CHECK(result == correct); } + SUBCASE("inputs") { - std::unordered_set result = g.get_inputs(); - std::unordered_set correct = {}; + std::unordered_set> result = + g.get_inputs(); + std::unordered_set> correct = {}; CHECK(result == correct); } - SUBCASE("get_open_dataflow_values") { - std::unordered_set values = - get_open_dataflow_values(g); + + SUBCASE("get_all_open_kwarg_dataflow_values") { + std::unordered_set> values = + get_all_open_kwarg_dataflow_values(g); CHECK(values.size() == 2); } } SUBCASE("subgraph_matched") { - OpenDataflowGraphView result = subgraph_matched(graph, match).graph; + OpenKwargDataflowGraphView result = + subgraph_matched(graph, match).graph; std::unordered_set result_nodes = get_nodes(result); std::unordered_set correct_nodes = {n0, n1}; CHECK(result_nodes == correct_nodes); @@ -179,24 +260,38 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("unlabelled_pattern_does_match") { - OpenDataflowGraph g = - OpenDataflowGraph::create(); - DataflowGraphInput i0 = g.add_input(); + OpenKwargDataflowGraph g = + OpenKwargDataflowGraph::create< + 
UnorderedSetOpenKwargDataflowGraph>(); + KwargDataflowGraphInput i0 = g.add_input(0); - NodeAddedResult g_n0_added = g.add_node({OpenDataflowValue{i0}}, 1_n); + KwargNodeAddedResult g_n0_added = g.add_node( + /*inputs=*/ + { + { + TensorSlotName::INPUT, + OpenKwargDataflowValue{i0}, + }, + }, + /*outputs=*/{ + TensorSlotName::OUTPUT, + }); Node g_n0 = g_n0_added.node; - OpenDataflowValue g_v0 = OpenDataflowValue{get_only(g_n0_added.outputs)}; + OpenKwargDataflowValue g_v0 = + OpenKwargDataflowValue{ + require_only_key(g_n0_added.outputs, TensorSlotName::OUTPUT), + }; PatternNode g_p0 = PatternNode{g_n0}; PatternInput g_pi0 = PatternInput{i0}; UnlabelledGraphPattern open_pattern = UnlabelledGraphPattern{g}; - UnlabelledDataflowGraphPatternMatch open_match = - UnlabelledDataflowGraphPatternMatch{ + UnlabelledKwargDataflowGraphPatternMatch open_match = + UnlabelledKwargDataflowGraphPatternMatch{ bidict{ {g_p0, n1}, }, - bidict{ + bidict>{ {g_pi0, v0}, }}; diff --git a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc index 1bddb9f680..458ab8a811 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc @@ -2,30 +2,37 @@ #include "substitutions/unlabelled/pattern_value.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include "utils/containers/get_only.h" -#include "utils/graph/instances/unordered_set_dataflow_graph.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include "utils/containers/require_only_key.h" +#include "utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("pattern_split (sequential)") { - OpenDataflowGraph g = - OpenDataflowGraph::create(); + OpenKwargDataflowGraph g 
= + OpenKwargDataflowGraph::create< + UnorderedSetOpenKwargDataflowGraph>(); - NodeAddedResult n0_added = g.add_node({}, 1_n); + KwargNodeAddedResult n0_added = g.add_node({}, {TensorSlotName::OUTPUT}); Node n0 = n0_added.node; - OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + OpenKwargDataflowValue v0 = OpenKwargDataflowValue{ + require_only_key(n0_added.outputs, TensorSlotName::OUTPUT), + }; - NodeAddedResult n1_added = g.add_node({v0}, 1_n); + KwargNodeAddedResult n1_added = + g.add_node({{TensorSlotName::INPUT, v0}}, {TensorSlotName::OUTPUT}); Node n1 = n1_added.node; - OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; + OpenKwargDataflowValue v1 = OpenKwargDataflowValue{ + require_only_key(n1_added.outputs, TensorSlotName::OUTPUT), + }; UnlabelledGraphPattern pattern = UnlabelledGraphPattern{g}; PatternNode p0 = PatternNode{n0}; PatternNode p1 = PatternNode{n1}; - PatternValue pv0 = pattern_value_from_raw_open_dataflow_value(v0); - PatternValue pv1 = pattern_value_from_raw_open_dataflow_value(v1); + PatternValue pv0 = pattern_value_from_raw_open_kwarg_dataflow_value(v0); + PatternValue pv1 = pattern_value_from_raw_open_kwarg_dataflow_value(v1); PatternSplit even_split = PatternSplit{ std::unordered_set{p0}, @@ -42,13 +49,13 @@ TEST_SUITE(FF_TEST_SUITE) { PatternSplitResult split_result = apply_split(pattern, even_split); SUBCASE("subpattern_1") { std::unordered_set result = - get_nodes(split_result.subpattern_1); + get_pattern_nodes(split_result.subpattern_1); std::unordered_set correct = even_split.first; CHECK(result == correct); } SUBCASE("subpattern_2") { std::unordered_set result = - get_nodes(split_result.subpattern_2); + get_pattern_nodes(split_result.subpattern_2); std::unordered_set correct = even_split.second; CHECK(result == correct); } @@ -61,7 +68,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("full_pattern_values_to_subpattern_2_inputs") { bidict result = 
split_result.full_pattern_values_to_subpattern_2_inputs; - PatternInput i0 = get_only(get_graph_inputs(split_result.subpattern_2)); + PatternInput i0 = + get_only(get_pattern_inputs(split_result.subpattern_2)); bidict correct = { {pv0, i0}, }; @@ -71,27 +79,38 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("pattern split (parallel)") { - OpenDataflowGraph g = - OpenDataflowGraph::create(); + OpenKwargDataflowGraph g = + OpenKwargDataflowGraph::create< + UnorderedSetOpenKwargDataflowGraph>(); - DataflowGraphInput i0 = g.add_input(); - DataflowGraphInput i1 = g.add_input(); + KwargDataflowGraphInput i0 = g.add_input(0); + KwargDataflowGraphInput i1 = g.add_input(1); - NodeAddedResult n0_added = g.add_node({OpenDataflowValue{i0}}, 1_n); + KwargNodeAddedResult n0_added = + g.add_node({{TensorSlotName::INPUT, + OpenKwargDataflowValue{i0}}}, + {TensorSlotName::OUTPUT}); Node n0 = n0_added.node; - OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + OpenKwargDataflowValue v0 = OpenKwargDataflowValue{ + require_only_key(n0_added.outputs, TensorSlotName::OUTPUT), + }; - NodeAddedResult n1_added = g.add_node({OpenDataflowValue{i1}}, 1_n); + KwargNodeAddedResult n1_added = + g.add_node({{TensorSlotName::INPUT, + OpenKwargDataflowValue{i1}}}, + {TensorSlotName::OUTPUT}); Node n1 = n1_added.node; - OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; + OpenKwargDataflowValue v1 = OpenKwargDataflowValue{ + require_only_key(n1_added.outputs, TensorSlotName::OUTPUT), + }; UnlabelledGraphPattern pattern = UnlabelledGraphPattern{g}; PatternInput pi0 = PatternInput{i0}; PatternInput pi1 = PatternInput{i1}; PatternNode p0 = PatternNode{n0}; PatternNode p1 = PatternNode{n1}; - PatternValue pv0 = pattern_value_from_raw_open_dataflow_value(v0); - PatternValue pv1 = pattern_value_from_raw_open_dataflow_value(v1); + PatternValue pv0 = pattern_value_from_raw_open_kwarg_dataflow_value(v0); + PatternValue pv1 = 
pattern_value_from_raw_open_kwarg_dataflow_value(v1); PatternSplit even_split = PatternSplit{ std::unordered_set{p0}, @@ -102,13 +121,13 @@ TEST_SUITE(FF_TEST_SUITE) { PatternSplitResult split_result = apply_split(pattern, even_split); SUBCASE("subpattern_1") { std::unordered_set result = - get_nodes(split_result.subpattern_1); + get_pattern_nodes(split_result.subpattern_1); std::unordered_set correct = even_split.first; CHECK(result == correct); } SUBCASE("subpattern_2") { std::unordered_set result = - get_nodes(split_result.subpattern_2); + get_pattern_nodes(split_result.subpattern_2); std::unordered_set correct = even_split.second; CHECK(result == correct); } @@ -117,7 +136,7 @@ TEST_SUITE(FF_TEST_SUITE) { split_result.full_pattern_values_to_subpattern_1_inputs; bidict correct = { {PatternValue{pi0}, - get_only(get_graph_inputs(split_result.subpattern_1))}, + get_only(get_pattern_inputs(split_result.subpattern_1))}, }; CHECK(result == correct); } @@ -126,7 +145,7 @@ TEST_SUITE(FF_TEST_SUITE) { split_result.full_pattern_values_to_subpattern_2_inputs; bidict correct = { {PatternValue{pi1}, - get_only(get_graph_inputs(split_result.subpattern_2))}, + get_only(get_pattern_inputs(split_result.subpattern_2))}, }; CHECK(result == correct); } diff --git a/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index 22d1b8a2a5..b0dc6c7e4b 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -1,15 +1,16 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" -#include "utils/containers/get_only.h" -#include "utils/graph/instances/unordered_set_dataflow_graph.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include "utils/containers/require_only_key.h" +#include 
"utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("is_singleton_pattern") { - OpenDataflowGraph g = - OpenDataflowGraph::create(); + OpenKwargDataflowGraph g = + OpenKwargDataflowGraph::create< + UnorderedSetOpenKwargDataflowGraph>(); SUBCASE("0 nodes") { UnlabelledGraphPattern pattern = UnlabelledGraphPattern{g}; @@ -17,8 +18,10 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK_FALSE(is_singleton_pattern(pattern)); } - NodeAddedResult n0_added = g.add_node({}, 1_n); - OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + KwargNodeAddedResult n0_added = g.add_node({}, {TensorSlotName::OUTPUT}); + OpenKwargDataflowValue v0 = OpenKwargDataflowValue{ + require_only_key(n0_added.outputs, TensorSlotName::OUTPUT), + }; SUBCASE("1 node") { UnlabelledGraphPattern pattern = UnlabelledGraphPattern{g}; @@ -26,8 +29,12 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(is_singleton_pattern(pattern)); } - NodeAddedResult n1_added = g.add_node({v0}, 1_n); - OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; + KwargNodeAddedResult n1_added = + g.add_node({{TensorSlotName::INPUT, v0}}, {TensorSlotName::OUTPUT}); + OpenKwargDataflowValue v1 = + OpenKwargDataflowValue{ + require_only_key(n1_added.outputs, TensorSlotName::OUTPUT), + }; SUBCASE("more than 1 node") { UnlabelledGraphPattern pattern = UnlabelledGraphPattern{g}; diff --git a/lib/task-spec/CMakeLists.txt b/lib/task-spec/CMakeLists.txt index 8ccd8312cb..3c7c91af67 100644 --- a/lib/task-spec/CMakeLists.txt +++ b/lib/task-spec/CMakeLists.txt @@ -13,6 +13,7 @@ ff_add_library( kernels pcg spdlog + compiler ) add_subdirectory(test) diff --git a/lib/task-spec/include/task-spec/arg_ref.h b/lib/task-spec/include/task-spec/arg_ref.h index 8d3402c578..a0b4717f3a 100644 --- a/lib/task-spec/include/task-spec/arg_ref.h +++ 
b/lib/task-spec/include/task-spec/arg_ref.h @@ -1,10 +1,5 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_ARG_REF_H -#define _FLEXFLOW_LOCAL_EXECUTION_ARG_REF_H - -#include "kernels/ff_handle.h" -// #include "task-spec/serialization.h -#include "utils/type_index.h" -#include "utils/visitable.h" +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_ARG_REF_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_ARG_REF_H namespace FlexFlow { @@ -13,80 +8,6 @@ struct ArgRef { LABEL_TYPE ref_type; }; -template -struct ArgRefSpec { -public: - ArgRefSpec() = delete; - - template - bool holds() const { - return matches(this->type_idx); - } - - LABEL_TYPE const &get_ref_type() const { - return this->ref_type; - } - - std::type_index get_type_index() const { - return this->type_idx; - } - - bool operator==(ArgRefSpec const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(ArgRefSpec const &other) const { - return this->tie() != other.tie(); - } - - template - static ArgRefSpec create(ArgRef const &r) { - // static_assert(is_serializable::value, "Type must be serializeable"); - - return ArgRefSpec(get_type_index_for_type(), r.ref_type); - } - -private: - ArgRefSpec(std::type_index const &type_index, LABEL_TYPE ref_type) - : type_idx(type_index), ref_type(ref_type) {} - - std::type_index type_idx; - LABEL_TYPE ref_type; - - std::tuple - tie() const { - return std::tie(this->type_idx, this->ref_type); - } - friend struct std::hash>; -}; - -template -std::string format_as(ArgRefSpec const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} - -template -std::ostream &operator<<(std::ostream &s, ArgRefSpec const &x) { - return (s << fmt::to_string(x)); -} - } // namespace FlexFlow -namespace std { - -template -struct hash<::FlexFlow::ArgRefSpec> { - size_t operator()(::FlexFlow::ArgRefSpec const &s) const { - size_t result = 0; - ::FlexFlow::hash_combine(result, s.type_idx, s.get_ref_type()); - return result; - } -}; - -} // namespace std - #endif diff 
--git a/lib/task-spec/include/task-spec/arg_ref_spec.h b/lib/task-spec/include/task-spec/arg_ref_spec.h new file mode 100644 index 0000000000..59a81180e9 --- /dev/null +++ b/lib/task-spec/include/task-spec/arg_ref_spec.h @@ -0,0 +1,88 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_ARG_REF_SPEC_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_ARG_REF_SPEC_H + +#include "task-spec/arg_ref.h" +#include "utils/hash-utils.h" +#include "utils/type_index.h" + +namespace FlexFlow { + +template +struct ArgRefSpec { +public: + ArgRefSpec() = delete; + + template + bool holds() const { + return matches(this->type_idx); + } + + LABEL_TYPE const &get_ref_type() const { + return this->ref_type; + } + + std::type_index get_type_index() const { + return this->type_idx; + } + + bool operator==(ArgRefSpec const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(ArgRefSpec const &other) const { + return this->tie() != other.tie(); + } + + template + static ArgRefSpec create(ArgRef const &r) { + // static_assert(is_serializable::value, "Type must be serializeable"); + + return ArgRefSpec(get_type_index_for_type(), r.ref_type); + } + +private: + ArgRefSpec(std::type_index const &type_index, LABEL_TYPE ref_type) + : type_idx(type_index), ref_type(ref_type) {} + +private: + std::type_index type_idx; + LABEL_TYPE ref_type; + +private: + std::tuple + tie() const { + return std::tie(this->type_idx, this->ref_type); + } + friend struct std::hash>; +}; + +template +std::string format_as(ArgRefSpec const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} + +template +std::ostream &operator<<(std::ostream &s, ArgRefSpec const &x) { + return (s << fmt::to_string(x)); +} + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::ArgRefSpec> { + size_t operator()(::FlexFlow::ArgRefSpec const &s) const { + size_t result = 0; + ::FlexFlow::hash_combine(result, s.type_idx, s.get_ref_type()); + return result; + } +}; + +} // 
namespace std + +#endif diff --git a/lib/task-spec/include/task-spec/concrete_arg_spec.h b/lib/task-spec/include/task-spec/concrete_arg_spec.h index 24a96e9f78..45bbd6ba6b 100644 --- a/lib/task-spec/include/task-spec/concrete_arg_spec.h +++ b/lib/task-spec/include/task-spec/concrete_arg_spec.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_CONCRETE_ARG_SPEC_H #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_CONCRETE_ARG_SPEC_H -#include "fmt/format.h" #include "task-spec/serialization.h" #include "utils/hash-utils.h" #include "utils/type_index.h" +#include #include namespace FlexFlow { diff --git a/lib/task-spec/include/task-spec/config.h b/lib/task-spec/include/task-spec/config.h deleted file mode 100644 index ff7c4af5a5..0000000000 --- a/lib/task-spec/include/task-spec/config.h +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef _FLEXFLOW_LOCAL_EXECUTION_CONFIG_H_ -#define _FLEXFLOW_LOCAL_EXECUTION_CONFIG_H_ - -#include "utils/fmt.h" -#include "utils/visitable.h" -#include - -namespace FlexFlow { - -enum class ComputationMode { - TRAINING, - INFERENCE, -}; - -// ======================================================== -// Define Runtime Constants -// ======================================================== -// Pre-assigned const flags -#define MAP_TO_FB_MEMORY 0xABCD0000 -#define MAP_TO_ZC_MEMORY 0xABCE0000 - -struct FFInitInfo : public use_visitable_cmp { - size_t workSpaceSize; - bool allowTensorOpMathConversion; -}; - -using legion_mapping_tag_id_t = unsigned long; - -struct FFConfig : public use_visitable_cmp { -public: - enum PreservedIDs { - InvalidID = 0, - DataParallelism_GPU = 1, - // DataParallelism_GPU_2D = 2, - // DataParallelism_GPU_3D = 3, - // DataParallelism_GPU_4D = 4, - // DataParallelism_GPU_5D = 5, - DataParallelism_CPU = 11, - // DataParallelism_CPU_2D = 12, - // DataParallelism_CPU_3D = 13, - // DataParallelism_CPU_4D = 14, - // DataParallelism_CPU_5D = 15, - }; - - FFConfig() = default; - static legion_mapping_tag_id_t get_hash_id(std::string const &pcname); - -public: - int epochs = 1; - int batchSize = 64; - int numNodes = 1; - int cpusPerNode = 0; - int workersPerNode = 0; - float learningRate = 0.01f; - float weightDecay = 0.0001f; - size_t workSpaceSize = (size_t)1 * 1024 * 1024 * 1024; // 2GB - bool profiling = false; - bool perform_fusion = false; - size_t simulator_work_space_size = (size_t)2 * 1024 * 1024 * 1024; // 2GB - size_t search_budget = -1; - float search_alpha = 1.2f; - bool search_overlap_backward_update = false; - ComputationMode computationMode = ComputationMode::TRAINING; - // Control parallelizable dimensions - bool only_data_parallel = false; - bool enable_parameter_parallel = false; - bool enable_inplace_optimizations = false; - // Control Tensor Op Math Conversion - bool allow_tensor_op_math_conversion = false; - 
std::optional dataset_path = std::nullopt; - std::optional export_strategy_computation_graph_file = - std::nullopt; - bool include_costs_dot_graph = false; - std::optional substitution_json_path = std::nullopt; - int machine_model_version = 0; - std::optional machine_model_file = std::nullopt; - int simulator_segment_size = 16777216; // 16 MB - int simulator_max_num_segments = 1; - std::optional search_num_nodes = std::nullopt; - std::optional search_num_workers = std::nullopt; - int base_optimize_threshold = 10; - bool enable_control_replication = true; - // The default python data loader type is 2 to enable control replication - int python_data_loader_type = 2; -}; - -struct FFIterationConfig { - FFIterationConfig() = delete; - void reset(); - int seq_length; -}; - -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); - -enum FieldIDs { - FID_DATA, -}; - -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::FFInitInfo, - workSpaceSize, - allowTensorOpMathConversion); -MAKE_VISIT_HASHABLE(::FlexFlow::FFInitInfo); - -VISITABLE_STRUCT(::FlexFlow::FFConfig, - epochs, - batchSize, - numNodes, - cpusPerNode, - workersPerNode, - learningRate, - weightDecay, - workSpaceSize, - profiling, - perform_fusion, - simulator_work_space_size, - search_budget, - search_alpha, - search_overlap_backward_update, - computationMode, - only_data_parallel, - enable_parameter_parallel, - enable_inplace_optimizations, - allow_tensor_op_math_conversion, - dataset_path, - export_strategy_computation_graph_file, - include_costs_dot_graph, - substitution_json_path, - machine_model_version, - machine_model_file, - simulator_segment_size, - simulator_max_num_segments, - search_num_nodes, - search_num_workers, - base_optimize_threshold, - enable_control_replication, - python_data_loader_type); - -namespace fmt { - -template <> -struct formatter<::FlexFlow::ComputationMode> : formatter { - template - auto format(::FlexFlow::ComputationMode m, FormatContext &ctx) const - 
-> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (m) { - case ComputationMode::TRAINING: - name = "Training"; - break; - case ComputationMode::INFERENCE: - name = "Inference"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/task-spec/include/task-spec/device_specific.h b/lib/task-spec/include/task-spec/device_specific.h index 3ef017f704..2055888b1b 100644 --- a/lib/task-spec/include/task-spec/device_specific.h +++ b/lib/task-spec/include/task-spec/device_specific.h @@ -1,19 +1,18 @@ #ifndef _FLEXFLOW_LOCAL_EXECUTION_DEVICE_SPECIFIC_H #define _FLEXFLOW_LOCAL_EXECUTION_DEVICE_SPECIFIC_H +#include "pcg/device_id_t.dtg.h" #include "task-spec/serialization.h" -#include "utils/exception.h" +#include "utils/hash/tuple.h" namespace FlexFlow { template struct DeviceSpecific { - DeviceSpecific() = delete; template - static DeviceSpecific create(Args &&...args) { - size_t device_idx = 0; + static DeviceSpecific create(device_id_t device_idx, Args &&...args) { return DeviceSpecific(std::make_shared(std::forward(args)...), device_idx); } @@ -26,31 +25,38 @@ struct DeviceSpecific { return this->tie() != other.tie(); } - T const *get(size_t curr_device_idx) const { - if (curr_device_idx != this->device_idx) { - throw mk_runtime_error( - fmt::format("Invalid access to DeviceSpecific: attempted " - "device_idx {} != correct device_idx {})", - curr_device_idx, - this->device_idx)); - } + T const *get(device_id_t curr_device_idx) const { + ASSERT(curr_device_idx == this->device_idx); return (T const *)this->ptr.get(); } - // TODO: can modify ptr - private: - DeviceSpecific(std::shared_ptr ptr, size_t device_idx) + DeviceSpecific(std::shared_ptr ptr, device_id_t device_idx) : ptr(ptr), device_idx(device_idx) {} +private: std::shared_ptr ptr; - size_t device_idx; + device_id_t device_idx; +private: std::tuple tie() const { return std::tie(this->ptr, this->device_idx); } + + 
friend struct ::std::hash>; + + friend std::string format_as(DeviceSpecific const &d) { + return fmt::format("DeviceSpecific({:p}, {})", + static_cast(d.ptr.get()), + d.device_idx); + } }; +template +std::ostream &operator<<(std::ostream &s, DeviceSpecific const &d) { + return (s << fmt::to_string(d)); +} + // manually force serialization to make DeviceSpecific trivially // serializable // template @@ -58,4 +64,15 @@ struct DeviceSpecific { } // namespace FlexFlow +namespace std { + +template +struct hash<::FlexFlow::DeviceSpecific> { + size_t operator()(::FlexFlow::DeviceSpecific const &x) const { + return get_std_hash(x.tie()); + } +}; + +} // namespace std + #endif diff --git a/lib/task-spec/include/task-spec/device_specific_device_states.variant.toml b/lib/task-spec/include/task-spec/device_specific_device_states.variant.toml deleted file mode 100644 index b77850c50d..0000000000 --- a/lib/task-spec/include/task-spec/device_specific_device_states.variant.toml +++ /dev/null @@ -1,75 +0,0 @@ -namespace = "FlexFlow" -name = "DeviceSpecificDeviceStates" -features = [ - "eq", -] - -includes = [ - "kernels/mha_per_device_state.dtg.h", - "kernels/batch_norm_per_device_state.dtg.h", - "kernels/conv_2d_per_device_state.dtg.h", - "kernels/dropout_per_device_state.dtg.h", - "kernels/element_binary_per_device_state.dtg.h", - "kernels/element_unary_per_device_state.dtg.h", - "kernels/gather_per_device_state.dtg.h", - "kernels/layer_norm_per_device_state.dtg.h", - "kernels/linear_per_device_state.dtg.h", - "kernels/partition_per_device_state.dtg.h", - "kernels/pool_2d_per_device_state.dtg.h", - "kernels/reduce_per_device_state.dtg.h", - "kernels/softmax_per_device_state.dtg.h", - "task-spec/device_specific.h", - "", -] - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_mha_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_batch_norm_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = 
"device_specific_conv2d_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_dropout_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_element_binary_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_element_unary_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_gather_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_layer_norm_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_linear_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_pool_2d_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_reduce_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_repartition_per_device_state" - -[[values]] -type = "::FlexFlow::DeviceSpecific>" -key = "device_specific_softmax_per_device_state" diff --git a/lib/task-spec/include/task-spec/device_specific_per_device_op_state.dtg.toml b/lib/task-spec/include/task-spec/device_specific_per_device_op_state.dtg.toml new file mode 100644 index 0000000000..4435a472ce --- /dev/null +++ b/lib/task-spec/include/task-spec/device_specific_per_device_op_state.dtg.toml @@ -0,0 +1,82 @@ +namespace = "FlexFlow" +name = "DeviceSpecificPerDeviceOpState" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "kernels/mha_per_device_state.dtg.h", + "kernels/batch_norm_per_device_state.dtg.h", + "kernels/conv_2d_per_device_state.dtg.h", + "kernels/dropout_per_device_state.dtg.h", + "kernels/element_binary_per_device_state.dtg.h", + "kernels/element_unary_per_device_state.dtg.h", + "kernels/gather_per_device_state.dtg.h", + "kernels/layer_norm_per_device_state.dtg.h", + "kernels/linear_per_device_state.dtg.h", + 
"kernels/partition_per_device_state.dtg.h", + "kernels/pool_2d_per_device_state.dtg.h", + "kernels/reduce_per_device_state.dtg.h", + "kernels/softmax_per_device_state.dtg.h", + "task-spec/device_specific.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_mha_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_batch_norm_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_conv2d_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_dropout_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_element_binary_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_element_unary_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_gather_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_layer_norm_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_linear_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_pool_2d_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_reduce_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_repartition_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific>" +key = "device_specific_softmax_per_device_state" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_loss_config.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_loss_config.dtg.toml new file mode 100644 index 0000000000..6195b37301 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_loss_config.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "DynamicLossConfig" +type = 
"struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "op-attrs/ops/loss_functions/loss_attrs.dtg.h", + "task-spec/dynamic_graph/dynamic_value_guid_t.dtg.h" +] + +[[fields]] +name = "loss_attrs" +type = "::FlexFlow::LossAttrs" + +[[fields]] +name = "logit_tensor" +type = "::FlexFlow::dynamic_value_guid_t" + +[[fields]] +name = "logit_grad_tensor" +type = "::FlexFlow::dynamic_value_guid_t" + +[[fields]] +name = "label_tensor" +type = "::FlexFlow::dynamic_value_guid_t" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attr_keys.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attr_keys.dtg.toml new file mode 100644 index 0000000000..3132d47d4e --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attr_keys.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "DynamicNodeAttrKeys" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "DEVICE_ID" + +[[values]] +name = "PASS" + +[[values]] +name = "MACHINE_VIEW" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attrs.dtg.toml new file mode 100644 index 0000000000..14f361ca75 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attrs.dtg.toml @@ -0,0 +1,46 @@ +namespace = "FlexFlow" +name = "DynamicNodeAttrs" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "task-spec/dynamic_graph/dynamic_task_type.dtg.h", + "pcg/machine_space_coordinate.dtg.h", + "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", + "op-attrs/pcg_operator_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "task-spec/device_specific_per_device_op_state.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "task_type" +type = "std::optional<::FlexFlow::DynamicTaskType>" + 
+[[fields]] +name = "device_coord" +type = "std::optional<::FlexFlow::MachineSpaceCoordinate>" + +[[fields]] +name = "mapping" +type = "std::optional<::FlexFlow::MappedOperatorTaskGroup>" + +[[fields]] +name = "op_attrs" +type = "std::optional<::FlexFlow::PCGOperatorAttrs>" + +[[fields]] +name = "pcg_layer_guid" +type = "::FlexFlow::parallel_layer_guid_t" + +[[fields]] +name = "per_device_op_state" +type = "std::optional<::FlexFlow::DeviceSpecificPerDeviceOpState>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_guid_t.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_guid_t.dtg.toml new file mode 100644 index 0000000000..7bfe4d14ba --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_guid_t.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "dynamic_node_guid_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "raw_node" +type = "::FlexFlow::Node" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.dtg.toml new file mode 100644 index 0000000000..07060106c0 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.dtg.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "DynamicNodeInvocation" +type = "struct" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "", + "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h", + "task-spec/dynamic_graph/dynamic_tensor_slot.dtg.h", + "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "inputs" +type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrs>" + +[[fields]] +name = "node_attrs" +type = "::FlexFlow::DynamicNodeAttrs" + +[[fields]] +name = "outputs" +type = 
"std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrs>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.h b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.h new file mode 100644 index 0000000000..94a4886b49 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_NODE_INVOCATION_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_NODE_INVOCATION_H + +#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" + +namespace FlexFlow { + +bool invocation_fully_satisfies_expansion_conditions( + std::function const &node_condition, + std::function const &slot_condition, + std::function const &); + + + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.toml new file mode 100644 index 0000000000..ba16732364 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "DynamicOpenDataflowGraph" +type = "struct" +features = [ + "eq", +] + +includes = [ + "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", + "", +] + +[[fields]] +name = "invocations" +type = "std::unordered_set<::FlexFlow::DynamicNodeInvocation>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h new file mode 100644 index 0000000000..a3bbba592f --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h @@ -0,0 +1,52 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_H +#define
_FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_H + +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +DynamicOpenDataflowGraph make_empty_dynamic_open_dataflow_graph(); + +nonnegative_int dynamic_graph_num_nodes(DynamicOpenDataflowGraph const &); + +bool full_dynamic_graph_satisfies( + DynamicOpenDataflowGraph const &, + std::function const &, + std::function const &, + std::function const &); + +bool no_part_of_dynamic_graph_satisfies( + DynamicOpenDataflowGraph const &, + std::function const &, + std::function const &, + std::function const &); + +std::unordered_multiset + get_dynamic_nodes(DynamicOpenDataflowGraph const &); +std::unordered_multiset + get_dynamic_values(DynamicOpenDataflowGraph const &); +std::unordered_multiset + get_dynamic_tensor_slots(DynamicOpenDataflowGraph const &); +std::unordered_set + get_dynamic_invocation_set(DynamicOpenDataflowGraph const &); + +DynamicOpenDataflowGraph transform_dynamic_invocation_set( + DynamicOpenDataflowGraph const &, + std::function const + &); + +DynamicOpenDataflowGraph flatmap_dynamic_invocation_set( + DynamicOpenDataflowGraph const &, + std::function( + DynamicNodeInvocation const &)> const &); + +DynamicOpenDataflowGraph dynamic_open_dataflow_graph_from_invocation_set( + std::unordered_set const &); + +bool dynamic_open_dataflow_graphs_are_isomorphic( + DynamicOpenDataflowGraph const &, DynamicOpenDataflowGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_optimizer_tensor_role.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_optimizer_tensor_role.dtg.toml new file mode 100644 index 0000000000..4dd1d406e8 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_optimizer_tensor_role.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = 
"DynamicOptimizerTensorRole" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "pcg/optimizer_slot_name.dtg.h", +] + +[[fields]] +name = "optimizer_slot_name" +type = "::FlexFlow::OptimizerSlotName" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_parameter_guid_t.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_parameter_guid_t.dtg.toml new file mode 100644 index 0000000000..952f4cc248 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_parameter_guid_t.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "dynamic_parameter_guid_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "raw_value" +type = "::FlexFlow::DataflowGraphInput" + diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_task_type.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_task_type.dtg.toml new file mode 100644 index 0000000000..2885d7d0d3 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_task_type.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "DynamicTaskType" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "FWD" + +[[values]] +name = "BWD" + +[[values]] +name = "UPD" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_role.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_role.dtg.toml new file mode 100644 index 0000000000..91d05dbc2d --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_role.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "DynamicTensorRole" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "task-spec/fwb_tensor_type.dtg.h", + "task-spec/dynamic_graph/dynamic_optimizer_tensor_role.dtg.h", +] + 
+[[values]] +type = "::FlexFlow::FwbTensorType" +key = "fwb_tensor" + +[[values]] +type = "::FlexFlow::DynamicOptimizerTensorRole" +key = "optimizer_tensor" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_role.h b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_role.h new file mode 100644 index 0000000000..374230bd0d --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_role.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_TENSOR_ROLE_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_TENSOR_ROLE_H + +#include "pcg/optimizer_slot_name.dtg.h" +#include "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h" + +namespace FlexFlow { + +DynamicTensorRole dynamic_tensor_role_from_fwb_tensor_type(FwbTensorType); + +DynamicTensorRole mk_dynamic_tensor_role_fwd(); +DynamicTensorRole mk_dynamic_tensor_role_bwd(); +DynamicTensorRole mk_dynamic_tensor_role_opt(OptimizerSlotName); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.dtg.toml new file mode 100644 index 0000000000..378582f428 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.dtg.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "DynamicTensorSlot" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "op-attrs/tensor_slot_name.dtg.h", + "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", + "", +] + +src_includes = [ + "utils/json/optional.h", + "utils/fmt/optional.h", +] + +[[fields]] +name = "slot_name" +type = "::FlexFlow::TensorSlotName" + +[[fields]] +name = "slot_tensor_role" +type = "std::optional<::FlexFlow::DynamicTensorRole>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.h 
b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.h new file mode 100644 index 0000000000..129f1a2eae --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_TENSOR_SLOT_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_TENSOR_SLOT_H + +#include "task-spec/dynamic_graph/dynamic_tensor_slot.dtg.h" + +namespace FlexFlow { + +DynamicTensorSlot decide_tensor_slot_role(DynamicTensorSlot const &, + DynamicTensorRole); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attr_keys.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attr_keys.dtg.toml new file mode 100644 index 0000000000..4cca1b93c1 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attr_keys.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "DynamicValueAttrKeys" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "PARALLEL_TENSOR_SHAPE" + +[[values]] +name = "ACCESSOR" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml new file mode 100644 index 0000000000..2332b2c93b --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml @@ -0,0 +1,41 @@ +namespace = "FlexFlow" +name = "DynamicValueAttrs" +type = "struct" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "", + "op-attrs/parallel_tensor_shape.dtg.h", + "kernels/accessor.h", + "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h", + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "pcg_tensor_guid" +type = 
"::FlexFlow::parallel_tensor_guid_t" + +[[fields]] +name = "parallel_tensor_shape" +type = "std::optional<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "shard_coord" +type = "std::optional<::FlexFlow::ParallelTensorSpaceCoordinate>" + +[[fields]] +name = "accessor" +type = "std::optional<::FlexFlow::GenericTensorAccessorW>" + +[[fields]] +name = "role" +type = "std::optional<::FlexFlow::DynamicTensorRole>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.h new file mode 100644 index 0000000000..9cccc565cc --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_VALUE_ATTRS_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_VALUE_ATTRS_H + +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" + +namespace FlexFlow { + +DynamicValueAttrs decide_dynamic_value_attrs_role(DynamicValueAttrs const &, + DynamicTensorRole); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_guid_t.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_guid_t.dtg.toml new file mode 100644 index 0000000000..8351c47faa --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_guid_t.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "dynamic_value_guid_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", +] + +[[fields]] +name = "raw_value" +type = "::FlexFlow::OpenDataflowValue" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/machine_slicing.h b/lib/task-spec/include/task-spec/dynamic_graph/machine_slicing.h new file mode 100644 index 0000000000..823f962c25 --- /dev/null +++ 
b/lib/task-spec/include/task-spec/dynamic_graph/machine_slicing.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_MACHINE_SLICING_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_MACHINE_SLICING_H + +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +std::unordered_set + perform_machine_slicing_for_invocation(DynamicNodeInvocation const &, + MachineSpaceCoordinate const &); + +DynamicOpenDataflowGraph + perform_machine_slicing(DynamicOpenDataflowGraph const &, + MachineSpaceCoordinate const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/pass_expansion.h b/lib/task-spec/include/task-spec/dynamic_graph/pass_expansion.h new file mode 100644 index 0000000000..6dce8ad514 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/pass_expansion.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_PASS_EXPANSION_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_PASS_EXPANSION_H + +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +bool node_is_pass_expanded(DynamicNodeAttrs const &); +bool value_is_pass_expanded(DynamicValueAttrs const &); +bool slot_is_pass_expanded(DynamicTensorSlot const &); + +bool no_part_of_graph_is_pass_expanded(DynamicOpenDataflowGraph const &); +bool graph_is_fully_pass_expanded(DynamicOpenDataflowGraph const &); + +DynamicNodeInvocation + perform_fwd_pass_expansion_for_invocation(DynamicNodeInvocation const &); +DynamicNodeInvocation + perform_bwd_pass_expansion_for_invocation(DynamicNodeInvocation const &); + +DynamicOpenDataflowGraph + perform_pass_expansion(DynamicOpenDataflowGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/shard_expansion.h 
b/lib/task-spec/include/task-spec/dynamic_graph/shard_expansion.h new file mode 100644 index 0000000000..4e0db1cd7e --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/shard_expansion.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SHARD_EXPANSION_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SHARD_EXPANSION_H + +#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +bool node_is_shard_expanded(DynamicNodeAttrs const &); +bool value_is_shard_expanded(DynamicValueAttrs const &); + +bool no_part_of_graph_is_shard_expanded(DynamicOpenDataflowGraph const &); +bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &); + +std::unordered_set + perform_shard_expansion_for_invocation(DynamicNodeInvocation const &); + +DynamicOpenDataflowGraph + perform_shard_expansion(DynamicOpenDataflowGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/update_insertion.h b/lib/task-spec/include/task-spec/dynamic_graph/update_insertion.h new file mode 100644 index 0000000000..23fb7050a0 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/update_insertion.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_UPDATE_INSERTION_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_UPDATE_INSERTION_H + +#include "pcg/optimizer_attrs.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +std::unordered_set + perform_update_insertion_for_invocation(DynamicNodeInvocation const &, + OptimizerAttrs const &); + +DynamicOpenDataflowGraph + perform_update_insertion(DynamicOpenDataflowGraph const &, + OptimizerAttrs const &); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/task-spec/include/task-spec/ff_config.dtg.toml b/lib/task-spec/include/task-spec/ff_config.dtg.toml new file mode 100644 index 0000000000..3d2b6a6b1d --- /dev/null +++ b/lib/task-spec/include/task-spec/ff_config.dtg.toml @@ -0,0 +1,116 @@ +namespace = "FlexFlow" +name = "FFConfig" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "utils/positive_int/positive_int.h", + "utils/nonnegative_int/nonnegative_int.h", + "", + "", +] + +src_includes = [ + "utils/rapidcheck/optional.h", + "utils/json/optional.h", + "utils/fmt/optional.h", +] + +[[fields]] +name = "epochs" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "batch_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_nodes" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "cpus_per_node" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "gpus_per_node" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "learning_rate" +type = "float" + +[[fields]] +name = "weight_decay" +type = "float" + +[[fields]] +name = "workspace_size" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "enable_profiling" +type = "bool" + +[[fields]] +name = "perform_fusion" +type = "bool" + +[[fields]] +name = "simulator_workspace_size" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "search_budget" +type = "std::optional<::FlexFlow::positive_int>" + +[[fields]] +name = "search_alpha" +type = "float" + +[[fields]] +name = "search_overlap_backward_update" +type = "bool" + +[[fields]] +name = "only_data_parallel" +type = "bool" + +[[fields]] +name = "enable_parameter_parallel" +type = "bool" + +[[fields]] +name = "enable_inplace_optimizations" +type = "bool" + +[[fields]] +name = "allow_tensor_op_math_conversion" +type = "bool" + +[[fields]] +name = "dataset_path" +type = "std::optional" + +[[fields]] +name = "export_strategy_computation_graph_file" +type = "std::optional" + +[[fields]] +name = 
"include_costs_dot_graph" +type = "bool" + +[[fields]] +name = "substitution_json_path" +type = "std::optional" + +[[fields]] +name = "base_optimize_threshold" +type = "::FlexFlow::positive_int" diff --git a/lib/task-spec/include/task-spec/ff_init_info.dtg.toml b/lib/task-spec/include/task-spec/ff_init_info.dtg.toml new file mode 100644 index 0000000000..9f6c3f9342 --- /dev/null +++ b/lib/task-spec/include/task-spec/ff_init_info.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "FFInitInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "workspace_size" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "allow_tensor_op_math_conversion" +type = "bool" diff --git a/lib/task-spec/include/task-spec/ff_iteration_config.dtg.toml b/lib/task-spec/include/task-spec/ff_iteration_config.dtg.toml new file mode 100644 index 0000000000..c8a8bb61b0 --- /dev/null +++ b/lib/task-spec/include/task-spec/ff_iteration_config.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "FFIterationConfig" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "seq_length" +type = "::FlexFlow::positive_int" diff --git a/lib/task-spec/include/task-spec/forward_tensor_guid_t.struct.toml b/lib/task-spec/include/task-spec/forward_tensor_guid_t.struct.toml deleted file mode 100644 index 68fc4b6815..0000000000 --- a/lib/task-spec/include/task-spec/forward_tensor_guid_t.struct.toml +++ /dev/null @@ -1,13 +0,0 @@ -namespace = "FlexFlow" -name = "forward_tensor_guid_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - - -[[fields]] -name = "raw_index" -type = "int" diff --git a/lib/task-spec/include/task-spec/forward_tensor_source.h b/lib/task-spec/include/task-spec/forward_tensor_source.h deleted file mode 100644 index 
7adde6e145..0000000000 --- a/lib/task-spec/include/task-spec/forward_tensor_source.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_FORWARD_TENSOR_SOURCE_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_FORWARD_TENSOR_SOURCE_H - -#include "task-spec/forward_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -struct ForwardTensorSource { -public: - ForwardTensorSource(); - - forward_tensor_guid_t new_forward_tensor(); - - void reset(); - -private: - static int next_available_forward_tensor_id; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/fwb_tensor_type.dtg.toml b/lib/task-spec/include/task-spec/fwb_tensor_type.dtg.toml new file mode 100644 index 0000000000..b494d88561 --- /dev/null +++ b/lib/task-spec/include/task-spec/fwb_tensor_type.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "FwbTensorType" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "FORWARD" + +[[values]] +name = "GRADIENT" diff --git a/lib/task-spec/include/task-spec/fwd_bwd_op_task_impl_function.h b/lib/task-spec/include/task-spec/fwd_bwd_op_task_impl_function.h index 3620ff87cb..fddad49ddf 100644 --- a/lib/task-spec/include/task-spec/fwd_bwd_op_task_impl_function.h +++ b/lib/task-spec/include/task-spec/fwd_bwd_op_task_impl_function.h @@ -1,13 +1,14 @@ #ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_FWD_BWD_OP_TASK_IMPL_FUNCTION_H #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_FWD_BWD_OP_TASK_IMPL_FUNCTION_H -#include "task-spec/task_argument_accessor.h" +#include "task-spec/task_argument_accessor/task_argument_accessor.h" +#include "utils/units/milliseconds_t.h" namespace FlexFlow { struct FwdBwdOpTaskImplFunction { - std::optional (*function_ptr)(TaskArgumentAccessor const &); + std::optional (*function_ptr)(TaskArgumentAccessor const &); bool operator==(FwdBwdOpTaskImplFunction const &) const; bool operator!=(FwdBwdOpTaskImplFunction const &) const; diff 
--git a/lib/task-spec/include/task-spec/generic_task_impl_function.h b/lib/task-spec/include/task-spec/generic_task_impl_function.h index 31bf132e4f..a4707a2f6f 100644 --- a/lib/task-spec/include/task-spec/generic_task_impl_function.h +++ b/lib/task-spec/include/task-spec/generic_task_impl_function.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_GENERIC_TASK_IMPL_FUNCTION_H #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_GENERIC_TASK_IMPL_FUNCTION_H -#include "task-spec/device_specific_device_states.dtg.h" -#include "task-spec/task_argument_accessor.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/task_argument_accessor/task_argument_accessor.h" namespace FlexFlow { diff --git a/lib/task-spec/include/task-spec/gradient_tensor_guid_t.struct.toml b/lib/task-spec/include/task-spec/gradient_tensor_guid_t.struct.toml deleted file mode 100644 index b75e27a9d2..0000000000 --- a/lib/task-spec/include/task-spec/gradient_tensor_guid_t.struct.toml +++ /dev/null @@ -1,13 +0,0 @@ -namespace = "FlexFlow" -name = "gradient_tensor_guid_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - - -[[fields]] -name = "raw_index" -type = "int" diff --git a/lib/task-spec/include/task-spec/gradient_tensor_source.h b/lib/task-spec/include/task-spec/gradient_tensor_source.h deleted file mode 100644 index 14ebf05d43..0000000000 --- a/lib/task-spec/include/task-spec/gradient_tensor_source.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_GRADIENT_TENSOR_SOURCE_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_GRADIENT_TENSOR_SOURCE_H - -#include "task-spec/gradient_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -struct GradientTensorSource { -public: - GradientTensorSource(); - - gradient_tensor_guid_t new_gradient_tensor(); - - void reset(); - -private: - static int next_available_gradient_tensor_id; -}; - -} // namespace FlexFlow - -#endif diff --git 
a/lib/task-spec/include/task-spec/init_op_task_impl_function.h b/lib/task-spec/include/task-spec/init_op_task_impl_function.h index 97daa7ef56..0fbdc95ac9 100644 --- a/lib/task-spec/include/task-spec/init_op_task_impl_function.h +++ b/lib/task-spec/include/task-spec/init_op_task_impl_function.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_INIT_OP_TASK_IMPL_FUNCTION_H #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_INIT_OP_TASK_IMPL_FUNCTION_H -#include "task-spec/device_specific_device_states.dtg.h" -#include "task-spec/task_argument_accessor.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/task_argument_accessor/task_argument_accessor.h" namespace FlexFlow { @@ -16,7 +16,7 @@ struct InitOpTaskImplFunction { bool operator>=(InitOpTaskImplFunction const &) const; public: - DeviceSpecificDeviceStates (*function_ptr)(TaskArgumentAccessor const &); + DeviceSpecificPerDeviceOpState (*function_ptr)(TaskArgumentAccessor const &); }; std::string format_as(InitOpTaskImplFunction const &x); diff --git a/lib/task-spec/include/task-spec/is_grad.dtg.toml b/lib/task-spec/include/task-spec/is_grad.dtg.toml new file mode 100644 index 0000000000..49cbe85ffd --- /dev/null +++ b/lib/task-spec/include/task-spec/is_grad.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "IsGrad" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "YES" + +[[values]] +name = "NO" diff --git a/lib/task-spec/include/task-spec/is_grad.enum.toml b/lib/task-spec/include/task-spec/is_grad.enum.toml deleted file mode 100644 index f955b7749d..0000000000 --- a/lib/task-spec/include/task-spec/is_grad.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "IsGrad" -features = [ - "hash", - "json", - "rapidcheck", - "fmt", -] - -[[values]] -name = "YES" - -[[values]] -name = "NO" diff --git a/lib/task-spec/include/task-spec/is_trainable.dtg.toml 
b/lib/task-spec/include/task-spec/is_trainable.dtg.toml new file mode 100644 index 0000000000..eef3fed693 --- /dev/null +++ b/lib/task-spec/include/task-spec/is_trainable.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "IsTrainable" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "YES" + +[[values]] +name = "NO" diff --git a/lib/task-spec/include/task-spec/is_trainable.enum.toml b/lib/task-spec/include/task-spec/is_trainable.enum.toml deleted file mode 100644 index 57ad9b6976..0000000000 --- a/lib/task-spec/include/task-spec/is_trainable.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "IsTrainable" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "YES" - -[[values]] -name = "NO" diff --git a/lib/task-spec/include/task-spec/itask_argument_accessor.h b/lib/task-spec/include/task-spec/itask_argument_accessor.h deleted file mode 100644 index 2e693e7983..0000000000 --- a/lib/task-spec/include/task-spec/itask_argument_accessor.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_ITASK_ARGUMENT_ACCESSOR_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_ITASK_ARGUMENT_ACCESSOR_H - -#include "kernels/allocation.h" -#include "task-spec/concrete_arg_spec.h" -#include "task-spec/op_task_signature.h" -#include "task-spec/privilege_tensor_accessor.h" -#include "task-spec/tensor_type.dtg.h" - -namespace FlexFlow { - -struct ITaskArgumentAccessor { - ITaskArgumentAccessor &operator=(ITaskArgumentAccessor const &) = delete; - - virtual ~ITaskArgumentAccessor() = default; - - virtual ConcreteArgSpec const &get_concrete_arg(slot_id_t) const = 0; - - virtual GenericTensorAccessor get_tensor(slot_id_t slot, - Permissions priv, - TensorType tensor_type) const = 0; - virtual VariadicGenericTensorAccessor get_variadic_tensor( - slot_id_t slot, Permissions priv, TensorType tensor_type) const = 0; - - virtual Allocator get_allocator() 
const = 0; - virtual size_t get_device_idx() const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(ITaskArgumentAccessor); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/loss_functions.h b/lib/task-spec/include/task-spec/loss_functions.h index a5f5886caa..97f8f15b26 100644 --- a/lib/task-spec/include/task-spec/loss_functions.h +++ b/lib/task-spec/include/task-spec/loss_functions.h @@ -17,21 +17,11 @@ #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_LOSS_FUNCTIONS_H #include "op-attrs/ops/loss_functions.h" -#include "task-spec/forward_tensor_guid_t.dtg.h" -#include "task-spec/gradient_tensor_guid_t.dtg.h" -#include "task-spec/loss_tensor_guid_t.dtg.h" #include "task-spec/task_impl_function.dtg.h" -#include "task-spec/task_invocation.dtg.h" -#include "task-spec/task_signature.h" namespace FlexFlow { TaskImplFunction get_loss_bwd_task_impl(); -TaskSignature get_loss_bwd_signature(); -TaskInvocation backward(LossAttrs const &, - forward_tensor_guid_t logit, - gradient_tensor_guid_t logit_grad, - loss_tensor_guid_t label); } // namespace FlexFlow diff --git a/lib/task-spec/include/task-spec/loss_tensor_guid_t.struct.toml b/lib/task-spec/include/task-spec/loss_tensor_guid_t.struct.toml deleted file mode 100644 index c00ccbb0f2..0000000000 --- a/lib/task-spec/include/task-spec/loss_tensor_guid_t.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "loss_tensor_guid_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/nonnegative_int/nonnegative_int.h" -] - -[[fields]] -name = "raw_index" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/task-spec/include/task-spec/loss_tensor_source.h b/lib/task-spec/include/task-spec/loss_tensor_source.h deleted file mode 100644 index 21091109e5..0000000000 --- a/lib/task-spec/include/task-spec/loss_tensor_source.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_LOSS_TENSOR_SOURCE_H -#define 
_FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_LOSS_TENSOR_SOURCE_H - -#include "task-spec/loss_tensor_guid_t.dtg.h" -#include "utils/nonnegative_int/nonnegative_int.h" - -namespace FlexFlow { - -struct LossTensorSource { -public: - LossTensorSource(); - - loss_tensor_guid_t new_loss_tensor(); - -private: - static nonnegative_int next_available_loss_tensor_id; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/op_arg_ref_type.dtg.toml b/lib/task-spec/include/task-spec/op_arg_ref_type.dtg.toml new file mode 100644 index 0000000000..5335a51dc9 --- /dev/null +++ b/lib/task-spec/include/task-spec/op_arg_ref_type.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "OpArgRefType" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "task-spec/per_device_op_state_ref_type.dtg.h", + "task-spec/parallel_tensor_shape_ref_type.dtg.h", +] + +[[values]] +type = "::FlexFlow::PerDeviceOpStateRefType" +key = "per_device_op_state_ref_type" + +[[values]] +type = "::FlexFlow::ParallelTensorShapeRefType" +key = "parallel_tensor_shape_ref_type" diff --git a/lib/task-spec/include/task-spec/op_arg_ref_type.variant.toml b/lib/task-spec/include/task-spec/op_arg_ref_type.variant.toml deleted file mode 100644 index e0452c6ce2..0000000000 --- a/lib/task-spec/include/task-spec/op_arg_ref_type.variant.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "OpArgRefType" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -includes = [ - "task-spec/per_device_op_state_ref_type.dtg.h", - "task-spec/parallel_tensor_shape_ref_type.dtg.h", -] - -[[values]] -type = "::FlexFlow::PerDeviceOpStateRefType" -key = "per_device_op_state_ref_type" - -[[values]] -type = "::FlexFlow::ParallelTensorShapeRefType" -key = "parallel_tensor_shape_ref_type" diff --git a/lib/task-spec/include/task-spec/op_arg_spec.h b/lib/task-spec/include/task-spec/op_arg_spec.h deleted file mode 100644 index 
1dc4efcdd1..0000000000 --- a/lib/task-spec/include/task-spec/op_arg_spec.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_OP_ARG_SPEC_H -#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_OP_ARG_SPEC_H - -#include "task-spec/op_arg_spec.dtg.h" - -namespace FlexFlow { - -std::type_index get_op_arg_spec_type_index(OpArgSpec const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/op_arg_spec.variant.toml b/lib/task-spec/include/task-spec/op_arg_spec.variant.toml deleted file mode 100644 index a03bc222e8..0000000000 --- a/lib/task-spec/include/task-spec/op_arg_spec.variant.toml +++ /dev/null @@ -1,28 +0,0 @@ -namespace = "FlexFlow" -name = "OpArgSpec" -features = [ - "eq", - # "ord", - # "hash", - # "json", - # "fmt", - # "rapidcheck", -] - -includes = [ - "task-spec/concrete_arg_spec.h", - "task-spec/op_arg_ref.h", - "task-spec/runtime_arg_ref.h", -] - -[[values]] -type = "::FlexFlow::ConcreteArgSpec" -key = "concrete_arg" - -[[values]] -type = "::FlexFlow::OpArgRefSpec" -key = "op_arg_ref" - -[[values]] -type = "::FlexFlow::RuntimeArgRefSpec" -key = "runtime_arg_ref" diff --git a/lib/task-spec/include/task-spec/op_slot_options.enum.toml b/lib/task-spec/include/task-spec/op_slot_options.enum.toml deleted file mode 100644 index 69867d3236..0000000000 --- a/lib/task-spec/include/task-spec/op_slot_options.enum.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "OpSlotOptions" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "OPTIONAL" - -[[values]] -name = "UNTRAINABLE" - -[[values]] -name = "OPTIONAL_UNTRAINABLE" - -[[values]] -name = "NECESSARY" diff --git a/lib/task-spec/include/task-spec/op_task_binding.h b/lib/task-spec/include/task-spec/op_task_binding.h deleted file mode 100644 index bcfea33877..0000000000 --- a/lib/task-spec/include/task-spec/op_task_binding.h +++ /dev/null @@ -1,97 +0,0 @@ -#ifndef 
_FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OP_TASK_BINDING_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OP_TASK_BINDING_H - -#include "task-spec/op_arg_ref.h" -#include "task-spec/op_arg_spec.dtg.h" -#include "task-spec/op_tensor_spec.h" -#include "task-spec/slot_grad_id.dtg.h" -#include "task-spec/slot_id_t.dtg.h" -#include "task-spec/variadic_tensor_ref.h" - -namespace FlexFlow { - -struct OpTaskBinding { - OpTaskBinding() = default; - - void bind(int, VariadicTensorRef const &); - void bind(slot_id_t, VariadicTensorRef const &); - - void bind(int, OpTensorSpec const &); - void bind(slot_id_t, OpTensorSpec const &); - - void bind_grad(int, OpTensorSpec const &); - void bind_grad(slot_id_t, OpTensorSpec const &); - - template - void bind_device_specific_arg(int name, T const &t) { - this->bind_device_specific_arg(slot_id_t{name}, t); - } - - template - void bind_device_specific_arg(slot_id_t name, T const &t) { - NOT_IMPLEMENTED(); - } - - template - void bind_device_specific_arg(int name, OpArgRef const &t) { - this->bind_device_specific_arg(slot_id_t{name}, t); - } - - template - void bind_device_specific_arg(slot_id_t name, OpArgRef const &t) { - NOT_IMPLEMENTED(); - } - - template - void bind_arg(int name, T const &t) { - this->bind_arg(slot_id_t{name}, t); - } - - template - void bind_arg(slot_id_t name, T const &t) { - this->insert_arg_spec(name, OpArgSpec{ConcreteArgSpec::create(t)}); - } - - template - void bind_arg(int name, RuntimeArgRef const &t) { - this->bind_arg(slot_id_t{name}, t); - } - - template - void bind_arg(slot_id_t name, RuntimeArgRef const &ref) { - this->insert_arg_spec(name, OpArgSpec{RuntimeArgRefSpec::create(ref)}); - } - - template - void bind_arg(int name, OpArgRef const &t) { - this->bind_arg(slot_id_t{name}, t); - } - - template - void bind_arg(slot_id_t name, OpArgRef const &ref) { - this->insert_arg_spec(name, OpArgSpec{OpArgRefSpec::create(ref)}); - } - bool operator==(OpTaskBinding const &other) const; - bool 
operator!=(OpTaskBinding const &other) const; - - std::unordered_map const & - get_tensor_bindings() const; - std::unordered_map const &get_arg_bindings() const; - - void bind_from_forward(OpTaskBinding const &fwd); - -private: - std::unordered_map tensor_bindings; - std::unordered_map arg_bindings; - -private: - void insert_arg_spec(slot_id_t name, OpArgSpec const &arg_spec); - std::tuple - tie() const; -}; - -OpTaskBinding infer_bwd_binding(OpTaskBinding const &fwd); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/op_task_invocation.h b/lib/task-spec/include/task-spec/op_task_invocation.h deleted file mode 100644 index 88e9e9bf26..0000000000 --- a/lib/task-spec/include/task-spec/op_task_invocation.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_OP_TASK_INVOCATION_H -#define _FLEXFLOW_LOCAL_EXECUTION_OP_TASK_INVOCATION_H - -#include "task-spec/op_task_invocation.dtg.h" -#include "task-spec/op_task_signature.h" - -namespace FlexFlow { - -bool is_invocation_valid(OpTaskSignature const &sig, - OpTaskInvocation const &inv); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/op_task_invocation.struct.toml b/lib/task-spec/include/task-spec/op_task_invocation.struct.toml deleted file mode 100644 index 465fa5f1ff..0000000000 --- a/lib/task-spec/include/task-spec/op_task_invocation.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "OpTaskInvocation" -features = [] - -includes = [ - "task-spec/op_task_binding.h", - "task-spec/task_id_t.dtg.h", -] - -[[fields]] -name = "task_id" -type = "::FlexFlow::task_id_t" - -[[fields]] -name = "binding" -type = "::FlexFlow::OpTaskBinding" diff --git a/lib/task-spec/include/task-spec/op_task_signature.h b/lib/task-spec/include/task-spec/op_task_signature.h deleted file mode 100644 index eba0023906..0000000000 --- a/lib/task-spec/include/task-spec/op_task_signature.h +++ /dev/null @@ -1,107 +0,0 @@ -#ifndef 
_FLEXFLOW_LOCAL_EXECUTION_OP_TASK_SIGNATURE_H -#define _FLEXFLOW_LOCAL_EXECUTION_OP_TASK_SIGNATURE_H - -#include "task-spec/is_grad.dtg.h" -#include "task-spec/op_task_type.dtg.h" -#include "task-spec/op_tensor_slot_spec.dtg.h" -#include "task-spec/serialization.h" -#include "task-spec/slot_id_t.dtg.h" -#include "task-spec/slot_type.dtg.h" -#include "task-spec/task_id_t.dtg.h" -#include "utils/hash/unordered_map.h" -#include "utils/hash/unordered_set.h" -#include "utils/type_index.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct OpTaskSignature { - OpTaskSignature() = delete; - explicit OpTaskSignature(OpTaskType); - - OpTaskType get_task_type() const { - return this->type; - } - - void add_input_slot(int, SlotType slot_type = SlotType::TENSOR); - void add_input_slot(slot_id_t, SlotType slot_type = SlotType::TENSOR); - - void add_optional_input_slot(int, SlotType slot_type = SlotType::TENSOR); - void add_optional_input_slot(slot_id_t, - SlotType slot_type = SlotType::TENSOR); - - void add_untrainable_input_slot(int, SlotType slot_type = SlotType::TENSOR); - void add_untrainable_input_slot(slot_id_t, - SlotType slot_type = SlotType::TENSOR); - - void add_optional_untrainable_input_slot( - int, SlotType slot_type = SlotType::TENSOR); - void add_optional_untrainable_input_slot( - slot_id_t, SlotType slot_type = SlotType::TENSOR); - - void add_output_slot(int, SlotType slot_type = SlotType::TENSOR); - void add_output_slot(slot_id_t, SlotType slot_type = SlotType::TENSOR); - - void add_bwd_optional_output_slot(int, SlotType slot_type = SlotType::TENSOR); - void add_bwd_optional_output_slot(slot_id_t, - SlotType slot_type = SlotType::TENSOR); - - void add_weight_slot(int, SlotType slot_type = SlotType::TENSOR); - void add_weight_slot(slot_id_t, SlotType slot_type = SlotType::TENSOR); - - void add_optional_weight_slot(int, SlotType slot_type = SlotType::TENSOR); - void add_optional_weight_slot(slot_id_t, - SlotType slot_type = SlotType::TENSOR); - - void 
add_from_slot_spec(OpTensorSlotSpec const &spec); - - template - void add_arg_slot(int name) { - this->add_arg_slot(slot_id_t{name}); - } - - template - void add_arg_slot(slot_id_t name) { - // static_assert(is_serializable::value, "Type must be serializable"); - this->task_arg_types.insert({name, get_type_index_for_type()}); - } - - template - void add_return_value() { - this->return_value = get_type_index_for_type(); - } - - // adds arg_slot without checking is_serializable, used for arguments that are - // deviceSpecific - template - void add_unchecked_arg_slot(int name) { - this->add_unchecked_arg_slot(slot_id_t{name}); - } - - // adds arg_slot without checking is_serializable, used for arguments that are - // deviceSpecific - template - void add_unchecked_arg_slot(slot_id_t name) { - this->task_arg_types.insert({name, get_type_index_for_type()}); - } - - std::unordered_set get_tensor_slots() const; - void set_arg_types(std::unordered_map const &); - std::unordered_map get_arg_types() const; - - OpTaskType type; - std::optional return_value; - std::unordered_map task_arg_types; - std::unordered_set op_tensor_slots; -}; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION( - OpTaskSignature, type, return_value, task_arg_types, op_tensor_slots); - -std::string format_as(OpTaskSignature const &x); -std::ostream &operator<<(std::ostream &s, OpTaskSignature const &x); - -OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/op_task_to_task_invocation.h b/lib/task-spec/include/task-spec/op_task_to_task_invocation.h deleted file mode 100644 index 3208e9d049..0000000000 --- a/lib/task-spec/include/task-spec/op_task_to_task_invocation.h +++ /dev/null @@ -1,45 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_OP_TASK_TO_TASK_INVOCATION_H -#define _FLEXFLOW_LOCAL_EXECUTION_OP_TASK_TO_TASK_INVOCATION_H - -#include "pcg/cg_operator_tensor_shape_signature.dtg.h" -#include 
"pcg/computation_graph.dtg.h" -#include "pcg/layer_guid_t.dtg.h" -#include "task-spec/device_specific_device_states.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/runtime_arg_config.dtg.h" -#include "task-spec/task_invocation.dtg.h" -#include "task-spec/training_layer_plus_context.dtg.h" -#include "task-spec/training_layer_tensor_group_signature.dtg.h" - -namespace FlexFlow { - -TaskInvocation - lower_to_task_invocation(OpTaskInvocation const &op_task_invocation, - TrainingLayerPlusContext const &training_layer, - std::optional const - &device_specific_device_states); - -std::pair lower_tensor_binding( - TrainingLayerTensorGroupSignature const &training_layer_signature, - SlotGradId const &slot_grad_id, - OpTensorSpec const &op_tensor_spec); - -TaskArgSpec lower_to_task_arg_spec( - OpArgSpec const &op_arg_spec, - CGOperatorTensorShapeSignature const &op_shape_signature, - layer_guid_t const &layer_guid, - std::optional const - &device_specific_device_states); - -ConcreteArgSpec lower_to_concrete_arg_spec(RuntimeArgRefSpec const &, - RuntimeArgConfig const &); - -ConcreteArgSpec lower_to_concrete_arg_spec( - OpArgRefSpec const &, - CGOperatorTensorShapeSignature const &, - layer_guid_t const &, - std::optional const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/op_task_type.enum.toml b/lib/task-spec/include/task-spec/op_task_type.enum.toml deleted file mode 100644 index c336476f50..0000000000 --- a/lib/task-spec/include/task-spec/op_task_type.enum.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "OpTaskType" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "INIT" - -[[values]] -name = "FWD" - -[[values]] -name = "BWD" diff --git a/lib/task-spec/include/task-spec/op_tensor_slot_spec.struct.toml b/lib/task-spec/include/task-spec/op_tensor_slot_spec.struct.toml deleted file mode 100644 index 3a388b8559..0000000000 --- 
a/lib/task-spec/include/task-spec/op_tensor_slot_spec.struct.toml +++ /dev/null @@ -1,36 +0,0 @@ -namespace = "FlexFlow" -name = "OpTensorSlotSpec" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "task-spec/slot_id_t.dtg.h", - "task-spec/slot_type.dtg.h", - "pcg/tensor_role.dtg.h", - "task-spec/is_grad.dtg.h", - "task-spec/op_slot_options.dtg.h", -] - -[[fields]] -name = "name" -type = "::FlexFlow::slot_id_t" - -[[fields]] -name = "slot_type" -type = "::FlexFlow::SlotType" - -[[fields]] -name = "tensor_role" -type = "::FlexFlow::TensorRole" - -[[fields]] -name = "is_grad" -type = "::FlexFlow::IsGrad" - -[[fields]] -name = "slot_option" -type = "::FlexFlow::OpSlotOptions" diff --git a/lib/task-spec/include/task-spec/op_tensor_spec.h b/lib/task-spec/include/task-spec/op_tensor_spec.h deleted file mode 100644 index 6f00a2e38d..0000000000 --- a/lib/task-spec/include/task-spec/op_tensor_spec.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_OP_TENSOR_SPEC_REF_H -#define _FLEXFLOW_LOCAL_EXECUTION_OP_TENSOR_SPEC_REF_H - -#include "task-spec/op_tensor_spec.dtg.h" - -namespace FlexFlow { - -OpTensorSpec input_tensor(nonnegative_int idx, - OpSlotOptions option = OpSlotOptions::NECESSARY); -OpTensorSpec output_tensor(nonnegative_int idx, - OpSlotOptions option = OpSlotOptions::NECESSARY); -OpTensorSpec weight_tensor(nonnegative_int idx, - OpSlotOptions option = OpSlotOptions::NECESSARY); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/op_tensor_spec.struct.toml b/lib/task-spec/include/task-spec/op_tensor_spec.struct.toml deleted file mode 100644 index 3e790c7e08..0000000000 --- a/lib/task-spec/include/task-spec/op_tensor_spec.struct.toml +++ /dev/null @@ -1,28 +0,0 @@ -namespace = "FlexFlow" -name = "OpTensorSpec" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", - "rapidcheck", -] - -includes = [ - "pcg/tensor_role.dtg.h", - "task-spec/op_slot_options.dtg.h", - 
"utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "role" -type = "::FlexFlow::TensorRole" - -[[fields]] -name = "slot_option" -type = "::FlexFlow::OpSlotOptions" - -[[fields]] -name = "idx" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/task-spec/include/task-spec/op_training_tensor_type.dtg.toml b/lib/task-spec/include/task-spec/op_training_tensor_type.dtg.toml new file mode 100644 index 0000000000..9e7af2dbfd --- /dev/null +++ b/lib/task-spec/include/task-spec/op_training_tensor_type.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "OpTrainingTensorType" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "FORWARD" + +[[values]] +name = "GRADIENT" + +[[values]] +name = "OPTIMIZER" diff --git a/lib/task-spec/include/task-spec/ops/arg_slot_id_t.dtg.toml b/lib/task-spec/include/task-spec/ops/arg_slot_id_t.dtg.toml new file mode 100644 index 0000000000..8dc08aa3f5 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/arg_slot_id_t.dtg.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "arg_slot_id_t" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +[[fields]] +name = "raw_id" +type = "int" diff --git a/lib/task-spec/include/task-spec/ops/attention.h b/lib/task-spec/include/task-spec/ops/attention.h deleted file mode 100644 index a8a444c9bf..0000000000 --- a/lib/task-spec/include/task-spec/ops/attention.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_ATTENTION_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_ATTENTION_H - -#include "op-attrs/ops/attention.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(MultiHeadAttentionAttrs const &); - -TaskImplFunction get_attention_init_task_impl(); -TaskImplFunction get_attention_fwd_task_impl(); -TaskImplFunction get_attention_bwd_task_impl(); - -OpTaskSignature 
get_attention_init_signature(); -OpTaskSignature get_attention_fwd_signature(); -OpTaskSignature get_attention_bwd_signature(); - -OpTaskInvocation init(MultiHeadAttentionAttrs const &); -OpTaskInvocation forward(MultiHeadAttentionAttrs const &); -OpTaskInvocation backward(MultiHeadAttentionAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/batch_matmul.h b/lib/task-spec/include/task-spec/ops/batch_matmul.h deleted file mode 100644 index a50d1889e1..0000000000 --- a/lib/task-spec/include/task-spec/ops/batch_matmul.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_BATCH_MATMUL_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_BATCH_MATMUL_H - -#include "op-attrs/ops/batch_matmul_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/op_task_signature.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(BatchMatmulAttrs const &); - -TaskImplFunction get_batch_matmul_fwd_task_impl(); -TaskImplFunction get_batch_matmul_bwd_task_impl(); - -OpTaskSignature get_batch_matmul_fwd_signature(); -OpTaskSignature get_batch_matmul_bwd_signature(); - -OpTaskInvocation forward(BatchMatmulAttrs const &); -OpTaskInvocation backward(BatchMatmulAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/batch_norm.h b/lib/task-spec/include/task-spec/ops/batch_norm.h deleted file mode 100644 index bab6a4404a..0000000000 --- a/lib/task-spec/include/task-spec/ops/batch_norm.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_BATCH_NORM_H -#define _FLEXFLOW_BATCH_NORM_H - -#include "op-attrs/ops/batch_norm_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(BatchNormAttrs const &); - -TaskImplFunction get_batch_norm_init_task_impl(); -TaskImplFunction get_batch_norm_fwd_task_impl(); 
-TaskImplFunction get_batch_norm_bwd_task_impl(); - -OpTaskSignature get_batch_norm_init_signature(); -OpTaskSignature get_batch_norm_fwd_signature(); -OpTaskSignature get_batch_norm_bwd_signature(); - -OpTaskInvocation init(BatchNormAttrs const &); -OpTaskInvocation forward(BatchNormAttrs const &); -OpTaskInvocation backward(BatchNormAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/cast.h b/lib/task-spec/include/task-spec/ops/cast.h deleted file mode 100644 index dadc8f8c74..0000000000 --- a/lib/task-spec/include/task-spec/ops/cast.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef _FLEXFLOW_CAST_H -#define _FLEXFLOW_CAST_H - -#include "op-attrs/ops/cast_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(CastAttrs const &); - -TaskImplFunction get_cast_fwd_task_impl(); -TaskImplFunction get_cast_bwd_task_impl(); - -OpTaskSignature get_cast_fwd_signature(); -OpTaskSignature get_cast_bwd_signature(); - -OpTaskInvocation forward(CastAttrs const &); -OpTaskInvocation backward(CastAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/concat.h b/lib/task-spec/include/task-spec/ops/concat.h deleted file mode 100644 index 4e7cfef629..0000000000 --- a/lib/task-spec/include/task-spec/ops/concat.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_CONCAT_H -#define _FLEXFLOW_CONCAT_H - -#include "op-attrs/ops/concat_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(ConcatAttrs const &); - -TaskImplFunction get_concat_fwd_task_impl(); -TaskImplFunction get_concat_bwd_task_impl(); - -OpTaskSignature get_concat_fwd_signature(); -OpTaskSignature get_concat_bwd_signature(); - -OpTaskInvocation forward(ConcatAttrs const &); -OpTaskInvocation backward(ConcatAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/conv_2d.h b/lib/task-spec/include/task-spec/ops/conv_2d.h deleted file mode 100644 index 1efb165d55..0000000000 --- a/lib/task-spec/include/task-spec/ops/conv_2d.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_CONV_2D_H -#define _FLEXFLOW_CONV_2D_H - -#include "op-attrs/ops/conv_2d_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(Conv2DAttrs const &); - -TaskImplFunction get_conv_2d_init_task_impl(); -TaskImplFunction 
get_conv_2d_fwd_task_impl(); -TaskImplFunction get_conv_2d_bwd_task_impl(); - -OpTaskSignature get_conv_2d_init_signature(); -OpTaskSignature get_conv_2d_fwd_signature(); -OpTaskSignature get_conv_2d_bwd_signature(); - -OpTaskInvocation init(Conv2DAttrs const &); -OpTaskInvocation forward(Conv2DAttrs const &); -OpTaskInvocation backward(Conv2DAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/dropout.h b/lib/task-spec/include/task-spec/ops/dropout.h deleted file mode 100644 index 931e3e591e..0000000000 --- a/lib/task-spec/include/task-spec/ops/dropout.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef _FLEXFLOW_DROPOUT_H -#define _FLEXFLOW_DROPOUT_H - -#include "op-attrs/ops/dropout_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_id_t.dtg.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(DropoutAttrs const &); - -TaskImplFunction get_dropout_init_task_impl(); -TaskImplFunction get_dropout_fwd_task_impl(); -TaskImplFunction get_dropout_bwd_task_impl(); - -OpTaskSignature get_dropout_init_signature(); -OpTaskSignature get_dropout_fwd_signature(); -OpTaskSignature get_dropout_bwd_signature(); - -OpTaskInvocation init(DropoutAttrs const &); -OpTaskInvocation forward(DropoutAttrs const &); -OpTaskInvocation backward(DropoutAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/element_binary.h b/lib/task-spec/include/task-spec/ops/element_binary.h deleted file mode 100644 index 2bd8c5dde7..0000000000 --- a/lib/task-spec/include/task-spec/ops/element_binary.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_ELEMENT_BINARY_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_ELEMENT_BINARY_H - -#include "op-attrs/ops/element_binary_attrs.dtg.h" -#include "task-spec/task_impl_function.dtg.h" -#include "task-spec/task_signature_impl.h" - -namespace FlexFlow { - 
-std::vector get_task_ids(ElementBinaryAttrs const &); - -OpTaskInvocation init(ElementBinaryAttrs const &); -OpTaskInvocation forward(ElementBinaryAttrs const &); -OpTaskInvocation backward(ElementBinaryAttrs const &); - -TaskImplFunction get_element_binary_init_task_impl(); -TaskImplFunction get_element_binary_fwd_task_impl(); -TaskImplFunction get_element_binary_bwd_task_impl(); - -OpTaskSignature get_element_binary_init_signature(); -OpTaskSignature get_element_binary_fwd_signature(); -OpTaskSignature get_element_binary_bwd_signature(); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/element_unary.h b/lib/task-spec/include/task-spec/ops/element_unary.h deleted file mode 100644 index 5c88871ee7..0000000000 --- a/lib/task-spec/include/task-spec/ops/element_unary.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _ELEMENT_UNARY_H -#define _ELEMENT_UNARY_H - -#include "op-attrs/ops/element_unary_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(ElementUnaryAttrs const &); - -TaskImplFunction get_element_unary_init_task_impl(); -TaskImplFunction get_element_unary_fwd_task_impl(); -TaskImplFunction get_element_unary_bwd_task_impl(); - -OpTaskSignature get_element_unary_init_signature(); -OpTaskSignature get_element_unary_fwd_signature(); -OpTaskSignature get_element_unary_bwd_signature(); - -OpTaskInvocation init(ElementUnaryAttrs const &); -OpTaskInvocation forward(ElementUnaryAttrs const &); -OpTaskInvocation backward(ElementUnaryAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/embedding.h b/lib/task-spec/include/task-spec/ops/embedding.h deleted file mode 100644 index 27ade01cfa..0000000000 --- a/lib/task-spec/include/task-spec/ops/embedding.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_EMBEDDING_H -#define _FLEXFLOW_EMBEDDING_H - -#include 
"op-attrs/ops/embedding_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(EmbeddingAttrs const &); - -TaskImplFunction get_embedding_fwd_task_impl(); -TaskImplFunction get_embedding_bwd_task_impl(); - -OpTaskSignature get_embedding_fwd_signature(); -OpTaskSignature get_embedding_bwd_signature(); - -OpTaskInvocation forward(EmbeddingAttrs const &); -OpTaskInvocation backward(EmbeddingAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/flat.h b/lib/task-spec/include/task-spec/ops/flat.h deleted file mode 100644 index 3a02965d3b..0000000000 --- a/lib/task-spec/include/task-spec/ops/flat.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_FLAT_H -#define _FLEXFLOW_FLAT_H - -#include "op-attrs/ops/flat_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(FlatAttrs const &); - -TaskImplFunction get_flat_fwd_task_impl(); -TaskImplFunction get_flat_bwd_task_impl(); - -OpTaskSignature get_flat_fwd_signature(); -OpTaskSignature get_flat_bwd_signature(); - -OpTaskInvocation forward(FlatAttrs const &); -OpTaskInvocation backward(FlatAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/gather.h b/lib/task-spec/include/task-spec/ops/gather.h deleted file mode 100644 index f800173f20..0000000000 --- a/lib/task-spec/include/task-spec/ops/gather.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_GATHER_H -#define _FLEXFLOW_GATHER_H - -#include "op-attrs/ops/gather_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(GatherAttrs const &); - -TaskImplFunction get_gather_init_task_impl(); -TaskImplFunction get_gather_fwd_task_impl(); -TaskImplFunction get_gather_bwd_task_impl(); - 
-OpTaskSignature get_gather_init_signature(); -OpTaskSignature get_gather_fwd_signature(); -OpTaskSignature get_gather_bwd_signature(); - -OpTaskInvocation init(GatherAttrs const &); -OpTaskInvocation forward(GatherAttrs const &); -OpTaskInvocation backward(GatherAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/attention.h b/lib/task-spec/include/task-spec/ops/impl/attention.h new file mode 100644 index 0000000000..ae951dda58 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/attention.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_ATTENTION_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_ATTENTION_H + +#include "op-attrs/ops/attention.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_attention_init_task_impl(); +TaskImplFunction get_attention_fwd_task_impl(); +TaskImplFunction get_attention_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/batch_matmul.h b/lib/task-spec/include/task-spec/ops/impl/batch_matmul.h new file mode 100644 index 0000000000..6184e194f2 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/batch_matmul.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_BATCH_MATMUL_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_BATCH_MATMUL_H + +#include "op-attrs/ops/batch_matmul_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_batch_matmul_fwd_task_impl(); +TaskImplFunction get_batch_matmul_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/batch_norm.h b/lib/task-spec/include/task-spec/ops/impl/batch_norm.h new file mode 100644 index 0000000000..791d3edf34 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/batch_norm.h @@ -0,0 +1,15 @@ +#ifndef 
_FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_BATCH_NORM_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_BATCH_NORM_H + +#include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_batch_norm_init_task_impl(); +TaskImplFunction get_batch_norm_fwd_task_impl(); +TaskImplFunction get_batch_norm_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/broadcast.h b/lib/task-spec/include/task-spec/ops/impl/broadcast.h new file mode 100644 index 0000000000..35bf73c665 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/broadcast.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_BROADCAST_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_BROADCAST_H + +#include "op-attrs/ops/broadcast_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_broadcast_fwd_task_impl(); +TaskImplFunction get_broadcast_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/cast.h b/lib/task-spec/include/task-spec/ops/impl/cast.h new file mode 100644 index 0000000000..c31a92a923 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/cast.h @@ -0,0 +1,28 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_CAST_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_CAST_H + +#include "op-attrs/ops/cast_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_cast_fwd_task_impl(); +TaskImplFunction get_cast_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/concat.h b/lib/task-spec/include/task-spec/ops/impl/concat.h new file mode 100644 index 0000000000..26cbc49956 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/concat.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_CONCAT_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_CONCAT_H + +#include "op-attrs/ops/concat_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_concat_fwd_task_impl(); +TaskImplFunction get_concat_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/conv_2d.h b/lib/task-spec/include/task-spec/ops/impl/conv_2d.h new file mode 100644 index 0000000000..34b3863fd8 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/conv_2d.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_CONV_2D_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_CONV_2D_H + +#include "op-attrs/ops/conv_2d_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_conv_2d_init_task_impl(); +TaskImplFunction get_conv_2d_fwd_task_impl(); +TaskImplFunction get_conv_2d_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/dropout.h b/lib/task-spec/include/task-spec/ops/impl/dropout.h new file mode 100644 index 0000000000..a7b382ce62 --- 
/dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/dropout.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_DROPOUT_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_DROPOUT_H + +#include "op-attrs/ops/dropout_attrs.dtg.h" +#include "task-spec/task_id_t.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_dropout_init_task_impl(); +TaskImplFunction get_dropout_fwd_task_impl(); +TaskImplFunction get_dropout_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/element_binary.h b/lib/task-spec/include/task-spec/ops/impl/element_binary.h new file mode 100644 index 0000000000..8808656085 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/element_binary.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_ELEMENT_BINARY_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_ELEMENT_BINARY_H + +#include "op-attrs/ops/element_binary_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_element_binary_init_task_impl(); +TaskImplFunction get_element_binary_fwd_task_impl(); +TaskImplFunction get_element_binary_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/element_unary.h b/lib/task-spec/include/task-spec/ops/impl/element_unary.h new file mode 100644 index 0000000000..9d2bcdfa76 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/element_unary.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_ELEMENT_UNARY_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_ELEMENT_UNARY_H + +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_element_unary_init_task_impl(); +TaskImplFunction get_element_unary_fwd_task_impl(); +TaskImplFunction 
get_element_unary_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/embedding.h b/lib/task-spec/include/task-spec/ops/impl/embedding.h new file mode 100644 index 0000000000..daf105d298 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/embedding.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_EMBEDDING_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_EMBEDDING_H + +#include "op-attrs/ops/embedding_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_embedding_fwd_task_impl(); +TaskImplFunction get_embedding_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/flat.h b/lib/task-spec/include/task-spec/ops/impl/flat.h new file mode 100644 index 0000000000..61a51f90f6 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/flat.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_FLAT_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_FLAT_H + +#include "op-attrs/ops/flat_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_flat_fwd_task_impl(); +TaskImplFunction get_flat_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/gather.h b/lib/task-spec/include/task-spec/ops/impl/gather.h new file mode 100644 index 0000000000..9957e91298 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/gather.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_GATHER_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_GATHER_H + +#include "op-attrs/ops/gather_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_gather_init_task_impl(); +TaskImplFunction get_gather_fwd_task_impl(); +TaskImplFunction 
get_gather_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/layer_norm.h b/lib/task-spec/include/task-spec/ops/impl/layer_norm.h new file mode 100644 index 0000000000..6ac835d693 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/layer_norm.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_LAYER_NORM_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_LAYER_NORM_H + +#include "op-attrs/ops/layer_norm_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_layer_norm_init_task_impl(); +TaskImplFunction get_layer_norm_fwd_task_impl(); +TaskImplFunction get_layer_norm_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/linear.h b/lib/task-spec/include/task-spec/ops/impl/linear.h new file mode 100644 index 0000000000..12d265697a --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/linear.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_LINEAR_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_LINEAR_H + +#include "op-attrs/ops/linear_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_linear_init_task_impl(); +TaskImplFunction get_linear_fwd_task_impl(); +TaskImplFunction get_linear_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/noop.h b/lib/task-spec/include/task-spec/ops/impl/noop.h new file mode 100644 index 0000000000..e7af654537 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/noop.h @@ -0,0 +1,8 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_NOOP_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_NOOP_H + +#include "op-attrs/ops/noop_attrs.dtg.h" + +namespace FlexFlow {} // namespace FlexFlow + +#endif diff --git 
a/lib/task-spec/include/task-spec/ops/parallel_op.h b/lib/task-spec/include/task-spec/ops/impl/parallel_op.h similarity index 89% rename from lib/task-spec/include/task-spec/ops/parallel_op.h rename to lib/task-spec/include/task-spec/ops/impl/parallel_op.h index e7bd98b8a8..7061821b62 100644 --- a/lib/task-spec/include/task-spec/ops/parallel_op.h +++ b/lib/task-spec/include/task-spec/ops/impl/parallel_op.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_PARALLEL_OP_H -#define _FLEXFLOW_PARALLEL_OP_H +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_PARALLEL_OP_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_PARALLEL_OP_H #include "parallel_op_info.h" #include "utils/optional.h" diff --git a/lib/task-spec/include/task-spec/ops/impl/pool_2d.h b/lib/task-spec/include/task-spec/ops/impl/pool_2d.h new file mode 100644 index 0000000000..27d8783377 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/pool_2d.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_POOL_2D_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_POOL_2D_H + +#include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_pool_2d_init_task_impl(); +TaskImplFunction get_pool_2d_fwd_task_impl(); +TaskImplFunction get_pool_2d_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/reduce.h b/lib/task-spec/include/task-spec/ops/impl/reduce.h new file mode 100644 index 0000000000..835479dcfb --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/reduce.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_REDUCE_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_REDUCE_H + +#include "op-attrs/ops/reduce_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_reduce_init_task_impl(); +TaskImplFunction get_reduce_fwd_task_impl(); 
+TaskImplFunction get_reduce_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/reshape.h b/lib/task-spec/include/task-spec/ops/impl/reshape.h new file mode 100644 index 0000000000..e5ebdb30c4 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/reshape.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_RESHAPE_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_RESHAPE_H + +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_reshape_fwd_task_impl(); +TaskImplFunction get_reshape_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/reverse.h b/lib/task-spec/include/task-spec/ops/impl/reverse.h new file mode 100644 index 0000000000..c74567a370 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/reverse.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_REVERSE_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_REVERSE_H + +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_reverse_fwd_task_impl(); +TaskImplFunction get_reverse_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/softmax.h b/lib/task-spec/include/task-spec/ops/impl/softmax.h new file mode 100644 index 0000000000..3f2a55d243 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/softmax.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_SOFTMAX_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_SOFTMAX_H + +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_softmax_init_task_impl(); +TaskImplFunction get_softmax_fwd_task_impl(); 
+TaskImplFunction get_softmax_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/split.h b/lib/task-spec/include/task-spec/ops/impl/split.h new file mode 100644 index 0000000000..7cf24e10d9 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_SPLIT_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_SPLIT_H + +#include "op-attrs/ops/split_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_split_fwd_task_impl(); +TaskImplFunction get_split_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/topk.h b/lib/task-spec/include/task-spec/ops/impl/topk.h new file mode 100644 index 0000000000..565166aa53 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/topk.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_TOPK_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_TOPK_H + +#include "op-attrs/ops/topk_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_topk_fwd_task_impl(); +TaskImplFunction get_topk_bwd_task_impl(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/impl/transpose.h b/lib/task-spec/include/task-spec/ops/impl/transpose.h new file mode 100644 index 0000000000..7ba83a2648 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/impl/transpose.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_TRANSPOSE_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_TRANSPOSE_H + +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include "task-spec/task_impl_function.dtg.h" + +namespace FlexFlow { + +TaskImplFunction get_transpose_fwd_task_impl(); +TaskImplFunction get_transpose_bwd_task_impl(); + +} // namespace FlexFlow + +#endif 
diff --git a/lib/task-spec/include/task-spec/ops/input.h b/lib/task-spec/include/task-spec/ops/input.h deleted file mode 100644 index 9181478363..0000000000 --- a/lib/task-spec/include/task-spec/ops/input.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_INPUT_H -#define _FLEXFLOW_INPUT_H - -#include "op-attrs/ops/input_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" - -namespace FlexFlow { - -std::vector get_task_ids(InputAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/layer_norm.h b/lib/task-spec/include/task-spec/ops/layer_norm.h deleted file mode 100644 index ad418826f2..0000000000 --- a/lib/task-spec/include/task-spec/ops/layer_norm.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_OPS_LAYER_NORM_H -#define _FLEXFLOW_RUNTIME_SRC_OPS_LAYER_NORM_H - -#include "op-attrs/ops/layer_norm_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(LayerNormAttrs const &); - -TaskImplFunction get_layer_norm_init_task_impl(); -TaskImplFunction get_layer_norm_fwd_task_impl(); -TaskImplFunction get_layer_norm_bwd_task_impl(); - -OpTaskSignature get_layer_norm_init_signature(); -OpTaskSignature get_layer_norm_fwd_signature(); -OpTaskSignature get_layer_norm_bwd_signature(); - -OpTaskInvocation init(LayerNormAttrs const &); -OpTaskInvocation forward(LayerNormAttrs const &); -OpTaskInvocation backward(LayerNormAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/linear.h b/lib/task-spec/include/task-spec/ops/linear.h deleted file mode 100644 index d3c188a2c4..0000000000 --- a/lib/task-spec/include/task-spec/ops/linear.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_LINEAR_H -#define _FLEXFLOW_LINEAR_H - -#include "op-attrs/ops/linear_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace 
FlexFlow { - -std::vector get_task_ids(LinearAttrs const &); - -OpTaskInvocation init(LinearAttrs const &); -OpTaskInvocation forward(LinearAttrs const &); -OpTaskInvocation backward(LinearAttrs const &); - -TaskImplFunction get_linear_init_task_impl(); -TaskImplFunction get_linear_fwd_task_impl(); -TaskImplFunction get_linear_bwd_task_impl(); - -OpTaskSignature get_linear_init_signature(); -OpTaskSignature get_linear_fwd_signature(); -OpTaskSignature get_linear_bwd_signature(); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/noop.h b/lib/task-spec/include/task-spec/ops/noop.h deleted file mode 100644 index adbc15cd3b..0000000000 --- a/lib/task-spec/include/task-spec/ops/noop.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_NOOP_H -#define _FLEXFLOW_NOOP_H - -#include "op-attrs/ops/noop_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" - -namespace FlexFlow { - -std::vector get_task_ids(NoopAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/op_arg_ref.h b/lib/task-spec/include/task-spec/ops/op_arg_ref.h similarity index 85% rename from lib/task-spec/include/task-spec/op_arg_ref.h rename to lib/task-spec/include/task-spec/ops/op_arg_ref.h index 88882abd46..41857b78bc 100644 --- a/lib/task-spec/include/task-spec/op_arg_ref.h +++ b/lib/task-spec/include/task-spec/ops/op_arg_ref.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_OP_ARG_REF_H -#define _FLEXFLOW_LOCAL_EXECUTION_OP_ARG_REF_H +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_OP_ARG_REF_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_OP_ARG_REF_H #include "op-attrs/parallel_tensor_shape.dtg.h" #include "task-spec/arg_ref.h" @@ -12,8 +12,6 @@ namespace FlexFlow { template using OpArgRef = ArgRef; -using OpArgRefSpec = ArgRefSpec; - template OpArgRef per_device_op_state() { OpArgRefType op_arg_ref_type = OpArgRefType{PerDeviceOpStateRefType{}}; diff --git 
a/lib/task-spec/include/task-spec/ops/op_arg_ref_spec.h b/lib/task-spec/include/task-spec/ops/op_arg_ref_spec.h new file mode 100644 index 0000000000..14c7c0f013 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/op_arg_ref_spec.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_OP_ARG_REF_SPEC_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_OP_ARG_REF_SPEC_H + +#include "task-spec/arg_ref_spec.h" +#include "task-spec/op_arg_ref_type.dtg.h" + +namespace FlexFlow { + +using OpArgRefSpec = ArgRefSpec; + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/op_slot_options.dtg.toml b/lib/task-spec/include/task-spec/ops/op_slot_options.dtg.toml new file mode 100644 index 0000000000..bb984e8ac0 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/op_slot_options.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "OpSlotOptions" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "OPTIONAL" + +[[values]] +name = "UNTRAINABLE" + +[[values]] +name = "OPTIONAL_UNTRAINABLE" + +[[values]] +name = "NECESSARY" diff --git a/lib/task-spec/include/task-spec/ops/op_task_id_t.dtg.toml b/lib/task-spec/include/task-spec/ops/op_task_id_t.dtg.toml new file mode 100644 index 0000000000..557da6cf4c --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/op_task_id_t.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "op_task_id_t" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "INIT" + +[[values]] +name = "FWD" + +[[values]] +name = "BWD" diff --git a/lib/task-spec/include/task-spec/ops/op_task_type.dtg.toml b/lib/task-spec/include/task-spec/ops/op_task_type.dtg.toml new file mode 100644 index 0000000000..582d51b657 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/op_task_type.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "OpTaskType" +type = "enum" +features = [ + "hash", + "fmt", + 
"rapidcheck", + "json", +] + +[[values]] +name = "INIT" + +[[values]] +name = "FWD" + +[[values]] +name = "BWD" diff --git a/lib/task-spec/include/task-spec/ops/op_tensor_slot_spec.dtg.toml b/lib/task-spec/include/task-spec/ops/op_tensor_slot_spec.dtg.toml new file mode 100644 index 0000000000..75410432c2 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/op_tensor_slot_spec.dtg.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "OpTensorSlotSpec" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "op-attrs/tensor_slot_name.dtg.h", + "op-attrs/tensor_role.dtg.h", + "task-spec/is_grad.dtg.h", + "task-spec/ops/op_slot_options.dtg.h", +] + +[[fields]] +name = "name" +type = "::FlexFlow::TensorSlotName" + +[[fields]] +name = "tensor_role" +type = "::FlexFlow::TensorRole" + +[[fields]] +name = "is_grad" +type = "::FlexFlow::IsGrad" + +[[fields]] +name = "slot_option" +type = "::FlexFlow::OpSlotOptions" diff --git a/lib/task-spec/include/task-spec/ops/op_tensor_spec.dtg.toml b/lib/task-spec/include/task-spec/ops/op_tensor_spec.dtg.toml new file mode 100644 index 0000000000..45587019ec --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/op_tensor_spec.dtg.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "OpTensorSpec" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/tensor_role.dtg.h", + "task-spec/ops/op_slot_options.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "role" +type = "::FlexFlow::TensorRole" + +[[fields]] +name = "slot_option" +type = "::FlexFlow::OpSlotOptions" + +[[fields]] +name = "idx" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/task-spec/include/task-spec/ops/op_tensor_spec.h b/lib/task-spec/include/task-spec/ops/op_tensor_spec.h new file mode 100644 index 0000000000..c7af5df2d9 --- /dev/null +++ b/lib/task-spec/include/task-spec/ops/op_tensor_spec.h @@ -0,0 +1,17 @@ +#ifndef 
_FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_OP_TENSOR_SPEC_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_OP_TENSOR_SPEC_H + +#include "task-spec/ops/op_tensor_spec.dtg.h" + +namespace FlexFlow { + +OpTensorSpec input_tensor(nonnegative_int idx, + OpSlotOptions option = OpSlotOptions::NECESSARY); +OpTensorSpec output_tensor(nonnegative_int idx, + OpSlotOptions option = OpSlotOptions::NECESSARY); +OpTensorSpec weight_tensor(nonnegative_int idx, + OpSlotOptions option = OpSlotOptions::NECESSARY); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/ops/pool_2d.h b/lib/task-spec/include/task-spec/ops/pool_2d.h deleted file mode 100644 index fbecd0e96f..0000000000 --- a/lib/task-spec/include/task-spec/ops/pool_2d.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_POOL_2D_H -#define _FLEXFLOW_POOL_2D_H - -#include "op-attrs/ops/pool_2d_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(Pool2DAttrs const &); - -TaskImplFunction get_pool_2d_init_task_impl(); -TaskImplFunction get_pool_2d_fwd_task_impl(); -TaskImplFunction get_pool_2d_bwd_task_impl(); - -OpTaskSignature get_pool_2d_init_signature(); -OpTaskSignature get_pool_2d_fwd_signature(); -OpTaskSignature get_pool_2d_bwd_signature(); - -OpTaskInvocation init(Pool2DAttrs const &); -OpTaskInvocation forward(Pool2DAttrs const &); -OpTaskInvocation backward(Pool2DAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/reduce.h b/lib/task-spec/include/task-spec/ops/reduce.h deleted file mode 100644 index ffcf66e752..0000000000 --- a/lib/task-spec/include/task-spec/ops/reduce.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_OPS_REDUCE_H -#define _FLEXFLOW_RUNTIME_SRC_OPS_REDUCE_H - -#include "op-attrs/ops/reduce_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - 
-namespace FlexFlow { - -std::vector get_task_ids(ReduceAttrs const &); - -TaskImplFunction get_reduce_init_task_impl(); -TaskImplFunction get_reduce_fwd_task_impl(); -TaskImplFunction get_reduce_bwd_task_impl(); - -OpTaskSignature get_reduce_init_signature(); -OpTaskSignature get_reduce_fwd_signature(); -OpTaskSignature get_reduce_bwd_signature(); - -OpTaskInvocation init(ReduceAttrs const &); -OpTaskInvocation forward(ReduceAttrs const &); -OpTaskInvocation backward(ReduceAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/reshape.h b/lib/task-spec/include/task-spec/ops/reshape.h deleted file mode 100644 index e5bf7170fb..0000000000 --- a/lib/task-spec/include/task-spec/ops/reshape.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_RESHAPE_H -#define _FLEXFLOW_RESHAPE_H - -#include "op-attrs/ops/reshape_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(ReshapeAttrs const &); - -TaskImplFunction get_reshape_fwd_task_impl(); -TaskImplFunction get_reshape_bwd_task_impl(); - -OpTaskSignature get_reshape_fwd_signature(); -OpTaskSignature get_reshape_bwd_signature(); - -OpTaskInvocation forward(ReshapeAttrs const &); -OpTaskInvocation backward(ReshapeAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/reverse.h b/lib/task-spec/include/task-spec/ops/reverse.h deleted file mode 100644 index 7c91f91c0b..0000000000 --- a/lib/task-spec/include/task-spec/ops/reverse.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_REVERSE_H_ -#define _FLEXFLOW_REVERSE_H_ - -#include "op-attrs/ops/reverse_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(ReverseAttrs const &); - -TaskImplFunction get_reverse_fwd_task_impl(); -TaskImplFunction get_reverse_bwd_task_impl(); - 
-OpTaskSignature get_reverse_fwd_signature(); -OpTaskSignature get_reverse_bwd_signature(); - -OpTaskInvocation forward(ReverseAttrs const &); -OpTaskInvocation backward(ReverseAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/softmax.h b/lib/task-spec/include/task-spec/ops/softmax.h deleted file mode 100644 index 8f99c2658a..0000000000 --- a/lib/task-spec/include/task-spec/ops/softmax.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_SOFTMAX_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_SOFTMAX_H - -#include "op-attrs/ops/softmax_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(SoftmaxAttrs const &); - -TaskImplFunction get_softmax_init_task_impl(); -TaskImplFunction get_softmax_fwd_task_impl(); -TaskImplFunction get_softmax_bwd_task_impl(); - -OpTaskSignature get_softmax_init_signature(); -OpTaskSignature get_softmax_fwd_signature(); -OpTaskSignature get_softmax_bwd_signature(); - -OpTaskInvocation init(SoftmaxAttrs const &); -OpTaskInvocation forward(SoftmaxAttrs const &); -OpTaskInvocation backward(SoftmaxAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/split.h b/lib/task-spec/include/task-spec/ops/split.h deleted file mode 100644 index 1aa8609011..0000000000 --- a/lib/task-spec/include/task-spec/ops/split.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_SPLIT_H -#define _FLEXFLOW_SPLIT_H - -#include "op-attrs/ops/split_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(SplitAttrs const &); - -TaskImplFunction get_split_fwd_task_impl(); -TaskImplFunction get_split_bwd_task_impl(); - -OpTaskSignature get_split_fwd_signature(); -OpTaskSignature get_split_bwd_signature(); - -OpTaskInvocation 
forward(SplitAttrs const &); -OpTaskInvocation backward(SplitAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/topk.h b/lib/task-spec/include/task-spec/ops/topk.h deleted file mode 100644 index ca1d43c2ee..0000000000 --- a/lib/task-spec/include/task-spec/ops/topk.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_TOPK_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_TOPK_H - -#include "op-attrs/ops/topk_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(TopKAttrs const &); - -TaskImplFunction get_topk_fwd_task_impl(); -TaskImplFunction get_topk_bwd_task_impl(); - -OpTaskSignature get_topk_fwd_signature(); -OpTaskSignature get_topk_bwd_signature(); - -OpTaskInvocation forward(TopKAttrs const &); -OpTaskInvocation backward(TopKAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/transpose.h b/lib/task-spec/include/task-spec/ops/transpose.h deleted file mode 100644 index 7762f440cd..0000000000 --- a/lib/task-spec/include/task-spec/ops/transpose.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_TRANSPOSE_H_ -#define _FLEXFLOW_TRANSPOSE_H_ - -#include "op-attrs/ops/transpose_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/task_impl_function.dtg.h" - -namespace FlexFlow { - -std::vector get_task_ids(TransposeAttrs const &); - -TaskImplFunction get_transpose_fwd_task_impl(); -TaskImplFunction get_transpose_bwd_task_impl(); - -OpTaskSignature get_transpose_fwd_signature(); -OpTaskSignature get_transpose_bwd_signature(); - -OpTaskInvocation forward(TransposeAttrs const &); -OpTaskInvocation backward(TransposeAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/ops/weight.h b/lib/task-spec/include/task-spec/ops/weight.h deleted file mode 100644 index 
162236e41e..0000000000 --- a/lib/task-spec/include/task-spec/ops/weight.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_WEIGHT_H -#define _FLEXFLOW_WEIGHT_H - -#include "op-attrs/ops/weight_attrs.dtg.h" -#include "task-spec/op_task_invocation.h" - -namespace FlexFlow { - -std::vector get_task_ids(WeightAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/optimizer.h b/lib/task-spec/include/task-spec/optimizer.h index 5b898d8699..6de9b9183b 100644 --- a/lib/task-spec/include/task-spec/optimizer.h +++ b/lib/task-spec/include/task-spec/optimizer.h @@ -5,32 +5,11 @@ #include "pcg/optimizers/adam_optimizer_attrs.dtg.h" #include "pcg/optimizers/sgd_optimizer_attrs.dtg.h" #include "task-spec/task_impl_function.dtg.h" -#include "task-spec/task_invocation.dtg.h" -#include "task-spec/task_signature.h" namespace FlexFlow { -TaskSignature get_update_signature(OptimizerAttrs const &); -TaskInvocation get_update_invocation( - OptimizerAttrs const &, - forward_tensor_guid_t const &weight, - gradient_tensor_guid_t const &weight_grad, - std::vector const &grad_buffer_tensors); TaskImplFunction get_update_task_impl(OptimizerAttrs const &); - -TaskSignature get_sgd_update_signature(); -TaskInvocation sgd_update(SGDOptimizerAttrs const &, - forward_tensor_guid_t const &weight, - gradient_tensor_guid_t const &weight_grad, - optimizer_tensor_guid_t const &sgd_v); TaskImplFunction get_sgd_update_task_impl(); - -TaskSignature get_adam_update_signature(); -TaskInvocation adam_update(AdamOptimizerAttrs const &, - forward_tensor_guid_t const &weight, - gradient_tensor_guid_t const &weight_grad, - optimizer_tensor_guid_t const &adam_v, - optimizer_tensor_guid_t const &adam_m); TaskImplFunction get_adam_update_task_impl(); } // namespace FlexFlow diff --git a/lib/task-spec/include/task-spec/optimizer_tensor_guid_t.struct.toml b/lib/task-spec/include/task-spec/optimizer_tensor_guid_t.struct.toml deleted file mode 100644 index 
dc5f98886f..0000000000 --- a/lib/task-spec/include/task-spec/optimizer_tensor_guid_t.struct.toml +++ /dev/null @@ -1,13 +0,0 @@ -namespace = "FlexFlow" -name = "optimizer_tensor_guid_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - - -[[fields]] -name = "raw_index" -type = "int" diff --git a/lib/task-spec/include/task-spec/optimizer_tensor_source.h b/lib/task-spec/include/task-spec/optimizer_tensor_source.h deleted file mode 100644 index 2f10c5c35b..0000000000 --- a/lib/task-spec/include/task-spec/optimizer_tensor_source.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPTIMIZER_TENSOR_SOURCE_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPTIMIZER_TENSOR_SOURCE_H - -#include "task-spec/optimizer_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -struct OptimizerTensorSource { -public: - OptimizerTensorSource(); - - optimizer_tensor_guid_t new_optimizer_tensor(); - - void reset(); - -private: - static int next_available_optimizer_tensor_id; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/parallel_tensor_shape_ref_type.dtg.toml b/lib/task-spec/include/task-spec/parallel_tensor_shape_ref_type.dtg.toml new file mode 100644 index 0000000000..a6684653a0 --- /dev/null +++ b/lib/task-spec/include/task-spec/parallel_tensor_shape_ref_type.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ParallelTensorShapeRefType" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", + "op-attrs/tensor_role.dtg.h", +] + +[[fields]] +name = "tensor_role" +type = "::FlexFlow::TensorRole" + +[[fields]] +name = "idx" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/task-spec/include/task-spec/parallel_tensor_shape_ref_type.struct.toml b/lib/task-spec/include/task-spec/parallel_tensor_shape_ref_type.struct.toml deleted file mode 100644 index 4ff411d17b..0000000000 --- 
a/lib/task-spec/include/task-spec/parallel_tensor_shape_ref_type.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelTensorShapeRefType" -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -includes = [ - "utils/nonnegative_int/nonnegative_int.h", - "pcg/tensor_role.dtg.h", -] - -[[fields]] -name = "tensor_role" -type = "::FlexFlow::TensorRole" - -[[fields]] -name = "idx" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/task-spec/include/task-spec/per_device_op_state.dtg.toml b/lib/task-spec/include/task-spec/per_device_op_state.dtg.toml new file mode 100644 index 0000000000..2bb7a2dce9 --- /dev/null +++ b/lib/task-spec/include/task-spec/per_device_op_state.dtg.toml @@ -0,0 +1,73 @@ +namespace = "FlexFlow" +name = "PerDeviceOpState" +type = "variant" +features = [] + +includes = [ + "kernels/mha_per_device_state.dtg.h", + "kernels/batch_norm_per_device_state.dtg.h", + "kernels/conv_2d_per_device_state.dtg.h", + "kernels/dropout_per_device_state.dtg.h", + "kernels/element_binary_per_device_state.dtg.h", + "kernels/element_unary_per_device_state.dtg.h", + "kernels/gather_per_device_state.dtg.h", + "kernels/layer_norm_per_device_state.dtg.h", + "kernels/linear_per_device_state.dtg.h", + "kernels/partition_per_device_state.dtg.h", + "kernels/pool_2d_per_device_state.dtg.h", + "kernels/reduce_per_device_state.dtg.h", + "kernels/softmax_per_device_state.dtg.h", + "", +] + +[[values]] +type = "std::optional<::FlexFlow::MHAPerDeviceState>" +key = "mha" + +[[values]] +type = "std::optional<::FlexFlow::BatchNormPerDeviceState>" +key = "batch_norm" + +[[values]] +type = "std::optional<::FlexFlow::Conv2DPerDeviceState>" +key = "conv2d" + +[[values]] +type = "std::optional<::FlexFlow::DropoutPerDeviceState>" +key = "dropout" + +[[values]] +type = "std::optional<::FlexFlow::ElementBinaryPerDeviceState>" +key = "element_binary" + +[[values]] +type = "std::optional<::FlexFlow::ElementUnaryPerDeviceState>" +key = "element_unary" + 
+[[values]] +type = "std::optional<::FlexFlow::GatherPerDeviceState>" +key = "gather" + +[[values]] +type = "std::optional<::FlexFlow::LayerNormPerDeviceState>" +key = "layer_norm" + +[[values]] +type = "std::optional<::FlexFlow::LinearPerDeviceState>" +key = "linear" + +[[values]] +type = "std::optional<::FlexFlow::Pool2DPerDeviceState>" +key = "pool_2d" + +[[values]] +type = "std::optional<::FlexFlow::ReducePerDeviceState>" +key = "reduce" + +[[values]] +type = "std::optional<::FlexFlow::RepartitionPerDeviceState>" +key = "repartition" + +[[values]] +type = "std::optional<::FlexFlow::SoftmaxPerDeviceState>" +key = "softmax" diff --git a/lib/task-spec/include/task-spec/per_device_op_state.h b/lib/task-spec/include/task-spec/per_device_op_state.h index ae6c93807c..68d3f98ebf 100644 --- a/lib/task-spec/include/task-spec/per_device_op_state.h +++ b/lib/task-spec/include/task-spec/per_device_op_state.h @@ -2,15 +2,14 @@ #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_PER_DEVICE_OP_STATE_H #include "task-spec/concrete_arg_spec.h" -#include "task-spec/device_specific_device_states.dtg.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" #include "task-spec/per_device_op_state.dtg.h" #include "utils/type_index.h" namespace FlexFlow { -PerDeviceOpState - get_device_state_from_device_specific(DeviceSpecificDeviceStates const &, - size_t device_idx); +PerDeviceOpState get_device_state_from_device_specific( + DeviceSpecificPerDeviceOpState const &, device_id_t device_idx); } diff --git a/lib/task-spec/include/task-spec/per_device_op_state.variant.toml b/lib/task-spec/include/task-spec/per_device_op_state.variant.toml deleted file mode 100644 index 7c340447f9..0000000000 --- a/lib/task-spec/include/task-spec/per_device_op_state.variant.toml +++ /dev/null @@ -1,72 +0,0 @@ -namespace = "FlexFlow" -name = "PerDeviceOpState" -features = [] - -includes = [ - "kernels/mha_per_device_state.dtg.h", - "kernels/batch_norm_per_device_state.dtg.h", - 
"kernels/conv_2d_per_device_state.dtg.h", - "kernels/dropout_per_device_state.dtg.h", - "kernels/element_binary_per_device_state.dtg.h", - "kernels/element_unary_per_device_state.dtg.h", - "kernels/gather_per_device_state.dtg.h", - "kernels/layer_norm_per_device_state.dtg.h", - "kernels/linear_per_device_state.dtg.h", - "kernels/partition_per_device_state.dtg.h", - "kernels/pool_2d_per_device_state.dtg.h", - "kernels/reduce_per_device_state.dtg.h", - "kernels/softmax_per_device_state.dtg.h", - "", -] - -[[values]] -type = "std::optional<::FlexFlow::MHAPerDeviceState>" -key = "mha_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::BatchNormPerDeviceState>" -key = "batch_norm_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::Conv2DPerDeviceState>" -key = "conv2d_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::DropoutPerDeviceState>" -key = "dropout_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::ElementBinaryPerDeviceState>" -key = "element_binary_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::ElementUnaryPerDeviceState>" -key = "element_unary_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::GatherPerDeviceState>" -key = "gather_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::LayerNormPerDeviceState>" -key = "layer_norm_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::LinearPerDeviceState>" -key = "linear_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::Pool2DPerDeviceState>" -key = "pool_2d_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::ReducePerDeviceState>" -key = "reduce_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::RepartitionPerDeviceState>" -key = "repartition_per_device_state" - -[[values]] -type = "std::optional<::FlexFlow::SoftmaxPerDeviceState>" -key = "softmax_per_device_state" diff --git 
a/lib/task-spec/include/task-spec/per_device_op_state_ref_type.dtg.toml b/lib/task-spec/include/task-spec/per_device_op_state_ref_type.dtg.toml new file mode 100644 index 0000000000..3b37e4a383 --- /dev/null +++ b/lib/task-spec/include/task-spec/per_device_op_state_ref_type.dtg.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "PerDeviceOpStateRefType" +type = "struct" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +fields = [] diff --git a/lib/task-spec/include/task-spec/per_device_op_state_ref_type.struct.toml b/lib/task-spec/include/task-spec/per_device_op_state_ref_type.struct.toml deleted file mode 100644 index e3d48a02ee..0000000000 --- a/lib/task-spec/include/task-spec/per_device_op_state_ref_type.struct.toml +++ /dev/null @@ -1,12 +0,0 @@ -namespace = "FlexFlow" -name = "PerDeviceOpStateRefType" - -features = [ - "eq", - "ord", - "hash", - "json", - "fmt", -] - -fields = [] diff --git a/lib/task-spec/include/task-spec/profiling.h b/lib/task-spec/include/task-spec/profiling.h index 91774f69ef..760d23240d 100644 --- a/lib/task-spec/include/task-spec/profiling.h +++ b/lib/task-spec/include/task-spec/profiling.h @@ -1,20 +1,20 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_PROFILING_H -#define _FLEXFLOW_LOCAL_EXECUTION_PROFILING_H +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_PROFILING_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_PROFILING_H #include "kernels/profiling.h" -#include "spdlog/spdlog.h" +#include namespace FlexFlow { enum class EnableProfiling { YES, NO }; template -std::optional profile(F const &f, - ProfilingSettings profiling, - DeviceType device_type, - Str s, - Ts &&...ts) { - std::optional elapsed = profiling_wrapper( +std::optional profile(F const &f, + ProfilingSettings profiling, + DeviceType device_type, + Str s, + Ts &&...ts) { + std::optional elapsed = profiling_wrapper( f, profiling, device_type, std::forward(ts)...); if (elapsed.has_value()) { spdlog::debug(s, elapsed.value()); diff --git 
a/lib/task-spec/include/task-spec/runtime_arg_config.h b/lib/task-spec/include/task-spec/runtime_arg_config.h deleted file mode 100644 index 5358caf331..0000000000 --- a/lib/task-spec/include/task-spec/runtime_arg_config.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_RUNTIME_ARG_CONFIG_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_RUNTIME_ARG_CONFIG_H - -#include "task-spec/runtime_arg_config.dtg.h" - -namespace FlexFlow { - -RuntimeArgConfig - cpu_make_runtime_arg_config(EnableProfiling enable_profiling, - ProfilingSettings profiling_settings); -RuntimeArgConfig - gpu_make_runtime_arg_config(PerDeviceFFHandle const &ff_handle, - EnableProfiling enable_profiling, - ProfilingSettings profiling_settings); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/runtime_arg_config.struct.toml b/lib/task-spec/include/task-spec/runtime_arg_config.struct.toml deleted file mode 100644 index 9d77616306..0000000000 --- a/lib/task-spec/include/task-spec/runtime_arg_config.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "RuntimeArgConfig" -features = [] - -includes = [ - "kernels/device_handle_t.dtg.h", - "task-spec/device_specific.h", - "task-spec/profiling.h", -] - -[[fields]] -name = "ff_handle" -type = "::FlexFlow::DeviceSpecific<::FlexFlow::device_handle_t>" - -[[fields]] -name = "enable_profiling" -type = "::FlexFlow::EnableProfiling" - -[[fields]] -name = "profiling_settings" -type = "::FlexFlow::ProfilingSettings" - -[[fields]] -name = "kernel_device_type" -type = "::FlexFlow::DeviceType" diff --git a/lib/task-spec/include/task-spec/runtime_arg_ref.h b/lib/task-spec/include/task-spec/runtime_arg_ref.h deleted file mode 100644 index 532482f89e..0000000000 --- a/lib/task-spec/include/task-spec/runtime_arg_ref.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_RUNTIME_ARG_REF_H -#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_RUNTIME_ARG_REF_H - -#include 
"kernels/device_handle_t.dtg.h" -#include "kernels/profiling_settings.dtg.h" -#include "pcg/device_type.dtg.h" -#include "task-spec/arg_ref.h" -#include "task-spec/config.h" -#include "task-spec/device_specific.h" -#include "task-spec/runtime_arg_ref_type.dtg.h" - -namespace FlexFlow { - -template -using RuntimeArgRef = ArgRef; - -using RuntimeArgRefSpec = ArgRefSpec; - -RuntimeArgRef profiling_settings(); -RuntimeArgRef> ff_handle(); -RuntimeArgRef iteration_config(); -RuntimeArgRef kernel_device_type(); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/runtime_arg_ref_type.enum.toml b/lib/task-spec/include/task-spec/runtime_arg_ref_type.enum.toml deleted file mode 100644 index e33eeebc56..0000000000 --- a/lib/task-spec/include/task-spec/runtime_arg_ref_type.enum.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "RuntimeArgRefType" -features = [ - "fmt", -] - -[[values]] -name = "FF_HANDLE" - -[[values]] -name = "PROFILING_SETTINGS" - -[[values]] -name = "FF_ITERATION_CONFIG" - -[[values]] -name = "KERNEL_DEVICE_TYPE" diff --git a/lib/task-spec/include/task-spec/serialization.h b/lib/task-spec/include/task-spec/serialization.h index 2fc4b4b706..29f9144a3b 100644 --- a/lib/task-spec/include/task-spec/serialization.h +++ b/lib/task-spec/include/task-spec/serialization.h @@ -1,25 +1,12 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_SERIALIZATION_H -#define _FLEXFLOW_LOCAL_EXECUTION_SERIALIZATION_H +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_SERIALIZATION_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_SERIALIZATION_H #include "kernels/device.h" #include "kernels/nccl.h" -#include "op-attrs/dim_ordered/dim_ordered.h" +#include "op-attrs/ff_ordered/ff_ordered.h" #include "utils/required.h" -#include "utils/strong_typedef.h" #include "utils/type_traits.h" #include "utils/variant.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct InternalTestType { - int x; - float y; -}; - -} // namespace FlexFlow - 
-VISITABLE_STRUCT(::FlexFlow::InternalTestType, x, y); namespace FlexFlow { @@ -46,26 +33,12 @@ struct visit_trivially_serializable> { template <> struct visit_trivially_serializable<> : std::true_type {}; -template -struct is_trivially_serializable< - T, - typename std::enable_if< - visit_trivially_serializable>::value>::type> - : std::true_type {}; - template struct is_trivially_serializable< T, typename std::enable_if::value>::type> : std::true_type {}; -template -struct is_trivially_serializable>> - : is_trivially_serializable> {}; - -template -struct is_trivially_serializable> : is_trivially_serializable {}; - template <> struct is_trivially_serializable : std::true_type {}; template <> @@ -86,9 +59,9 @@ template struct is_trivially_serializable> : is_trivially_serializable {}; -template -struct is_trivially_serializable> - : is_trivially_serializable {}; +template +struct is_trivially_serializable> : is_trivially_serializable { +}; template struct is_trivially_serializable> @@ -134,11 +107,6 @@ static_assert(is_trivially_serializable::value, ""); static_assert(is_trivially_serializable::value, ""); static_assert(is_trivially_serializable>::value, ""); -static_assert(std::is_same, - std::tuple>::value, - ""); -static_assert(visit_trivially_serializable::value, ""); -static_assert(is_trivially_serializable::value, ""); } // namespace FlexFlow diff --git a/lib/task-spec/include/task-spec/slot_grad_id.struct.toml b/lib/task-spec/include/task-spec/slot_grad_id.struct.toml deleted file mode 100644 index a6533ea884..0000000000 --- a/lib/task-spec/include/task-spec/slot_grad_id.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "SlotGradId" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "task-spec/is_grad.dtg.h", - "task-spec/slot_id_t.dtg.h", -] - -[[fields]] -name = "slot_id" -type = "::FlexFlow::slot_id_t" - -[[fields]] -name = "is_grad" -type = "::FlexFlow::IsGrad" diff --git 
a/lib/task-spec/include/task-spec/slot_id_t.struct.toml b/lib/task-spec/include/task-spec/slot_id_t.struct.toml deleted file mode 100644 index 0a5f360638..0000000000 --- a/lib/task-spec/include/task-spec/slot_id_t.struct.toml +++ /dev/null @@ -1,12 +0,0 @@ -namespace = "FlexFlow" -name = "slot_id_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -[[fields]] -name = "raw_id" -type = "int" diff --git a/lib/task-spec/include/task-spec/slot_type.enum.toml b/lib/task-spec/include/task-spec/slot_type.enum.toml deleted file mode 100644 index 0871a0bae4..0000000000 --- a/lib/task-spec/include/task-spec/slot_type.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "SlotType" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "TENSOR" - -[[values]] -name = "VARIADIC" diff --git a/lib/task-spec/include/task-spec/task_arg_spec.h b/lib/task-spec/include/task-spec/task_arg_spec.h deleted file mode 100644 index 38879ecab9..0000000000 --- a/lib/task-spec/include/task-spec/task_arg_spec.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ARG_SPEC_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ARG_SPEC_H - -#include "task-spec/task_arg_spec.dtg.h" - -namespace FlexFlow { - -std::type_index get_type_index(TaskArgSpec const &); - -} - -#endif diff --git a/lib/task-spec/include/task-spec/task_arg_spec.variant.toml b/lib/task-spec/include/task-spec/task_arg_spec.variant.toml deleted file mode 100644 index 4829a50ff6..0000000000 --- a/lib/task-spec/include/task-spec/task_arg_spec.variant.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "TaskArgSpec" -features = [ - "eq", - "fmt", - "hash" -] - -includes = [ - "task-spec/concrete_arg_spec.h", - "task-spec/runtime_arg_ref.h" -] - -[[values]] -type = "::FlexFlow::ConcreteArgSpec" -key = "concrete_arg_spec" - -[[values]] -type = "::FlexFlow::RuntimeArgRefSpec" -key = "runtime_arg_ref" diff --git 
a/lib/task-spec/include/task-spec/task_argument_accessor.h b/lib/task-spec/include/task-spec/task_argument_accessor.h deleted file mode 100644 index a6d71b6b70..0000000000 --- a/lib/task-spec/include/task-spec/task_argument_accessor.h +++ /dev/null @@ -1,141 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ARGUMENT_ACCESSOR_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ARGUMENT_ACCESSOR_H - -#include "task-spec/device_specific.h" -#include "task-spec/itask_argument_accessor.h" -#include "task-spec/per_device_op_state.dtg.h" - -namespace FlexFlow { - -struct TaskArgumentAccessor { - // arguments - template - T const &get_argument(slot_id_t slot) const { - return this->ptr->get_concrete_arg(slot).get(); - } - - template - T const &get_argument(int slot) const { - return this->get_argument(slot_id_t{slot}); - } - - // tensors - template - privilege_mode_to_accessor get_tensor(int slot) const { - return this->get_tensor(slot_id_t{slot}); - } - - template - privilege_mode_to_accessor get_tensor(slot_id_t slot) const { - return std::get>( - this->ptr->get_tensor(slot, PRIV, TensorType::FORWARD)); - } - - template - privilege_mode_to_accessor get_tensor_grad(int slot) const { - return this->get_tensor_grad(slot_id_t{slot}); - } - - template - privilege_mode_to_accessor get_tensor_grad(slot_id_t slot) const { - return std::get>( - this->ptr->get_tensor(slot, PRIV, TensorType::GRADIENT)); - } - - template - privilege_mode_to_accessor get_optimizer_tensor(int slot) const { - return this->get_optimizer_tensor(slot_id_t{slot}); - } - - template - privilege_mode_to_accessor get_optimizer_tensor(slot_id_t slot) const { - return std::get>( - this->ptr->get_tensor(slot, PRIV, TensorType::OPTIMIZER)); - } - - template - privilege_mode_to_accessor get_loss_tensor(int slot) const { - return this->get_loss_tensor(slot_id_t{slot}); - } - - template - privilege_mode_to_accessor get_loss_tensor(slot_id_t slot) const { - return std::get>( - 
this->ptr->get_tensor(slot, PRIV, TensorType::LOSS)); - } - - // variadic tensors - template - std::vector> - get_variadic_tensor(int slot) const { - return this->get_variadic_tensor(slot_id_t{slot}); - } - - template - std::vector> - get_variadic_tensor(slot_id_t slot) const { - return std::get>>( - this->ptr->get_variadic_tensor(slot, PRIV, TensorType::FORWARD)); - } - - template - std::vector> - get_variadic_tensor_grad(int slot) const { - return this->get_variadic_tensor_grad(slot_id_t{slot}); - } - - template - std::vector> - get_variadic_tensor_grad(slot_id_t slot) const { - return std::get>>( - this->ptr->get_variadic_tensor(slot, PRIV, TensorType::GRADIENT)); - } - - template - std::vector> - get_variadic_optimizer_tensor(int slot) const { - return this->get_variadic_optimizer_tensor(slot_id_t{slot}); - } - - template - std::vector> - get_variadic_optimizer_tensor(slot_id_t slot) const { - return std::get>>( - this->ptr->get_variadic_tensor(slot, PRIV, TensorType::OPTIMIZER)); - } - - template - std::vector> - get_variadic_loss_tensor(int slot) const { - return this->get_variadic_loss_tensor(slot_id_t{slot}); - } - - template - std::vector> - get_variadic_loss_tensor(slot_id_t slot) const { - return std::get>>( - this->ptr->get_variadic_tensor(slot, PRIV, TensorType::LOSS)); - } - - Allocator get_allocator() const { - return this->ptr->get_allocator(); - } - - template - static - typename std::enable_if::value, - TaskArgumentAccessor>::type - create(Args &&...args) { - return TaskArgumentAccessor( - std::make_shared(std::forward(args)...)); - } - -private: - TaskArgumentAccessor(std::shared_ptr ptr) - : ptr(ptr) {} - std::shared_ptr ptr; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/task_argument_accessor/itask_argument_accessor.h b/lib/task-spec/include/task-spec/task_argument_accessor/itask_argument_accessor.h new file mode 100644 index 0000000000..8a8d741d90 --- /dev/null +++ 
b/lib/task-spec/include/task-spec/task_argument_accessor/itask_argument_accessor.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ARGUMENT_ACCESSOR_ITASK_ARGUMENT_ACCESSOR_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ARGUMENT_ACCESSOR_ITASK_ARGUMENT_ACCESSOR_H + +#include "kernels/allocation.h" +#include "kernels/device_handle_t.dtg.h" +#include "kernels/profiling_settings.dtg.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "task-spec/concrete_arg_spec.h" +#include "task-spec/ff_iteration_config.dtg.h" +#include "task-spec/ops/arg_slot_id_t.dtg.h" +#include "task-spec/per_device_op_state.dtg.h" +#include "task-spec/privilege_tensor_accessor.h" +#include "task-spec/task_argument_accessor/task_tensor_parameter.dtg.h" +#include "task-spec/training_tensor_type.dtg.h" + +namespace FlexFlow { + +struct ITaskArgumentAccessor { + ITaskArgumentAccessor &operator=(ITaskArgumentAccessor const &) = delete; + + virtual ~ITaskArgumentAccessor() = default; + + virtual ConcreteArgSpec const &get_concrete_arg(arg_slot_id_t) const = 0; + + virtual GenericTensorAccessor get_tensor(TaskTensorParameter, + Permissions priv) const = 0; + + virtual ProfilingSettings get_profiling_settings() const = 0; + virtual device_handle_t get_ff_handle() const = 0; + virtual DeviceType get_kernel_device_type() const = 0; + virtual PCGOperatorAttrs get_op_attrs() const = 0; + virtual LossAttrs get_loss_attrs() const = 0; + virtual PerDeviceOpState get_per_device_op_state() const = 0; + virtual FFIterationConfig get_iteration_config() const = 0; + virtual OptimizerAttrs get_optimizer_attrs() const = 0; + + virtual Allocator get_allocator() const = 0; + virtual device_id_t get_device_idx() const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(ITaskArgumentAccessor); + +} // namespace 
FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/task_argument_accessor/task_argument_accessor.h b/lib/task-spec/include/task-spec/task_argument_accessor/task_argument_accessor.h new file mode 100644 index 0000000000..e350387684 --- /dev/null +++ b/lib/task-spec/include/task-spec/task_argument_accessor/task_argument_accessor.h @@ -0,0 +1,89 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ARGUMENT_ACCESSOR_TASK_ARGUMENT_ACCESSOR_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ARGUMENT_ACCESSOR_TASK_ARGUMENT_ACCESSOR_H + +#include "kernels/device_handle_t.dtg.h" +#include "kernels/profiling_settings.dtg.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "pcg/optimizer_slot_name.dtg.h" +#include "task-spec/device_specific.h" +#include "task-spec/ff_iteration_config.dtg.h" +#include "task-spec/per_device_op_state.dtg.h" +#include "task-spec/task_argument_accessor/itask_argument_accessor.h" +#include "task-spec/task_argument_accessor/task_tensor_parameter.h" + +namespace FlexFlow { + +struct TaskArgumentAccessor { + ProfilingSettings get_profiling_settings() const; + device_handle_t get_ff_handle() const; + DeviceType get_kernel_device_type() const; + PCGOperatorAttrs get_op_attrs() const; + LossAttrs get_loss_attrs() const; + PerDeviceOpState get_per_device_op_state() const; + FFIterationConfig get_iteration_config() const; + OptimizerAttrs get_optimizer_attrs() const; + + TensorShape get_tensor_shape(TensorSlotName slot) const { + NOT_IMPLEMENTED(); + } + + template + privilege_mode_to_accessor get_tensor(TensorSlotName slot) const { + return std::get>( + this->ptr->get_tensor(make_task_tensor_parameter_fwd(slot), PRIV)); + } + + template + privilege_mode_to_accessor get_tensor_grad(TensorSlotName slot) const { + return std::get>( + 
this->ptr->get_tensor(make_task_tensor_parameter_grad(slot), PRIV)); + } + + template + privilege_mode_to_accessor + get_optimizer_tensor(TensorSlotName slot, + OptimizerSlotName opt_slot) const { + return std::get>(this->ptr->get_tensor( + make_task_tensor_parameter_opt(slot, opt_slot), PRIV)); + } + + template + privilege_mode_to_accessor get_loss_tensor() const { + return std::get>( + this->ptr->get_tensor(make_task_tensor_parameter_loss(), PRIV)); + } + + Allocator get_allocator() const { + return this->ptr->get_allocator(); + } + + device_id_t get_device_idx() const { + return this->ptr->get_device_idx(); + } + + template + DeviceSpecific make_device_specific(T const &t) const { + return DeviceSpecific::create(this->get_device_idx(), t); + } + + template + static + typename std::enable_if::value, + TaskArgumentAccessor>::type + create(Args &&...args) { + return TaskArgumentAccessor( + std::make_shared(std::forward(args)...)); + } + +private: + TaskArgumentAccessor(std::shared_ptr ptr) + : ptr(ptr) {} + std::shared_ptr ptr; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/task_argument_accessor/task_forward_tensor_parameter.dtg.toml b/lib/task-spec/include/task-spec/task_argument_accessor/task_forward_tensor_parameter.dtg.toml new file mode 100644 index 0000000000..230ce250fd --- /dev/null +++ b/lib/task-spec/include/task-spec/task_argument_accessor/task_forward_tensor_parameter.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "TaskForwardTensorParameter" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/tensor_slot_name.dtg.h", +] +src_includes = [] + +[[fields]] +name = "name" +type = "::FlexFlow::TensorSlotName" diff --git a/lib/task-spec/include/task-spec/task_argument_accessor/task_gradient_tensor_parameter.dtg.toml b/lib/task-spec/include/task-spec/task_argument_accessor/task_gradient_tensor_parameter.dtg.toml new file mode 100644 index 
0000000000..c3ebb0ee96 --- /dev/null +++ b/lib/task-spec/include/task-spec/task_argument_accessor/task_gradient_tensor_parameter.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TaskGradientTensorParameter" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [] + +[[fields]] +name = "name" +type = "::FlexFlow::TensorSlotName" diff --git a/lib/task-spec/include/task-spec/task_argument_accessor/task_loss_tensor_parameter.dtg.toml b/lib/task-spec/include/task-spec/task_argument_accessor/task_loss_tensor_parameter.dtg.toml new file mode 100644 index 0000000000..6bf0e728ce --- /dev/null +++ b/lib/task-spec/include/task-spec/task_argument_accessor/task_loss_tensor_parameter.dtg.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "TaskLossTensorParameter" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [] +src_includes = [] + +fields = [] diff --git a/lib/task-spec/include/task-spec/task_argument_accessor/task_optimizer_tensor_parameter.dtg.toml b/lib/task-spec/include/task-spec/task_argument_accessor/task_optimizer_tensor_parameter.dtg.toml new file mode 100644 index 0000000000..2957b00efe --- /dev/null +++ b/lib/task-spec/include/task-spec/task_argument_accessor/task_optimizer_tensor_parameter.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "TaskOptimizerTensorParameter" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/tensor_slot_name.dtg.h", + "pcg/optimizer_slot_name.dtg.h", +] + +src_includes = [ +] + +[[fields]] +name = "name" +type = "::FlexFlow::TensorSlotName" + +[[fields]] +name = "optimizer_slot" +type = "::FlexFlow::OptimizerSlotName" diff --git a/lib/task-spec/include/task-spec/task_argument_accessor/task_tensor_parameter.dtg.toml 
b/lib/task-spec/include/task-spec/task_argument_accessor/task_tensor_parameter.dtg.toml new file mode 100644 index 0000000000..4dd8374d19 --- /dev/null +++ b/lib/task-spec/include/task-spec/task_argument_accessor/task_tensor_parameter.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "TaskTensorParameter" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "task-spec/task_argument_accessor/task_forward_tensor_parameter.dtg.h", + "task-spec/task_argument_accessor/task_gradient_tensor_parameter.dtg.h", + "task-spec/task_argument_accessor/task_optimizer_tensor_parameter.dtg.h", + "task-spec/task_argument_accessor/task_loss_tensor_parameter.dtg.h", +] + +[[values]] +type = "::FlexFlow::TaskForwardTensorParameter" +key = "forward" + +[[values]] +type = "::FlexFlow::TaskGradientTensorParameter" +key = "gradient" + +[[values]] +type = "::FlexFlow::TaskOptimizerTensorParameter" +key = "optimizer" + +[[values]] +type = "::FlexFlow::TaskLossTensorParameter" +key = "loss" diff --git a/lib/task-spec/include/task-spec/task_argument_accessor/task_tensor_parameter.h b/lib/task-spec/include/task-spec/task_argument_accessor/task_tensor_parameter.h new file mode 100644 index 0000000000..6f1e15c92c --- /dev/null +++ b/lib/task-spec/include/task-spec/task_argument_accessor/task_tensor_parameter.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ARGUMENT_ACCESSOR_TASK_TENSOR_PARAMETER_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ARGUMENT_ACCESSOR_TASK_TENSOR_PARAMETER_H + +#include "task-spec/task_argument_accessor/task_tensor_parameter.dtg.h" + +namespace FlexFlow { + +TaskTensorParameter make_task_tensor_parameter_fwd(TensorSlotName); +TaskTensorParameter make_task_tensor_parameter_grad(TensorSlotName); +TaskTensorParameter make_task_tensor_parameter_opt(TensorSlotName, + OptimizerSlotName); +TaskTensorParameter make_task_tensor_parameter_loss(); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/task-spec/include/task-spec/task_binding.h b/lib/task-spec/include/task-spec/task_binding.h deleted file mode 100644 index 4cc286e104..0000000000 --- a/lib/task-spec/include/task-spec/task_binding.h +++ /dev/null @@ -1,89 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_BINDING_H -#define _FLEXFLOW_LOCAL_EXECUTION_TASK_BINDING_H - -#include "task-spec/loss_tensor_guid_t.dtg.h" -#include "task-spec/optimizer_tensor_guid_t.dtg.h" -#include "task-spec/slot_id_t.dtg.h" -#include "task-spec/task_arg_spec.dtg.h" -#include "task-spec/task_id_t.dtg.h" -#include "task-spec/task_signature.dtg.h" -#include "task-spec/tensor_sub_slot_id_t.dtg.h" -#include "task-spec/training_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -struct TaskBinding { - TaskBinding(); - - explicit TaskBinding( - std::unordered_map const - &tensor_bindings, - std::unordered_map const &arg_bindings); - - void bind(int, forward_tensor_guid_t const &); - void bind(slot_id_t, forward_tensor_guid_t const &); - - void bind_grad(int, gradient_tensor_guid_t const &); - void bind_grad(slot_id_t, gradient_tensor_guid_t const &); - - void bind_optimizer(int, optimizer_tensor_guid_t const &); - void bind_optimizer(slot_id_t, optimizer_tensor_guid_t const &); - - void bind_loss(int, loss_tensor_guid_t const &); - void bind_loss(slot_id_t, loss_tensor_guid_t const &); - - template - void bind_arg(int name, T const &t) { - this->bind_arg(slot_id_t{name}, t); - } - - template - void bind_arg(slot_id_t name, T const &t) { - this->insert_arg_spec(name, TaskArgSpec{ConcreteArgSpec::create(t)}); - } - - template - void bind_arg(int name, RuntimeArgRef const &t) { - this->bind_arg(slot_id_t{name}, t); - } - - template - void bind_arg(slot_id_t name, RuntimeArgRef const &ref) { - this->insert_arg_spec(name, TaskArgSpec{RuntimeArgRefSpec::create(ref)}); - } - - bool operator==(TaskBinding const &other) const; - bool operator!=(TaskBinding const &other) const; - - std::unordered_map const & - get_tensor_bindings() const; 
- std::unordered_map const &get_arg_bindings() const; - void insert_arg_spec(slot_id_t name, TaskArgSpec const &arg_spec); - -private: - std::unordered_map - tensor_bindings; - std::unordered_map arg_bindings; - -private: - std::tuple - tie() const; - - friend ::std::hash; -}; - -std::string format_as(TaskBinding const &x); -std::ostream &operator<<(std::ostream &s, TaskBinding const &x); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::TaskBinding> { - size_t operator()(::FlexFlow::TaskBinding const &s) const; -}; - -} // namespace std - -#endif diff --git a/lib/task-spec/include/task-spec/task_id_t.dtg.toml b/lib/task-spec/include/task-spec/task_id_t.dtg.toml new file mode 100644 index 0000000000..ce2de52d40 --- /dev/null +++ b/lib/task-spec/include/task-spec/task_id_t.dtg.toml @@ -0,0 +1,405 @@ +namespace = "FlexFlow" +name = "task_id_t" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "TOP_LEVEL_TASK_ID" + +[[values]] +name = "FF_INIT_TASK_ID" + +[[values]] +name = "IMAGE_INIT_TASK_ID" + +[[values]] +name = "LABEL_INIT_TASK_ID" + +[[values]] +name = "LOAD_IMAGES_TASK_ID" + +[[values]] +name = "NORMALIZE_IMAGES_TASK_ID" + +[[values]] +name = "ELEMENTBINARY_INIT_TASK_ID" + +[[values]] +name = "ELEMENTBINARY_FWD_TASK_ID" + +[[values]] +name = "ELEMENTBINARY_BWD_TASK_ID" + +[[values]] +name = "ELEMENTUNARY_INIT_TASK_ID" + +[[values]] +name = "ELEMENTUNARY_FWD_TASK_ID" + +[[values]] +name = "ELEMENTUNARY_BWD_TASK_ID" + +[[values]] +name = "CONV2D_INIT_TASK_ID" + +[[values]] +name = "CONV2D_FWD_TASK_ID" + +[[values]] +name = "CONV2D_BWD_TASK_ID" + +[[values]] +name = "DROPOUT_INIT_TASK_ID" + +[[values]] +name = "DROPOUT_FWD_TASK_ID" + +[[values]] +name = "DROPOUT_BWD_TASK_ID" + +[[values]] +name = "EMBED_FWD_TASK_ID" + +[[values]] +name = "EMBED_BWD_TASK_ID" + +[[values]] +name = "GATHER_INIT_TASK_ID" + +[[values]] +name = "GATHER_FWD_TASK_ID" + +[[values]] +name = 
"GATHER_BWD_TASK_ID" + +[[values]] +name = "CAST_FWD_TASK_ID" + +[[values]] +name = "CAST_BWD_TASK_ID" + +[[values]] +name = "POOL2D_INIT_TASK_ID" + +[[values]] +name = "POOL2D_FWD_TASK_ID" + +[[values]] +name = "POOL2D_BWD_TASK_ID" + +[[values]] +name = "BATCHNORM_INIT_TASK_ID" + +[[values]] +name = "BATCHNORM_FWD_TASK_ID" + +[[values]] +name = "BATCHNORM_BWD_TASK_ID" + +[[values]] +name = "BATCHMATMUL_FWD_TASK_ID" + +[[values]] +name = "BATCHMATMUL_BWD_TASK_ID" + +[[values]] +name = "LAYERNORM_INIT_TASK_ID" + +[[values]] +name = "LAYERNORM_FWD_TASK_ID" + +[[values]] +name = "LAYERNORM_BWD_TASK_ID" + +[[values]] +name = "LINEAR_INIT_TASK_ID" + +[[values]] +name = "LINEAR_FWD_TASK_ID" + +[[values]] +name = "LINEAR_BWD_TASK_ID" + +[[values]] +name = "FLAT_FWD_TASK_ID" + +[[values]] +name = "FLAT_BWD_TASK_ID" + +[[values]] +name = "SOFTMAX_INIT_TASK_ID" + +[[values]] +name = "SOFTMAX_FWD_TASK_ID" + +[[values]] +name = "SOFTMAX_BWD_TASK_ID" + +[[values]] +name = "CONCAT_FWD_TASK_ID" + +[[values]] +name = "CONCAT_BWD_TASK_ID" + +[[values]] +name = "SPLIT_FWD_TASK_ID" + +[[values]] +name = "SPLIT_BWD_TASK_ID" + +[[values]] +name = "REDUCE_INIT_TASK_ID" + +[[values]] +name = "REDUCE_FWD_TASK_ID" + +[[values]] +name = "REDUCE_BWD_TASK_ID" + +[[values]] +name = "RESHAPE_FWD_TASK_ID" + +[[values]] +name = "RESHAPE_BWD_TASK_ID" + +[[values]] +name = "REVERSE_FWD_TASK_ID" + +[[values]] +name = "REVERSE_BWD_TASK_ID" + +[[values]] +name = "TOPK_FWD_TASK_ID" + +[[values]] +name = "TOPK_BWD_TASK_ID" + +[[values]] +name = "TRANSPOSE_FWD_TASK_ID" + +[[values]] +name = "TRANSPOSE_BWD_TASK_ID" + +[[values]] +name = "ATTENTION_INIT_TASK_ID" + +[[values]] +name = "ATTENTION_FWD_TASK_ID" + +[[values]] +name = "ATTENTION_BWD_TASK_ID" + +[[values]] +name = "BROADCAST_FWD_TASK_ID" + +[[values]] +name = "BROADCAST_BWD_TASK_ID" + +[[values]] +name = "MSELOSS_BWD_TASK_ID" + +[[values]] +name = "FUSEDOP_INIT_TASK_ID" + +[[values]] +name = "FUSEDOP_FWD_TASK_ID" + +[[values]] +name = 
"FUSEDOP_BWD_TASK_ID" + +[[values]] +name = "METRICS_COMP_TASK_ID" + +[[values]] +name = "UPDATE_METRICS_TASK_ID" + +[[values]] +name = "PS_PREFETCH_TASK_ID" + +[[values]] +name = "LOSS_BWD_TASK_ID" + +[[values]] +name = "SGD_UPD_PS_TASK_ID" + +[[values]] +name = "ADAM_UPD_PS_TASK_ID" + +[[values]] +name = "SGD_UPD_NCCL_TASK_ID" + +[[values]] +name = "ADAM_UPD_NCCL_TASK_ID" + +[[values]] +name = "GLOROT_INIT_TASK_ID" + +[[values]] +name = "ZERO_INIT_TASK_ID" + +[[values]] +name = "CONSTANT_INIT_TASK_ID" + +[[values]] +name = "UNIFORM_INIT_TASK_ID" + +[[values]] +name = "NORMAL_INIT_TASK_ID" + +[[values]] +name = "NCCL_GETUNIQUEID_TASK_ID" + +[[values]] +name = "NCCL_INIT_COMMS_TASK_ID" + +[[values]] +name = "STRATEGY_SEARCH_TASK_ID" + +[[values]] +name = "GRAPH_OPTIMIZE_TASK_ID" + +[[values]] +name = "PY_DL_FLOAT_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_INT32_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_INT64_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_FLOAT_INDEX_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_INT32_INDEX_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_INT64_INDEX_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_FLOAT_LOAD_BATCH_GPU_TASK_ID" + +[[values]] +name = "PY_DL_INT32_LOAD_BATCH_GPU_TASK_ID" + +[[values]] +name = "PY_DL_INT64_LOAD_BATCH_GPU_TASK_ID" + +[[values]] +name = "REPARTITION_INIT_TASK_ID" + +[[values]] +name = "REPARTITION_FWD_TASK_ID" + +[[values]] +name = "REPARTITION_BWD_TASK_ID" + +[[values]] +name = "COMBINE_INIT_TASK_ID" + +[[values]] +name = "COMBINE_FWD_TASK_ID" + +[[values]] +name = "COMBINE_BWD_TASK_ID" + +[[values]] +name = "REPLICATE_INIT_TASK_ID" + +[[values]] +name = "REPLICATE_FWD_TASK_ID" + +[[values]] +name = "REPLICATE_BWD_TASK_ID" + +[[values]] +name = "REDUCTION_INIT_TASK_ID" + +[[values]] +name = "REDUCTION_FWD_TASK_ID" + +[[values]] +name = "REDUCTION_BWD_TASK_ID" + +[[values]] +name = "PIPELINE_INIT_TASK_ID" + +[[values]] +name = "PIPELINE_FWD_TASK_ID" + 
+[[values]] +name = "PIPELINE_BWD_TASK_ID" + +[[values]] +name = "FUSED_PARALLELOP_INIT_TASK_ID" + +[[values]] +name = "FUSED_PARALLELOP_FWD_TASK_ID" + +[[values]] +name = "FUSED_PARALLELOP_BWD_TASK_ID" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_FIRST" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_1" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_2" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_3" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_4" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_5" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_6" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_7" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_8" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_LAST" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_FIRST" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_1" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_2" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_3" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_4" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_5" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_6" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_7" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_LAST" + +[[values]] +name = "PYTHON_TOP_LEVEL_TASK_ID" diff --git a/lib/task-spec/include/task-spec/task_id_t.enum.toml b/lib/task-spec/include/task-spec/task_id_t.enum.toml deleted file mode 100644 index 2e8f0a0046..0000000000 --- a/lib/task-spec/include/task-spec/task_id_t.enum.toml +++ /dev/null @@ -1,419 +0,0 @@ -namespace = "FlexFlow" -name = "task_id_t" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "TOP_LEVEL_TASK_ID" - -[[values]] -name = "FF_INIT_TASK_ID" - -[[values]] -name = "IMAGE_INIT_TASK_ID" - -[[values]] -name = "LABEL_INIT_TASK_ID" - -[[values]] -name = "LOAD_IMAGES_TASK_ID" - -[[values]] -name = "NORMALIZE_IMAGES_TASK_ID" - -[[values]] -name = "ELEMENTBINARY_INIT_TASK_ID" - -[[values]] -name = "ELEMENTBINARY_FWD_TASK_ID" - -[[values]] -name = "ELEMENTBINARY_BWD_TASK_ID" - -[[values]] -name = "ELEMENTUNARY_INIT_TASK_ID" - -[[values]] -name = 
"ELEMENTUNARY_FWD_TASK_ID" - -[[values]] -name = "ELEMENTUNARY_BWD_TASK_ID" - -[[values]] -name = "CONV2D_INIT_TASK_ID" - -[[values]] -name = "CONV2D_FWD_TASK_ID" - -[[values]] -name = "CONV2D_BWD_TASK_ID" - -[[values]] -name = "DROPOUT_INIT_TASK_ID" - -[[values]] -name = "DROPOUT_FWD_TASK_ID" - -[[values]] -name = "DROPOUT_BWD_TASK_ID" - -[[values]] -name = "EMBED_INIT_TASK_ID" - -[[values]] -name = "EMBED_FWD_TASK_ID" - -[[values]] -name = "EMBED_BWD_TASK_ID" - -[[values]] -name = "GATHER_INIT_TASK_ID" - -[[values]] -name = "GATHER_FWD_TASK_ID" - -[[values]] -name = "GATHER_BWD_TASK_ID" - -[[values]] -name = "CAST_INIT_TASK_ID" - -[[values]] -name = "CAST_FWD_TASK_ID" - -[[values]] -name = "CAST_BWD_TASK_ID" - -[[values]] -name = "POOL2D_INIT_TASK_ID" - -[[values]] -name = "POOL2D_FWD_TASK_ID" - -[[values]] -name = "POOL2D_BWD_TASK_ID" - -[[values]] -name = "BATCHNORM_INIT_TASK_ID" - -[[values]] -name = "BATCHNORM_FWD_TASK_ID" - -[[values]] -name = "BATCHNORM_BWD_TASK_ID" - -[[values]] -name = "BATCHMATMUL_FWD_TASK_ID" - -[[values]] -name = "BATCHMATMUL_BWD_TASK_ID" - -[[values]] -name = "LAYERNORM_INIT_TASK_ID" - -[[values]] -name = "LAYERNORM_FWD_TASK_ID" - -[[values]] -name = "LAYERNORM_BWD_TASK_ID" - -[[values]] -name = "LINEAR_INIT_TASK_ID" - -[[values]] -name = "LINEAR_FWD_TASK_ID" - -[[values]] -name = "LINEAR_BWD_TASK_ID" - -[[values]] -name = "FLAT_INIT_TASK_ID" - -[[values]] -name = "FLAT_FWD_TASK_ID" - -[[values]] -name = "FLAT_BWD_TASK_ID" - -[[values]] -name = "SOFTMAX_INIT_TASK_ID" - -[[values]] -name = "SOFTMAX_FWD_TASK_ID" - -[[values]] -name = "SOFTMAX_BWD_TASK_ID" - -[[values]] -name = "CONCAT_INIT_TASK_ID" - -[[values]] -name = "CONCAT_FWD_TASK_ID" - -[[values]] -name = "CONCAT_BWD_TASK_ID" - -[[values]] -name = "SPLIT_INIT_TASK_ID" - -[[values]] -name = "SPLIT_FWD_TASK_ID" - -[[values]] -name = "SPLIT_BWD_TASK_ID" - -[[values]] -name = "REDUCE_INIT_TASK_ID" - -[[values]] -name = "REDUCE_FWD_TASK_ID" - -[[values]] -name = "REDUCE_BWD_TASK_ID" - 
-[[values]] -name = "RESHAPE_FWD_TASK_ID" - -[[values]] -name = "RESHAPE_BWD_TASK_ID" - -[[values]] -name = "REVERSE_INIT_TASK_ID" - -[[values]] -name = "REVERSE_FWD_TASK_ID" - -[[values]] -name = "REVERSE_BWD_TASK_ID" - -[[values]] -name = "TOPK_FWD_TASK_ID" - -[[values]] -name = "TOPK_BWD_TASK_ID" - -[[values]] -name = "TRANSPOSE_FWD_TASK_ID" - -[[values]] -name = "TRANSPOSE_BWD_TASK_ID" - -[[values]] -name = "ATTENTION_INIT_TASK_ID" - -[[values]] -name = "ATTENTION_FWD_TASK_ID" - -[[values]] -name = "ATTENTION_BWD_TASK_ID" - -[[values]] -name = "MSELOSS_BWD_TASK_ID" - -[[values]] -name = "FUSEDOP_INIT_TASK_ID" - -[[values]] -name = "FUSEDOP_FWD_TASK_ID" - -[[values]] -name = "FUSEDOP_BWD_TASK_ID" - -[[values]] -name = "NOOP_INIT_TASK_ID" - -[[values]] -name = "METRICS_COMP_TASK_ID" - -[[values]] -name = "UPDATE_METRICS_TASK_ID" - -[[values]] -name = "PS_PREFETCH_TASK_ID" - -[[values]] -name = "LOSS_BWD_TASK_ID" - -[[values]] -name = "SGD_UPD_PS_TASK_ID" - -[[values]] -name = "ADAM_UPD_PS_TASK_ID" - -[[values]] -name = "SGD_UPD_NCCL_TASK_ID" - -[[values]] -name = "ADAM_UPD_NCCL_TASK_ID" - -[[values]] -name = "GLOROT_INIT_TASK_ID" - -[[values]] -name = "ZERO_INIT_TASK_ID" - -[[values]] -name = "CONSTANT_INIT_TASK_ID" - -[[values]] -name = "UNIFORM_INIT_TASK_ID" - -[[values]] -name = "NORMAL_INIT_TASK_ID" - -[[values]] -name = "NCCL_GETUNIQUEID_TASK_ID" - -[[values]] -name = "NCCL_INIT_COMMS_TASK_ID" - -[[values]] -name = "STRATEGY_SEARCH_TASK_ID" - -[[values]] -name = "GRAPH_OPTIMIZE_TASK_ID" - -[[values]] -name = "PY_DL_FLOAT_LOAD_ENTIRE_CPU_TASK_ID" - -[[values]] -name = "PY_DL_INT32_LOAD_ENTIRE_CPU_TASK_ID" - -[[values]] -name = "PY_DL_INT64_LOAD_ENTIRE_CPU_TASK_ID" - -[[values]] -name = "PY_DL_FLOAT_INDEX_LOAD_ENTIRE_CPU_TASK_ID" - -[[values]] -name = "PY_DL_INT32_INDEX_LOAD_ENTIRE_CPU_TASK_ID" - -[[values]] -name = "PY_DL_INT64_INDEX_LOAD_ENTIRE_CPU_TASK_ID" - -[[values]] -name = "PY_DL_FLOAT_LOAD_BATCH_GPU_TASK_ID" - -[[values]] -name = 
"PY_DL_INT32_LOAD_BATCH_GPU_TASK_ID" - -[[values]] -name = "PY_DL_INT64_LOAD_BATCH_GPU_TASK_ID" - -[[values]] -name = "REPARTITION_INIT_TASK_ID" - -[[values]] -name = "REPARTITION_FWD_TASK_ID" - -[[values]] -name = "REPARTITION_BWD_TASK_ID" - -[[values]] -name = "COMBINE_INIT_TASK_ID" - -[[values]] -name = "COMBINE_FWD_TASK_ID" - -[[values]] -name = "COMBINE_BWD_TASK_ID" - -[[values]] -name = "REPLICATE_INIT_TASK_ID" - -[[values]] -name = "REPLICATE_FWD_TASK_ID" - -[[values]] -name = "REPLICATE_BWD_TASK_ID" - -[[values]] -name = "REDUCTION_INIT_TASK_ID" - -[[values]] -name = "REDUCTION_FWD_TASK_ID" - -[[values]] -name = "REDUCTION_BWD_TASK_ID" - -[[values]] -name = "PIPELINE_INIT_TASK_ID" - -[[values]] -name = "PIPELINE_FWD_TASK_ID" - -[[values]] -name = "PIPELINE_BWD_TASK_ID" - -[[values]] -name = "FUSED_PARALLELOP_INIT_TASK_ID" - -[[values]] -name = "FUSED_PARALLELOP_FWD_TASK_ID" - -[[values]] -name = "FUSED_PARALLELOP_BWD_TASK_ID" - -[[values]] -name = "CUSTOM_GPU_TASK_ID_FIRST" - -[[values]] -name = "CUSTOM_GPU_TASK_ID_1" - -[[values]] -name = "CUSTOM_GPU_TASK_ID_2" - -[[values]] -name = "CUSTOM_GPU_TASK_ID_3" - -[[values]] -name = "CUSTOM_GPU_TASK_ID_4" - -[[values]] -name = "CUSTOM_GPU_TASK_ID_5" - -[[values]] -name = "CUSTOM_GPU_TASK_ID_6" - -[[values]] -name = "CUSTOM_GPU_TASK_ID_7" - -[[values]] -name = "CUSTOM_GPU_TASK_ID_8" - -[[values]] -name = "CUSTOM_GPU_TASK_ID_LAST" - -[[values]] -name = "CUSTOM_CPU_TASK_ID_FIRST" - -[[values]] -name = "CUSTOM_CPU_TASK_ID_1" - -[[values]] -name = "CUSTOM_CPU_TASK_ID_2" - -[[values]] -name = "CUSTOM_CPU_TASK_ID_3" - -[[values]] -name = "CUSTOM_CPU_TASK_ID_4" - -[[values]] -name = "CUSTOM_CPU_TASK_ID_5" - -[[values]] -name = "CUSTOM_CPU_TASK_ID_6" - -[[values]] -name = "CUSTOM_CPU_TASK_ID_7" - -[[values]] -name = "CUSTOM_CPU_TASK_ID_LAST" - -[[values]] -name = "PYTHON_TOP_LEVEL_TASK_ID" diff --git a/lib/task-spec/include/task-spec/task_id_with_noop_default_t.dtg.toml 
b/lib/task-spec/include/task-spec/task_id_with_noop_default_t.dtg.toml new file mode 100644 index 0000000000..50349d5773 --- /dev/null +++ b/lib/task-spec/include/task-spec/task_id_with_noop_default_t.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "task_id_with_noop_default_t" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", +] + +includes = [ + "task-spec/task_id_t.dtg.h", + "", +] + +src_includes = [ + "utils/rapidcheck/monostate.h", + "utils/fmt/monostate.h", +] + +[[values]] +type = "::FlexFlow::task_id_t" +key = "real_task" + +[[values]] +type = "std::monostate" +key = "noop_task" diff --git a/lib/task-spec/include/task-spec/task_id_with_noop_default_t.h b/lib/task-spec/include/task-spec/task_id_with_noop_default_t.h new file mode 100644 index 0000000000..054b73844e --- /dev/null +++ b/lib/task-spec/include/task-spec/task_id_with_noop_default_t.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ID_WITH_NOOP_DEFAULT_T_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ID_WITH_NOOP_DEFAULT_T_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/operator_type.dtg.h" +#include "task-spec/ops/op_task_id_t.dtg.h" +#include "task-spec/task_id_with_noop_default_t.dtg.h" + +namespace FlexFlow { + +task_id_with_noop_default_t lift_task_id_t(task_id_t); +task_id_with_noop_default_t default_noop_task(); + +task_id_with_noop_default_t lower_op_task_id_to_task_id_with_noop_default_t( + op_task_id_t, ComputationGraphOpAttrs const &); + +task_id_with_noop_default_t + get_init_task_id_for_op_attrs(ComputationGraphOpAttrs const &); + +task_id_with_noop_default_t + get_fwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &); + +task_id_with_noop_default_t + get_bwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/task_impl_function.dtg.toml 
b/lib/task-spec/include/task-spec/task_impl_function.dtg.toml new file mode 100644 index 0000000000..50f15d7bee --- /dev/null +++ b/lib/task-spec/include/task-spec/task_impl_function.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "TaskImplFunction" +type = "variant" +features = [ + "eq", + "fmt", + "hash", + "ord" +] + +includes = [ + "task-spec/init_op_task_impl_function.h", + "task-spec/fwd_bwd_op_task_impl_function.h", + "task-spec/generic_task_impl_function.h", +] + +[[values]] +type = "::FlexFlow::InitOpTaskImplFunction" +key = "init_op_task_impl_function" + +[[values]] +type = "::FlexFlow::FwdBwdOpTaskImplFunction" +key = "fwd_bwd_op_task_impl_function" + +[[values]] +type = "::FlexFlow::GenericTaskImplFunction" +key = "generic_task_impl_function" diff --git a/lib/task-spec/include/task-spec/task_impl_function.variant.toml b/lib/task-spec/include/task-spec/task_impl_function.variant.toml deleted file mode 100644 index 74347a3290..0000000000 --- a/lib/task-spec/include/task-spec/task_impl_function.variant.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "TaskImplFunction" -features = [ - "eq", - "fmt", - "hash", - "ord" -] - -includes = [ - "task-spec/init_op_task_impl_function.h", - "task-spec/fwd_bwd_op_task_impl_function.h", - "task-spec/generic_task_impl_function.h", -] - -[[values]] -type = "::FlexFlow::InitOpTaskImplFunction" -key = "init_op_task_impl_function" - -[[values]] -type = "::FlexFlow::FwdBwdOpTaskImplFunction" -key = "fwd_bwd_op_task_impl_function" - -[[values]] -type = "::FlexFlow::GenericTaskImplFunction" -key = "generic_task_impl_function" diff --git a/lib/task-spec/include/task-spec/task_invocation.h b/lib/task-spec/include/task-spec/task_invocation.h deleted file mode 100644 index 85940091a1..0000000000 --- a/lib/task-spec/include/task-spec/task_invocation.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_INVOCATION_H -#define _FLEXFLOW_LOCAL_EXECUTION_TASK_INVOCATION_H - -#include 
"task-spec/task_invocation.dtg.h" - -namespace FlexFlow { - -bool is_invocation_valid(TaskSignature const &sig, TaskInvocation const &inv); - -} - -#endif diff --git a/lib/task-spec/include/task-spec/task_invocation.struct.toml b/lib/task-spec/include/task-spec/task_invocation.struct.toml deleted file mode 100644 index 38e02a1370..0000000000 --- a/lib/task-spec/include/task-spec/task_invocation.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "TaskInvocation" -features = [ - "eq", - "fmt", - "hash" -] - -includes = [ - "task-spec/task_binding.h", - "task-spec/task_id_t.dtg.h" -] - - -[[fields]] -name = "task_id" -type = "::FlexFlow::task_id_t" - -[[fields]] -name = "binding" -type = "::FlexFlow::TaskBinding" diff --git a/lib/task-spec/include/task-spec/task_signature.h b/lib/task-spec/include/task-spec/task_signature.h deleted file mode 100644 index 8214e7e1b5..0000000000 --- a/lib/task-spec/include/task-spec/task_signature.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_SIGNATURE_H -#define _FLEXFLOW_LOCAL_EXECUTION_TASK_SIGNATURE_H - -#include "task-spec/task_signature.dtg.h" -#include "utils/type_index.h" - -namespace FlexFlow { - -TaskSignature make_empty_task_signature(); - -void add_slot(TaskSignature &, - int name, - TensorType, - SlotType slot_type = SlotType::TENSOR); -void add_slot(TaskSignature &, - slot_id_t name, - TensorType, - SlotType slot_type = SlotType::TENSOR); - -template -void add_arg_slot(TaskSignature &task_signature, int name) { - add_arg_slot(task_signature, slot_id_t{name}); -} - -template -void add_arg_slot(TaskSignature &task_signature, slot_id_t name) { - // static_assert(is_serializable::value, "Type must be serializable"); - task_signature.task_arg_types.insert({name, get_type_index_for_type()}); -} - -template -void add_return_value(TaskSignature &task_signature) { - task_signature.return_value = get_type_index_for_type(); -} - -/** - * @brief Adds an argument slot without checking 
if it is serializable. - * - * This function is used for arguments that are device-specific. - */ - -template -void add_unchecked_arg_slot(TaskSignature &task_signature, int name) { - add_unchecked_arg_slot(task_signature, slot_id_t{name}); -} - -/** - * @brief Adds an argument slot without checking if it is serializable. - * - * This function is used for arguments that are device-specific. - */ - -template -void add_unchecked_arg_slot(TaskSignature &task_signature, slot_id_t name) { - task_signature.task_arg_types.insert({name, get_type_index_for_type()}); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/task_signature.struct.toml b/lib/task-spec/include/task-spec/task_signature.struct.toml deleted file mode 100644 index 3df0a8cfc7..0000000000 --- a/lib/task-spec/include/task-spec/task_signature.struct.toml +++ /dev/null @@ -1,33 +0,0 @@ -namespace = "FlexFlow" -name = "TaskSignature" -features = [ - "eq", - "fmt", - "hash" -] - -includes = [ - "task-spec/tensor_type_slot_spec.dtg.h", - "task-spec/slot_id_t.dtg.h", - "", - "" -] - -src_includes = [ - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", - "utils/fmt/optional.h", - "utils/type_index.h" -] - -[[fields]] -name = "return_value" -type = "std::optional" - -[[fields]] -name = "task_arg_types" -type = "std::unordered_map<::FlexFlow::slot_id_t, std::type_index>" - -[[fields]] -name = "tensor_guid_slots" -type = "std::unordered_map<::FlexFlow::slot_id_t, ::FlexFlow::TensorTypeSlotSpec>" diff --git a/lib/task-spec/include/task-spec/task_signature_impl.h b/lib/task-spec/include/task-spec/task_signature_impl.h deleted file mode 100644 index a781e53485..0000000000 --- a/lib/task-spec/include/task-spec/task_signature_impl.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_SIGNATURE_IMPL_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_SIGNATURE_IMPL_H - -#include "op-attrs/computation_graph_op_attrs.h" -#include 
"task-spec/op_task_invocation.h" -#include "task-spec/task_id_t.dtg.h" -#include "task-spec/task_signature_impl.dtg.h" - -namespace FlexFlow { - -TaskSignatureAndImpl get_task_signature_and_impl_for_task_id(task_id_t const &); -std::vector get_task_ids(ComputationGraphOpAttrs const &); - -OpTaskInvocation get_init_op_task_invocation(ComputationGraphOpAttrs const &); -OpTaskInvocation - get_forward_op_task_invocation(ComputationGraphOpAttrs const &); -OpTaskInvocation - get_backward_op_task_invocation(ComputationGraphOpAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/task_signature_impl.struct.toml b/lib/task-spec/include/task-spec/task_signature_impl.struct.toml deleted file mode 100644 index 574f11a084..0000000000 --- a/lib/task-spec/include/task-spec/task_signature_impl.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "TaskSignatureAndImpl" -features = [ - "eq", - "fmt", - "hash" -] - -includes = [ - "task-spec/task_impl_function.dtg.h", - "task-spec/op_task_signature.h", -] - -[[fields]] -name = "impl_function" -type = "::FlexFlow::TaskImplFunction" - -[[fields]] -name = "task_signature" -type = "::FlexFlow::OpTaskSignature" diff --git a/lib/task-spec/include/task-spec/tensor_sub_slot_id_t.struct.toml b/lib/task-spec/include/task-spec/tensor_sub_slot_id_t.struct.toml deleted file mode 100644 index a830725a27..0000000000 --- a/lib/task-spec/include/task-spec/tensor_sub_slot_id_t.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "tensor_sub_slot_id_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "task-spec/tensor_type.dtg.h", - "task-spec/slot_id_t.dtg.h", -] - -[[fields]] -name = "slot_id" -type = "::FlexFlow::slot_id_t" - -[[fields]] -name = "tensor_type" -type = "::FlexFlow::TensorType" diff --git a/lib/task-spec/include/task-spec/tensor_type.enum.toml b/lib/task-spec/include/task-spec/tensor_type.enum.toml deleted file mode 100644 index 
b1ae8fa667..0000000000 --- a/lib/task-spec/include/task-spec/tensor_type.enum.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "TensorType" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "LOSS" - -[[values]] -name = "FORWARD" - -[[values]] -name = "GRADIENT" - -[[values]] -name = "OPTIMIZER" diff --git a/lib/task-spec/include/task-spec/tensor_type_slot_spec.dtg.toml b/lib/task-spec/include/task-spec/tensor_type_slot_spec.dtg.toml new file mode 100644 index 0000000000..0ade832c3c --- /dev/null +++ b/lib/task-spec/include/task-spec/tensor_type_slot_spec.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "TensorTypeSlotSpec" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "ord", +] + +includes = [ + "op-attrs/tensor_slot_name.dtg.h", + "task-spec/training_tensor_type.dtg.h", +] + +[[fields]] +name = "name" +type = "::FlexFlow::TensorSlotName" + +[[fields]] +name = "tensor_type" +type = "::FlexFlow::TrainingTensorType" diff --git a/lib/task-spec/include/task-spec/tensor_type_slot_spec.struct.toml b/lib/task-spec/include/task-spec/tensor_type_slot_spec.struct.toml deleted file mode 100644 index 26e70a5ef8..0000000000 --- a/lib/task-spec/include/task-spec/tensor_type_slot_spec.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "TensorTypeSlotSpec" -features = [ - "eq", - "fmt", - "hash", - "ord", -] - -includes = [ - "task-spec/slot_type.dtg.h", - "task-spec/slot_id_t.dtg.h", - "task-spec/tensor_type.dtg.h", -] - -[[fields]] -name = "slot_id" -type = "::FlexFlow::slot_id_t" - -[[fields]] -name = "tensor_type" -type = "::FlexFlow::TensorType" - -[[fields]] -name = "slot_type" -type = "::FlexFlow::SlotType" diff --git a/lib/task-spec/include/task-spec/training_computation_graph.h b/lib/task-spec/include/task-spec/training_computation_graph.h deleted file mode 100644 index 1cda57a49e..0000000000 --- a/lib/task-spec/include/task-spec/training_computation_graph.h +++ /dev/null 
@@ -1,68 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TRAINING_COMPUTATION_GRAPH_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TRAINING_COMPUTATION_GRAPH_H - -#include "pcg/optimizer_attrs.dtg.h" -#include "task-spec/forward_tensor_source.h" -#include "task-spec/gradient_tensor_source.h" -#include "task-spec/loss_tensor_source.h" -#include "task-spec/optimizer_tensor_source.h" -#include "task-spec/training_computation_graph.dtg.h" -#include "task-spec/training_layer_plus_context.dtg.h" -#include "task-spec/training_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -TrainingComputationGraph generate_training_computation_graph( - ComputationGraph const &computation_graph, - OptimizerAttrs const &optimizer_attrs, - tensor_guid_t const &logit_tensor, - ForwardTensorSource &forward_tensor_source, - GradientTensorSource &gradient_tensor_source, - OptimizerTensorSource &optimizer_tensor_source, - LossTensorSource &loss_tensor_source); - -TrainingTensorGroup - get_training_tensor_group_for_tensor_guid(TrainingComputationGraph const &, - tensor_guid_t); -TrainingTensorGroupWithAttrs - get_training_tensor_group_with_attrs_for_tensor_guid( - TrainingComputationGraph const &, tensor_guid_t); - -forward_tensor_guid_t - get_forward_tensor_guid_for_tensor_guid(TrainingComputationGraph const &, - tensor_guid_t); -gradient_tensor_guid_t - get_gradient_tensor_guid_for_tensor_guid(TrainingComputationGraph const &, - tensor_guid_t); -std::vector - get_optimizer_tensor_guids_for_tensor_guid(TrainingComputationGraph const &, - tensor_guid_t); - -tensor_guid_t - get_tensor_guid_for_forward_tensor_guid(TrainingComputationGraph const &, - forward_tensor_guid_t); -tensor_guid_t - get_tensor_guid_for_gradient_tensor_guid(TrainingComputationGraph const &, - gradient_tensor_guid_t); -tensor_guid_t - get_tensor_guid_for_optimizer_tensor_guid(TrainingComputationGraph const &, - optimizer_tensor_guid_t); - -tensor_guid_t - 
get_tensor_guid_for_training_tensor_guid(TrainingComputationGraph const &, - training_tensor_guid_t); - -std::unordered_set - get_all_training_tensors_in_training_computation_graph( - TrainingComputationGraph const &); - -TrainingLayerPlusContext - get_training_layer_plus_context(TrainingComputationGraph const &, - layer_guid_t); - -std::unordered_map - get_all_training_tensor_shapes(TrainingComputationGraph const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/training_computation_graph.struct.toml b/lib/task-spec/include/task-spec/training_computation_graph.struct.toml deleted file mode 100644 index 1e294df7eb..0000000000 --- a/lib/task-spec/include/task-spec/training_computation_graph.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "TrainingComputationGraph" -features = [] - -includes = [ - "pcg/computation_graph.h", - "", - "pcg/tensor_guid_t.dtg.h", - "task-spec/training_tensor_group.dtg.h", - "task-spec/loss_tensor_guid_t.dtg.h", -] - -[[fields]] -name = "computation_graph" -type = "::FlexFlow::ComputationGraph" - -[[fields]] -name = "training_tensor_group_for_tensor" -type = "std::unordered_map" - -[[fields]] -name = "logit_tensor" -type = "::FlexFlow::tensor_guid_t" - -[[fields]] -name = "label_tensor" -type = "::FlexFlow::loss_tensor_guid_t" diff --git a/lib/task-spec/include/task-spec/training_layer_plus_context.h b/lib/task-spec/include/task-spec/training_layer_plus_context.h deleted file mode 100644 index 4ce1ddf1a9..0000000000 --- a/lib/task-spec/include/task-spec/training_layer_plus_context.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TRAINING_LAYER_PLUS_CONTEXT_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TRAINING_LAYER_PLUS_CONTEXT_H - -#include "pcg/cg_operator_tensor_shape_signature.dtg.h" -#include "pcg/tensor_role.dtg.h" -#include "task-spec/training_layer_plus_context.dtg.h" -#include 
"task-spec/training_layer_tensor_group_signature.dtg.h" - -namespace FlexFlow { - -std::vector - get_training_tensor_groups_with_attrs_for_role( - TrainingLayerPlusContext const &training_layer_plus_context, - TensorRole tensor_role); - -TrainingTensorGroupWithAttrs - get_training_tensor_group_with_attrs_for_role_and_index( - TrainingLayerPlusContext const &training_layer_plus_context, - TensorRole tensor_role, - nonnegative_int index); - -std::vector - get_input_tensors(TrainingLayerPlusContext const &); -std::vector - get_input_grad_tensors(TrainingLayerPlusContext const &); -std::vector - get_input_tensor_shapes(TrainingLayerPlusContext const &); - -std::vector - get_weight_tensors(TrainingLayerPlusContext const &); -std::vector - get_weight_grad_tensors(TrainingLayerPlusContext const &); -std::vector - get_weight_tensor_shapes(TrainingLayerPlusContext const &); - -std::vector - get_output_tensors(TrainingLayerPlusContext const &); -std::vector - get_output_grad_tensors(TrainingLayerPlusContext const &); -std::vector - get_output_tensor_shapes(TrainingLayerPlusContext const &); - -TrainingLayerTensorGroupSignature - get_tensor_group_signature(TrainingLayerPlusContext const &); -CGOperatorTensorShapeSignature - get_cg_op_shape_signature(TrainingLayerPlusContext const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/training_layer_plus_context.struct.toml b/lib/task-spec/include/task-spec/training_layer_plus_context.struct.toml deleted file mode 100644 index 9090059351..0000000000 --- a/lib/task-spec/include/task-spec/training_layer_plus_context.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "TrainingLayerPlusContext" -features = [] - -includes = [ - "pcg/layer_guid_t.dtg.h", - "pcg/layer_attrs.dtg.h", - "task-spec/training_tensor_group_with_attrs.dtg.h", -] - -[[fields]] -name = "layer_guid" -type = "::FlexFlow::layer_guid_t" - -[[fields]] -name = "layer_attrs" -type = "::FlexFlow::LayerAttrs" - 
-[[fields]] -name = "input_tensor_groups" -type = "std::vector<::FlexFlow::TrainingTensorGroupWithAttrs>" - -[[fields]] -name = "weight_tensor_groups" -type = "std::vector<::FlexFlow::TrainingTensorGroupWithAttrs>" - -[[fields]] -name = "output_tensor_groups" -type = "std::vector<::FlexFlow::TrainingTensorGroupWithAttrs>" diff --git a/lib/task-spec/include/task-spec/training_layer_tensor_group_signature.h b/lib/task-spec/include/task-spec/training_layer_tensor_group_signature.h deleted file mode 100644 index 62b11e3af3..0000000000 --- a/lib/task-spec/include/task-spec/training_layer_tensor_group_signature.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TRAINING_LAYER_TENSOR_GROUP_SIGNATURE_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TRAINING_LAYER_TENSOR_GROUP_SIGNATURE_H - -#include "pcg/tensor_role.dtg.h" -#include "task-spec/training_layer_tensor_group_signature.dtg.h" -#include "utils/nonnegative_int/nonnegative_int.h" - -namespace FlexFlow { - -std::vector get_training_tensor_groups_for_role( - TrainingLayerTensorGroupSignature const &signature, TensorRole tensor_role); - -TrainingTensorGroup get_training_tensor_group_for_role_and_index( - TrainingLayerTensorGroupSignature const &signature, - TensorRole tensor_role, - nonnegative_int index); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/training_layer_tensor_group_signature.struct.toml b/lib/task-spec/include/task-spec/training_layer_tensor_group_signature.struct.toml deleted file mode 100644 index d9859559a1..0000000000 --- a/lib/task-spec/include/task-spec/training_layer_tensor_group_signature.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "TrainingLayerTensorGroupSignature" -features = [] - -includes = [ - "task-spec/training_tensor_group.dtg.h", -] - -[[fields]] -name = "input_tensor_groups" -type = "std::vector<::FlexFlow::TrainingTensorGroup>" - -[[fields]] -name = "weight_tensor_groups" -type = 
"std::vector<::FlexFlow::TrainingTensorGroup>" - -[[fields]] -name = "output_tensor_groups" -type = "std::vector<::FlexFlow::TrainingTensorGroup>" diff --git a/lib/task-spec/include/task-spec/training_tensor_group.h b/lib/task-spec/include/task-spec/training_tensor_group.h deleted file mode 100644 index 40269ceab0..0000000000 --- a/lib/task-spec/include/task-spec/training_tensor_group.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_TRAINING_TENSOR_GROUP_H -#define _FLEXFLOW_LIB_LOCAL_EXECUTION_INCLUDE_LOCAL_EXECUTION_TRAINING_TENSOR_GROUP_H - -#include "pcg/optimizer_attrs.dtg.h" -#include "pcg/tensor_attrs.dtg.h" -#include "pcg/tensor_guid_t.dtg.h" -#include "task-spec/forward_tensor_source.h" -#include "task-spec/gradient_tensor_source.h" -#include "task-spec/optimizer_tensor_source.h" -#include "task-spec/training_tensor_group.dtg.h" -#include "task-spec/training_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -TrainingTensorGroup make_training_tensor_group_for_tensor_guid_t( - tensor_guid_t tensor_guid, - TensorAttrs const &tensor_attrs, - OptimizerAttrs const &optimizer_attrs, - ForwardTensorSource &forward_tensor_source, - GradientTensorSource &gradient_tensor_source, - OptimizerTensorSource &optimizer_tensor_source); - -std::unordered_set - get_all_training_tensors_in_tensor_group(TrainingTensorGroup const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/training_tensor_group.struct.toml b/lib/task-spec/include/task-spec/training_tensor_group.struct.toml deleted file mode 100644 index eadaac08ad..0000000000 --- a/lib/task-spec/include/task-spec/training_tensor_group.struct.toml +++ /dev/null @@ -1,31 +0,0 @@ -namespace = "FlexFlow" -name = "TrainingTensorGroup" -features = [ - "eq", - "ord", - "fmt", - "hash", -] - -includes = [ - "task-spec/forward_tensor_guid_t.dtg.h", - "task-spec/gradient_tensor_guid_t.dtg.h", - "task-spec/optimizer_tensor_guid_t.dtg.h", -] - 
-src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "forward_tensor" -type = "::FlexFlow::forward_tensor_guid_t" - -[[fields]] -name = "gradient_tensor" -type = "::FlexFlow::gradient_tensor_guid_t" - -[[fields]] -name = "optimizer_tensors" -type = "std::vector<::FlexFlow::optimizer_tensor_guid_t>" diff --git a/lib/task-spec/include/task-spec/training_tensor_group_with_attrs.h b/lib/task-spec/include/task-spec/training_tensor_group_with_attrs.h deleted file mode 100644 index 2560228b1c..0000000000 --- a/lib/task-spec/include/task-spec/training_tensor_group_with_attrs.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TRAINING_TENSOR_GROUP_WITH_ATTRS_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TRAINING_TENSOR_GROUP_WITH_ATTRS_H - -#include "task-spec/training_tensor_group.dtg.h" -#include "task-spec/training_tensor_group_with_attrs.dtg.h" - -namespace FlexFlow { - -TrainingTensorGroupWithAttrs - make_training_tensor_group_with_attrs_from_group_and_attrs( - TrainingTensorGroup const &group, TensorAttrs const &attrs); - -TrainingTensorGroup - tensor_group_without_attrs(TrainingTensorGroupWithAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/include/task-spec/training_tensor_group_with_attrs.struct.toml b/lib/task-spec/include/task-spec/training_tensor_group_with_attrs.struct.toml deleted file mode 100644 index 5816214fb3..0000000000 --- a/lib/task-spec/include/task-spec/training_tensor_group_with_attrs.struct.toml +++ /dev/null @@ -1,37 +0,0 @@ -namespace = "FlexFlow" -name = "TrainingTensorGroupWithAttrs" -features = [ - "eq", - "ord", - "fmt", - "hash", -] - -includes = [ - "pcg/tensor_attrs.dtg.h", - "task-spec/forward_tensor_guid_t.dtg.h", - "task-spec/gradient_tensor_guid_t.dtg.h", - "task-spec/optimizer_tensor_guid_t.dtg.h", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "tensor_attrs" -type = 
"::FlexFlow::TensorAttrs" - -[[fields]] -name = "forward_tensor" -type = "::FlexFlow::forward_tensor_guid_t" - -[[fields]] -name = "gradient_tensor" -type = "::FlexFlow::gradient_tensor_guid_t" - -[[fields]] -name = "optimizer_tensors" -type = "std::vector<::FlexFlow::optimizer_tensor_guid_t>" - diff --git a/lib/task-spec/include/task-spec/training_tensor_guid_t.variant.toml b/lib/task-spec/include/task-spec/training_tensor_guid_t.variant.toml deleted file mode 100644 index d2520dacbf..0000000000 --- a/lib/task-spec/include/task-spec/training_tensor_guid_t.variant.toml +++ /dev/null @@ -1,31 +0,0 @@ -namespace = "FlexFlow" -name = "training_tensor_guid_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "task-spec/forward_tensor_guid_t.dtg.h", - "task-spec/optimizer_tensor_guid_t.dtg.h", - "task-spec/gradient_tensor_guid_t.dtg.h", - "task-spec/loss_tensor_guid_t.dtg.h" -] - -[[values]] -type = "::FlexFlow::forward_tensor_guid_t" -key = "forward_tensor" - -[[values]] -type = "::FlexFlow::gradient_tensor_guid_t" -key = "gradient_tensor" - -[[values]] -type = "::FlexFlow::optimizer_tensor_guid_t" -key = "optimizer_tensor" - -[[values]] -type = "::FlexFlow::loss_tensor_guid_t" -key = "loss_tensor" diff --git a/lib/task-spec/include/task-spec/training_tensor_type.dtg.toml b/lib/task-spec/include/task-spec/training_tensor_type.dtg.toml new file mode 100644 index 0000000000..febd6a2713 --- /dev/null +++ b/lib/task-spec/include/task-spec/training_tensor_type.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TrainingTensorType" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LOSS" + +[[values]] +name = "FORWARD" + +[[values]] +name = "GRADIENT" + +[[values]] +name = "OPTIMIZER" diff --git a/lib/task-spec/include/task-spec/variadic_tensor_ref.h b/lib/task-spec/include/task-spec/variadic_tensor_ref.h index e990fd5366..5dcae716c9 100644 --- a/lib/task-spec/include/task-spec/variadic_tensor_ref.h 
+++ b/lib/task-spec/include/task-spec/variadic_tensor_ref.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LOCAL_EXECUTION_VARIADIC_TENSOR_ARG_REF_H #include "task-spec/arg_ref.h" -#include "task-spec/op_tensor_spec.h" +#include "task-spec/ops/op_tensor_spec.h" namespace FlexFlow { diff --git a/lib/task-spec/src/task-spec/arg_ref.cc b/lib/task-spec/src/task-spec/arg_ref.cc new file mode 100644 index 0000000000..2221fe5932 --- /dev/null +++ b/lib/task-spec/src/task-spec/arg_ref.cc @@ -0,0 +1,11 @@ +#include "task-spec/arg_ref.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using LABEL_TYPE = value_type<0>; +using T = value_type<1>; + +template struct ArgRef; + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/arg_ref_spec.cc b/lib/task-spec/src/task-spec/arg_ref_spec.cc new file mode 100644 index 0000000000..4f53ae0711 --- /dev/null +++ b/lib/task-spec/src/task-spec/arg_ref_spec.cc @@ -0,0 +1,10 @@ +#include "task-spec/arg_ref_spec.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using LABEL_TYPE = value_type<0>; + +template struct ArgRefSpec; + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/device_specific.cc b/lib/task-spec/src/task-spec/device_specific.cc new file mode 100644 index 0000000000..c767714554 --- /dev/null +++ b/lib/task-spec/src/task-spec/device_specific.cc @@ -0,0 +1,16 @@ +#include "task-spec/device_specific.h" +#include "utils/archetypes/value_type.h" + +using T = ::FlexFlow::value_type<0>; + +namespace FlexFlow { + +template struct DeviceSpecific; + +} // namespace FlexFlow + +namespace std { + +template struct hash<::FlexFlow::DeviceSpecific>; + +} diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc new file mode 100644 index 0000000000..8568b56b11 --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -0,0 +1,246 @@ +#include 
"task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "utils/containers/all_of.h" +#include "utils/containers/contains_duplicates.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/multiset_union.h" +#include "utils/containers/zip_strict.h" +#include "utils/containers/zip_values_strict.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/find_isomorphism_between_labelled_open_kwarg_dataflow_graphs.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h" +#include "utils/many_to_one/many_to_one.h" + +namespace FlexFlow { + +DynamicOpenDataflowGraph make_empty_dynamic_open_dataflow_graph() { + return DynamicOpenDataflowGraph{ + std::unordered_set{}, + }; +} + +nonnegative_int dynamic_graph_num_nodes(DynamicOpenDataflowGraph const &g) { + return num_elements(get_dynamic_nodes(g)); +} + +bool full_dynamic_graph_satisfies( + DynamicOpenDataflowGraph const &g, + std::function const &node_condition, + std::function const &value_condition, + std::function const &slot_condition) { + + return all_of(get_dynamic_nodes(g), node_condition) && + all_of(get_dynamic_values(g), value_condition) && + all_of(get_dynamic_tensor_slots(g), slot_condition); +} + +bool no_part_of_dynamic_graph_satisfies( + DynamicOpenDataflowGraph const &g, + std::function const &node_condition, + std::function const &value_condition, + std::function const &slot_condition) { + + return full_dynamic_graph_satisfies( + g, + 
[&](DynamicNodeAttrs const &n) -> bool { return !node_condition(n); }, + [&](DynamicValueAttrs const &v) -> bool { return !value_condition(v); }, + [&](DynamicTensorSlot const &s) -> bool { return !slot_condition(s); }); +} + +std::unordered_multiset + get_dynamic_nodes(DynamicOpenDataflowGraph const &g) { + return transform(unordered_multiset_of(g.invocations), + [&](DynamicNodeInvocation const &i) -> DynamicNodeAttrs { + return i.node_attrs; + }); +} + +std::unordered_multiset + get_dynamic_values(DynamicOpenDataflowGraph const &g) { + return flatmap(unordered_multiset_of(g.invocations), + [&](DynamicNodeInvocation const &i) + -> std::unordered_multiset { + return multiset_union(values(i.inputs), values(i.outputs)); + }); +} + +std::unordered_multiset + get_dynamic_tensor_slots(DynamicOpenDataflowGraph const &g) { + return flatmap(unordered_multiset_of(g.invocations), + [&](DynamicNodeInvocation const &i) + -> std::unordered_multiset { + return unordered_multiset_of( + set_union(keys(i.inputs), keys(i.outputs))); + }); +} + +std::unordered_set + get_dynamic_invocation_set(DynamicOpenDataflowGraph const &g) { + return g.invocations; +} + +DynamicOpenDataflowGraph transform_dynamic_invocation_set( + DynamicOpenDataflowGraph const &g, + std::function const + &f) { + std::unordered_set current_invocation_set = + get_dynamic_invocation_set(g); + std::unordered_set new_invocation_set = + transform(current_invocation_set, f); + + return dynamic_open_dataflow_graph_from_invocation_set(new_invocation_set); +} + +DynamicOpenDataflowGraph flatmap_dynamic_invocation_set( + DynamicOpenDataflowGraph const &g, + std::function( + DynamicNodeInvocation const &)> const &f) { + + std::unordered_set current_invocation_set = + get_dynamic_invocation_set(g); + std::vector new_invocation_set = + flatmap(vector_of(current_invocation_set), f); + + ASSERT(!contains_duplicates(new_invocation_set)); + + return dynamic_open_dataflow_graph_from_invocation_set( + 
unordered_set_of(new_invocation_set)); +} + +DynamicOpenDataflowGraph dynamic_open_dataflow_graph_from_invocation_set( + std::unordered_set const &invocation_set) { + + return DynamicOpenDataflowGraph{ + invocation_set, + }; +} + +LabelledOpenKwargDataflowGraph + labelled_open_kwarg_dataflow_graph_from_dynamic_open_dataflow_graph( + DynamicOpenDataflowGraph const &g) { + + std::unordered_set all_values = + unordered_set_of(get_dynamic_values(g)); + + ManyToOne value_to_producer; + for (DynamicNodeInvocation const &invocation : + get_dynamic_invocation_set(g)) { + for (DynamicValueAttrs const &output : values(invocation.outputs)) { + value_to_producer.insert({output, invocation}); + } + } + + std::unordered_set graph_inputs = + filter(all_values, [&](DynamicValueAttrs const &v) -> bool { + return !value_to_producer.contains_l(v); + }); + + LabelledOpenKwargDataflowGraph + result = LabelledOpenKwargDataflowGraph:: + create< + UnorderedSetLabelledOpenKwargDataflowGraph>(); + + bidict, DynamicValueAttrs> + value_map; + + for (auto const &kv : enumerate(graph_inputs)) { + int input_idx = kv.first.unwrap_nonnegative(); + DynamicValueAttrs graph_input = kv.second; + KwargDataflowGraphInput added = + result.add_input(input_idx, graph_input); + value_map.equate(OpenKwargDataflowValue{added}, + graph_input); + } + + auto inputs_have_been_added = + [&](DynamicNodeInvocation const &invocation) -> bool { + return all_of(values(invocation.inputs), + [&](DynamicValueAttrs const &input) -> bool { + return value_map.contains_r(input); + }); + }; + + std::unordered_set to_add = g.invocations; + + auto add_invocation_to_graph = + [&](DynamicNodeInvocation const &invocation) -> void { + KwargNodeAddedResult added = result.add_node( + invocation.node_attrs, + map_values(invocation.inputs, + [&](DynamicValueAttrs const &input) + -> OpenKwargDataflowValue { + return value_map.at_r(input); + }), + invocation.outputs); + + for (auto const &[k, v] : + zip_values_strict(invocation.outputs, 
added.outputs)) { + DynamicValueAttrs invocation_output = v.first; + KwargDataflowOutput graph_output = v.second; + value_map.equate( + OpenKwargDataflowValue{graph_output}, + invocation_output); + } + + to_add.erase(invocation); + }; + + auto add_next_invocation_to_graph = [&]() { + for (DynamicNodeInvocation const &invocation : to_add) { + if (inputs_have_been_added(invocation)) { + add_invocation_to_graph(invocation); + return; + } + } + + PANIC("Failed to add any invocations in to_add", to_add); + }; + + while (to_add.size() > 0) { + add_next_invocation_to_graph(); + } + + return result; +} + +bool dynamic_open_dataflow_graphs_are_isomorphic( + DynamicOpenDataflowGraph const &lhs, DynamicOpenDataflowGraph const &rhs) { + LabelledOpenKwargDataflowGraphView + lhs_dataflow_graph = + labelled_open_kwarg_dataflow_graph_from_dynamic_open_dataflow_graph( + lhs); + + LabelledOpenKwargDataflowGraphView + rhs_dataflow_graph = + labelled_open_kwarg_dataflow_graph_from_dynamic_open_dataflow_graph( + rhs); + + return find_isomorphism_between_labelled_open_kwarg_dataflow_graphs( + lhs_dataflow_graph, rhs_dataflow_graph) + .has_value(); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_tensor_role.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_tensor_role.cc new file mode 100644 index 0000000000..235436cdac --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_tensor_role.cc @@ -0,0 +1,22 @@ +#include "task-spec/dynamic_graph/dynamic_tensor_role.h" + +namespace FlexFlow { + +DynamicTensorRole + dynamic_tensor_role_from_fwb_tensor_type(FwbTensorType tensor_type) { + return DynamicTensorRole{tensor_type}; +} + +DynamicTensorRole mk_dynamic_tensor_role_fwd() { + return DynamicTensorRole{FwbTensorType::FORWARD}; +} + +DynamicTensorRole mk_dynamic_tensor_role_bwd() { + return DynamicTensorRole{FwbTensorType::GRADIENT}; +} + +DynamicTensorRole mk_dynamic_tensor_role_opt(OptimizerSlotName s) { + return 
DynamicTensorRole{DynamicOptimizerTensorRole{s}}; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_tensor_slot.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_tensor_slot.cc new file mode 100644 index 0000000000..ef05974ccc --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_tensor_slot.cc @@ -0,0 +1,15 @@ +#include "task-spec/dynamic_graph/dynamic_tensor_slot.h" + +namespace FlexFlow { + +DynamicTensorSlot decide_tensor_slot_role(DynamicTensorSlot const &slot, + DynamicTensorRole role) { + ASSERT(slot.slot_tensor_role == std::nullopt); + + DynamicTensorSlot result = slot; + result.slot_tensor_role = role; + + return result; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_value_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_value_attrs.cc new file mode 100644 index 0000000000..418f496450 --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_value_attrs.cc @@ -0,0 +1,11 @@ +#include "task-spec/dynamic_graph/dynamic_value_attrs.h" + +namespace FlexFlow { + +DynamicValueAttrs decide_dynamic_value_attrs_role(DynamicValueAttrs const &, + DynamicTensorRole) { + + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/machine_slicing.cc b/lib/task-spec/src/task-spec/dynamic_graph/machine_slicing.cc new file mode 100644 index 0000000000..0a22015ddf --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/machine_slicing.cc @@ -0,0 +1,33 @@ +#include "task-spec/dynamic_graph/machine_slicing.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" + +namespace FlexFlow { + +std::unordered_set + perform_machine_slicing_for_invocation( + DynamicNodeInvocation const &invocation, + MachineSpaceCoordinate const &device_coord) { + + ASSERT(invocation.node_attrs.device_coord.has_value()); + + if (invocation.node_attrs.device_coord.value() == device_coord) { + return 
{invocation}; + } else { + return {}; + } +} + +DynamicOpenDataflowGraph + perform_machine_slicing(DynamicOpenDataflowGraph const &g, + MachineSpaceCoordinate const &device_coord) { + DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set( + g, + [&](DynamicNodeInvocation const &invocation) + -> std::unordered_set { + return perform_machine_slicing_for_invocation(invocation, device_coord); + }); + + return result; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc new file mode 100644 index 0000000000..2a1ae071fa --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -0,0 +1,137 @@ +#include "task-spec/dynamic_graph/pass_expansion.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "utils/containers/are_all_same.h" +#include "utils/containers/merge_disjoint_maps.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +bool node_is_pass_expanded(DynamicNodeAttrs const &n) { + return n.task_type.has_value(); +} + +bool slot_is_pass_expanded(DynamicTensorSlot const &s) { + return s.slot_tensor_role.has_value(); +} + +bool value_is_pass_expanded(DynamicValueAttrs const &v) { + return v.role.has_value(); +} + +bool no_part_of_graph_is_pass_expanded(DynamicOpenDataflowGraph const &g) { + return no_part_of_dynamic_graph_satisfies( + g, node_is_pass_expanded, value_is_pass_expanded, slot_is_pass_expanded); +} + +bool graph_is_fully_pass_expanded(DynamicOpenDataflowGraph const &g) { + return full_dynamic_graph_satisfies( + g, node_is_pass_expanded, value_is_pass_expanded, slot_is_pass_expanded); +} + +DynamicTensorSlot pass_expand_slot(DynamicTensorSlot const &s, + FwbTensorType tensor_type) { + ASSERT(s.slot_tensor_role == std::nullopt); + + DynamicTensorSlot result = s; + result.slot_tensor_role = + 
dynamic_tensor_role_from_fwb_tensor_type(tensor_type); + return result; +} + +DynamicValueAttrs pass_expand_value(DynamicValueAttrs const &v, + FwbTensorType tensor_type) { + ASSERT(!value_is_pass_expanded(v)); + + DynamicValueAttrs result = v; + result.role = DynamicTensorRole{tensor_type}; + return result; +}; + +DynamicNodeAttrs pass_expand_node(DynamicNodeAttrs const &n, + DynamicTaskType task_type) { + ASSERT(!node_is_pass_expanded(n)); + ASSERT(task_type == DynamicTaskType::FWD || + task_type == DynamicTaskType::BWD); + + DynamicNodeAttrs result = n; + result.task_type = task_type; + return result; +} + +DynamicNodeInvocation perform_fwd_pass_expansion_for_invocation( + DynamicNodeInvocation const &task) { + + auto to_fwd = [](DynamicTensorSlot const &k, DynamicValueAttrs const &v) { + return std::pair{ + pass_expand_slot(k, FwbTensorType::FORWARD), + pass_expand_value(v, FwbTensorType::FORWARD), + }; + }; + + return DynamicNodeInvocation{ + /*inputs=*/ + transform(task.inputs, to_fwd), + /*node_attrs=*/ + pass_expand_node(task.node_attrs, DynamicTaskType::FWD), + /*outputs=*/ + transform(task.outputs, to_fwd), + }; +} + +DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( + DynamicNodeInvocation const &invocation) { + + auto to_fwd = [](DynamicTensorSlot const &k, DynamicValueAttrs const &v) { + return std::pair{ + pass_expand_slot(k, FwbTensorType::FORWARD), + pass_expand_value(v, FwbTensorType::FORWARD), + }; + }; + + auto to_grad = [](DynamicTensorSlot const &k, DynamicValueAttrs const &v) { + return std::pair{ + pass_expand_slot(k, FwbTensorType::GRADIENT), + pass_expand_value(v, FwbTensorType::GRADIENT), + }; + }; + + return DynamicNodeInvocation{ + /*inputs=*/ + merge_disjoint_maps(std::vector{ + transform(invocation.inputs, to_fwd), + transform(invocation.outputs, to_fwd), + transform(invocation.outputs, to_grad), + }), + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), + /*outputs=*/ + 
transform(invocation.inputs, to_grad), + }; +} + +DynamicOpenDataflowGraph + perform_pass_expansion(DynamicOpenDataflowGraph const &g) { + + ASSERT(no_part_of_graph_is_pass_expanded(g)); + + DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set( + g, [](DynamicNodeInvocation const &invocation) { + if (invocation.inputs.empty()) { + return std::unordered_set{ + perform_fwd_pass_expansion_for_invocation(invocation), + }; + } else { + return std::unordered_set{ + perform_fwd_pass_expansion_for_invocation(invocation), + perform_bwd_pass_expansion_for_invocation(invocation), + }; + }; + }); + + ASSERT(graph_is_fully_pass_expanded(result)); + + return result; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc new file mode 100644 index 0000000000..ea253b63f8 --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -0,0 +1,84 @@ +#include "task-spec/dynamic_graph/shard_expansion.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "utils/containers/map_values2.h" +#include "utils/optional.h" + +namespace FlexFlow { + +bool node_is_shard_expanded(DynamicNodeAttrs const &n) { + return n.device_coord.has_value(); +} + +bool value_is_shard_expanded(DynamicValueAttrs const &n) { + return n.shard_coord.has_value(); +} + +bool no_part_of_graph_is_shard_expanded(DynamicOpenDataflowGraph const &g) { + auto slot_is_shard_expanded = [](DynamicTensorSlot const &) -> bool { + return true; + }; + + return no_part_of_dynamic_graph_satisfies(g, + node_is_shard_expanded, + value_is_shard_expanded, + slot_is_shard_expanded); +} + +bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &g) { + auto slot_is_shard_expanded = [](DynamicTensorSlot const &) -> bool { + return true; + }; + + return full_dynamic_graph_satisfies(g, + node_is_shard_expanded, + value_is_shard_expanded, + slot_is_shard_expanded); 
+} + +static DynamicNodeInvocation shard_invocation_for_binding( + DynamicNodeInvocation const &i, + MachineSpaceCoordinate const &machine_coord, + OperatorAtomicTaskShardBinding const &binding) { + auto shard_expand_value_attrs = + [&](DynamicTensorSlot const &s, + DynamicValueAttrs const &v) -> DynamicValueAttrs { + ParallelTensorSpaceCoordinate parallel_tensor_coord = + binding.tensor_coords.at(s.slot_name); + + DynamicValueAttrs result = v; + result.shard_coord = parallel_tensor_coord; + return result; + }; + + DynamicNodeAttrs expanded_node_attrs = [&]() { + DynamicNodeAttrs result = i.node_attrs; + result.device_coord = machine_coord; + return result; + }(); + + return DynamicNodeInvocation{ + /*inputs=*/map_values2(i.inputs, shard_expand_value_attrs), + /*node_attrs=*/expanded_node_attrs, + /*outputs=*/map_values2(i.outputs, shard_expand_value_attrs), + }; +} + +std::unordered_set + perform_shard_expansion_for_invocation(DynamicNodeInvocation const &i) { + + MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); + + std::unordered_set shard_machine_coords = + mapping.get_shard_bindings().left_values(); + + return transform( + shard_machine_coords, + [&](MachineSpaceCoordinate const &c) -> DynamicNodeInvocation { + OperatorAtomicTaskShardBinding slot_bindings = + mapping.get_shard_bindings().at_l(c); + + return shard_invocation_for_binding(i, c, slot_bindings); + }); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc new file mode 100644 index 0000000000..66e7115a83 --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc @@ -0,0 +1,94 @@ +#include "task-spec/dynamic_graph/update_insertion.h" +#include "pcg/optimizer_attrs.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include 
"task-spec/dynamic_graph/dynamic_tensor_slot.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.h" +#include "task-spec/optimizer.h" +#include "utils/containers/get_only.h" +#include "utils/containers/map_from_pairs.h" +#include "utils/containers/set_union.h" + +namespace FlexFlow { + +static std::pair + get_weight_output(DynamicNodeInvocation const &i) { + ASSERT(i.node_attrs.op_attrs.value().is_weight()); + ASSERT(i.inputs.size() == 0); + + auto [slot, value_attrs] = get_only(i.outputs); + + return std::pair{ + slot, + value_attrs, + }; +} + +static DynamicNodeInvocation get_update_invocation_for_invocation( + DynamicNodeInvocation const &i, OptimizerAttrs const &optimizer_attrs) { + + auto output = get_weight_output(i); + DynamicTensorSlot slot = output.first; + DynamicValueAttrs value_attrs = output.second; + + ASSERT(value_attrs.accessor == std::nullopt); + + DynamicNodeAttrs update_node_attrs = i.node_attrs; + update_node_attrs.task_type = DynamicTaskType::UPD; + + auto create_binding_for_role = [&](DynamicTensorRole const &role) + -> std::pair { + DynamicTensorSlot binding_slot = decide_tensor_slot_role(slot, role); + DynamicValueAttrs value_attrs = decide_dynamic_value_attrs_role( + value_attrs, mk_dynamic_tensor_role_fwd()); + + return std::pair{ + binding_slot, + value_attrs, + }; + }; + + std::unordered_set tensor_roles = set_union( + std::unordered_set{ + mk_dynamic_tensor_role_fwd(), + mk_dynamic_tensor_role_bwd(), + }, + transform(get_slot_names_for_optimizer(optimizer_attrs), + mk_dynamic_tensor_role_opt)); + + return DynamicNodeInvocation{ + /*inputs=*/map_from_pairs( + transform(tensor_roles, create_binding_for_role)), + /*node_attrs=*/update_node_attrs, + /*outputs=*/std::unordered_map{}, + }; +} + +std::unordered_set + perform_update_insertion_for_invocation( + DynamicNodeInvocation const &invocation, + OptimizerAttrs const &optimizer_attrs) { + + if (invocation.node_attrs.task_type.value() == DynamicTaskType::FWD && + 
invocation.node_attrs.op_attrs.value().is_weight()) { + return std::unordered_set{ + invocation, + get_update_invocation_for_invocation(invocation, optimizer_attrs), + }; + } else { + return std::unordered_set{ + invocation, + }; + }; +} + +DynamicOpenDataflowGraph + perform_update_insertion(DynamicOpenDataflowGraph const &g, + OptimizerAttrs const &optimizer_attrs) { + + return flatmap_dynamic_invocation_set(g, [&](DynamicNodeInvocation const &i) { + return perform_update_insertion_for_invocation(i, optimizer_attrs); + }); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/forward_tensor_source.cc b/lib/task-spec/src/task-spec/forward_tensor_source.cc deleted file mode 100644 index 3d82452377..0000000000 --- a/lib/task-spec/src/task-spec/forward_tensor_source.cc +++ /dev/null @@ -1,18 +0,0 @@ -#include "task-spec/forward_tensor_source.h" - -namespace FlexFlow { - -int ForwardTensorSource::next_available_forward_tensor_id = 0; - -ForwardTensorSource::ForwardTensorSource() {} - -forward_tensor_guid_t ForwardTensorSource::new_forward_tensor() { - return forward_tensor_guid_t{ - ForwardTensorSource::next_available_forward_tensor_id++}; -} - -void ForwardTensorSource::reset() { - ForwardTensorSource::next_available_forward_tensor_id = 0; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/gradient_tensor_source.cc b/lib/task-spec/src/task-spec/gradient_tensor_source.cc deleted file mode 100644 index 8bc5034634..0000000000 --- a/lib/task-spec/src/task-spec/gradient_tensor_source.cc +++ /dev/null @@ -1,18 +0,0 @@ -#include "task-spec/gradient_tensor_source.h" - -namespace FlexFlow { - -int GradientTensorSource::next_available_gradient_tensor_id = 0; - -GradientTensorSource::GradientTensorSource() {} - -gradient_tensor_guid_t GradientTensorSource::new_gradient_tensor() { - return gradient_tensor_guid_t{ - GradientTensorSource::next_available_gradient_tensor_id++}; -} - -void GradientTensorSource::reset() { - 
GradientTensorSource::next_available_gradient_tensor_id = 0; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/itask_argument_accessor.cc b/lib/task-spec/src/task-spec/itask_argument_accessor.cc deleted file mode 100644 index c7878b1abc..0000000000 --- a/lib/task-spec/src/task-spec/itask_argument_accessor.cc +++ /dev/null @@ -1 +0,0 @@ -#include "task-spec/itask_argument_accessor.h" diff --git a/lib/task-spec/src/task-spec/loss_functions.cc b/lib/task-spec/src/task-spec/loss_functions.cc index 698ca941d3..53db6b9cc4 100644 --- a/lib/task-spec/src/task-spec/loss_functions.cc +++ b/lib/task-spec/src/task-spec/loss_functions.cc @@ -21,43 +21,14 @@ namespace FlexFlow { -enum Slots { LOGIT, LABEL, LOGIT_GRAD, ATTRS, PROFILING, KERNEL_DEVICE_TYPE }; - -TaskSignature get_loss_bwd_signature() { - TaskSignature sig = make_empty_task_signature(); - add_slot(sig, LOGIT, TensorType::FORWARD); - add_slot(sig, LABEL, TensorType::LOSS); - add_slot(sig, LOGIT_GRAD, TensorType::GRADIENT); - - add_arg_slot(sig, ATTRS); - add_arg_slot(sig, PROFILING); - add_arg_slot(sig, KERNEL_DEVICE_TYPE); - return sig; -} - -TaskInvocation backward(LossAttrs const &attrs, - forward_tensor_guid_t logit, - gradient_tensor_guid_t logit_grad, - loss_tensor_guid_t label) { - TaskBinding b; - b.bind(LOGIT, logit); - b.bind_loss(LABEL, label); - b.bind_grad(LOGIT_GRAD, logit_grad); - - b.bind_arg(ATTRS, attrs); - b.bind_arg(PROFILING, profiling_settings()); - b.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return TaskInvocation{task_id_t::LOSS_BWD_TASK_ID, b}; -} - static void backward_task_impl(TaskArgumentAccessor const &acc) { - auto attrs = acc.get_argument(ATTRS); - auto profiling = acc.get_argument(PROFILING); - auto kernel_device_type = acc.get_argument(KERNEL_DEVICE_TYPE); - auto logit_grad = acc.get_tensor_grad(LOGIT_GRAD); - auto logit = acc.get_tensor(LOGIT); - auto label = acc.get_loss_tensor(LABEL); + LossAttrs attrs = acc.get_loss_attrs(); + ProfilingSettings 
profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto logit_grad = acc.get_tensor_grad(TensorSlotName::LOGIT); + auto logit = acc.get_tensor(TensorSlotName::LOGIT); + auto label = acc.get_loss_tensor(); int batch_size = dim_at_idx(logit.shape.dims, legion_dim_t{1_n}).int_from_positive_int(); @@ -75,7 +46,7 @@ static void backward_task_impl(TaskArgumentAccessor const &acc) { if (loss_type == LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY) { // label shape is [batch dim, 1] auto scce_attrs = attrs.get(); - size_t ndim = get_num_dims(logit.shape.dims).unwrap_nonnegative(); + size_t ndim = get_num_dims(logit.shape.dims).int_from_num_tensor_dims(); int num_classes = dim_at_idx(logit.shape.dims, legion_dim_t{0_n}).int_from_positive_int(); ASSERT(logit_grad.shape == logit.shape); diff --git a/lib/task-spec/src/task-spec/loss_tensor_source.cc b/lib/task-spec/src/task-spec/loss_tensor_source.cc deleted file mode 100644 index 13b97fd604..0000000000 --- a/lib/task-spec/src/task-spec/loss_tensor_source.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "task-spec/loss_tensor_source.h" - -namespace FlexFlow { - -nonnegative_int LossTensorSource::next_available_loss_tensor_id = 0_n; - -LossTensorSource::LossTensorSource() {} - -loss_tensor_guid_t LossTensorSource::new_loss_tensor() { - return loss_tensor_guid_t{LossTensorSource::next_available_loss_tensor_id++}; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/op_arg_spec.cc b/lib/task-spec/src/task-spec/op_arg_spec.cc deleted file mode 100644 index 6e48a7c5f7..0000000000 --- a/lib/task-spec/src/task-spec/op_arg_spec.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "task-spec/op_arg_spec.h" - -namespace FlexFlow { - -std::type_index get_op_arg_spec_type_index(OpArgSpec const &s) { - return s.visit( - [](auto &&arg) { return arg.get_type_index(); }); -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/op_task_invocation.cc 
b/lib/task-spec/src/task-spec/op_task_invocation.cc deleted file mode 100644 index a55995920a..0000000000 --- a/lib/task-spec/src/task-spec/op_task_invocation.cc +++ /dev/null @@ -1,127 +0,0 @@ -#include "task-spec/op_task_invocation.h" -#include "task-spec/op_arg_spec.h" -#include "utils/containers/contains_key.h" - -namespace FlexFlow { - -void OpTaskBinding::bind( - int slot, VariadicTensorRef const &variadic_tensor_ref) { - this->bind(slot_id_t{slot}, variadic_tensor_ref); -} - -void OpTaskBinding::bind( - slot_id_t slot, - VariadicTensorRef const &variadic_tensor_ref) { - NOT_IMPLEMENTED(); -} - -void OpTaskBinding::bind(int slot, OpTensorSpec const &tensor_spec) { - this->bind(slot_id_t{slot}, tensor_spec); -} - -void OpTaskBinding::bind(slot_id_t slot, OpTensorSpec const &tensor_spec) { - this->tensor_bindings.insert({SlotGradId{slot, IsGrad::NO}, tensor_spec}); -} - -void OpTaskBinding::bind_grad(int slot, OpTensorSpec const &tensor_spec) { - this->bind_grad(slot_id_t{slot}, tensor_spec); -} - -void OpTaskBinding::bind_grad(slot_id_t slot, OpTensorSpec const &tensor_spec) { - this->tensor_bindings.insert({SlotGradId{slot, IsGrad::YES}, tensor_spec}); -} - -void OpTaskBinding::insert_arg_spec(slot_id_t name, OpArgSpec const &arg_spec) { - assert(!contains_key(this->arg_bindings, name)); - this->arg_bindings.insert({name, arg_spec}); -} - -bool OpTaskBinding::operator==(OpTaskBinding const &other) const { - return this->tie() == other.tie(); -} - -bool OpTaskBinding::operator!=(OpTaskBinding const &other) const { - return this->tie() != other.tie(); -} - -std::tuple const &, - std::unordered_map const &> - OpTaskBinding::tie() const { - return std::tie(this->tensor_bindings, this->arg_bindings); -} - -std::unordered_map const & - OpTaskBinding::get_tensor_bindings() const { - return this->tensor_bindings; -} - -std::unordered_map const & - OpTaskBinding::get_arg_bindings() const { - return this->arg_bindings; -} - -void 
OpTaskBinding::bind_from_forward(OpTaskBinding const &fwd) { - this->arg_bindings = fwd.get_arg_bindings(); - this->tensor_bindings = fwd.get_tensor_bindings(); -} - -OpTaskBinding infer_bwd_binding(OpTaskBinding const &fwd) { - OpTaskBinding bwd; - bwd.bind_from_forward(fwd); - for (auto const &[key, spec] : fwd.get_tensor_bindings()) { - OpSlotOptions slot_option = spec.slot_option; - if (slot_option != OpSlotOptions::UNTRAINABLE || - slot_option != OpSlotOptions::OPTIONAL_UNTRAINABLE) { - slot_id_t slot = key.slot_id; - bwd.bind_grad(slot, spec); - } - } - return bwd; -} - -bool is_tensor_invocation_valid(OpTaskSignature const &sig, - OpTaskInvocation const &inv) { - // TODO: fix for variadic inputs (need to implement .bind() for variadic - // first) - for (std::pair const &tensor_binding : - inv.binding.get_tensor_bindings()) { - OpTensorSlotSpec op_tensor_slot_spec = - OpTensorSlotSpec{tensor_binding.first.slot_id, - SlotType::TENSOR, - tensor_binding.second.role, - tensor_binding.first.is_grad, - tensor_binding.second.slot_option}; - - if (!sig.get_tensor_slots().count(op_tensor_slot_spec)) { - return false; - } - } - - return true; -} - -bool is_arg_invocation_valid(OpTaskSignature const &sig, - OpTaskInvocation const &inv) { - // TODO: fix for device specific args - // for (std::pair const & arg_binding : - // inv.binding.get_arg_bindings()) { - // if (sig.get_arg_types().count(arg_binding.first)) { - // if (get_op_arg_spec_type_index(arg_binding.second) != - // sig.get_arg_types().at(arg_binding.first)) { - // return false; - // } - // } else { - // return false; - // } - // } - - return true; -} - -bool is_invocation_valid(OpTaskSignature const &sig, - OpTaskInvocation const &inv) { - return is_tensor_invocation_valid(sig, inv) && - is_arg_invocation_valid(sig, inv); -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/op_task_signature.cc b/lib/task-spec/src/task-spec/op_task_signature.cc deleted file mode 100644 index 
94ac16d092..0000000000 --- a/lib/task-spec/src/task-spec/op_task_signature.cc +++ /dev/null @@ -1,165 +0,0 @@ -#include "task-spec/op_task_signature.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" - -namespace FlexFlow { - -OpTaskSignature::OpTaskSignature(OpTaskType t) : type(t){}; - -void OpTaskSignature::add_input_slot(int name, SlotType slot_type) { - this->add_input_slot(slot_id_t{name}, slot_type); -} - -void OpTaskSignature::add_input_slot(slot_id_t name, SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ - name, slot_type, TensorRole::INPUT, IsGrad::NO, OpSlotOptions::NECESSARY}; - this->op_tensor_slots.insert(op_tensor_slot_spec); -} - -void OpTaskSignature::add_optional_input_slot(int name, SlotType slot_type) { - this->add_optional_input_slot(slot_id_t{name}, slot_type); -} - -void OpTaskSignature::add_optional_input_slot(slot_id_t name, - SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ - name, slot_type, TensorRole::INPUT, IsGrad::NO, OpSlotOptions::OPTIONAL}; - this->op_tensor_slots.insert(op_tensor_slot_spec); -} - -void OpTaskSignature::add_untrainable_input_slot(int name, SlotType slot_type) { - this->add_untrainable_input_slot(slot_id_t{name}, slot_type); -} - -void OpTaskSignature::add_untrainable_input_slot(slot_id_t name, - SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = - OpTensorSlotSpec{name, - slot_type, - TensorRole::INPUT, - IsGrad::NO, - OpSlotOptions::UNTRAINABLE}; - this->op_tensor_slots.insert(op_tensor_slot_spec); -} - -void OpTaskSignature::add_optional_untrainable_input_slot(int name, - SlotType slot_type) { - this->add_optional_untrainable_input_slot(slot_id_t{name}, slot_type); -} - -void OpTaskSignature::add_optional_untrainable_input_slot(slot_id_t name, - SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = - OpTensorSlotSpec{name, - slot_type, - TensorRole::INPUT, - IsGrad::NO, 
- OpSlotOptions::OPTIONAL_UNTRAINABLE}; - this->op_tensor_slots.insert(op_tensor_slot_spec); -} - -void OpTaskSignature::add_output_slot(int name, SlotType slot_type) { - this->add_output_slot(slot_id_t{name}, slot_type); -} - -void OpTaskSignature::add_output_slot(slot_id_t name, SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = - OpTensorSlotSpec{name, - slot_type, - TensorRole::OUTPUT, - IsGrad::NO, - OpSlotOptions::NECESSARY}; - this->op_tensor_slots.insert(op_tensor_slot_spec); -} - -void OpTaskSignature::add_bwd_optional_output_slot(int name, - SlotType slot_type) { - this->add_bwd_optional_output_slot(slot_id_t{name}, slot_type); -} - -void OpTaskSignature::add_bwd_optional_output_slot(slot_id_t name, - SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ - name, slot_type, TensorRole::OUTPUT, IsGrad::NO, OpSlotOptions::OPTIONAL}; - this->op_tensor_slots.insert(op_tensor_slot_spec); -} - -void OpTaskSignature::add_weight_slot(int name, SlotType slot_type) { - this->add_weight_slot(slot_id_t{name}, slot_type); -} - -void OpTaskSignature::add_weight_slot(slot_id_t name, SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = - OpTensorSlotSpec{name, - slot_type, - TensorRole::WEIGHT, - IsGrad::NO, - OpSlotOptions::NECESSARY}; - this->op_tensor_slots.insert(op_tensor_slot_spec); -} - -void OpTaskSignature::add_optional_weight_slot(int name, SlotType slot_type) { - this->add_optional_weight_slot(slot_id_t{name}, slot_type); -} - -void OpTaskSignature::add_optional_weight_slot(slot_id_t name, - SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ - name, slot_type, TensorRole::WEIGHT, IsGrad::NO, OpSlotOptions::OPTIONAL}; - this->op_tensor_slots.insert(op_tensor_slot_spec); -} - -void OpTaskSignature::set_arg_types( - std::unordered_map const &arg_type) { - this->task_arg_types = arg_type; -} - -void OpTaskSignature::add_from_slot_spec(OpTensorSlotSpec const &spec) { - 
this->op_tensor_slots.insert(spec); -} - -OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd) { - OpTaskSignature bwd = fwd; - bwd.type = OpTaskType::BWD; - for (auto const &op_tensor_slot_spec : fwd.get_tensor_slots()) { - OpSlotOptions slot_option = op_tensor_slot_spec.slot_option; - if (slot_option != OpSlotOptions::UNTRAINABLE || - slot_option != OpSlotOptions::OPTIONAL_UNTRAINABLE) { - OpTensorSlotSpec grad_spec = - OpTensorSlotSpec{op_tensor_slot_spec.name, - op_tensor_slot_spec.slot_type, - op_tensor_slot_spec.tensor_role, - IsGrad::YES, - op_tensor_slot_spec.slot_option}; - bwd.op_tensor_slots.insert(grad_spec); - } - } - - return bwd; -} - -std::unordered_set OpTaskSignature::get_tensor_slots() const { - return this->op_tensor_slots; -} - -std::unordered_map - OpTaskSignature::get_arg_types() const { - return this->task_arg_types; -} - -std::string format_as(OpTaskSignature const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} -std::ostream &operator<<(std::ostream &s, OpTaskSignature const &x) { - return s << fmt::to_string(x); -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/op_task_to_task_invocation.cc b/lib/task-spec/src/task-spec/op_task_to_task_invocation.cc deleted file mode 100644 index b33edc9a76..0000000000 --- a/lib/task-spec/src/task-spec/op_task_to_task_invocation.cc +++ /dev/null @@ -1,162 +0,0 @@ -#include "task-spec/op_task_to_task_invocation.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/cg_operator_tensor_shape_signature.h" -#include "pcg/computation_graph.h" -#include "task-spec/slot_grad_id.dtg.h" -#include "task-spec/training_layer_plus_context.h" -#include "task-spec/training_layer_tensor_group_signature.h" -#include "utils/containers/map_values.h" -#include "utils/containers/transform.h" -#include "utils/overload.h" - -namespace FlexFlow { - -TaskInvocation - lower_to_task_invocation(OpTaskInvocation const &op_task_invocation, - TrainingLayerPlusContext const 
&training_layer, - std::optional const - &device_specific_device_states) { - - std::unordered_map - tensor_bindings = - transform(op_task_invocation.binding.get_tensor_bindings(), - [&](SlotGradId const &slot_grad_id, - OpTensorSpec const &op_tensor_spec) { - return lower_tensor_binding( - get_tensor_group_signature(training_layer), - slot_grad_id, - op_tensor_spec); - }); - - std::unordered_map arg_bindings = map_values( - op_task_invocation.binding.get_arg_bindings(), - [&](OpArgSpec const &op_arg_spec) { - return lower_to_task_arg_spec(op_arg_spec, - get_cg_op_shape_signature(training_layer), - training_layer.layer_guid, - device_specific_device_states); - }); - - return TaskInvocation{ - op_task_invocation.task_id, - TaskBinding{ - tensor_bindings, - arg_bindings, - }, - }; -} - -std::pair - lower_tensor_binding(TrainingLayerTensorGroupSignature const &signature, - SlotGradId const &slot_grad_id, - OpTensorSpec const &op_tensor_spec) { - auto [tensor_to_bind, gradient_tensor_guid_to_bind] = [&] { - TrainingTensorGroup group = get_training_tensor_group_for_role_and_index( - signature, op_tensor_spec.role, op_tensor_spec.idx); - - return std::pair{ - group.forward_tensor, - group.gradient_tensor, - }; - }(); - - if (slot_grad_id.is_grad == IsGrad::NO) { - return std::pair{ - tensor_sub_slot_id_t{ - slot_grad_id.slot_id, - TensorType::FORWARD, - }, - training_tensor_guid_t{ - tensor_to_bind, - }, - }; - } else if (slot_grad_id.is_grad == IsGrad::YES) { - return std::pair{ - tensor_sub_slot_id_t{ - slot_grad_id.slot_id, - TensorType::GRADIENT, - }, - training_tensor_guid_t{ - gradient_tensor_guid_to_bind, - }, - }; - } else { - PANIC("Invalid value for IsGrad {}", slot_grad_id.is_grad); - } -} - -TaskArgSpec lower_to_task_arg_spec( - OpArgSpec const &op_arg_spec, - CGOperatorTensorShapeSignature const &op_shape_signature, - layer_guid_t const &layer_guid, - std::optional const - &device_specific_device_states) { - return op_arg_spec.visit(overload{ - 
[](ConcreteArgSpec const &concrete_arg_spec) { - return TaskArgSpec{concrete_arg_spec}; - }, - [](RuntimeArgRefSpec const &runtime_arg_ref_spec) { - return TaskArgSpec{runtime_arg_ref_spec}; - }, - [&](OpArgRefSpec const &op_arg_ref_spec) { - return TaskArgSpec{ - lower_to_concrete_arg_spec(op_arg_ref_spec, - op_shape_signature, - layer_guid, - device_specific_device_states), - }; - }, - }); -} - -ConcreteArgSpec lower_to_concrete_arg_spec( - OpArgRefSpec const &op_arg_ref_spec, - CGOperatorTensorShapeSignature const &op_signature, - layer_guid_t const &op_guid, - std::optional const &device_states) { - - OpArgRefType op_arg_ref_type = op_arg_ref_spec.get_ref_type(); - return op_arg_ref_type.visit(overload{ - [&](PerDeviceOpStateRefType const &) { - PerDeviceOpState per_device_op_state = - get_device_state_from_device_specific(device_states.value(), 0); - - return per_device_op_state.visit(overload{ - [&](auto const &x) { - ASSERT(matches(op_arg_ref_spec.get_type_index())); - return ConcreteArgSpec::create(x); - }, - }); - }, - [&](ParallelTensorShapeRefType const &ref_type) { - TensorShape tensor_shape = tensor_shape_for_role_and_index( - /*signature=*/op_signature, - /*tensor_role=*/ref_type.tensor_role, - /*index=*/ref_type.idx); - ParallelTensorShape shape = lift_to_parallel(tensor_shape); - return ConcreteArgSpec::create(shape); - }, - }); -} - -ConcreteArgSpec - lower_to_concrete_arg_spec(RuntimeArgRefSpec const &runtime_arg_ref_spec, - RuntimeArgConfig const &runtime_arg_config) { - switch (runtime_arg_ref_spec.get_ref_type()) { - case RuntimeArgRefType::FF_HANDLE: - return ConcreteArgSpec::create(*(runtime_arg_config.ff_handle.get(0))); - case RuntimeArgRefType::PROFILING_SETTINGS: - return ConcreteArgSpec::create(runtime_arg_config.profiling_settings); - case RuntimeArgRefType::FF_ITERATION_CONFIG: - PANIC("FF_ITERATION_CONFIG is currently not handled. 
Please create an " - "issue or contact the FlexFlow train developers if you need this " - "feature."); - case RuntimeArgRefType::KERNEL_DEVICE_TYPE: - return ConcreteArgSpec::create(runtime_arg_config.kernel_device_type); - default: - PANIC(fmt::format("Unhandled RuntimeArgRefType {}", - runtime_arg_ref_spec.get_ref_type())); - } -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/attention.cc b/lib/task-spec/src/task-spec/ops/attention.cc deleted file mode 100644 index ea2282792a..0000000000 --- a/lib/task-spec/src/task-spec/ops/attention.cc +++ /dev/null @@ -1,290 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "task-spec/ops/attention.h" -#include "kernels/attention_kernels.h" -#include "kernels/device_handle_t.dtg.h" -#include "op-attrs/ops/attention.h" -#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" -#include "task-spec/op_task_signature.h" -#include "task-spec/profiling.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::MultiHeadAttention; - -enum Slots { - QUERY_PARALLEL_TENSOR_SHAPE, - KEY_PARALLEL_TENSOR_SHAPE, - VALUE_PARALLEL_TENSOR_SHAPE, - QPROJSIZE, - KPROJSIZE, - VPROJSIZE, - OPROJSIZE, - ATTRS, - PROFILING, - QUERY, - KEY, - VALUE, - WEIGHTS, - OUTPUT, - HANDLE, - PER_DEVICE_STATE, - KERNEL_DEVICE_TYPE, -}; - -OpTaskInvocation init(MultiHeadAttentionAttrs const &attrs) { - OpTaskBinding b; - - b.bind_arg(HANDLE, ff_handle()); - b.bind_arg(ATTRS, attrs); - - b.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - b.bind_arg(QUERY_PARALLEL_TENSOR_SHAPE, input_parallel_tensor_shape(0_n)); - b.bind_arg(KEY_PARALLEL_TENSOR_SHAPE, input_parallel_tensor_shape(1_n)); - b.bind_arg(VALUE_PARALLEL_TENSOR_SHAPE, input_parallel_tensor_shape(2_n)); - - b.bind_arg(QPROJSIZE, get_qProjSize(attrs)); - b.bind_arg(KPROJSIZE, get_kProjSize(attrs)); - b.bind_arg(VPROJSIZE, get_vProjSize(attrs)); - b.bind_arg(OPROJSIZE, get_oProjSize(attrs)); - - return OpTaskInvocation{ - task_id_t::ATTENTION_INIT_TASK_ID, - b, - }; -} - -OpTaskInvocation forward(MultiHeadAttentionAttrs const &attrs) { - OpTaskBinding b; - - b.bind(QUERY, input_tensor(0_n)); - b.bind(KEY, input_tensor(1_n)); - b.bind(VALUE, input_tensor(2_n)); - b.bind(WEIGHTS, weight_tensor(0_n)); - b.bind(OUTPUT, output_tensor(0_n)); - - b.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - b.bind_arg(PROFILING, profiling_settings()); - b.bind_arg(PER_DEVICE_STATE, - per_device_op_state>()); - - return OpTaskInvocation{ - task_id_t::ATTENTION_FWD_TASK_ID, - b, - }; -} - -OpTaskInvocation backward(MultiHeadAttentionAttrs const &attrs) { - OpTaskBinding b = 
infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::ATTENTION_BWD_TASK_ID, - b, - }; -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) { - auto const &attrs = acc.get_argument(ATTRS); - Allocator allocator = acc.get_allocator(); - - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - positive_int qProjSize = acc.get_argument(QPROJSIZE); - positive_int kProjSize = acc.get_argument(KPROJSIZE); - positive_int vProjSize = acc.get_argument(VPROJSIZE); - positive_int oProjSize = acc.get_argument(OPROJSIZE); - - device_handle_t handle = acc.get_argument(HANDLE); - ParallelTensorShape query_parallel_tensor_shape = - acc.get_argument(QUERY_PARALLEL_TENSOR_SHAPE); - ParallelTensorShape key_parallel_tensor_shape = - acc.get_argument(KEY_PARALLEL_TENSOR_SHAPE); - ParallelTensorShape value_parallel_tensor_shape = - acc.get_argument(VALUE_PARALLEL_TENSOR_SHAPE); - - MultiHeadAttentionParallelInputs parsed = throw_if_unexpected( - parse_attention_parallel_input_shape(query_parallel_tensor_shape, - key_parallel_tensor_shape, - value_parallel_tensor_shape)); - ParallelTensorShape weight_parallel_tensor_shape = - throw_if_unexpected(get_weights_shape(attrs, - query_parallel_tensor_shape, - key_parallel_tensor_shape, - value_parallel_tensor_shape)); - - positive_int kvSeqLength = get_kvSeqLength(parsed); - positive_int qSize = get_qSize(parsed); - positive_int kSize = get_kSize(parsed); - positive_int vSize = get_vSize(parsed); - - positive_int qoSeqLength = get_qoSeqLength(parsed); - positive_int num_samples = get_num_samples(parsed); - positive_int num_heads = attrs.num_heads; - - std::optional per_device_state = init_kernel( - /*device_type=*/kernel_device_type, - /*per_device_ff_handle=*/handle, - /*allocator=*/allocator, - /*num_samples=*/num_samples.int_from_positive_int(), - /*num_heads=*/num_heads.int_from_positive_int(), - /*qSize=*/qSize.int_from_positive_int(), - 
/*kSize=*/kSize.int_from_positive_int(), - /*vSize=*/vSize.int_from_positive_int(), - /*qProjSize=*/qProjSize.int_from_positive_int(), - /*kProjSize=*/kProjSize.int_from_positive_int(), - /*vProjSize=*/vProjSize.int_from_positive_int(), - /*oProjSize=*/oProjSize.int_from_positive_int(), - /*qoSeqLength=*/qoSeqLength.int_from_positive_int(), - /*kvSeqLength=*/kvSeqLength.int_from_positive_int(), - /*add_bias_kv=*/attrs.add_bias_kv); - - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto query = acc.get_tensor(QUERY); - auto key = acc.get_tensor(KEY); - auto value = acc.get_tensor(VALUE); - auto weight = acc.get_tensor(WEIGHTS); - auto output = acc.get_tensor(OUTPUT); - - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - std::optional per_device_state = - acc.get_argument>(PER_DEVICE_STATE); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[MultiHeadAttention] forward_time = {:.2lf}ms\n", - per_device_state, - query.get_float_ptr(), - key.get_float_ptr(), - value.get_float_ptr(), - weight.get_float_ptr(), - output.get_float_ptr()); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - auto query = acc.get_tensor(QUERY); - auto key = acc.get_tensor(KEY); - auto value = acc.get_tensor(VALUE); - auto weight = acc.get_tensor(WEIGHTS); - - auto output_grad = acc.get_tensor_grad(OUTPUT); - auto weight_grad = acc.get_tensor_grad(WEIGHTS); - auto query_grad = acc.get_tensor_grad(QUERY); - auto key_grad = acc.get_tensor_grad(KEY); - auto value_grad = acc.get_tensor_grad(VALUE); - - std::optional per_device_state = - acc.get_argument>(PER_DEVICE_STATE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - float *key_grad_ptr = - 
(key_grad == query_grad) ? nullptr : key_grad.get_float_ptr(); - float *value_grad_ptr = (value_grad == query_grad || value_grad == key_grad) - ? nullptr - : value_grad.get_float_ptr(); - - ASSERT(value_grad.shape == value.shape); - ASSERT(key_grad.shape == key.shape); - - ASSERT(query_grad.shape == query.shape); - ASSERT(weight_grad.shape == weight.shape); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[MultiHeadAttention] backward_time = {:.2lf}ms\n", - per_device_state, - query.get_float_ptr(), - query_grad.get_float_ptr(), - key.get_float_ptr(), - key_grad_ptr, - value.get_float_ptr(), - value_grad_ptr, - weight.get_float_ptr(), - weight_grad.get_float_ptr(), - output_grad.get_float_ptr()); -} - -TaskImplFunction get_attention_init_task_impl() { - return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} -TaskImplFunction get_attention_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_attention_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_attention_init_signature() { - OpTaskSignature init(OpTaskType::INIT); - init.add_arg_slot(QUERY_PARALLEL_TENSOR_SHAPE); - init.add_arg_slot(KEY_PARALLEL_TENSOR_SHAPE); - init.add_arg_slot(VALUE_PARALLEL_TENSOR_SHAPE); - init.add_arg_slot(QPROJSIZE); - init.add_arg_slot(KPROJSIZE); - init.add_arg_slot(VPROJSIZE); - init.add_arg_slot(OPROJSIZE); - init.add_arg_slot(ATTRS); - init.add_unchecked_arg_slot(HANDLE); - - init.add_return_value>(); - - return init; -} - -OpTaskSignature get_attention_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(QUERY); - fwd.add_input_slot(KEY); - fwd.add_input_slot(VALUE); - fwd.add_weight_slot(WEIGHTS); - fwd.add_output_slot(OUTPUT); - - fwd.add_arg_slot(PROFILING); - fwd.add_unchecked_arg_slot>( - PER_DEVICE_STATE); - - return fwd; -} - -OpTaskSignature get_attention_bwd_signature() { - 
OpTaskSignature bwd = infer_bwd_signature(get_attention_fwd_signature()); - - return bwd; -} - -std::vector get_task_ids(MultiHeadAttentionAttrs const &) { - return {task_id_t::ATTENTION_INIT_TASK_ID, - task_id_t::ATTENTION_FWD_TASK_ID, - task_id_t::ATTENTION_BWD_TASK_ID}; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/batch_matmul.cc b/lib/task-spec/src/task-spec/ops/batch_matmul.cc deleted file mode 100644 index f8d6955b41..0000000000 --- a/lib/task-spec/src/task-spec/ops/batch_matmul.cc +++ /dev/null @@ -1,216 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "task-spec/ops/batch_matmul.h" -#include "kernels/batch_matmul_kernels.h" -#include "op-attrs/ops/batch_matmul.h" -#include "task-spec/op_task_signature.h" -#include "task-spec/profiling.h" -#include "utils/containers/transform.h" -#include "utils/nonnegative_int/nonnegative_range.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::BatchMatmul; - -enum Slots { - A_INPUT, // tensor - B_INPUT, // tensor - ATTRS, - OUTPUT, // tensor - PROFILING, - HANDLE, - ITERATION_CONFIG, - KERNEL_DEVICE_TYPE, -}; - -OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { - OpTaskBinding fwd; - - fwd.bind(A_INPUT, input_tensor(0_n)); - fwd.bind(B_INPUT, input_tensor(1_n)); - fwd.bind(OUTPUT, output_tensor(0_n)); - - fwd.bind_arg(ATTRS, attrs); - fwd.bind_arg(HANDLE, ff_handle()); - fwd.bind_arg(PROFILING, profiling_settings()); - fwd.bind_arg(ITERATION_CONFIG, iteration_config()); - fwd.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return OpTaskInvocation{ - task_id_t::BATCHMATMUL_FWD_TASK_ID, - fwd, - }; -} - -OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { - OpTaskBinding bwd = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::BATCHMATMUL_BWD_TASK_ID, - bwd, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto a_input = acc.get_tensor(A_INPUT); - auto b_input = acc.get_tensor(B_INPUT); - auto output = acc.get_tensor(OUTPUT); - auto attrs = acc.get_argument(ATTRS); - device_handle_t handle = acc.get_argument(HANDLE); - - ProfilingSettings profiling = acc.get_argument(PROFILING); - FFIterationConfig iter_config = - acc.get_argument(ITERATION_CONFIG); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - positive_int m = dim_at_idx(b_input.shape.dims, legion_dim_t{0_n}); - ASSERT(m == dim_at_idx(output.shape.dims, legion_dim_t{0_n})); - positive_int n = dim_at_idx(a_input.shape.dims, legion_dim_t{1_n}); - ASSERT(n == 
dim_at_idx(output.shape.dims, legion_dim_t{1_n})); - positive_int k = dim_at_idx(a_input.shape.dims, legion_dim_t{0_n}); - ASSERT(k == dim_at_idx(b_input.shape.dims, legion_dim_t{1_n})); - - ASSERT(get_num_elements(a_input.shape.dims) == - get_num_elements(b_input.shape.dims)); - ASSERT(get_num_elements(a_input.shape.dims) == - get_num_elements(output.shape.dims)); - - positive_int batch = 1_p; - for (nonnegative_int i : - nonnegative_range(2_n, get_num_dims(a_input.shape.dims))) { - positive_int dim_size = dim_at_idx(a_input.shape.dims, legion_dim_t{i}); - ASSERT(dim_size == dim_at_idx(b_input.shape.dims, legion_dim_t{i})); - ASSERT(dim_size == dim_at_idx(output.shape.dims, legion_dim_t{i})); - batch *= dim_size; - } - - auto get_raw_seq_len = [](std::optional seq_len) -> int { - return transform(seq_len, - [](nonnegative_int x) { return x.unwrap_nonnegative(); }) - .value_or(-1); - }; - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[BatchMatmul] forward_time = {:.2lf}ms\n", - handle, - output.get_float_ptr(), - a_input.get_float_ptr(), - b_input.get_float_ptr(), - m.int_from_positive_int(), - n.int_from_positive_int(), - k.int_from_positive_int(), - batch.int_from_positive_int(), - get_raw_seq_len(attrs.a_seq_length_dim), - get_raw_seq_len(attrs.b_seq_length_dim), - iter_config.seq_length); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - // BatchMatmul* bmm = (BatchMatmul*) task->args; - FFIterationConfig iter_config = - acc.get_argument(ITERATION_CONFIG); - ProfilingSettings profiling = acc.get_argument(PROFILING); - device_handle_t handle = acc.get_argument(HANDLE); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - ASSERT(output.shape == output_grad.shape); - - auto a_input = acc.get_tensor(A_INPUT); - auto a_input_grad = acc.get_tensor_grad(A_INPUT); - ASSERT(a_input.shape == 
a_input_grad.shape); - - auto b_input = acc.get_tensor(B_INPUT); - auto b_input_grad = acc.get_tensor_grad(B_INPUT); - ASSERT(b_input.shape == b_input_grad.shape); - - // check dins - positive_int m = dim_at_idx(b_input.shape.dims, legion_dim_t{0_n}); - ASSERT(m == dim_at_idx(output.shape.dims, legion_dim_t{0_n})); - positive_int n = dim_at_idx(a_input.shape.dims, legion_dim_t{1_n}); - ASSERT(n == dim_at_idx(output.shape.dims, legion_dim_t{1_n})); - positive_int k = dim_at_idx(a_input.shape.dims, legion_dim_t{0_n}); - ASSERT(k == dim_at_idx(b_input.shape.dims, legion_dim_t{1_n})); - ASSERT(get_num_elements(a_input.shape.dims) == - get_num_elements(b_input.shape.dims)); - ASSERT(get_num_elements(a_input.shape.dims) == - get_num_elements(output.shape.dims)); - - positive_int batch = 1_p; - for (nonnegative_int i : - nonnegative_range(2_n, get_num_dims(a_input.shape.dims))) { - positive_int dim_size = dim_at_idx(a_input.shape.dims, legion_dim_t{i}); - ASSERT(dim_size == dim_at_idx(b_input.shape.dims, legion_dim_t{i})); - ASSERT(dim_size == dim_at_idx(output.shape.dims, legion_dim_t{i})); - batch *= dim_size; - } - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[BatchMatmul] backward_time = {:.2lf}ms\n", - handle, - output.get_float_ptr(), - output_grad.get_float_ptr(), - a_input.get_float_ptr(), - a_input_grad.get_float_ptr(), - b_input.get_float_ptr(), - b_input_grad.get_float_ptr(), - m.int_from_positive_int(), - n.int_from_positive_int(), - k.int_from_positive_int(), - batch.int_from_positive_int()); -} - -TaskImplFunction get_batch_matmul_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_batch_matmul_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_batch_matmul_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(A_INPUT); - fwd.add_input_slot(B_INPUT); - 
fwd.add_output_slot(OUTPUT); - fwd.add_arg_slot(ATTRS); - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_unchecked_arg_slot(HANDLE); - - return fwd; -} - -OpTaskSignature get_batch_matmul_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_batch_matmul_fwd_signature()); - - return bwd; -} - -std::vector get_task_ids(BatchMatmulAttrs const &) { - return {task_id_t::BATCHMATMUL_FWD_TASK_ID, - task_id_t::BATCHMATMUL_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/batch_norm.cc b/lib/task-spec/src/task-spec/ops/batch_norm.cc deleted file mode 100644 index 0599eec3f5..0000000000 --- a/lib/task-spec/src/task-spec/ops/batch_norm.cc +++ /dev/null @@ -1,223 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "task-spec/ops/batch_norm.h" -#include "kernels/batch_norm_kernels.h" -#include "task-spec/profiling.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::BatchNorm; - -enum Slots { - INPUT, - SCALE, - BIAS, - OUTPUT, - ATTRS, - PROFILING, - PER_DEVICE_STATE, - RELU, - HANDLE, - KERNEL_DEVICE_TYPE, -}; - -OpTaskInvocation init(BatchNormAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(BIAS, weight_tensor(1_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(HANDLE, ff_handle()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return OpTaskInvocation{ - task_id_t::BATCHNORM_INIT_TASK_ID, - binding, - }; -} - -OpTaskInvocation forward(BatchNormAttrs const &attrs) { - OpTaskBinding binding; - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - binding.bind_arg( - PER_DEVICE_STATE, - per_device_op_state>()); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(SCALE, weight_tensor(0_n)); - binding.bind(BIAS, weight_tensor(1_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - return OpTaskInvocation{ - task_id_t::BATCHNORM_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(BatchNormAttrs const &attrs) { - OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::BATCHNORM_BWD_TASK_ID, - binding, - }; -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) { - Allocator allocator = acc.get_allocator(); - device_handle_t handle = acc.get_argument(HANDLE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto output = acc.get_tensor(OUTPUT); - auto const &attrs = acc.get_argument(ATTRS); - - positive_int output_w = 
dim_at_idx(output.shape.dims, legion_dim_t{0_n}); - positive_int output_h = dim_at_idx(output.shape.dims, legion_dim_t{1_n}); - positive_int output_c = dim_at_idx(output.shape.dims, legion_dim_t{2_n}); - positive_int output_n = dim_at_idx(output.shape.dims, legion_dim_t{3_n}); - - float *runningMean; - - std::optional per_device_state = init_kernel( - /*device_type=*/kernel_device_type, - /*handle=*/handle, - /*allocator=*/allocator, - /*runningMean=*/runningMean, - /*output_n=*/output_n.int_from_positive_int(), - /*output_c=*/output_c.int_from_positive_int(), - /*output_h=*/output_h.int_from_positive_int(), - /*output_w=*/output_w.int_from_positive_int(), - /*relu=*/attrs.relu); - - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto scale = acc.get_tensor(SCALE); - auto bias = acc.get_tensor(SCALE); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[BatchNorm] forward_time = {:.2lf}ms\n", - per_device_state, - input.get_float_ptr(), - output.get_float_ptr(), - scale.get_float_ptr(), - bias.get_float_ptr()); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input = acc.get_tensor(INPUT); - auto input_grad = acc.get_tensor_grad(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - auto scale = acc.get_tensor(SCALE); - auto scale_grad = acc.get_tensor_grad(SCALE); - 
auto bias_grad = acc.get_tensor_grad(BIAS); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[BatchNorm] backward_time = {:.2lf}ms\n", - per_device_state, - output.get_float_ptr(), - output_grad.get_float_ptr(), - input.get_float_ptr(), - input_grad.get_float_ptr(), - scale.get_float_ptr(), - scale_grad.get_float_ptr(), - bias_grad.get_float_ptr(), - get_num_elements(output.shape.dims).int_from_positive_int()); -} - -TaskImplFunction get_batch_norm_init_task_impl() { - return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} -TaskImplFunction get_batch_norm_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_batch_norm_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_batch_norm_init_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_input_slot(INPUT); - init.add_input_slot(BIAS); - init.add_output_slot(OUTPUT); - init.add_arg_slot(ATTRS); - init.add_arg_slot(PROFILING); - init.add_unchecked_arg_slot(HANDLE); - - return init; -} - -OpTaskSignature get_batch_norm_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(INPUT); - fwd.add_input_slot(SCALE); - fwd.add_input_slot(BIAS); - fwd.add_output_slot(OUTPUT); - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_unchecked_arg_slot>( - PER_DEVICE_STATE); - - return fwd; -} -OpTaskSignature get_batch_norm_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_batch_norm_fwd_signature()); - - return bwd; -} - -std::vector get_task_ids(BatchNormAttrs const &) { - return { - task_id_t::BATCHNORM_INIT_TASK_ID, - task_id_t::BATCHNORM_FWD_TASK_ID, - task_id_t::BATCHNORM_BWD_TASK_ID, - }; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/cast.cc b/lib/task-spec/src/task-spec/ops/cast.cc deleted file mode 100644 index 0c00f1be58..0000000000 --- 
a/lib/task-spec/src/task-spec/ops/cast.cc +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "task-spec/ops/cast.h" -#include "kernels/cast_kernels.h" -#include "task-spec/op_task_signature.h" -#include "task-spec/profiling.h" -#include "utils/hash-utils.h" - -using namespace FlexFlow::Kernels::Cast; - -namespace FlexFlow { - -enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, KERNEL_DEVICE_TYPE }; - -OpTaskInvocation forward(CastAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - return OpTaskInvocation{ - task_id_t::CAST_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(CastAttrs const &attrs) { - OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::CAST_BWD_TASK_ID, - binding, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto const &attrs = acc.get_argument(ATTRS); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - return 
profile(forward_kernel, - profiling, - kernel_device_type, - "[Cast] forward_time = {:.2lf}ms\n", - input, - output); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto const &attrs = acc.get_argument(ATTRS); - - auto input = acc.get_tensor(INPUT); - - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[Cast] forward_time = {:.2lf}ms\n", - input_grad, - output_grad); -} - -TaskImplFunction get_cast_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_cast_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_cast_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - - return fwd; -} - -OpTaskSignature get_cast_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_cast_fwd_signature()); - - return bwd; -} - -std::vector get_task_ids(CastAttrs const &) { - return {task_id_t::CAST_FWD_TASK_ID, task_id_t::CAST_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/concat.cc b/lib/task-spec/src/task-spec/ops/concat.cc deleted file mode 100644 index 26aa64f6ec..0000000000 --- a/lib/task-spec/src/task-spec/ops/concat.cc +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "task-spec/ops/concat.h" -#include "kernels/concat_kernels.h" -#include "task-spec/op_task_signature.h" -#include "task-spec/profiling.h" -#include "task-spec/variadic_tensor_ref.h" -#include "utils/hash-utils.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::Concat; - -enum Slots { - INPUTS, - OUTPUT, - ATTRS, - PROFILING, - HANDLE, - NUM_INPUTS, - KERNEL_DEVICE_TYPE -}; - -OpTaskInvocation forward(ConcatAttrs const &attrs) { - OpTaskBinding binding; - binding.bind(INPUTS, get_input_tensors()); - binding.bind(OUTPUT, output_tensor(0_n)); - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return OpTaskInvocation{ - task_id_t::CONCAT_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(ConcatAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::CONCAT_BWD_TASK_ID, - b, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto const &attrs = acc.get_argument(ATTRS); - - auto output = acc.get_tensor(OUTPUT); - auto inputs = acc.get_variadic_tensor(INPUTS); - - assert(inputs.size() <= MAX_NUM_INPUTS); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[Concat] forward_time = {:.2lf}ms\n", - output, - inputs, - attrs.axis); -} - -static std::optional - 
backward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto const &attrs = acc.get_argument(ATTRS); - - auto input_grads = acc.get_variadic_tensor_grad(INPUTS); - auto output_grad = acc.get_tensor_grad(OUTPUT); - - assert(input_grads.size() <= MAX_NUM_INPUTS); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[Concat] backward_time = {:.2lf}ms\n", - output_grad, - input_grads, - attrs.axis); -} - -TaskImplFunction get_concat_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_concat_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_concat_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(ATTRS); - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_input_slot(INPUTS, SlotType::VARIADIC); - fwd.add_output_slot(OUTPUT); - - return fwd; -} - -OpTaskSignature get_concat_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_concat_fwd_signature()); - - return bwd; -} - -std::vector get_task_ids(ConcatAttrs const &) { - return {task_id_t::CONCAT_FWD_TASK_ID, task_id_t::CONCAT_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/conv_2d.cc b/lib/task-spec/src/task-spec/ops/conv_2d.cc deleted file mode 100644 index d7110eabfa..0000000000 --- a/lib/task-spec/src/task-spec/ops/conv_2d.cc +++ /dev/null @@ -1,211 +0,0 @@ -#include "task-spec/ops/conv_2d.h" -#include "kernels/conv_2d_kernels.h" -#include "task-spec/profiling.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::Conv2D; - -enum Slots { - INPUT, - OUTPUT, - FILTER, - BIAS, - ATTRS, - PROFILING, - PER_DEVICE_STATE, - HANDLE, - KERNEL_DEVICE_TYPE, -}; - -OpTaskInvocation init(Conv2DAttrs const &attrs) { - 
OpTaskBinding binding; - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - binding.bind(FILTER, weight_tensor(0_n)); - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(HANDLE, ff_handle()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return OpTaskInvocation{ - task_id_t::CONV2D_INIT_TASK_ID, - binding, - }; -} - -OpTaskInvocation forward(Conv2DAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - binding.bind_arg(PER_DEVICE_STATE, - per_device_op_state>()); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - binding.bind(FILTER, weight_tensor(0_n)); - binding.bind(BIAS, weight_tensor(1_n)); - - return OpTaskInvocation{ - task_id_t::CONV2D_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(Conv2DAttrs const &attrs) { - OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::CONV2D_BWD_TASK_ID, - binding, - }; -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) { - - device_handle_t handle = acc.get_argument(HANDLE); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto attrs = acc.get_argument(ATTRS); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto filter = acc.get_tensor(FILTER); - auto filter_grad = acc.get_tensor_grad(FILTER); - - std::optional per_device_state = init_kernel( - /*device_type=*/kernel_device_type, - /*handle=*/handle, - /*activation=*/attrs.activation, - /*kernel_h=*/attrs.kernel_h.int_from_positive_int(), - /*kernel_w=*/attrs.kernel_w.int_from_positive_int(), - /*groups=*/attrs.groups.int_from_positive_int(), - /*padding_h=*/attrs.padding_h.unwrap_nonnegative(), - /*padding_w=*/attrs.padding_w.unwrap_nonnegative(), - 
/*stride_h=*/attrs.stride_h.int_from_positive_int(), - /*stride_w=*/attrs.stride_w.int_from_positive_int(), - /*input=*/input, - /*output=*/output, - /*filter_ptr=*/filter.get_float_ptr(), - /*filter_grad_ptr=*/filter_grad.get_float_ptr()); - - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - auto attrs = acc.get_argument(ATTRS); - - auto input = acc.get_tensor(INPUT); - auto filter = acc.get_tensor(FILTER); - auto bias = acc.get_tensor(BIAS); - auto output = acc.get_tensor(OUTPUT); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[Conv2d] forward_time = {:.2lf}ms\n", - per_device_state, - input.get_float_ptr(), - output.get_float_ptr(), - filter.get_float_ptr(), - bias.get_float_ptr(), - attrs.activation); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - auto attrs = acc.get_argument(ATTRS); - - auto output = acc.get_tensor(OUTPUT); - auto input = acc.get_tensor(INPUT); - auto filter = acc.get_tensor(FILTER); - - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - auto filter_grad = acc.get_tensor_grad(FILTER); - auto bias_grad = acc.get_tensor_grad(BIAS); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[Conv2d] backward_time = {:.2lf}ms\n", - per_device_state, - output.get_float_ptr(), - output_grad.get_float_ptr(), - input.get_float_ptr(), - input_grad.get_float_ptr(), - filter.get_float_ptr(), - 
filter_grad.get_float_ptr(), - bias_grad.get_float_ptr(), - attrs.activation); -} - -TaskImplFunction get_conv_2d_init_task_impl() { - return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} -TaskImplFunction get_conv_2d_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_conv_2d_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_conv_2d_init_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_input_slot(INPUT); - init.add_output_slot(OUTPUT); - init.add_weight_slot(FILTER); - init.add_arg_slot(ATTRS); - init.add_arg_slot(KERNEL_DEVICE_TYPE); - init.add_unchecked_arg_slot(HANDLE); - - init.add_return_value(); - - return init; -} - -OpTaskSignature get_conv_2d_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(PROFILING); - fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - fwd.add_weight_slot(FILTER); - fwd.add_weight_slot(BIAS); - - return fwd; -} - -OpTaskSignature get_conv_2d_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_conv_2d_fwd_signature()); - - return bwd; -} - -std::vector get_task_ids(Conv2DAttrs const &) { - return {task_id_t::CONV2D_INIT_TASK_ID, - task_id_t::CONV2D_FWD_TASK_ID, - task_id_t::CONV2D_BWD_TASK_ID}; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/dropout.cc b/lib/task-spec/src/task-spec/ops/dropout.cc deleted file mode 100644 index a36506984e..0000000000 --- a/lib/task-spec/src/task-spec/ops/dropout.cc +++ /dev/null @@ -1,174 +0,0 @@ -#include "task-spec/ops/dropout.h" -#include "kernels/dropout_kernels.h" -#include "task-spec/op_task_invocation.h" -#include "task-spec/op_task_signature.h" -#include "task-spec/profiling.h" -#include "utils/hash-utils.h" - -namespace FlexFlow { - -using 
namespace FlexFlow::Kernels::Dropout; - -enum Slots { - INPUT, - OUTPUT, - ATTRS, - PER_DEVICE_STATE, - FF_HANDLE, - PROFILING, - KERNEL_DEVICE_TYPE -}; - -OpTaskInvocation init(DropoutAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(FF_HANDLE, ff_handle()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - binding.bind(OUTPUT, output_tensor(0_n)); - - return OpTaskInvocation{ - task_id_t::DROPOUT_INIT_TASK_ID, - binding, - }; -} - -OpTaskInvocation forward(DropoutAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - binding.bind_arg(PER_DEVICE_STATE, - per_device_op_state>()); - - return OpTaskInvocation{ - task_id_t::DROPOUT_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(DropoutAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::DROPOUT_BWD_TASK_ID, - b, - }; -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) { - auto output = acc.get_tensor(OUTPUT); - Allocator allocator = acc.get_allocator(); - device_handle_t handle = acc.get_argument(FF_HANDLE); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto const &attrs = acc.get_argument(ATTRS); - - std::optional per_device_state = - init_kernel(kernel_device_type, - handle, - attrs.rate, - attrs.seed, - output.shape, - allocator); - - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = - acc.get_argument>(PER_DEVICE_STATE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - 
auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[Dropout] forward_time = {:.2lf}ms\n", - per_device_state, - input.get_float_ptr(), - output.get_float_ptr()); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - auto const &attrs = acc.get_argument(ATTRS); - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[Dropout] backward_time = {:.2lf}ms\n", - per_device_state, - output_grad.get_float_ptr(), - input_grad.get_float_ptr()); -} - -TaskImplFunction get_dropout_init_task_impl() { - return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} -TaskImplFunction get_dropout_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_dropout_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_dropout_init_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(KERNEL_DEVICE_TYPE); - init.add_unchecked_arg_slot(FF_HANDLE); - init.add_output_slot(OUTPUT); - - init.add_return_value>(); - - return init; -} - -OpTaskSignature get_dropout_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_unchecked_arg_slot>( - PER_DEVICE_STATE); - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - - return fwd; -} - -OpTaskSignature get_dropout_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_dropout_fwd_signature()); - - return bwd; -} - -std::vector 
get_task_ids(DropoutAttrs const &) { - return {task_id_t::DROPOUT_INIT_TASK_ID, - task_id_t::DROPOUT_FWD_TASK_ID, - task_id_t::DROPOUT_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/element_binary.cc b/lib/task-spec/src/task-spec/ops/element_binary.cc deleted file mode 100644 index a5f9f012fe..0000000000 --- a/lib/task-spec/src/task-spec/ops/element_binary.cc +++ /dev/null @@ -1,211 +0,0 @@ -#include "task-spec/ops/element_binary.h" -#include "kernels/element_binary_kernels.h" -#include "task-spec/profiling.h" -#include "task-spec/task_signature_impl.h" -#include "utils/hash-utils.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::ElementBinary; - -enum Slots { - LHS_INPUT, - RHS_INPUT, - OUTPUT, - PROFILING, - PER_DEVICE_STATE, - HANDLE, - ATTRS, - KERNEL_DEVICE_TYPE, -}; - -OpTaskInvocation init(ElementBinaryAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind(LHS_INPUT, input_tensor(0_n)); - binding.bind(RHS_INPUT, input_tensor(1_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(HANDLE, ff_handle()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return OpTaskInvocation{ - task_id_t::ELEMENTBINARY_INIT_TASK_ID, - binding, - }; -} - -OpTaskInvocation forward(ElementBinaryAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind(LHS_INPUT, input_tensor(0_n)); - binding.bind(RHS_INPUT, input_tensor(1_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg( - PER_DEVICE_STATE, - per_device_op_state>()); - binding.bind_arg(HANDLE, ff_handle()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return OpTaskInvocation{ - task_id_t::ELEMENTBINARY_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(ElementBinaryAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - 
return OpTaskInvocation{ - task_id_t::ELEMENTBINARY_BWD_TASK_ID, - b, - }; -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) { - auto input_lhs = acc.get_tensor(LHS_INPUT); - auto input_rhs = acc.get_tensor(RHS_INPUT); - auto output = acc.get_tensor(OUTPUT); - - device_handle_t handle = acc.get_argument(HANDLE); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto const &attrs = acc.get_argument(ATTRS); - - std::optional per_device_state = - init_kernel(kernel_device_type, - handle, - attrs.type, - attrs.should_broadcast_lhs, - attrs.should_broadcast_rhs, - input_lhs.shape, - input_rhs.shape, - output.shape); - - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - auto const &attrs = acc.get_argument(ATTRS); - - auto input_lhs = acc.get_tensor(LHS_INPUT); - auto input_rhs = acc.get_tensor(RHS_INPUT); - auto output = acc.get_tensor(OUTPUT); - device_handle_t handle = acc.get_argument(HANDLE); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[ElementBinary] forward_time = {:.2lf}ms\n", - per_device_state, - input_lhs.get_float_ptr(), - input_rhs.get_float_ptr(), - output.get_float_ptr(), - attrs.type, - attrs.should_broadcast_lhs, - handle); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto const &attrs = acc.get_argument(ATTRS); - device_handle_t handle = acc.get_argument(HANDLE); - - auto input_lhs = acc.get_tensor(LHS_INPUT); - 
auto input_rhs = acc.get_tensor(RHS_INPUT); - - auto output_grad = acc.get_tensor_grad(OUTPUT); - auto input_lhs_grad = acc.get_tensor_grad(LHS_INPUT); - auto input_rhs_grad = acc.get_tensor_grad(RHS_INPUT); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[ElementBinary] backward_time = {:.2lf}ms\n", - per_device_state, - output_grad.get_float_ptr(), - input_lhs.get_float_ptr(), - input_rhs.get_float_ptr(), - input_lhs_grad.get_float_ptr(), - input_rhs_grad.get_float_ptr(), - attrs.type, - attrs.should_broadcast_lhs, - attrs.should_broadcast_rhs, - handle); -} - -TaskImplFunction get_element_binary_init_task_impl() { - return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} - -TaskImplFunction get_element_binary_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} - -TaskImplFunction get_element_binary_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_element_binary_init_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_input_slot(LHS_INPUT); - init.add_input_slot(RHS_INPUT); - init.add_output_slot(OUTPUT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(KERNEL_DEVICE_TYPE); - init.add_unchecked_arg_slot(HANDLE); - - init.add_return_value(); - - return init; -} - -OpTaskSignature get_element_binary_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(PROFILING); - fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - fwd.add_arg_slot(ATTRS); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_unchecked_arg_slot(HANDLE); - - fwd.add_input_slot(LHS_INPUT); - fwd.add_input_slot(RHS_INPUT); - fwd.add_output_slot(OUTPUT); - - return fwd; -} - -OpTaskSignature get_element_binary_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_element_binary_fwd_signature()); - - return bwd; -} - -std::vector get_task_ids(ElementBinaryAttrs const &) { - return 
{task_id_t::ELEMENTBINARY_INIT_TASK_ID, - task_id_t::ELEMENTBINARY_FWD_TASK_ID, - task_id_t::ELEMENTBINARY_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/element_unary.cc b/lib/task-spec/src/task-spec/ops/element_unary.cc deleted file mode 100644 index f8df53b578..0000000000 --- a/lib/task-spec/src/task-spec/ops/element_unary.cc +++ /dev/null @@ -1,194 +0,0 @@ -#include "task-spec/ops/element_unary.h" -#include "kernels/element_unary_kernels.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "task-spec/profiling.h" -#include "utils/hash-utils.h" - -namespace FlexFlow { - -// declare Legion names - -using namespace FlexFlow::Kernels::ElementUnary; - -enum Slots { - INPUT, - INPUT_SHAPE, - OUTPUT, - OUTPUT_SHAPE, - ATTRS, - HANDLE, - PROFILING, - PER_DEVICE_STATE, - KERNEL_DEVICE_TYPE, -}; - -/* ElementUnary */ -OpTaskInvocation init(ElementUnaryAttrs const &attrs) { - OpTaskBinding b; - - b.bind_arg(ATTRS, attrs); - b.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - b.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0_n)); - b.bind_arg(OUTPUT_SHAPE, output_parallel_tensor_shape(0_n)); - - return OpTaskInvocation{ - task_id_t::ELEMENTUNARY_INIT_TASK_ID, - b, - }; -} - -OpTaskInvocation forward(ElementUnaryAttrs const &attrs) { - OpTaskBinding b; - - b.bind(INPUT, input_tensor(0_n)); - b.bind(OUTPUT, output_tensor(0_n)); - b.bind_arg(ATTRS, attrs); - - b.bind_arg(HANDLE, ff_handle()); - b.bind_arg(PROFILING, profiling_settings()); - b.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - b.bind_arg(PER_DEVICE_STATE, - per_device_op_state>()); - - return OpTaskInvocation{ - task_id_t::ELEMENTUNARY_FWD_TASK_ID, - b, - }; -} - -OpTaskInvocation backward(ElementUnaryAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::ELEMENTUNARY_BWD_TASK_ID, - b, - }; -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) 
{ - - auto attrs = acc.get_argument(ATTRS); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - ParallelTensorShape input_shape = - acc.get_argument(INPUT_SHAPE); - ParallelTensorShape output_shape = - acc.get_argument(OUTPUT_SHAPE); - - std::optional per_device_state = - init_kernel(kernel_device_type, - get_piece_shape(input_shape), - get_piece_shape(output_shape), - attrs); - - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto attrs = acc.get_argument(ATTRS); - - auto handle = acc.get_argument(HANDLE); - - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[ElementUnary] forward_time = {:.2lf}ms\n", - per_device_state, - attrs, - handle, - input, - output); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - auto input = acc.get_tensor(INPUT); - auto input_grad = acc.get_tensor_grad(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - - auto const &attrs = acc.get_argument(ATTRS); - auto handle = acc.get_argument(HANDLE); - - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[ElementUnary] backward_time = {:.2lf}ms\n", - per_device_state, - attrs, - handle, - output, - output_grad, - input, - input_grad); -} - -TaskImplFunction get_element_unary_init_task_impl() { - return 
TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} -TaskImplFunction get_element_unary_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_element_unary_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_element_unary_init_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_arg_slot(INPUT_SHAPE); - init.add_arg_slot(ATTRS); - init.add_arg_slot(KERNEL_DEVICE_TYPE); - init.add_unchecked_arg_slot(HANDLE); - - init.add_return_value(); - - return init; -} - -OpTaskSignature get_element_unary_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - - return fwd; -} - -OpTaskSignature get_element_unary_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_element_unary_fwd_signature()); - - return bwd; -} - -std::vector get_task_ids(ElementUnaryAttrs const &) { - return {task_id_t::ELEMENTUNARY_INIT_TASK_ID, - task_id_t::ELEMENTUNARY_FWD_TASK_ID, - task_id_t::ELEMENTUNARY_BWD_TASK_ID}; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/embedding.cc b/lib/task-spec/src/task-spec/ops/embedding.cc deleted file mode 100644 index 4ba32c8483..0000000000 --- a/lib/task-spec/src/task-spec/ops/embedding.cc +++ /dev/null @@ -1,120 +0,0 @@ -#include "task-spec/ops/embedding.h" -#include "kernels/embedding_kernels.h" -#include "task-spec/profiling.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::Embedding; - -enum Slots { INPUT, WEIGHT, OUTPUT, ATTRS, PROFILING, KERNEL_DEVICE_TYPE }; - -OpTaskInvocation forward(EmbeddingAttrs const &attrs) { - OpTaskBinding b; - - b.bind(INPUT, input_tensor(0_n)); - b.bind(WEIGHT, weight_tensor(0_n)); - b.bind(OUTPUT, output_tensor(0_n)); - - b.bind_arg(ATTRS, 
attrs); - b.bind_arg(PROFILING, profiling_settings()); - b.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return OpTaskInvocation{ - task_id_t::EMBED_FWD_TASK_ID, - b, - }; -} - -OpTaskInvocation backward(EmbeddingAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::EMBED_BWD_TASK_ID, - b, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto input = acc.get_tensor(INPUT); - auto weight = acc.get_tensor(WEIGHT); - auto output = acc.get_tensor(OUTPUT); - - ProfilingSettings profiling = acc.get_argument(PROFILING); - EmbeddingAttrs attrs = acc.get_argument(ATTRS); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - return profile( - forward_kernel, - profiling, - kernel_device_type, - "[Embedding] forward_time = {:.2lf}ms\n", - input, - output, - weight, - input.shape.data_type, - output.shape.data_type, - attrs.aggr, - get_num_dims(input.shape.dims).unwrap_nonnegative(), - get_num_dims(output.shape.dims).unwrap_nonnegative(), - dim_at_idx(input.shape.dims, legion_dim_t{1_n}).int_from_positive_int()); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto weight_grad = acc.get_tensor_grad(WEIGHT); - - ProfilingSettings profiling = acc.get_argument(PROFILING); - EmbeddingAttrs attrs = acc.get_argument(ATTRS); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - return profile( - backward_kernel, - profiling, - kernel_device_type, - "[Embedding] backward_time = {:.2lf}ms\n", - output, - input, - weight_grad, - output.shape.data_type, - input.shape.data_type, - attrs.aggr, - get_num_dims(input.shape.dims).unwrap_nonnegative(), - get_num_dims(output.shape.dims).unwrap_nonnegative(), - dim_at_idx(input.shape.dims, ff_dim_t{0_n}).int_from_positive_int()); -} - -TaskImplFunction 
get_embedding_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_embedding_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_embedding_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(INPUT); - fwd.add_input_slot(OUTPUT); - fwd.add_input_slot(WEIGHT); - - fwd.add_arg_slot(ATTRS); - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - - return fwd; -} - -OpTaskSignature get_embedding_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_embedding_fwd_signature()); - return bwd; -} - -std::vector get_task_ids(EmbeddingAttrs const &) { - return {task_id_t::EMBED_FWD_TASK_ID, task_id_t::EMBED_BWD_TASK_ID}; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/flat.cc b/lib/task-spec/src/task-spec/ops/flat.cc deleted file mode 100644 index 6cec1b383f..0000000000 --- a/lib/task-spec/src/task-spec/ops/flat.cc +++ /dev/null @@ -1,97 +0,0 @@ -#include "task-spec/ops/flat.h" -#include "kernels/flat_kernels.h" -#include "task-spec/profiling.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::Flat; - -enum SLOTS { INPUT, OUTPUT, HANDLE, PROFILING, KERNEL_DEVICE_TYPE }; - -OpTaskInvocation forward(FlatAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return OpTaskInvocation{ - task_id_t::FLAT_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(FlatAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::FLAT_BWD_TASK_ID, - b, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - 
DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[Flat] forward_time = {:.2lf}ms\n", - input, - output.get_float_ptr()); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input = acc.get_tensor(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - auto input_grad = acc.get_tensor_grad(INPUT); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[Flat] backward_time = {:.2lf}ms\n", - input, - output_grad.get_float_ptr(), - input_grad.get_float_ptr()); -} - -TaskImplFunction get_flat_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_flat_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_flat_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - - return fwd; -} - -OpTaskSignature get_flat_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_flat_fwd_signature()); - - return bwd; -} - -std::vector get_task_ids(FlatAttrs const &) { - return {task_id_t::FLAT_FWD_TASK_ID, task_id_t::FLAT_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/gather.cc b/lib/task-spec/src/task-spec/ops/gather.cc deleted file mode 100644 index 7f8aacf9d6..0000000000 --- a/lib/task-spec/src/task-spec/ops/gather.cc +++ /dev/null @@ -1,207 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may 
not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "task-spec/ops/gather.h" -#include "kernels/gather_kernels.h" -#include "op-attrs/ff_ordered/get_idxs.h" -#include "task-spec/profiling.h" -#include "utils/nonnegative_int/nonnegative_range.h" -#include - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::Gather; - -enum Slots { - INPUT, - OUTPUT, - INDEX, - ATTRS, - HANDLE, - PROFILING, - PER_DEVICE_STATE, - KERNEL_DEVICE_TYPE -}; - -OpTaskInvocation init(GatherAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(INDEX, input_tensor(1_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(HANDLE, ff_handle()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return OpTaskInvocation{ - task_id_t::GATHER_INIT_TASK_ID, - binding, - }; -} - -OpTaskInvocation forward(GatherAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - binding.bind_arg(PER_DEVICE_STATE, - per_device_op_state>()); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - binding.bind(INDEX, weight_tensor(0_n)); - - return OpTaskInvocation{ - task_id_t::GATHER_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(GatherAttrs const &attrs) { - OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - 
task_id_t::GATHER_BWD_TASK_ID, - binding, - }; -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) { - auto input = acc.get_tensor(INPUT); - auto index = acc.get_tensor(INDEX); - auto output = acc.get_tensor(OUTPUT); - - device_handle_t handle = acc.get_argument(HANDLE); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto const &attrs = acc.get_argument(ATTRS); - - ASSERT(get_num_dims(input.shape.dims) == get_num_dims(index.shape.dims)); - ASSERT(get_num_dims(output.shape.dims) == get_num_dims(index.shape.dims)); - - for (ff_dim_t i : get_idxs(input.shape.dims.ff_ordered)) { - ASSERT(dim_at_idx(index.shape.dims, i) == dim_at_idx(output.shape.dims, i)); - if (i != attrs.dim) { - ASSERT(dim_at_idx(input.shape.dims, i) == - dim_at_idx(index.shape.dims, i)); - } - } - - std::optional per_device_state = - init_kernel(kernel_device_type, handle, attrs.dim); - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - - auto input = acc.get_tensor(INPUT); - auto index = acc.get_tensor(INDEX); - auto output = acc.get_tensor(OUTPUT); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[Gather] forward_time = {:.2lf}ms\n", - per_device_state, - input, - index, - output); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - - auto output_grad = acc.get_tensor_grad(OUTPUT); - auto index = acc.get_tensor(INDEX); - auto input_grad = acc.get_tensor_grad(INPUT); - 
- return profile(backward_kernel, - profiling, - kernel_device_type, - "[Gather] backward_time = {:.2lf}ms\n", - per_device_state, - output_grad, - index, - input_grad); -} - -TaskImplFunction get_gather_init_task_impl() { - return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} -TaskImplFunction get_gather_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_gather_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_gather_init_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_input_slot(INPUT); - init.add_input_slot(INDEX); - init.add_output_slot(OUTPUT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(KERNEL_DEVICE_TYPE); - init.add_unchecked_arg_slot(HANDLE); - - init.add_return_value(); - - return init; -} - -OpTaskSignature get_gather_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_arg_slot(ATTRS); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - fwd.add_weight_slot(INDEX); - - return fwd; -} - -OpTaskSignature get_gather_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_gather_fwd_signature()); - - return bwd; -} - -std::vector get_task_ids(GatherAttrs const &) { - return {task_id_t::GATHER_INIT_TASK_ID, - task_id_t::GATHER_FWD_TASK_ID, - task_id_t::GATHER_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/attention.cc b/lib/task-spec/src/task-spec/ops/impl/attention.cc new file mode 100644 index 0000000000..1f1c0a507a --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/attention.cc @@ -0,0 +1,169 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "task-spec/ops/impl/attention.h" +#include "kernels/attention_kernels.h" +#include "kernels/device_handle_t.dtg.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/attention/multihead_attention_inputs.h" +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" +#include "task-spec/profiling.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::MultiHeadAttention; + +static DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + MultiHeadAttentionAttrs attrs = + acc.get_op_attrs().require_multi_head_attention(); + Allocator allocator = acc.get_allocator(); + + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + positive_int qProjSize = get_qProjSize(attrs); + positive_int kProjSize = get_kProjSize(attrs); + positive_int vProjSize = get_vProjSize(attrs); + positive_int oProjSize = get_oProjSize(attrs); + + device_handle_t handle = acc.get_ff_handle(); + + TensorShape query_tensor_shape = acc.get_tensor_shape(TensorSlotName::QUERY); + TensorShape key_tensor_shape = acc.get_tensor_shape(TensorSlotName::KEY); + TensorShape value_tensor_shape = acc.get_tensor_shape(TensorSlotName::VALUE); + + MultiHeadAttentionInputs parsed = + throw_if_unexpected(parse_attention_input_shape( + query_tensor_shape, key_tensor_shape, value_tensor_shape)); + TensorShape weight_tensor_shape = throw_if_unexpected(get_weights_shape( + attrs, query_tensor_shape, key_tensor_shape, value_tensor_shape)); + + positive_int kvSeqLength = get_kvSeqLength(parsed); + positive_int qSize = 
get_qSize(parsed); + positive_int kSize = get_kSize(parsed); + positive_int vSize = get_vSize(parsed); + + positive_int qoSeqLength = get_qoSeqLength(parsed); + positive_int num_samples = get_num_samples(parsed); + positive_int num_heads = attrs.num_heads; + + std::optional per_device_state = init_kernel( + /*device_type=*/kernel_device_type, + /*per_device_ff_handle=*/handle, + /*allocator=*/allocator, + /*num_samples=*/num_samples.int_from_positive_int(), + /*num_heads=*/num_heads.int_from_positive_int(), + /*qSize=*/qSize.int_from_positive_int(), + /*kSize=*/kSize.int_from_positive_int(), + /*vSize=*/vSize.int_from_positive_int(), + /*qProjSize=*/qProjSize.int_from_positive_int(), + /*kProjSize=*/kProjSize.int_from_positive_int(), + /*vProjSize=*/vProjSize.int_from_positive_int(), + /*oProjSize=*/oProjSize.int_from_positive_int(), + /*qoSeqLength=*/qoSeqLength.int_from_positive_int(), + /*kvSeqLength=*/kvSeqLength.int_from_positive_int(), + /*add_bias_kv=*/attrs.add_bias_kv); + + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + auto query = acc.get_tensor(TensorSlotName::QUERY); + auto key = acc.get_tensor(TensorSlotName::KEY); + auto value = acc.get_tensor(TensorSlotName::VALUE); + auto weight = acc.get_tensor(TensorSlotName::WEIGHT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + std::optional per_device_state = + acc.get_per_device_op_state().require_mha(); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[MultiHeadAttention] forward_time = {:.2lf}ms\n", + per_device_state, + query.get_float_ptr(), + key.get_float_ptr(), + value.get_float_ptr(), + weight.get_float_ptr(), + output.get_float_ptr()); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) 
{ + auto query = acc.get_tensor(TensorSlotName::QUERY); + auto key = acc.get_tensor(TensorSlotName::KEY); + auto value = acc.get_tensor(TensorSlotName::VALUE); + auto weight = acc.get_tensor(TensorSlotName::WEIGHT); + + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + auto weight_grad = + acc.get_tensor_grad(TensorSlotName::WEIGHT); + auto query_grad = acc.get_tensor_grad(TensorSlotName::QUERY); + auto key_grad = acc.get_tensor_grad(TensorSlotName::KEY); + auto value_grad = acc.get_tensor_grad(TensorSlotName::VALUE); + + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + std::optional per_device_state = + acc.get_per_device_op_state().require_mha(); + + float *key_grad_ptr = + (key_grad == query_grad) ? nullptr : key_grad.get_float_ptr(); + float *value_grad_ptr = (value_grad == query_grad || value_grad == key_grad) + ? nullptr + : value_grad.get_float_ptr(); + + ASSERT(value_grad.shape == value.shape); + ASSERT(key_grad.shape == key.shape); + + ASSERT(query_grad.shape == query.shape); + ASSERT(weight_grad.shape == weight.shape); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[MultiHeadAttention] backward_time = {:.2lf}ms\n", + per_device_state, + query.get_float_ptr(), + query_grad.get_float_ptr(), + key.get_float_ptr(), + key_grad_ptr, + value.get_float_ptr(), + value_grad_ptr, + weight.get_float_ptr(), + weight_grad.get_float_ptr(), + output_grad.get_float_ptr()); +} + +TaskImplFunction get_attention_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_attention_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_attention_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/batch_matmul.cc 
b/lib/task-spec/src/task-spec/ops/impl/batch_matmul.cc new file mode 100644 index 0000000000..43bc185b0d --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/batch_matmul.cc @@ -0,0 +1,95 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "task-spec/ops/impl/batch_matmul.h" +#include "kernels/batch_matmul_kernels.h" +#include "op-attrs/ops/batch_matmul.h" +#include "task-spec/profiling.h" +#include "utils/containers/transform.h" +#include "utils/nonnegative_int/nonnegative_range.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::BatchMatmul; + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + auto a_input = acc.get_tensor(TensorSlotName::LHS_INPUT); + auto b_input = acc.get_tensor(TensorSlotName::RHS_INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + BatchMatmulAttrs attrs = acc.get_op_attrs().require_batch_matmul(); + device_handle_t handle = acc.get_ff_handle(); + + ProfilingSettings profiling = acc.get_profiling_settings(); + FFIterationConfig iter_config = acc.get_iteration_config(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[BatchMatmul] forward_time = {:.2lf}ms\n", + handle, + output, + a_input, + b_input, + iter_config.seq_length, + attrs.a_seq_length_dim, + attrs.b_seq_length_dim); +} + +static 
std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + FFIterationConfig iter_config = acc.get_iteration_config(); + ProfilingSettings profiling = acc.get_profiling_settings(); + device_handle_t handle = acc.get_ff_handle(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + ASSERT(output.shape == output_grad.shape); + + auto a_input = acc.get_tensor(TensorSlotName::LHS_INPUT); + auto a_input_grad = + acc.get_tensor_grad(TensorSlotName::LHS_INPUT); + ASSERT(a_input.shape == a_input_grad.shape); + + auto b_input = acc.get_tensor(TensorSlotName::RHS_INPUT); + auto b_input_grad = + acc.get_tensor_grad(TensorSlotName::RHS_INPUT); + ASSERT(b_input.shape == b_input_grad.shape); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[BatchMatmul] backward_time = {:.2lf}ms\n", + handle, + output, + output_grad, + a_input, + a_input_grad, + b_input, + b_input_grad); +} + +TaskImplFunction get_batch_matmul_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_batch_matmul_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/batch_norm.cc b/lib/task-spec/src/task-spec/ops/impl/batch_norm.cc new file mode 100644 index 0000000000..e622f0bc4c --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/batch_norm.cc @@ -0,0 +1,123 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "task-spec/ops/impl/batch_norm.h" +#include "kernels/batch_norm_kernels.h" +#include "task-spec/profiling.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::BatchNorm; + +static DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + Allocator allocator = acc.get_allocator(); + device_handle_t handle = acc.get_ff_handle(); + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + BatchNormAttrs attrs = acc.get_op_attrs().require_batch_norm(); + + positive_int output_w = dim_at_idx(output.shape.dims, legion_dim_t{0_n}); + positive_int output_h = dim_at_idx(output.shape.dims, legion_dim_t{1_n}); + positive_int output_c = dim_at_idx(output.shape.dims, legion_dim_t{2_n}); + positive_int output_n = dim_at_idx(output.shape.dims, legion_dim_t{3_n}); + + float *runningMean; + + std::optional per_device_state = init_kernel( + /*device_type=*/kernel_device_type, + /*handle=*/handle, + /*allocator=*/allocator, + /*runningMean=*/runningMean, + /*output_n=*/output_n.int_from_positive_int(), + /*output_c=*/output_c.int_from_positive_int(), + /*output_h=*/output_h.int_from_positive_int(), + /*output_w=*/output_w.int_from_positive_int(), + /*relu=*/attrs.relu); + + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + auto per_device_state = + 
acc.get_per_device_op_state().require_batch_norm().value(); + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + auto scale = acc.get_tensor(TensorSlotName::SCALE); + auto bias = acc.get_tensor(TensorSlotName::BIAS); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[BatchNorm] forward_time = {:.2lf}ms\n", + per_device_state, + input.get_float_ptr(), + output.get_float_ptr(), + scale.get_float_ptr(), + bias.get_float_ptr()); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + BatchNormPerDeviceState per_device_state = + acc.get_per_device_op_state().require_batch_norm().value(); + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + auto scale = acc.get_tensor(TensorSlotName::SCALE); + auto scale_grad = acc.get_tensor_grad(TensorSlotName::SCALE); + auto bias_grad = acc.get_tensor_grad(TensorSlotName::BIAS); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[BatchNorm] backward_time = {:.2lf}ms\n", + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + input.get_float_ptr(), + input_grad.get_float_ptr(), + scale.get_float_ptr(), + scale_grad.get_float_ptr(), + bias_grad.get_float_ptr(), + get_num_elements(output.shape.dims).int_from_positive_int()); +} + +TaskImplFunction get_batch_norm_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_batch_norm_fwd_task_impl() { + return 
TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_batch_norm_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/broadcast.cc b/lib/task-spec/src/task-spec/ops/impl/broadcast.cc new file mode 100644 index 0000000000..e83132549d --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/broadcast.cc @@ -0,0 +1,13 @@ +#include "task-spec/ops/impl/broadcast.h" + +namespace FlexFlow { + +TaskImplFunction get_broadcast_fwd_task_impl() { + NOT_IMPLEMENTED(); +} + +TaskImplFunction get_broadcast_bwd_task_impl() { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/cast.cc b/lib/task-spec/src/task-spec/ops/impl/cast.cc new file mode 100644 index 0000000000..fbde3b7a25 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/cast.cc @@ -0,0 +1,70 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "task-spec/ops/impl/cast.h" +#include "kernels/cast_kernels.h" +#include "task-spec/profiling.h" +#include "utils/hash-utils.h" + +using namespace FlexFlow::Kernels::Cast; + +namespace FlexFlow { + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + CastAttrs attrs = acc.get_op_attrs().require_cast(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Cast] forward_time = {:.2lf}ms\n", + input, + output); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + CastAttrs attrs = acc.get_op_attrs().require_cast(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[Cast] forward_time = {:.2lf}ms\n", + input_grad, + output_grad); +} + +TaskImplFunction get_cast_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_cast_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/concat.cc b/lib/task-spec/src/task-spec/ops/impl/concat.cc new file mode 100644 index 0000000000..39f9806226 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/concat.cc @@ -0,0 +1,97 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use 
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "task-spec/ops/impl/concat.h" +#include "kernels/concat_kernels.h" +#include "op-attrs/tensor_slot_name.h" +#include "task-spec/profiling.h" +#include "task-spec/variadic_tensor_ref.h" +#include "utils/containers/slice.h" +#include "utils/hash-utils.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Concat; + +static std::vector get_input_slots(ConcatAttrs const &attrs) { + return slice(get_variadic_inputs_slot_name_sequence(), + 0, + attrs.num_inputs.int_from_int_ge_two()); +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ConcatAttrs attrs = acc.get_op_attrs().require_concat(); + + std::vector input_slots = get_input_slots(attrs); + + std::vector inputs = + transform(input_slots, + [&](TensorSlotName input_slot_name) -> GenericTensorAccessorR { + return acc.get_tensor(input_slot_name); + }); + + ASSERT(inputs.size() <= MAX_NUM_INPUTS); + + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Concat] forward_time = {:.2lf}ms\n", + output, + inputs, + attrs.axis); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ConcatAttrs attrs = acc.get_op_attrs().require_concat(); + + 
std::vector input_slots = get_input_slots(attrs); + + std::vector input_grads = + transform(input_slots, + [&](TensorSlotName input_slot_name) -> GenericTensorAccessorW { + return acc.get_tensor_grad(input_slot_name); + }); + + ASSERT(input_grads.size() <= MAX_NUM_INPUTS); + + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[Concat] backward_time = {:.2lf}ms\n", + output_grad, + input_grads, + attrs.axis); +} + +TaskImplFunction get_concat_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_concat_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/conv_2d.cc b/lib/task-spec/src/task-spec/ops/impl/conv_2d.cc new file mode 100644 index 0000000000..28fe73c3fc --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/conv_2d.cc @@ -0,0 +1,113 @@ +#include "task-spec/ops/impl/conv_2d.h" +#include "kernels/conv_2d_kernels.h" +#include "task-spec/profiling.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Conv2D; + +static DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + + device_handle_t handle = acc.get_ff_handle(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + Conv2DAttrs attrs = acc.get_op_attrs().require_conv2d(); + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + auto filter = acc.get_tensor(TensorSlotName::FILTER); + auto filter_grad = + acc.get_tensor_grad(TensorSlotName::FILTER); + + std::optional per_device_state = init_kernel( + /*device_type=*/kernel_device_type, + /*handle=*/handle, + /*activation=*/attrs.activation, + /*kernel_h=*/attrs.kernel_h.int_from_positive_int(), + /*kernel_w=*/attrs.kernel_w.int_from_positive_int(), + 
/*groups=*/attrs.groups.int_from_positive_int(), + /*padding_h=*/attrs.padding_h.unwrap_nonnegative(), + /*padding_w=*/attrs.padding_w.unwrap_nonnegative(), + /*stride_h=*/attrs.stride_h.int_from_positive_int(), + /*stride_w=*/attrs.stride_w.int_from_positive_int(), + /*input=*/input, + /*output=*/output, + /*filter_ptr=*/filter.get_float_ptr(), + /*filter_grad_ptr=*/filter_grad.get_float_ptr()); + + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + Conv2DPerDeviceState per_device_state = + acc.get_per_device_op_state().require_conv2d().value(); + Conv2DAttrs attrs = acc.get_op_attrs().require_conv2d(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto filter = acc.get_tensor(TensorSlotName::FILTER); + auto bias = acc.get_tensor(TensorSlotName::BIAS); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Conv2d] forward_time = {:.2lf}ms\n", + per_device_state, + input.get_float_ptr(), + output.get_float_ptr(), + filter.get_float_ptr(), + bias.get_float_ptr(), + attrs.activation); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + Conv2DPerDeviceState per_device_state = + acc.get_per_device_op_state().require_conv2d().value(); + Conv2DAttrs attrs = acc.get_op_attrs().require_conv2d(); + + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto filter = acc.get_tensor(TensorSlotName::FILTER); + + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto output_grad = + 
acc.get_tensor_grad(TensorSlotName::OUTPUT); + auto filter_grad = + acc.get_tensor_grad(TensorSlotName::FILTER); + auto bias_grad = acc.get_tensor_grad(TensorSlotName::BIAS); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[Conv2d] backward_time = {:.2lf}ms\n", + per_device_state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + input.get_float_ptr(), + input_grad.get_float_ptr(), + filter.get_float_ptr(), + filter_grad.get_float_ptr(), + bias_grad.get_float_ptr(), + attrs.activation); +} + +TaskImplFunction get_conv_2d_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_conv_2d_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_conv_2d_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/dropout.cc b/lib/task-spec/src/task-spec/ops/impl/dropout.cc new file mode 100644 index 0000000000..016e7cde75 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/dropout.cc @@ -0,0 +1,84 @@ +#include "task-spec/ops/impl/dropout.h" +#include "kernels/dropout_kernels.h" +#include "task-spec/profiling.h" +#include "utils/hash-utils.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Dropout; + +static DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + Allocator allocator = acc.get_allocator(); + device_handle_t handle = acc.get_ff_handle(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + DropoutAttrs attrs = acc.get_op_attrs().require_dropout(); + + std::optional per_device_state = + init_kernel(kernel_device_type, + handle, + attrs.rate, + attrs.seed, + output.shape, + allocator); + + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + 
+static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + DropoutPerDeviceState per_device_state = + acc.get_per_device_op_state().require_dropout().value(); + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Dropout] forward_time = {:.2lf}ms\n", + per_device_state, + input.get_float_ptr(), + output.get_float_ptr()); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + + DropoutPerDeviceState per_device_state = + acc.get_per_device_op_state().require_dropout().value(); + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[Dropout] backward_time = {:.2lf}ms\n", + per_device_state, + output_grad.get_float_ptr(), + input_grad.get_float_ptr()); +} + +TaskImplFunction get_dropout_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_dropout_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_dropout_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/element_binary.cc b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc new file mode 100644 index 0000000000..13465d7a5f --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc @@ -0,0 +1,108 @@ +#include "task-spec/ops/impl/element_binary.h" +#include 
"kernels/element_binary_kernels.h" +#include "task-spec/profiling.h" +#include "utils/hash-utils.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::ElementBinary; + +static DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + auto input_lhs = acc.get_tensor(TensorSlotName::LHS_INPUT); + auto input_rhs = acc.get_tensor(TensorSlotName::RHS_INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + device_handle_t handle = acc.get_ff_handle(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary(); + + std::optional per_device_state = + init_kernel(kernel_device_type, + handle, + attrs.type, + attrs.should_broadcast_lhs, + attrs.should_broadcast_rhs, + input_lhs.shape, + input_rhs.shape, + output.shape); + + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ElementBinaryPerDeviceState per_device_state = + acc.get_per_device_op_state().require_element_binary().value(); + ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary(); + device_handle_t handle = acc.get_ff_handle(); + + auto input_lhs = acc.get_tensor(TensorSlotName::LHS_INPUT); + auto input_rhs = acc.get_tensor(TensorSlotName::RHS_INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[ElementBinary] forward_time = {:.2lf}ms\n", + per_device_state, + input_lhs.get_float_ptr(), + input_rhs.get_float_ptr(), + output.get_float_ptr(), + attrs.type, + attrs.should_broadcast_lhs, + handle); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + 
DeviceType kernel_device_type = acc.get_kernel_device_type(); + ElementBinaryPerDeviceState per_device_state = + acc.get_per_device_op_state().require_element_binary().value(); + ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary(); + device_handle_t handle = acc.get_ff_handle(); + + auto input_lhs = acc.get_tensor(TensorSlotName::LHS_INPUT); + auto input_rhs = acc.get_tensor(TensorSlotName::RHS_INPUT); + + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + auto input_lhs_grad = + acc.get_tensor_grad(TensorSlotName::LHS_INPUT); + auto input_rhs_grad = + acc.get_tensor_grad(TensorSlotName::RHS_INPUT); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[ElementBinary] backward_time = {:.2lf}ms\n", + per_device_state, + output_grad.get_float_ptr(), + input_lhs.get_float_ptr(), + input_rhs.get_float_ptr(), + input_lhs_grad.get_float_ptr(), + input_rhs_grad.get_float_ptr(), + attrs.type, + attrs.should_broadcast_lhs, + attrs.should_broadcast_rhs, + handle); +} + +TaskImplFunction get_element_binary_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_element_binary_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_element_binary_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/element_unary.cc b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc new file mode 100644 index 0000000000..d66ff9ab8d --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc @@ -0,0 +1,93 @@ +#include "task-spec/ops/impl/element_unary.h" +#include "kernels/element_unary_kernels.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "task-spec/profiling.h" +#include "utils/hash-utils.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::ElementUnary; + +static 
DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + + ElementUnaryAttrs attrs = acc.get_op_attrs().require_element_unary(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + TensorShape input_shape = acc.get_tensor_shape(TensorSlotName::INPUT); + TensorShape output_shape = acc.get_tensor_shape(TensorSlotName::OUTPUT); + + std::optional per_device_state = + init_kernel(kernel_device_type, input_shape, output_shape, attrs); + + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + ElementUnaryAttrs attrs = acc.get_op_attrs().require_element_unary(); + + device_handle_t handle = acc.get_ff_handle(); + + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ElementUnaryPerDeviceState per_device_state = + acc.get_per_device_op_state().require_element_unary().value(); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[ElementUnary] forward_time = {:.2lf}ms\n", + per_device_state, + attrs, + handle, + input, + output); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + ElementUnaryAttrs attrs = acc.get_op_attrs().require_element_unary(); + device_handle_t handle = acc.get_ff_handle(); + + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ElementUnaryPerDeviceState per_device_state = + acc.get_per_device_op_state().require_element_unary().value(); + + 
return profile(backward_kernel, + profiling, + kernel_device_type, + "[ElementUnary] backward_time = {:.2lf}ms\n", + per_device_state, + attrs, + handle, + output, + output_grad, + input, + input_grad); +} + +TaskImplFunction get_element_unary_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_element_unary_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_element_unary_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/embedding.cc b/lib/task-spec/src/task-spec/ops/impl/embedding.cc new file mode 100644 index 0000000000..dfc55d020a --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/embedding.cc @@ -0,0 +1,70 @@ +#include "task-spec/ops/impl/embedding.h" +#include "kernels/embedding_kernels.h" +#include "task-spec/profiling.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Embedding; + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto weight = acc.get_tensor(TensorSlotName::WEIGHT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + ProfilingSettings profiling = acc.get_profiling_settings(); + EmbeddingAttrs attrs = acc.get_op_attrs().require_embedding(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + return profile( + forward_kernel, + profiling, + kernel_device_type, + "[Embedding] forward_time = {:.2lf}ms\n", + input, + output, + weight, + input.shape.data_type, + output.shape.data_type, + attrs.aggr, + get_num_dims(input.shape.dims), + get_num_dims(output.shape.dims), + dim_at_idx(input.shape.dims, legion_dim_t{1_n}).int_from_positive_int()); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(TensorSlotName::INPUT); + 
auto output = acc.get_tensor(TensorSlotName::OUTPUT); + auto weight_grad = + acc.get_tensor_grad(TensorSlotName::WEIGHT); + + ProfilingSettings profiling = acc.get_profiling_settings(); + EmbeddingAttrs attrs = acc.get_op_attrs().require_embedding(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + return profile( + backward_kernel, + profiling, + kernel_device_type, + "[Embedding] backward_time = {:.2lf}ms\n", + output, + input, + weight_grad, + output.shape.data_type, + input.shape.data_type, + attrs.aggr, + get_num_dims(input.shape.dims), + get_num_dims(output.shape.dims), + dim_at_idx(input.shape.dims, ff_dim_t{0_n}).int_from_positive_int()); +} + +TaskImplFunction get_embedding_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_embedding_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/flat.cc b/lib/task-spec/src/task-spec/ops/impl/flat.cc new file mode 100644 index 0000000000..321f7720b7 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/flat.cc @@ -0,0 +1,51 @@ +#include "task-spec/ops/impl/flat.h" +#include "kernels/flat_kernels.h" +#include "task-spec/profiling.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Flat; + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Flat] forward_time = {:.2lf}ms\n", + input, + output.get_float_ptr()); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType 
kernel_device_type = acc.get_kernel_device_type(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[Flat] backward_time = {:.2lf}ms\n", + input, + output_grad.get_float_ptr(), + input_grad.get_float_ptr()); +} + +TaskImplFunction get_flat_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_flat_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/gather.cc b/lib/task-spec/src/task-spec/ops/impl/gather.cc new file mode 100644 index 0000000000..6544e2f521 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/gather.cc @@ -0,0 +1,110 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "task-spec/ops/impl/gather.h" +#include "kernels/gather_kernels.h" +#include "op-attrs/ff_ordered/get_idxs.h" +#include "task-spec/profiling.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Gather; + +static DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto index = acc.get_tensor(TensorSlotName::INDEX); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + device_handle_t handle = acc.get_ff_handle(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + GatherAttrs attrs = acc.get_op_attrs().require_gather(); + + ASSERT(get_num_dims(input.shape.dims) == get_num_dims(index.shape.dims)); + ASSERT(get_num_dims(output.shape.dims) == get_num_dims(index.shape.dims)); + + for (ff_dim_t i : get_idxs(input.shape.dims.ff_ordered)) { + ASSERT(dim_at_idx(index.shape.dims, i) == dim_at_idx(output.shape.dims, i)); + if (i != attrs.dim) { + ASSERT(dim_at_idx(input.shape.dims, i) == + dim_at_idx(index.shape.dims, i)); + } + } + + std::optional per_device_state = + init_kernel(kernel_device_type, handle, attrs.dim); + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + GatherPerDeviceState per_device_state = + acc.get_per_device_op_state().require_gather().value(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto index = acc.get_tensor(TensorSlotName::INDEX); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Gather] forward_time = {:.2lf}ms\n", + per_device_state, + input, + index, + output); +} + +static std::optional + 
backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + GatherPerDeviceState per_device_state = + acc.get_per_device_op_state().require_gather().value(); + + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + auto index = acc.get_tensor(TensorSlotName::INDEX); + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[Gather] backward_time = {:.2lf}ms\n", + per_device_state, + output_grad, + index, + input_grad); +} + +TaskImplFunction get_gather_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_gather_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_gather_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/layer_norm.cc b/lib/task-spec/src/task-spec/ops/impl/layer_norm.cc new file mode 100644 index 0000000000..2f1952d769 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/layer_norm.cc @@ -0,0 +1,128 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "task-spec/ops/impl/layer_norm.h" +#include "kernels/layer_norm_kernels.h" +#include "op-attrs/ff_ordered/transform.h" +#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "task-spec/profiling.h" +#include "utils/containers/product.h" +#include "utils/exception.h" +#include "utils/hash-utils.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::LayerNorm; + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + auto gamma = acc.get_tensor(TensorSlotName::GAMMA); + auto beta = acc.get_tensor(TensorSlotName::BETA); + + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + LayerNormPerDeviceState state = + acc.get_per_device_op_state().require_layer_norm().value(); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[LayerNorm] forward time = {:.2lf}ms\n", + state, + input, + output, + gamma, + beta); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto gamma = acc.get_tensor(TensorSlotName::GAMMA); + + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto gamma_grad = acc.get_tensor_grad(TensorSlotName::GAMMA); + auto beta_grad = acc.get_tensor_grad(TensorSlotName::BETA); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + LayerNormPerDeviceState state = + acc.get_per_device_op_state().require_layer_norm().value(); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[LayerNorm] backward time = {:.2lf}ms\n", + state, + output_grad, + input, + 
input_grad, + gamma, + gamma_grad, + beta_grad); +} + +static DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + LayerNormAttrs attrs = acc.get_op_attrs().require_layer_norm(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + Allocator allocator = acc.get_allocator(); + auto input = acc.get_tensor(TensorSlotName::INPUT); + device_handle_t handle = acc.get_ff_handle(); + + positive_int M = product(transform(attrs.axes, [&](ff_dim_t dim) { + return dim_at_idx(input.shape.dims, dim); + })); + + positive_int num_replicas = get_num_elements(input.shape.dims); + + positive_int effective_num_elements = M; + positive_int effective_batch_size = + positive_int{get_num_elements(input.shape.dims) / M}; + + std::optional per_device_state = + init_kernel(kernel_device_type, + handle, + allocator, + attrs.elementwise_affine, + effective_batch_size.int_from_positive_int(), + effective_num_elements.int_from_positive_int(), + attrs.eps); + + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + +TaskImplFunction get_layer_norm_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_layer_norm_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_layer_norm_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/linear.cc b/lib/task-spec/src/task-spec/ops/impl/linear.cc new file mode 100644 index 0000000000..e90cbd2544 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/linear.cc @@ -0,0 +1,124 @@ +#include "task-spec/ops/impl/linear.h" +#include "kernels/format_accessor_contents.h" +#include "kernels/linear_kernels.h" +#include "op-attrs/ff_dim_t.h" +#include "task-spec/profiling.h" +#include "task-spec/task_argument_accessor/task_argument_accessor.h" 
+#include "utils/exception.h" +#include "utils/hash-utils.h" + +namespace FlexFlow { + +static DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + LinearAttrs attrs = acc.get_op_attrs().require_linear(); + device_handle_t handle = acc.get_ff_handle(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto weight = acc.get_tensor(TensorSlotName::WEIGHT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + positive_int out_dim = dim_at_idx(output.shape.dims, ff_dim_t{0_n}); + positive_int batch_size = dim_at_idx(output.shape.dims, ff_dim_t{1_n}); + + std::optional per_device_state = + linear_init_kernel(kernel_device_type, + handle, + attrs.activation, + attrs.regularizer, + attrs.use_bias, + input.shape.data_type, + weight.shape.data_type, + output.shape.data_type, + batch_size.int_from_positive_int(), + attrs.out_channels.int_from_positive_int()); + + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto weight = acc.get_tensor(TensorSlotName::WEIGHT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + LinearAttrs attrs = acc.get_op_attrs().require_linear(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ProfilingSettings profiling = acc.get_profiling_settings(); + LinearPerDeviceState per_device_state = + acc.get_per_device_op_state().require_linear().value(); + + std::optional bias = std::nullopt; + if (attrs.use_bias) { + bias = acc.get_tensor(TensorSlotName::BIAS); + } + + auto result = profile(linear_forward_kernel, + profiling, + kernel_device_type, + "[Linear] forward_time = {:.2lf}ms\n", + per_device_state, + attrs, + input, + output, + weight, + bias); + + return result; +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const 
&acc) { + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto weight = acc.get_tensor(TensorSlotName::WEIGHT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto weight_grad = + acc.get_tensor_grad(TensorSlotName::WEIGHT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + LinearAttrs attrs = acc.get_op_attrs().require_linear(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ProfilingSettings profiling = acc.get_profiling_settings(); + LinearPerDeviceState per_device_state = + acc.get_per_device_op_state().require_linear().value(); + + std::optional bias_grad = std::nullopt; + if (attrs.use_bias) { + bias_grad = acc.get_tensor_grad(TensorSlotName::BIAS); + } + + auto result = profile(linear_backward_kernel, + profiling, + kernel_device_type, + "[Linear] backward_time = {:.2lf}ms\n", + per_device_state, + attrs, + output, + output_grad, + input, + input_grad, + weight, + weight_grad, + bias_grad); + + return result; +} + +TaskImplFunction get_linear_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_linear_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_linear_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/noop.cc b/lib/task-spec/src/task-spec/ops/impl/noop.cc similarity index 80% rename from lib/task-spec/src/task-spec/ops/noop.cc rename to lib/task-spec/src/task-spec/ops/impl/noop.cc index 4d69b8fd5f..17846c24e6 100644 --- a/lib/task-spec/src/task-spec/ops/noop.cc +++ b/lib/task-spec/src/task-spec/ops/impl/noop.cc @@ -13,12 +13,4 @@ * limitations under the License. 
 */ -#include "task-spec/ops/noop.h" - -namespace FlexFlow { - -std::vector get_task_ids(NoopAttrs const &attrs) { - return {}; - } - -}; // namespace FlexFlow +#include "task-spec/ops/impl/noop.h" diff --git a/lib/task-spec/src/task-spec/ops/impl/pool_2d.cc b/lib/task-spec/src/task-spec/ops/impl/pool_2d.cc new file mode 100644 index 0000000000..ba2d984115 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/pool_2d.cc @@ -0,0 +1,123 @@ +#include "task-spec/ops/impl/pool_2d.h" +#include "kernels/pool_2d_kernels.h" +#include "op-attrs/ops/pool_2d.h" +#include "task-spec/profiling.h" +#include "utils/exception.h" +#include "utils/hash-utils.h" + +using namespace FlexFlow::Kernels::Pool2D; + +namespace FlexFlow { + +static nonnegative_int calculate_padding(nonnegative_int output_size, + nonnegative_int stride, + nonnegative_int kernel_size, + nonnegative_int input_size) { + int o = output_size.unwrap_nonnegative(); + int s = stride.unwrap_nonnegative(); + int k = kernel_size.unwrap_nonnegative(); + int i = input_size.unwrap_nonnegative(); + + return nonnegative_int{ + ((o - 1) * s + k - i + 1) / 2, + }; +} + +static DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + Pool2DAttrs attrs = acc.get_op_attrs().require_pool2d(); + device_handle_t handle = acc.get_ff_handle(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + positive_int input_w = dim_at_idx(input.shape.dims, ff_dim_t{0_n}); + positive_int input_h = dim_at_idx(input.shape.dims, ff_dim_t{1_n}); + positive_int input_c = dim_at_idx(input.shape.dims, ff_dim_t{2_n}); + positive_int input_n = dim_at_idx(input.shape.dims, ff_dim_t{3_n}); + positive_int output_w = dim_at_idx(output.shape.dims, ff_dim_t{0_n}); + positive_int output_h = dim_at_idx(output.shape.dims, ff_dim_t{1_n}); + positive_int output_c = dim_at_idx(output.shape.dims, ff_dim_t{2_n}); + 
positive_int output_n = dim_at_idx(output.shape.dims, ff_dim_t{3_n}); + + std::optional per_device_state = + init_kernel(kernel_device_type, + handle, + attrs.activation, + input_w.int_from_positive_int(), + input_h.int_from_positive_int(), + input_c.int_from_positive_int(), + input_n.int_from_positive_int(), + output_w.int_from_positive_int(), + output_h.int_from_positive_int(), + output_c.int_from_positive_int(), + output_n.int_from_positive_int(), + attrs.padding_h.unwrap_nonnegative(), + attrs.padding_w.unwrap_nonnegative(), + attrs.kernel_h.int_from_positive_int(), + attrs.kernel_w.int_from_positive_int(), + attrs.stride_h.int_from_positive_int(), + attrs.stride_w.int_from_positive_int(), + attrs.pool_type); + + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + Pool2DPerDeviceState state = + acc.get_per_device_op_state().require_pool_2d().value(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Pool2D] forward_time = {:.2lf}ms\n", + state, + input.get_float_ptr(), + output.get_float_ptr()); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + Pool2DPerDeviceState state = + acc.get_per_device_op_state().require_pool_2d().value(); + + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + auto output_grad = acc.get_tensor_grad(TensorSlotName::OUTPUT); + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + + return profile(backward_kernel, + profiling, + 
kernel_device_type, + "[Pool2D] backward_time = {:.2lf}ms\n", + state, + output.get_float_ptr(), + output_grad.get_float_ptr(), + input.get_float_ptr(), + input_grad.get_float_ptr()); +} + +TaskImplFunction get_pool_2d_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_pool_2d_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_pool_2d_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/reduce.cc b/lib/task-spec/src/task-spec/ops/impl/reduce.cc new file mode 100644 index 0000000000..45034114a2 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/reduce.cc @@ -0,0 +1,89 @@ +#include "task-spec/ops/impl/reduce.h" +#include "kernels/reduce_kernels.h" +#include "task-spec/profiling.h" +#include "utils/exception.h" +#include "utils/hash-utils.h" +#include "utils/type_traits_core.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Reduce; + +static DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + device_handle_t handle = acc.get_ff_handle(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ReduceAttrs attrs = acc.get_op_attrs().require_reduce(); + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + OperatorType op_type = attrs.op_type; + + nonnegative_int reduction_size = + get_num_elements(input.shape.dims) / get_num_elements(output.shape.dims); + + std::optional per_device_state = + init_kernel(kernel_device_type, + handle, + op_type, + reduction_size.unwrap_nonnegative(), + input.shape, + output.shape); + + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ReducePerDeviceState 
per_device_state = + acc.get_per_device_op_state().require_reduce().value(); + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Reduce] forward_time = {:.2lf}ms\n", + per_device_state, + input.get_float_ptr(), + output.get_float_ptr()); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ReducePerDeviceState per_device_state = + acc.get_per_device_op_state().require_reduce().value(); + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[Reduce] backward_time = {:.2lf}ms\n", + per_device_state, + output_grad.get_float_ptr(), + input_grad.get_float_ptr()); +} + +TaskImplFunction get_reduce_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_reduce_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_reduce_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/reshape.cc b/lib/task-spec/src/task-spec/ops/impl/reshape.cc new file mode 100644 index 0000000000..98470004c2 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/reshape.cc @@ -0,0 +1,67 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "task-spec/ops/impl/reshape.h" +#include "kernels/reshape_kernels.h" +#include "task-spec/profiling.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Reshape; + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ReshapeAttrs attrs = acc.get_op_attrs().require_reshape(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Reshape] forward time = {:.2lf}ms\n", + input, + output); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ReshapeAttrs attrs = acc.get_op_attrs().require_reshape(); + + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[Reshape] backward time = {:.2lf}ms\n", + output_grad, + input_grad); +} + +TaskImplFunction get_reshape_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_reshape_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git 
a/lib/task-spec/src/task-spec/ops/impl/reverse.cc b/lib/task-spec/src/task-spec/ops/impl/reverse.cc new file mode 100644 index 0000000000..56baebd67d --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/reverse.cc @@ -0,0 +1,73 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "task-spec/ops/impl/reverse.h" +#include "kernels/accessor.h" +#include "kernels/reverse_kernels.h" +#include "task-spec/profiling.h" +#include "utils/nonnegative_int/nonnegative_range.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Reverse; + +using coord_t = long long; + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ReverseAttrs attrs = acc.get_op_attrs().require_reverse(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[reverse] forward_time = {:.2lf}ms\n", + input, + output, + attrs); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + ReverseAttrs attrs = acc.get_op_attrs().require_reverse(); + + auto input_grad = 
acc.get_tensor_grad(TensorSlotName::INPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[reverse] backward_time = {:.2lf}ms\n", + output_grad, + input_grad, + attrs); +} + +TaskImplFunction get_reverse_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_reverse_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/softmax.cc b/lib/task-spec/src/task-spec/ops/impl/softmax.cc new file mode 100644 index 0000000000..66693913e6 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/softmax.cc @@ -0,0 +1,112 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "task-spec/ops/impl/softmax.h" +#include "kernels/softmax_kernels.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "task-spec/profiling.h" +#include "utils/exception.h" +#include "utils/hash-utils.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Softmax; + +static DeviceSpecificPerDeviceOpState + init_task_impl(TaskArgumentAccessor const &acc) { + device_handle_t handle = acc.get_ff_handle(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + SoftmaxAttrs attrs = acc.get_op_attrs().require_softmax(); + + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + positive_int output_w = dim_at_idx(output.shape.dims, legion_dim_t{0_n}); + positive_int output_h = dim_at_idx(output.shape.dims, legion_dim_t{1_n}); + positive_int output_c = dim_at_idx(output.shape.dims, legion_dim_t{2_n}); + positive_int output_n = dim_at_idx(output.shape.dims, legion_dim_t{3_n}); + + std::optional per_device_state = + init_kernel(kernel_device_type, + handle, + attrs.dim, + output_n.int_from_positive_int(), + output_c.int_from_positive_int(), + output_h.int_from_positive_int(), + output_w.int_from_positive_int()); + + return DeviceSpecificPerDeviceOpState{ + acc.make_device_specific(per_device_state), + }; +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + SoftmaxPerDeviceState per_device_state = + acc.get_per_device_op_state().require_softmax().value(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Softmax] forward_time = {:.2lf}ms\n", + per_device_state, + input.get_float_ptr(), + output.get_float_ptr()); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = 
acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + SoftmaxPerDeviceState per_device_state = + acc.get_per_device_op_state().require_softmax().value(); + + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto input = acc.get_tensor(TensorSlotName::INPUT); + assert(input_grad.shape == input.shape); + + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + assert(output_grad.shape == output.shape); + + return profile( + backward_kernel, + profiling, + kernel_device_type, + "[Softmax] backward_time = {:.2lf}ms\n", + output_grad.get_float_ptr(), + input_grad.get_float_ptr(), + get_num_elements(output_grad.shape.dims).int_from_positive_int()); +} + +TaskImplFunction get_softmax_init_task_impl() { + return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; +} + +TaskImplFunction get_softmax_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_softmax_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/split.cc b/lib/task-spec/src/task-spec/ops/impl/split.cc new file mode 100644 index 0000000000..1e3d3dde92 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/split.cc @@ -0,0 +1,112 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "task-spec/ops/impl/split.h" +#include "kernels/split_kernels.h" +#include "task-spec/profiling.h" +#include "utils/exception.h" +#include "utils/hash-utils.h" +#include "utils/nonnegative_int/nonnegative_range.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Split; + +static std::pair + calc_block_size(TensorShape const &tensor_shape, ff_dim_t axis) { + positive_int num_blocks = 1_p; + positive_int block_size = 1_p; + for (nonnegative_int d : + nonnegative_range(get_num_elements(tensor_shape.dims) + .nonnegative_int_from_positive_int())) { + if (d <= axis.value) { + block_size *= dim_at_idx(tensor_shape.dims, legion_dim_t{d}); + } else { + num_blocks *= dim_at_idx(tensor_shape.dims, legion_dim_t{d}); + } + } + return {num_blocks, block_size}; +} + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + SplitAttrs attrs = acc.get_op_attrs().require_split(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + int out_block_sizes[MAX_NUM_OUTPUTS]; + auto [num_blocks, in_block_size] = calc_block_size(input.shape, attrs.axis); + + for (int i = 0; i < attrs.splits.size(); i++) { + auto [_, out_block_size] = calc_block_size(output.shape, attrs.axis); + out_block_sizes[i] = out_block_size.int_from_positive_int(); + } + float *output_float_ptr = output.get_float_ptr(); + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Split] forward_time = {:.2lf}ms\n", + &output_float_ptr, + input.get_float_ptr(), + out_block_sizes, + in_block_size.int_from_positive_int(), + num_blocks.int_from_positive_int(), + attrs.splits.size()); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + 
ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + SplitAttrs attrs = acc.get_op_attrs().require_split(); + + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + int out_block_sizes[MAX_NUM_OUTPUTS]; + auto [num_blocks, in_block_size] = + calc_block_size(input_grad.shape, attrs.axis); + + for (int i = 0; i < attrs.splits.size(); i++) { + int out_num_blocks; + auto [_, out_block_size] = calc_block_size(output_grad.shape, attrs.axis); + out_block_sizes[i] = out_block_size.int_from_positive_int(); + } + float const *output_grad_ptr = output_grad.get_float_ptr(); + return profile(backward_kernel, + profiling, + kernel_device_type, + "[Split] backward_time = {:.2lf}ms\n", + input_grad.get_float_ptr(), + &output_grad_ptr, + out_block_sizes, + in_block_size.int_from_positive_int(), + num_blocks.int_from_positive_int(), + attrs.splits.size()); +} + +TaskImplFunction get_split_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_split_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/topk.cc b/lib/task-spec/src/task-spec/ops/impl/topk.cc new file mode 100644 index 0000000000..4a1a813c3e --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/topk.cc @@ -0,0 +1,88 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "task-spec/ops/impl/topk.h" +#include "kernels/topk_kernels.h" +#include "task-spec/profiling.h" +#include "utils/exception.h" + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::TopK; + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + TopKAttrs attrs = acc.get_op_attrs().require_topk(); + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + positive_int length = dim_at_idx(input.shape.dims, legion_dim_t{0_n}); + positive_int batch_size = + positive_int{get_num_elements(input.shape.dims) / length}; + auto indices = acc.get_tensor(TensorSlotName::INDEX); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[TopK] forward_time = {:.2lf}ms\n", + input.get_float_ptr(), + output.get_float_ptr(), + indices.get_int32_ptr(), + batch_size.int_from_positive_int(), + length.int_from_positive_int(), + attrs.k.int_from_positive_int(), + attrs.sorted); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + auto attrs = acc.get_op_attrs().require_topk(); + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + auto indices = acc.get_tensor(TensorSlotName::INDEX); + + 
positive_int length = dim_at_idx(input_grad.shape.dims, legion_dim_t{0_n}); + positive_int batch_size = + positive_int{get_num_elements(input_grad.shape.dims) / length}; + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[TopK] backward_time = {:.2lf}ms\n", + output_grad.get_float_ptr(), + indices.get_int32_ptr(), + input_grad.get_float_ptr(), + batch_size.int_from_positive_int(), + length.int_from_positive_int(), + attrs.k.int_from_positive_int()); +} + +TaskImplFunction get_topk_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_topk_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/transpose.cc b/lib/task-spec/src/task-spec/ops/impl/transpose.cc new file mode 100644 index 0000000000..6b0c1d6d44 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/impl/transpose.cc @@ -0,0 +1,71 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "task-spec/ops/impl/transpose.h" +#include "kernels/transpose_kernels.h" +#include "op-attrs/ops/transpose.h" +#include "task-spec/profiling.h" +#include "utils/integer_conversions.h" + +using namespace FlexFlow::Kernels::Transpose; + +namespace FlexFlow { + +static std::optional + forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + TransposeAttrs attrs = acc.get_op_attrs().require_transpose(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input = acc.get_tensor(TensorSlotName::INPUT); + auto output = acc.get_tensor(TensorSlotName::OUTPUT); + + return profile(forward_kernel, + profiling, + kernel_device_type, + "[Transpose] Forward_time = {:.2lf} [ms]", + attrs, + input, + output); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_profiling_settings(); + TransposeAttrs attrs = acc.get_op_attrs().require_transpose(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); + + auto input_grad = acc.get_tensor_grad(TensorSlotName::INPUT); + auto output_grad = + acc.get_tensor_grad(TensorSlotName::OUTPUT); + + return profile(backward_kernel, + profiling, + kernel_device_type, + "[Transpose] Backward_time = {:.2lf} [ms]", + attrs, + output_grad, + input_grad); +} + +TaskImplFunction get_transpose_fwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; +} + +TaskImplFunction get_transpose_bwd_task_impl() { + return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/input.cc b/lib/task-spec/src/task-spec/ops/input.cc deleted file mode 100644 index 53caadfe68..0000000000 --- a/lib/task-spec/src/task-spec/ops/input.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "task-spec/ops/input.h" - -namespace FlexFlow { - -std::vector get_task_ids(InputAttrs const &attrs) { - return 
{}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/layer_norm.cc b/lib/task-spec/src/task-spec/ops/layer_norm.cc deleted file mode 100644 index b37e63c2d1..0000000000 --- a/lib/task-spec/src/task-spec/ops/layer_norm.cc +++ /dev/null @@ -1,217 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "task-spec/ops/layer_norm.h" -#include "kernels/layer_norm_kernels.h" -#include "op-attrs/ff_ordered/transform.h" -#include "op-attrs/ops/layer_norm.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "task-spec/profiling.h" -#include "utils/containers/product.h" -#include "utils/exception.h" -#include "utils/hash-utils.h" -#include "utils/nonnegative_int/nonnegative_range.h" -#include - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::LayerNorm; - -enum Slots { - PROFILING, - INPUT, - OUTPUT, - GAMMA, - BETA, - PER_DEVICE_STATE, - ATTRS, - HANDLE, - KERNEL_DEVICE_TYPE, -}; - -OpTaskInvocation init(LayerNormAttrs const &attrs) { - OpTaskBinding b; - - b.bind(INPUT, input_tensor(0_n)); - - b.bind_arg(HANDLE, ff_handle()); - b.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - b.bind_arg(ATTRS, attrs); - - return OpTaskInvocation{ - task_id_t::LAYERNORM_INIT_TASK_ID, - b, - }; -} - -OpTaskInvocation forward(LayerNormAttrs const &attrs) { - OpTaskBinding b; - - b.bind(INPUT, input_tensor(0_n)); - b.bind(OUTPUT, 
output_tensor(0_n)); - b.bind(GAMMA, weight_tensor(0_n)); - b.bind(BETA, weight_tensor(1_n)); - b.bind_arg(PROFILING, profiling_settings()); - b.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - b.bind_arg(PER_DEVICE_STATE, - per_device_op_state>()); - - return OpTaskInvocation{ - task_id_t::LAYERNORM_FWD_TASK_ID, - b, - }; -} - -OpTaskInvocation backward(LayerNormAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::LAYERNORM_BWD_TASK_ID, - b, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto gamma = acc.get_tensor(GAMMA); - auto beta = acc.get_tensor(BETA); - - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto &state = acc.get_argument(PER_DEVICE_STATE); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[LayerNorm] forward time = {:.2lf}ms\n", - state, - input, - output, - gamma, - beta); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - auto input = acc.get_tensor(INPUT); - auto gamma = acc.get_tensor(GAMMA); - - auto input_grad = acc.get_tensor_grad(INPUT); - auto gamma_grad = acc.get_tensor_grad(GAMMA); - auto beta_grad = acc.get_tensor_grad(BETA); - auto output_grad = acc.get_tensor_grad(OUTPUT); - - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto &state = acc.get_argument(PER_DEVICE_STATE); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[LayerNorm] backward time = {:.2lf}ms\n", - state, - output_grad, - input, - input_grad, - gamma, - gamma_grad, - beta_grad); -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) { - auto const &attrs = acc.get_argument(ATTRS); - 
DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - Allocator allocator = acc.get_allocator(); - auto input = acc.get_tensor(INPUT); - auto handle = acc.get_argument(HANDLE); - - positive_int M = product(transform(attrs.axes, [&](ff_dim_t dim) { - return dim_at_idx(input.shape.dims, dim); - })); - - positive_int num_replicas = get_num_elements(input.shape.dims); - - positive_int effective_num_elements = M; - positive_int effective_batch_size = - positive_int{get_num_elements(input.shape.dims) / M}; - - std::optional per_device_state = - init_kernel(kernel_device_type, - handle, - allocator, - attrs.elementwise_affine, - effective_batch_size.int_from_positive_int(), - effective_num_elements.int_from_positive_int(), - attrs.eps); - - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -TaskImplFunction get_layer_norm_init_task_impl() { - return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} -TaskImplFunction get_layer_norm_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_layer_norm_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_layer_norm_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - fwd.add_weight_slot(GAMMA); - fwd.add_weight_slot(BETA); - - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - return fwd; -} - -OpTaskSignature get_layer_norm_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_layer_norm_fwd_signature()); - return bwd; -} - -OpTaskSignature get_layer_norm_init_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_input_slot(INPUT); - init.add_arg_slot(ATTRS); - init.add_arg_slot(KERNEL_DEVICE_TYPE); - init.add_unchecked_arg_slot(HANDLE); - - init.add_return_value(); - 
return init; -} - -std::vector get_task_ids(LayerNormAttrs const &) { - return {task_id_t::LAYERNORM_INIT_TASK_ID, - task_id_t::LAYERNORM_FWD_TASK_ID, - task_id_t::LAYERNORM_BWD_TASK_ID}; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/linear.cc b/lib/task-spec/src/task-spec/ops/linear.cc deleted file mode 100644 index 9ce02bc7fd..0000000000 --- a/lib/task-spec/src/task-spec/ops/linear.cc +++ /dev/null @@ -1,226 +0,0 @@ -#include "task-spec/ops/linear.h" -#include "kernels/format_accessor_contents.h" -#include "kernels/linear_kernels.h" -#include "op-attrs/ff_dim_t.h" -#include "task-spec/profiling.h" -#include "task-spec/task_argument_accessor.h" -#include "utils/exception.h" -#include "utils/hash-utils.h" - -namespace FlexFlow { - -enum slots { - INPUT, - OUTPUT, - WEIGHT, - BIAS, - ATTRS, - PROFILING, - HANDLE, - PER_DEVICE_STATE, - KERNEL_DEVICE_TYPE, -}; - -OpTaskInvocation init(LinearAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(HANDLE, ff_handle()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - binding.bind_arg(ATTRS, attrs); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(WEIGHT, weight_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - return OpTaskInvocation{ - task_id_t::LINEAR_INIT_TASK_ID, - binding, - }; -} - -OpTaskInvocation forward(LinearAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(WEIGHT, weight_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - if (attrs.use_bias) { - binding.bind(BIAS, weight_tensor(1_n)); - } - - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - binding.bind_arg(PER_DEVICE_STATE, - per_device_op_state>()); - binding.bind_arg(ATTRS, attrs); - - return OpTaskInvocation{ - task_id_t::LINEAR_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(LinearAttrs const &attrs) { - OpTaskBinding b = 
infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::LINEAR_BWD_TASK_ID, - b, - }; -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) { - auto const &attrs = acc.get_argument(ATTRS); - device_handle_t handle = acc.get_argument(HANDLE); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input = acc.get_tensor(INPUT); - auto weight = acc.get_tensor(WEIGHT); - auto output = acc.get_tensor(OUTPUT); - positive_int out_dim = dim_at_idx(output.shape.dims, ff_dim_t{0_n}); - positive_int batch_size = dim_at_idx(output.shape.dims, ff_dim_t{1_n}); - - std::optional per_device_state = - linear_init_kernel(kernel_device_type, - handle, - attrs.activation, - attrs.regularizer, - attrs.use_bias, - input.shape.data_type, - weight.shape.data_type, - output.shape.data_type, - batch_size.int_from_positive_int(), - attrs.out_channels.int_from_positive_int()); - - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto input = acc.get_tensor(INPUT); - auto weight = acc.get_tensor(WEIGHT); - auto output = acc.get_tensor(OUTPUT); - - auto per_device_state = - acc.get_argument>(PER_DEVICE_STATE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto attrs = acc.get_argument(ATTRS); - - std::optional bias = std::nullopt; - if (attrs.use_bias) { - bias = acc.get_tensor(BIAS); - } - - auto result = profile(linear_forward_kernel, - profiling, - kernel_device_type, - "[Linear] forward_time = {:.2lf}ms\n", - per_device_state, - attrs, - input, - output, - weight, - bias); - - return result; -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - auto input = acc.get_tensor(INPUT); - auto weight = acc.get_tensor(WEIGHT); - auto output = acc.get_tensor(OUTPUT); - 
- auto input_grad = acc.get_tensor_grad(INPUT); - auto weight_grad = acc.get_tensor_grad(WEIGHT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - - auto per_device_state = - acc.get_argument>(PER_DEVICE_STATE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto attrs = acc.get_argument(ATTRS); - - std::optional bias_grad = std::nullopt; - if (attrs.use_bias) { - bias_grad = acc.get_tensor(BIAS); - } - - auto result = profile(linear_backward_kernel, - profiling, - kernel_device_type, - "[Linear] backward_time = {:.2lf}ms\n", - per_device_state, - attrs, - output, - output_grad, - input, - input_grad, - weight, - weight_grad, - bias_grad); - - return result; -} - -TaskImplFunction get_linear_init_task_impl() { - return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} - -TaskImplFunction get_linear_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} - -TaskImplFunction get_linear_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_linear_init_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_input_slot(INPUT); - init.add_weight_slot(WEIGHT); - init.add_output_slot(OUTPUT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(KERNEL_DEVICE_TYPE); - init.add_unchecked_arg_slot(HANDLE); - - init.add_return_value(); - return init; -} - -OpTaskSignature get_linear_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(INPUT); - fwd.add_weight_slot(WEIGHT); - fwd.add_optional_weight_slot(BIAS); - fwd.add_output_slot(OUTPUT); - - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_arg_slot(ATTRS); - fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - return fwd; -} - -OpTaskSignature get_linear_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_linear_fwd_signature()); - return bwd; -} - 
-std::vector get_task_ids(LinearAttrs const &) { - return {task_id_t::LINEAR_INIT_TASK_ID, - task_id_t::LINEAR_FWD_TASK_ID, - task_id_t::LINEAR_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/op_arg_ref.cc b/lib/task-spec/src/task-spec/ops/op_arg_ref.cc similarity index 96% rename from lib/task-spec/src/task-spec/op_arg_ref.cc rename to lib/task-spec/src/task-spec/ops/op_arg_ref.cc index 29c895f1c8..3b56780294 100644 --- a/lib/task-spec/src/task-spec/op_arg_ref.cc +++ b/lib/task-spec/src/task-spec/ops/op_arg_ref.cc @@ -1,4 +1,4 @@ -#include "task-spec/op_arg_ref.h" +#include "task-spec/ops/op_arg_ref.h" namespace FlexFlow { diff --git a/lib/task-spec/src/task-spec/ops/op_arg_ref_spec.cc b/lib/task-spec/src/task-spec/ops/op_arg_ref_spec.cc new file mode 100644 index 0000000000..d40ff26db4 --- /dev/null +++ b/lib/task-spec/src/task-spec/ops/op_arg_ref_spec.cc @@ -0,0 +1 @@ +#include "task-spec/ops/op_arg_ref_spec.h" diff --git a/lib/task-spec/src/task-spec/op_tensor_spec.cc b/lib/task-spec/src/task-spec/ops/op_tensor_spec.cc similarity index 91% rename from lib/task-spec/src/task-spec/op_tensor_spec.cc rename to lib/task-spec/src/task-spec/ops/op_tensor_spec.cc index ed312e47af..aa9befdbb3 100644 --- a/lib/task-spec/src/task-spec/op_tensor_spec.cc +++ b/lib/task-spec/src/task-spec/ops/op_tensor_spec.cc @@ -1,4 +1,4 @@ -#include "task-spec/op_tensor_spec.h" +#include "task-spec/ops/op_tensor_spec.h" namespace FlexFlow { diff --git a/lib/task-spec/src/task-spec/ops/pool_2d.cc b/lib/task-spec/src/task-spec/ops/pool_2d.cc deleted file mode 100644 index 20707acb2d..0000000000 --- a/lib/task-spec/src/task-spec/ops/pool_2d.cc +++ /dev/null @@ -1,212 +0,0 @@ -#include "task-spec/ops/pool_2d.h" -#include "kernels/pool_2d_kernels.h" -#include "op-attrs/ops/pool_2d.h" -#include "task-spec/profiling.h" -#include "utils/exception.h" -#include "utils/hash-utils.h" - -using namespace FlexFlow::Kernels::Pool2D; - -namespace FlexFlow { - -enum 
Slots { - INPUT, - OUTPUT, - ATTRS, - PROFILING, - PER_DEVICE_STATE, - HANDLE, - KERNEL_DEVICE_TYPE -}; - -OpTaskInvocation init(Pool2DAttrs const &attrs) { - OpTaskBinding binding; - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(HANDLE, ff_handle()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return OpTaskInvocation{ - task_id_t::POOL2D_INIT_TASK_ID, - binding, - }; -} - -static nonnegative_int calculate_padding(nonnegative_int output_size, - nonnegative_int stride, - nonnegative_int kernel_size, - nonnegative_int input_size) { - int o = output_size.unwrap_nonnegative(); - int s = stride.unwrap_nonnegative(); - int k = kernel_size.unwrap_nonnegative(); - int i = kernel_size.unwrap_nonnegative(); - - return nonnegative_int{ - ((o - 1) * s + k - i + 1) / 2, - }; -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) { - auto const &attrs = acc.get_argument(ATTRS); - device_handle_t handle = acc.get_argument(HANDLE); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - positive_int input_w = dim_at_idx(input.shape.dims, ff_dim_t{0_n}); - positive_int input_h = dim_at_idx(input.shape.dims, ff_dim_t{1_n}); - positive_int input_c = dim_at_idx(input.shape.dims, ff_dim_t{2_n}); - positive_int input_n = dim_at_idx(input.shape.dims, ff_dim_t{3_n}); - positive_int output_w = dim_at_idx(output.shape.dims, ff_dim_t{0_n}); - positive_int output_h = dim_at_idx(output.shape.dims, ff_dim_t{1_n}); - positive_int output_c = dim_at_idx(output.shape.dims, ff_dim_t{2_n}); - positive_int output_n = dim_at_idx(output.shape.dims, ff_dim_t{3_n}); - - std::optional per_device_state = - init_kernel(kernel_device_type, - handle, - attrs.activation, - input_w.int_from_positive_int(), - input_h.int_from_positive_int(), - 
input_c.int_from_positive_int(), - input_n.int_from_positive_int(), - output_w.int_from_positive_int(), - output_h.int_from_positive_int(), - output_c.int_from_positive_int(), - output_n.int_from_positive_int(), - attrs.padding_h.unwrap_nonnegative(), - attrs.padding_w.unwrap_nonnegative(), - attrs.kernel_h.int_from_positive_int(), - attrs.kernel_w.int_from_positive_int(), - attrs.stride_h.int_from_positive_int(), - attrs.stride_w.int_from_positive_int(), - attrs.pool_type); - - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -OpTaskInvocation forward(Pool2DAttrs const &attrs) { - OpTaskBinding binding; - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - binding.bind_arg(PER_DEVICE_STATE, - per_device_op_state>()); - - return OpTaskInvocation{ - task_id_t::POOL2D_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(Pool2DAttrs const &attrs) { - OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::POOL2D_BWD_TASK_ID, - b, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - Pool2DPerDeviceState state = - acc.get_argument(PER_DEVICE_STATE); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[Pool2D] forward_time = {:.2lf}ms\n", - state, - input.get_float_ptr(), - output.get_float_ptr()); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - Pool2DPerDeviceState state = - 
acc.get_argument(PER_DEVICE_STATE); - - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor(OUTPUT); - auto input = acc.get_tensor(INPUT); - auto input_grad = acc.get_tensor(INPUT); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[Pool2D] backward_time = {:.2lf}ms\n", - state, - output.get_float_ptr(), - output_grad.get_float_ptr(), - input.get_float_ptr(), - input_grad.get_float_ptr()); -} - -TaskImplFunction get_pool_2d_init_task_impl() { - return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} - -TaskImplFunction get_pool_2d_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} - -TaskImplFunction get_pool_2d_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_pool_2d_init_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_input_slot(INPUT); - init.add_output_slot(OUTPUT); - - init.add_arg_slot(ATTRS); - init.add_arg_slot(KERNEL_DEVICE_TYPE); - init.add_unchecked_arg_slot(HANDLE); - - init.add_return_value(); - return init; -} - -OpTaskSignature get_pool_2d_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - - fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - return fwd; -} - -OpTaskSignature get_pool_2d_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_pool_2d_fwd_signature()); - return bwd; -} - -std::vector get_task_ids(Pool2DAttrs const &) { - return {task_id_t::POOL2D_INIT_TASK_ID, - task_id_t::POOL2D_FWD_TASK_ID, - task_id_t::POOL2D_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/reduce.cc b/lib/task-spec/src/task-spec/ops/reduce.cc deleted file mode 100644 index d8818393ec..0000000000 --- a/lib/task-spec/src/task-spec/ops/reduce.cc +++ /dev/null @@ -1,179 +0,0 @@ -#include 
"task-spec/ops/reduce.h" -#include "kernels/reduce_kernels.h" -#include "task-spec/profiling.h" -#include "utils/exception.h" -#include "utils/hash-utils.h" -#include "utils/type_traits_core.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::Reduce; - -enum Slots { - INPUT, - OUTPUT, - ATTRS, - PROFILING, - REDUCE, - PER_DEVICE_STATE, - HANDLE, - KERNEL_DEVICE_TYPE, -}; - -OpTaskInvocation init(ReduceAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(HANDLE, ff_handle()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - binding.bind_arg(ATTRS, attrs); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - return OpTaskInvocation{ - task_id_t::REDUCE_INIT_TASK_ID, - binding, - }; -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) { - device_handle_t handle = acc.get_argument(HANDLE); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto attrs = acc.get_argument(ATTRS); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - OperatorType op_type = attrs.op_type; - - nonnegative_int reduction_size = - get_num_elements(input.shape.dims) / get_num_elements(output.shape.dims); - - std::optional per_device_state = - init_kernel(kernel_device_type, - handle, - op_type, - reduction_size.unwrap_nonnegative(), - input.shape, - output.shape); - - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -// Note: forward_kernel only needs ReducePerDeviceState, input, output -OpTaskInvocation forward(ReduceAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(PER_DEVICE_STATE, - per_device_op_state>()); - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - return OpTaskInvocation{ - 
task_id_t::REDUCE_FWD_TASK_ID, - binding, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[Reduce] forward_time = {:.2lf}ms\n", - per_device_state, - input.get_float_ptr(), - output.get_float_ptr()); -} - -OpTaskInvocation backward(ReduceAttrs const &attrs) { - OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::REDUCE_BWD_TASK_ID, - binding, - }; -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[Reduce] backward_time = {:.2lf}ms\n", - per_device_state, - output_grad.get_float_ptr(), - input_grad.get_float_ptr()); -} - -TaskImplFunction get_reduce_init_task_impl() { - return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} - -TaskImplFunction get_reduce_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} - -TaskImplFunction get_reduce_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_reduce_init_signature() { - OpTaskSignature init(OpTaskType::INIT); - - init.add_unchecked_arg_slot(HANDLE); - init.add_arg_slot(ATTRS); - init.add_arg_slot(KERNEL_DEVICE_TYPE); - - init.add_return_value(); - return init; -} - 
-OpTaskSignature get_reduce_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - return fwd; -} - -OpTaskSignature get_reduce_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_reduce_fwd_signature()); - return bwd; -} - -std::vector get_task_ids(ReduceAttrs const &) { - return {task_id_t::REDUCE_INIT_TASK_ID, - task_id_t::REDUCE_FWD_TASK_ID, - task_id_t::REDUCE_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/reshape.cc b/lib/task-spec/src/task-spec/ops/reshape.cc deleted file mode 100644 index b6d8cabd82..0000000000 --- a/lib/task-spec/src/task-spec/ops/reshape.cc +++ /dev/null @@ -1,112 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "task-spec/ops/reshape.h" -#include "kernels/reshape_kernels.h" -#include "task-spec/profiling.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::Reshape; - -enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, KERNEL_DEVICE_TYPE }; - -OpTaskInvocation forward(ReshapeAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - binding.bind_arg(ATTRS, attrs); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - return OpTaskInvocation{ - task_id_t::RESHAPE_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(ReshapeAttrs const &attrs) { - OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::RESHAPE_BWD_TASK_ID, - binding, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - ReshapeAttrs attrs = acc.get_argument(ATTRS); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[Reshape] forward time = {:.2lf}ms\n", - input, - output); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - ReshapeAttrs attrs = acc.get_argument(ATTRS); - - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[Reshape] backward time = {:.2lf}ms\n", - output_grad, - input_grad); -} - -TaskImplFunction get_reshape_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} - -TaskImplFunction 
get_reshape_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_reshape_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(PROFILING); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - return fwd; -} - -OpTaskSignature get_reshape_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_reshape_fwd_signature()); - return bwd; -} - -std::vector get_task_ids(ReshapeAttrs const &) { - return {task_id_t::RESHAPE_FWD_TASK_ID, task_id_t::RESHAPE_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/reverse.cc b/lib/task-spec/src/task-spec/ops/reverse.cc deleted file mode 100644 index 9d1a8e1753..0000000000 --- a/lib/task-spec/src/task-spec/ops/reverse.cc +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "task-spec/ops/reverse.h" -#include "kernels/accessor.h" -#include "kernels/reverse_kernels.h" -#include "task-spec/profiling.h" -#include "utils/nonnegative_int/nonnegative_range.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::Reverse; -using coord_t = long long; - -enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, KERNEL_DEVICE_TYPE }; - -OpTaskInvocation forward(ReverseAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - binding.bind_arg(ATTRS, attrs); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - return OpTaskInvocation{ - task_id_t::REVERSE_FWD_TASK_ID, - binding, - }; -} -OpTaskInvocation backward(ReverseAttrs const &attrs) { - OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::REVERSE_BWD_TASK_ID, - binding, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto attrs = acc.get_argument(ATTRS); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[reverse] forward_time = {:.2lf}ms\n", - input, - output, - attrs); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - auto attrs = acc.get_argument(ATTRS); - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[reverse] backward_time = {:.2lf}ms\n", - output_grad, - input_grad, - attrs); -} - -TaskImplFunction 
get_reverse_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_reverse_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_reverse_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - return fwd; -} - -OpTaskSignature get_reverse_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_reverse_fwd_signature()); - return bwd; -} - -std::vector get_task_ids(ReverseAttrs const &) { - return {task_id_t::REVERSE_FWD_TASK_ID, task_id_t::REVERSE_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/softmax.cc b/lib/task-spec/src/task-spec/ops/softmax.cc deleted file mode 100644 index 89ea42299f..0000000000 --- a/lib/task-spec/src/task-spec/ops/softmax.cc +++ /dev/null @@ -1,192 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "task-spec/ops/softmax.h" -#include "kernels/softmax_kernels.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "task-spec/profiling.h" -#include "utils/exception.h" -#include "utils/hash-utils.h" - -namespace FlexFlow { -using namespace FlexFlow::Kernels::Softmax; - -enum Slots { - INPUT, - OUTPUT, - ATTRS, - PROFILING, - PER_DEVICE_STATE, - HANDLE, - KERNEL_DEVICE_TYPE -}; - -OpTaskInvocation init(SoftmaxAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(HANDLE, ff_handle()); - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - return OpTaskInvocation{ - task_id_t::SOFTMAX_INIT_TASK_ID, - binding, - }; -} - -OpTaskInvocation forward(SoftmaxAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(PER_DEVICE_STATE, - per_device_op_state>()); - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - return OpTaskInvocation{ - task_id_t::SOFTMAX_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(SoftmaxAttrs const &attrs) { - OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::SOFTMAX_BWD_TASK_ID, - binding, - }; -} - -static DeviceSpecificDeviceStates - init_task_impl(TaskArgumentAccessor const &acc) { - device_handle_t handle = acc.get_argument(HANDLE); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto output = acc.get_tensor(OUTPUT); - auto const &attrs = acc.get_argument(ATTRS); - - positive_int output_w = dim_at_idx(output.shape.dims, legion_dim_t{0_n}); - positive_int output_h = dim_at_idx(output.shape.dims, legion_dim_t{1_n}); - positive_int output_c = dim_at_idx(output.shape.dims, legion_dim_t{2_n}); - positive_int output_n = dim_at_idx(output.shape.dims, legion_dim_t{3_n}); - - std::optional per_device_state = - 
init_kernel(kernel_device_type, - handle, - attrs.dim, - output_n.int_from_positive_int(), - output_c.int_from_positive_int(), - output_h.int_from_positive_int(), - output_w.int_from_positive_int()); - - return DeviceSpecificDeviceStates{ - DeviceSpecific>::create( - per_device_state), - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto per_device_state = - acc.get_argument(PER_DEVICE_STATE); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[Softmax] forward_time = {:.2lf}ms\n", - per_device_state, - input.get_float_ptr(), - output.get_float_ptr()); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input_grad = acc.get_tensor_grad(INPUT); - auto input = acc.get_tensor(INPUT); - assert(input_grad.shape == input.shape); - - auto output_grad = acc.get_tensor_grad(OUTPUT); - auto output = acc.get_tensor(OUTPUT); - - assert(output_grad.shape == output.shape); - - return profile( - backward_kernel, - profiling, - kernel_device_type, - "[Softmax] backward_time = {:.2lf}ms\n", - output_grad.get_float_ptr(), - input_grad.get_float_ptr(), - get_num_elements(output_grad.shape.dims).int_from_positive_int()); -} - -TaskImplFunction get_softmax_init_task_impl() { - return TaskImplFunction{InitOpTaskImplFunction{init_task_impl}}; -} - -TaskImplFunction get_softmax_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} - -TaskImplFunction get_softmax_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_softmax_init_signature() { - 
OpTaskSignature init(OpTaskType::INIT); - - init.add_unchecked_arg_slot(HANDLE); - init.add_arg_slot(ATTRS); - init.add_arg_slot(KERNEL_DEVICE_TYPE); - init.add_return_value(); - return init; -} - -OpTaskSignature get_softmax_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - return fwd; -} - -OpTaskSignature get_softmax_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_softmax_fwd_signature()); - return bwd; -} - -std::vector get_task_ids(SoftmaxAttrs const &) { - return {task_id_t::SOFTMAX_INIT_TASK_ID, - task_id_t::SOFTMAX_FWD_TASK_ID, - task_id_t::SOFTMAX_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/split.cc b/lib/task-spec/src/task-spec/ops/split.cc deleted file mode 100644 index 88c16be57c..0000000000 --- a/lib/task-spec/src/task-spec/ops/split.cc +++ /dev/null @@ -1,155 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "task-spec/ops/split.h" -#include "kernels/split_kernels.h" -#include "task-spec/profiling.h" -#include "utils/exception.h" -#include "utils/hash-utils.h" -#include "utils/nonnegative_int/nonnegative_range.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::Split; - -enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, KERNEL_DEVICE_TYPE }; - -OpTaskInvocation forward(SplitAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - return OpTaskInvocation{ - task_id_t::SPLIT_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(SplitAttrs const &attrs) { - OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::SPLIT_BWD_TASK_ID, - binding, - }; -} - -static std::pair - calc_block_size(TensorShape const &tensor_shape, ff_dim_t axis) { - positive_int num_blocks = 1_p; - positive_int block_size = 1_p; - for (nonnegative_int d : - nonnegative_range(get_num_elements(tensor_shape.dims) - .nonnegative_int_from_positive_int())) { - if (d <= axis.value) { - block_size *= dim_at_idx(tensor_shape.dims, legion_dim_t{d}); - } else { - num_blocks *= dim_at_idx(tensor_shape.dims, legion_dim_t{d}); - } - } - return {num_blocks, block_size}; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto attrs = acc.get_argument(ATTRS); - - int out_block_sizes[MAX_NUM_OUTPUTS]; - auto [num_blocks, in_block_size] = calc_block_size(input.shape, attrs.axis); - - for (int i = 0; i < attrs.splits.size(); i++) { - auto [_, out_block_size] = 
calc_block_size(output.shape, attrs.axis); - out_block_sizes[i] = out_block_size.int_from_positive_int(); - } - float *output_float_ptr = output.get_float_ptr(); - return profile(forward_kernel, - profiling, - kernel_device_type, - "[Split] forward_time = {:.2lf}ms\n", - &output_float_ptr, - input.get_float_ptr(), - out_block_sizes, - in_block_size.int_from_positive_int(), - num_blocks.int_from_positive_int(), - attrs.splits.size()); -} - -// maybe we should add assert like the original code -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - auto attrs = acc.get_argument(ATTRS); - - int out_block_sizes[MAX_NUM_OUTPUTS]; - auto [num_blocks, in_block_size] = - calc_block_size(input_grad.shape, attrs.axis); - - for (int i = 0; i < attrs.splits.size(); i++) { - int out_num_blocks; - auto [_, out_block_size] = calc_block_size(output_grad.shape, attrs.axis); - out_block_sizes[i] = out_block_size.int_from_positive_int(); - } - float const *output_grad_ptr = output_grad.get_float_ptr(); - return profile(backward_kernel, - profiling, - kernel_device_type, - "[Split] backward_time = {:.2lf}ms\n", - input_grad.get_float_ptr(), - &output_grad_ptr, - out_block_sizes, - in_block_size.int_from_positive_int(), - num_blocks.int_from_positive_int(), - attrs.splits.size()); -} - -TaskImplFunction get_split_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} -TaskImplFunction get_split_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_split_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(PROFILING); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - return fwd; -} 
-OpTaskSignature get_split_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_split_fwd_signature()); - return bwd; -} - -std::vector get_task_ids(SplitAttrs const &) { - return {task_id_t::SPLIT_FWD_TASK_ID, task_id_t::SPLIT_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/topk.cc b/lib/task-spec/src/task-spec/ops/topk.cc deleted file mode 100644 index 8ff275dac3..0000000000 --- a/lib/task-spec/src/task-spec/ops/topk.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "task-spec/ops/topk.h" -#include "kernels/topk_kernels.h" -#include "task-spec/profiling.h" -#include "utils/exception.h" - -namespace FlexFlow { - -using namespace FlexFlow::Kernels::TopK; - -// For an input tensor, computes the top k entries in each row -// (resp. vector along the last dimension). 
Thus, -// values.shape = indices.shape = input.shape[:-1] + [k] - -enum Slots { INPUT, OUTPUT, INDICES, ATTRS, PROFILING, KERNEL_DEVICE_TYPE }; - -OpTaskInvocation forward(TopKAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(ATTRS, attrs); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - binding.bind(INDICES, output_tensor(1_n)); - - return OpTaskInvocation{ - task_id_t::TOPK_FWD_TASK_ID, - binding, - }; -} - -OpTaskInvocation backward(TopKAttrs const &attrs) { - OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::TOPK_BWD_TASK_ID, - binding, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - auto attrs = acc.get_argument(ATTRS); - auto profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - positive_int length = dim_at_idx(input.shape.dims, legion_dim_t{0_n}); - positive_int batch_size = - positive_int{get_num_elements(input.shape.dims) / length}; - auto indices = acc.get_tensor(INDICES); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[TopK] forward_time = {:.2lf}ms\n", - input.get_float_ptr(), - output.get_float_ptr(), - indices.get_int32_ptr(), - batch_size.int_from_positive_int(), - length.int_from_positive_int(), - attrs.k.int_from_positive_int(), - attrs.sorted); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - auto attrs = acc.get_argument(ATTRS); - auto profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - - auto indices = 
acc.get_tensor(INDICES); - - positive_int length = dim_at_idx(input_grad.shape.dims, legion_dim_t{0_n}); - positive_int batch_size = - positive_int{get_num_elements(input_grad.shape.dims) / length}; - - return profile(backward_kernel, - profiling, - kernel_device_type, - "[TopK] backward_time = {:.2lf}ms\n", - output_grad.get_float_ptr(), - indices.get_int32_ptr(), - input_grad.get_float_ptr(), - batch_size.int_from_positive_int(), - length.int_from_positive_int(), - attrs.k.int_from_positive_int()); -} - -TaskImplFunction get_topk_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} - -TaskImplFunction get_topk_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_topk_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(ATTRS); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - fwd.add_output_slot(INDICES); - return fwd; -} - -OpTaskSignature get_topk_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_topk_fwd_signature()); - return bwd; -} - -std::vector get_task_ids(TopKAttrs const &) { - return {task_id_t::TOPK_FWD_TASK_ID, task_id_t::TOPK_BWD_TASK_ID}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/transpose.cc b/lib/task-spec/src/task-spec/ops/transpose.cc deleted file mode 100644 index b2f94b6484..0000000000 --- a/lib/task-spec/src/task-spec/ops/transpose.cc +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "task-spec/ops/transpose.h" -#include "kernels/transpose_kernels.h" -#include "op-attrs/ops/transpose.h" -#include "task-spec/profiling.h" -#include "utils/integer_conversions.h" - -using namespace FlexFlow::Kernels::Transpose; - -namespace FlexFlow { - -enum Slots { - INPUT, - OUTPUT, - ATTRS, - PROFILING, - KERNEL_DEVICE_TYPE, -}; - -OpTaskInvocation forward(TransposeAttrs const &attrs) { - OpTaskBinding binding; - - binding.bind_arg(PROFILING, profiling_settings()); - binding.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - binding.bind(INPUT, input_tensor(0_n)); - binding.bind(OUTPUT, output_tensor(0_n)); - - return OpTaskInvocation{ - task_id_t::TRANSPOSE_FWD_TASK_ID, - binding, - }; -} - -static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - auto attrs = acc.get_argument(ATTRS); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - - return profile(forward_kernel, - profiling, - kernel_device_type, - "[Transpose] Forward_time = {:.2lf} [ms]", - attrs, - input, - output); -} - -static std::optional - backward_task_impl(TaskArgumentAccessor const &acc) { - ProfilingSettings profiling = acc.get_argument(PROFILING); - auto attrs = acc.get_argument(ATTRS); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); - - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - - return 
profile(backward_kernel, - profiling, - kernel_device_type, - "[Transpose] Backward_time = {:.2lf} [ms]", - attrs, - output_grad, - input_grad); -} - -OpTaskInvocation backward(TransposeAttrs const &attrs) { - OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - - return OpTaskInvocation{ - task_id_t::TRANSPOSE_BWD_TASK_ID, - binding, - }; -} - -TaskImplFunction get_transpose_fwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{forward_task_impl}}; -} - -TaskImplFunction get_transpose_bwd_task_impl() { - return TaskImplFunction{FwdBwdOpTaskImplFunction{backward_task_impl}}; -} - -OpTaskSignature get_transpose_fwd_signature() { - OpTaskSignature fwd(OpTaskType::FWD); - - fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(KERNEL_DEVICE_TYPE); - - fwd.add_input_slot(INPUT); - fwd.add_output_slot(OUTPUT); - return fwd; -} - -OpTaskSignature get_transpose_bwd_signature() { - OpTaskSignature bwd = infer_bwd_signature(get_transpose_fwd_signature()); - return bwd; -} - -std::vector get_task_ids(TransposeAttrs const &) { - return {task_id_t::TRANSPOSE_FWD_TASK_ID, task_id_t::TRANSPOSE_BWD_TASK_ID}; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/weight.cc b/lib/task-spec/src/task-spec/ops/weight.cc deleted file mode 100644 index 08c9be26e9..0000000000 --- a/lib/task-spec/src/task-spec/ops/weight.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "task-spec/ops/weight.h" - -namespace FlexFlow { - -std::vector get_task_ids(WeightAttrs const &attrs) { - return {}; -} - -}; // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/optimizer.cc b/lib/task-spec/src/task-spec/optimizer.cc index c8fa23c2af..447f6095d2 100644 --- a/lib/task-spec/src/task-spec/optimizer.cc +++ b/lib/task-spec/src/task-spec/optimizer.cc @@ -3,73 +3,17 @@ #include "task-spec/profiling.h" #include "utils/containers/get_only.h" #include "utils/overload.h" +#include "utils/units/milliseconds_t.h" namespace FlexFlow { -enum Slots { - ATTRS, - WEIGHT, - 
WEIGHT_GRAD, - SGD_V, - PROFILING, - ADAM_M, - ADAM_V, - HANDLE, - KERNEL_DEVICE_TYPE, -}; - -TaskSignature get_sgd_update_signature() { - TaskSignature sig = make_empty_task_signature(); - add_slot(sig, WEIGHT, TensorType::FORWARD); - add_slot(sig, WEIGHT_GRAD, TensorType::GRADIENT); - add_slot(sig, SGD_V, TensorType::OPTIMIZER); - - add_arg_slot(sig, ATTRS); - add_arg_slot(sig, PROFILING); - add_arg_slot(sig, KERNEL_DEVICE_TYPE); - add_unchecked_arg_slot( - sig, HANDLE); // how to deal with removal of ParamSync? - - // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - // add_unchecked_arg_slot(sig, HANDLE); - // } - return sig; -} - -TaskInvocation sgd_update(SGDOptimizerAttrs const &attrs, - forward_tensor_guid_t const &weight, - gradient_tensor_guid_t const &weight_grad, - optimizer_tensor_guid_t const &sgd_v) { - TaskBinding b; - b.bind(WEIGHT, weight); - b.bind_grad(WEIGHT_GRAD, weight_grad); - - if (attrs.momentum > 0.0f) { - b.bind_optimizer(SGD_V, sgd_v); - } - b.bind_arg(ATTRS, attrs); - b.bind_arg(PROFILING, profiling_settings()); - b.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - - b.bind_arg(HANDLE, ff_handle()); - return TaskInvocation{task_id_t::SGD_UPD_NCCL_TASK_ID, - b}; // how to deal with removal of ParamSync? 
- - // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - // b.bind_arg(HANDLE, ff_handle()); - // return TaskInvocation{task_id_t::SGD_UPD_NCCL_TASK_ID, b}; - // } else { - // return TaskInvocation{task_id_t::SGD_UPD_PS_TASK_ID, b}; - // } -} - static void sgd_update_task_impl(TaskArgumentAccessor const &acc) { - auto attrs = acc.get_argument(ATTRS); - auto weight_grad = acc.get_tensor_grad(WEIGHT_GRAD); - auto weight = acc.get_tensor(WEIGHT); - auto profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); + SGDOptimizerAttrs attrs = acc.get_optimizer_attrs().require_sgd_optimizer(); + auto weight_grad = + acc.get_tensor_grad(TensorSlotName::WEIGHT); + auto weight = acc.get_tensor(TensorSlotName::WEIGHT); + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); ASSERT(weight.shape == weight_grad.shape); @@ -82,11 +26,12 @@ static void sgd_update_task_impl(TaskArgumentAccessor const &acc) { std::optional sgd_v = std::nullopt; if (attrs.momentum > 0.0f) { - sgd_v = acc.get_optimizer_tensor(SGD_V); + sgd_v = acc.get_optimizer_tensor(TensorSlotName::WEIGHT, + OptimizerSlotName::SGD_V); ASSERT(sgd_v.value().shape == weight.shape); } - auto handle = acc.get_argument(HANDLE); + device_handle_t handle = acc.get_ff_handle(); profile(sgd_update_task, profiling, kernel_device_type, @@ -99,96 +44,25 @@ static void sgd_update_task_impl(TaskArgumentAccessor const &acc) { weight_grad, num_replicas, weight, - sgd_v); // how to deal with removal of ParamSync? 
- - // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - // auto handle = acc.get_argument(HANDLE); - // profile(sgd_nccl_update_task_gpu, - // profiling, - // "[SGD NCCL] update_time = %.2lfms\n", - // attrs.lr, - // attrs.momentum, - // attrs.nesterov, - // attrs.weight_decay, - // handle, - // weight_grad.get_float_ptr(), - // size, - // weight.get_float_ptr(), - // sgd_v_ptr); - - // } else { - // profile(sgd_ps_update_task_gpu, - // profiling, - // "[SGD PS] update_time = %.2lfms\n", - // attrs.lr, - // attrs.momentum, - // attrs.nesterov, - // attrs.weight_decay, - // weight_grad.get_float_ptr(), - // size, - // num_replicas, - // weight.get_float_ptr(), - // sgd_v_ptr); - // } + sgd_v); } TaskImplFunction get_sgd_update_task_impl() { return TaskImplFunction{GenericTaskImplFunction{sgd_update_task_impl}}; } -TaskSignature get_adam_update_signature() { - TaskSignature sig = make_empty_task_signature(); - add_slot(sig, WEIGHT, TensorType::FORWARD); - add_slot(sig, WEIGHT_GRAD, TensorType::GRADIENT); - add_slot(sig, ADAM_V, TensorType::OPTIMIZER); - add_slot(sig, ADAM_M, TensorType::OPTIMIZER); - - add_arg_slot(sig, ATTRS); - add_arg_slot(sig, PROFILING); - add_arg_slot(sig, KERNEL_DEVICE_TYPE); - add_unchecked_arg_slot( - sig, HANDLE); // how to deal with removal of ParamSync? 
- // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - // add_unchecked_arg_slot(sig, HANDLE); - // } - return sig; -} - -TaskInvocation adam_update(AdamOptimizerAttrs const &attrs, - forward_tensor_guid_t const &weight, - gradient_tensor_guid_t const &weight_grad, - optimizer_tensor_guid_t const &adam_v, - optimizer_tensor_guid_t const &adam_m) { - TaskBinding b; - b.bind(WEIGHT, weight); - b.bind_grad(WEIGHT_GRAD, weight_grad); - b.bind_optimizer(ADAM_M, adam_m); - b.bind_optimizer(ADAM_V, adam_v); - b.bind_arg(ATTRS, attrs); - b.bind_arg(PROFILING, profiling_settings()); - b.bind_arg(KERNEL_DEVICE_TYPE, kernel_device_type()); - b.bind_arg(HANDLE, ff_handle()); - return TaskInvocation{task_id_t::ADAM_UPD_NCCL_TASK_ID, - b}; // how to deal with removal of ParamSync? - - // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - // b.bind_arg(HANDLE, ff_handle()); - // return TaskInvocation{task_id_t::ADAM_UPD_NCCL_TASK_ID, b}; - // } else { - // return TaskInvocation{task_id_t::ADAM_UPD_PS_TASK_ID, b}; - // } -} - static void adam_update_task_impl(TaskArgumentAccessor const &acc) { - auto attrs = acc.get_argument(ATTRS); - auto weight_grad = acc.get_tensor_grad(WEIGHT_GRAD); - auto weight = acc.get_tensor(WEIGHT); - auto v_tensor = acc.get_optimizer_tensor(ADAM_V); - auto m_tensor = acc.get_optimizer_tensor(ADAM_M); - - auto profiling = acc.get_argument(PROFILING); - DeviceType kernel_device_type = - acc.get_argument(KERNEL_DEVICE_TYPE); + AdamOptimizerAttrs attrs = acc.get_optimizer_attrs().require_adam_optimizer(); + auto weight_grad = + acc.get_tensor_grad(TensorSlotName::WEIGHT); + auto weight = acc.get_tensor(TensorSlotName::WEIGHT); + auto v_tensor = acc.get_optimizer_tensor( + TensorSlotName::WEIGHT, OptimizerSlotName::ADAM_V); + auto m_tensor = acc.get_optimizer_tensor( + TensorSlotName::WEIGHT, OptimizerSlotName::ADAM_M); + + ProfilingSettings profiling = acc.get_profiling_settings(); + DeviceType kernel_device_type = acc.get_kernel_device_type(); ASSERT(weight.shape == 
weight_grad.shape); int size = get_num_elements(weight_grad.shape.dims).int_from_positive_int(); @@ -200,7 +74,7 @@ static void adam_update_task_impl(TaskArgumentAccessor const &acc) { get_num_elements(weight_grad.shape.dims).int_from_positive_int() / get_num_elements(weight.shape.dims).int_from_positive_int(); - auto handle = acc.get_argument(HANDLE); + device_handle_t handle = acc.get_ff_handle(); profile(adam_update_task, profiling, kernel_device_type, @@ -216,71 +90,13 @@ static void adam_update_task_impl(TaskArgumentAccessor const &acc) { num_replicas, m_tensor.get_float_ptr(), v_tensor.get_float_ptr(), - weight.get_float_ptr()); // how to deal with removal of ParamSync? - - // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - // auto handle = acc.get_argument(HANDLE); - // profile(adam_nccl_update_task_gpu, - // profiling, - // "[Adam NCCL] update_time = %.2lfms\n", - // attrs.alpha_t, - // attrs.beta1, - // attrs.beta2, - // attrs.weight_decay, - // attrs.epsilon, - // size, - // handle, - // weight_grad.get_float_ptr(), - // m_tensor.get_float_ptr(), - // v_tensor.get_float_ptr(), - // weight.get_float_ptr()); - // } else { - // profile(adam_ps_update_task_gpu, - // profiling, - // "[Adam NCCL] update_time = %.2lfms\n", - // attrs.alpha_t, - // attrs.beta1, - // attrs.beta2, - // attrs.weight_decay, - // attrs.epsilon, - // size, - // num_replicas, - // weight_grad.get_float_ptr(), - // m_tensor.get_float_ptr(), - // v_tensor.get_float_ptr(), - // weight.get_float_ptr()); - // } + weight.get_float_ptr()); } TaskImplFunction get_adam_update_task_impl() { return TaskImplFunction{GenericTaskImplFunction{adam_update_task_impl}}; } -TaskSignature get_update_signature(OptimizerAttrs const &attrs) { - return attrs.visit(overload{ - [&](SGDOptimizerAttrs const &) { return get_sgd_update_signature(); }, - [&](AdamOptimizerAttrs const &) { return get_adam_update_signature(); }}); -} - -TaskInvocation get_update_invocation( - OptimizerAttrs const &attrs, - 
forward_tensor_guid_t const &weight, - gradient_tensor_guid_t const &weight_grad, - std::vector const &grad_buffer_tensors) { - return attrs.visit( - overload{[&](SGDOptimizerAttrs const &s) { - return sgd_update( - s, weight, weight_grad, get_only(grad_buffer_tensors)); - }, - [&](AdamOptimizerAttrs const &s) { - return adam_update(s, - weight, - weight_grad, - grad_buffer_tensors.at(0), - grad_buffer_tensors.at(1)); - }}); -} - TaskImplFunction get_update_task_impl(OptimizerAttrs const &attrs) { return attrs.visit(overload{ [&](SGDOptimizerAttrs const &) { return get_sgd_update_task_impl(); }, diff --git a/lib/task-spec/src/task-spec/optimizer_tensor_source.cc b/lib/task-spec/src/task-spec/optimizer_tensor_source.cc deleted file mode 100644 index ad7bf9f489..0000000000 --- a/lib/task-spec/src/task-spec/optimizer_tensor_source.cc +++ /dev/null @@ -1,18 +0,0 @@ -#include "task-spec/optimizer_tensor_source.h" - -namespace FlexFlow { - -int OptimizerTensorSource::next_available_optimizer_tensor_id = 0; - -OptimizerTensorSource::OptimizerTensorSource() {} - -optimizer_tensor_guid_t OptimizerTensorSource::new_optimizer_tensor() { - return optimizer_tensor_guid_t{ - OptimizerTensorSource::next_available_optimizer_tensor_id++}; -} - -void OptimizerTensorSource::reset() { - OptimizerTensorSource::next_available_optimizer_tensor_id = 0; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/per_device_op_state.cc b/lib/task-spec/src/task-spec/per_device_op_state.cc index a959f4a8c9..12b649e663 100644 --- a/lib/task-spec/src/task-spec/per_device_op_state.cc +++ b/lib/task-spec/src/task-spec/per_device_op_state.cc @@ -4,7 +4,8 @@ namespace FlexFlow { PerDeviceOpState get_device_state_from_device_specific( - DeviceSpecificDeviceStates const &device_specific, size_t device_idx) { + DeviceSpecificPerDeviceOpState const &device_specific, + device_id_t device_idx) { return device_specific.visit( [&](auto const &x) { return PerDeviceOpState{*(x.get(device_idx))}; 
}); } diff --git a/lib/task-spec/src/task-spec/runtime_arg_config.cc b/lib/task-spec/src/task-spec/runtime_arg_config.cc deleted file mode 100644 index 9f3dc61545..0000000000 --- a/lib/task-spec/src/task-spec/runtime_arg_config.cc +++ /dev/null @@ -1,30 +0,0 @@ -#include "task-spec/runtime_arg_config.h" -#include "kernels/device_handle_t.h" - -namespace FlexFlow { - -RuntimeArgConfig - cpu_make_runtime_arg_config(EnableProfiling enable_profiling, - ProfilingSettings profiling_settings) { - return RuntimeArgConfig{ - DeviceSpecific::create(cpu_make_device_handle_t()), - enable_profiling, - profiling_settings, - DeviceType::CPU, - }; -} - -RuntimeArgConfig - gpu_make_runtime_arg_config(PerDeviceFFHandle const &ff_handle, - EnableProfiling enable_profiling, - ProfilingSettings profiling_settings) { - return RuntimeArgConfig{ - DeviceSpecific::create( - gpu_make_device_handle_t(ff_handle)), - enable_profiling, - profiling_settings, - DeviceType::GPU, - }; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/runtime_arg_ref.cc b/lib/task-spec/src/task-spec/runtime_arg_ref.cc deleted file mode 100644 index 3aa1b7f907..0000000000 --- a/lib/task-spec/src/task-spec/runtime_arg_ref.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "task-spec/runtime_arg_ref.h" -#include "kernels/device_handle_t.dtg.h" -#include "task-spec/device_specific.h" - -namespace FlexFlow { - -RuntimeArgRef profiling_settings() { - return {RuntimeArgRefType::PROFILING_SETTINGS}; -} - -RuntimeArgRef> ff_handle() { - return {RuntimeArgRefType::FF_HANDLE}; -} - -RuntimeArgRef iteration_config() { - return {RuntimeArgRefType::FF_ITERATION_CONFIG}; -} - -RuntimeArgRef kernel_device_type() { - return {RuntimeArgRefType::KERNEL_DEVICE_TYPE}; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/task_arg_spec.cc b/lib/task-spec/src/task-spec/task_arg_spec.cc deleted file mode 100644 index 36fa2f71fd..0000000000 --- a/lib/task-spec/src/task-spec/task_arg_spec.cc +++ /dev/null @@ 
-1,11 +0,0 @@ -#include "task-spec/task_arg_spec.h" -#include "utils/overload.h" - -namespace FlexFlow { - -std::type_index get_type_index(TaskArgSpec const &task_arg_spec) { - return task_arg_spec.visit( - overload{[](auto const &e) { return e.get_type_index(); }}); -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/task_argument_accessor.cc b/lib/task-spec/src/task-spec/task_argument_accessor.cc deleted file mode 100644 index cee9fc0708..0000000000 --- a/lib/task-spec/src/task-spec/task_argument_accessor.cc +++ /dev/null @@ -1 +0,0 @@ -#include "task-spec/task_argument_accessor.h" diff --git a/lib/task-spec/src/task-spec/task_argument_accessor/itask_argument_accessor.cc b/lib/task-spec/src/task-spec/task_argument_accessor/itask_argument_accessor.cc new file mode 100644 index 0000000000..e1ccf9bd26 --- /dev/null +++ b/lib/task-spec/src/task-spec/task_argument_accessor/itask_argument_accessor.cc @@ -0,0 +1 @@ +#include "task-spec/task_argument_accessor/itask_argument_accessor.h" diff --git a/lib/task-spec/src/task-spec/task_argument_accessor/task_argument_accessor.cc b/lib/task-spec/src/task-spec/task_argument_accessor/task_argument_accessor.cc new file mode 100644 index 0000000000..97f6069d68 --- /dev/null +++ b/lib/task-spec/src/task-spec/task_argument_accessor/task_argument_accessor.cc @@ -0,0 +1,36 @@ +#include "task-spec/task_argument_accessor/task_argument_accessor.h" + +namespace FlexFlow { + +ProfilingSettings TaskArgumentAccessor::get_profiling_settings() const { + return this->ptr->get_profiling_settings(); +} + +device_handle_t TaskArgumentAccessor::get_ff_handle() const { + return this->ptr->get_ff_handle(); +} +DeviceType TaskArgumentAccessor::get_kernel_device_type() const { + return this->ptr->get_kernel_device_type(); +} + +PCGOperatorAttrs TaskArgumentAccessor::get_op_attrs() const { + return this->ptr->get_op_attrs(); +} + +LossAttrs TaskArgumentAccessor::get_loss_attrs() const { + return this->ptr->get_loss_attrs(); +} + 
+PerDeviceOpState TaskArgumentAccessor::get_per_device_op_state() const { + return this->ptr->get_per_device_op_state(); +} + +FFIterationConfig TaskArgumentAccessor::get_iteration_config() const { + return this->ptr->get_iteration_config(); +} + +OptimizerAttrs TaskArgumentAccessor::get_optimizer_attrs() const { + return this->ptr->get_optimizer_attrs(); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/task_argument_accessor/task_tensor_parameter.cc b/lib/task-spec/src/task-spec/task_argument_accessor/task_tensor_parameter.cc new file mode 100644 index 0000000000..6e8f8ae104 --- /dev/null +++ b/lib/task-spec/src/task-spec/task_argument_accessor/task_tensor_parameter.cc @@ -0,0 +1,23 @@ +#include "task-spec/task_argument_accessor/task_tensor_parameter.h" +#include + +namespace FlexFlow { + +TaskTensorParameter make_task_tensor_parameter_fwd(TensorSlotName slot) { + return TaskTensorParameter{TaskForwardTensorParameter{slot}}; +} + +TaskTensorParameter make_task_tensor_parameter_grad(TensorSlotName slot) { + return TaskTensorParameter{TaskGradientTensorParameter{slot}}; +} + +TaskTensorParameter make_task_tensor_parameter_opt(TensorSlotName slot, + OptimizerSlotName opt_slot) { + return TaskTensorParameter{TaskOptimizerTensorParameter{slot, opt_slot}}; +} + +TaskTensorParameter make_task_tensor_parameter_loss() { + return TaskTensorParameter{TaskLossTensorParameter{}}; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/task_id_with_noop_default_t.cc b/lib/task-spec/src/task-spec/task_id_with_noop_default_t.cc new file mode 100644 index 0000000000..998d73e9ff --- /dev/null +++ b/lib/task-spec/src/task-spec/task_id_with_noop_default_t.cc @@ -0,0 +1,243 @@ +#include "task-spec/task_id_with_noop_default_t.h" +#include "utils/overload.h" + +namespace FlexFlow { + +task_id_with_noop_default_t lift_task_id_t(task_id_t task_id) { + return task_id_with_noop_default_t{task_id}; +} + +task_id_with_noop_default_t default_noop_task() { + 
return task_id_with_noop_default_t{std::monostate{}}; +} + +task_id_with_noop_default_t lower_op_task_id_to_task_id_with_noop_default_t( + op_task_id_t op_task_id, ComputationGraphOpAttrs const &op_attrs) { + switch (op_task_id) { + case op_task_id_t::INIT: + return get_init_task_id_for_op_attrs(op_attrs); + case op_task_id_t::FWD: + return get_fwd_task_id_for_op_attrs(op_attrs); + case op_task_id_t::BWD: + return get_bwd_task_id_for_op_attrs(op_attrs); + default: + PANIC("Unhandled op_task_id_t", op_task_id); + } +} + +task_id_with_noop_default_t + get_init_task_id_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { + + return op_attrs.visit(overload{ + [](BatchMatmulAttrs const &) { return default_noop_task(); }, + [](BatchNormAttrs const &) { + return lift_task_id_t(task_id_t::BATCHNORM_INIT_TASK_ID); + }, + [](BroadcastAttrs const &) { return default_noop_task(); }, + [](CastAttrs const &) { return default_noop_task(); }, + [](ConcatAttrs const &) { return default_noop_task(); }, + [](Conv2DAttrs const &) { + return lift_task_id_t(task_id_t::CONV2D_INIT_TASK_ID); + }, + [](DropoutAttrs const &) { + return lift_task_id_t(task_id_t::DROPOUT_INIT_TASK_ID); + }, + [](ElementBinaryAttrs const &) { + return lift_task_id_t(task_id_t::ELEMENTBINARY_INIT_TASK_ID); + }, + [](ElementUnaryAttrs const &) { + return lift_task_id_t(task_id_t::ELEMENTBINARY_INIT_TASK_ID); + }, + [](EmbeddingAttrs const &) { return default_noop_task(); }, + [](FlatAttrs const &) { return default_noop_task(); }, + [](GatherAttrs const &) { + return lift_task_id_t(task_id_t::GATHER_INIT_TASK_ID); + }, + [](InputAttrs const &) { return default_noop_task(); }, + [](LayerNormAttrs const &) { + return lift_task_id_t(task_id_t::LAYERNORM_INIT_TASK_ID); + }, + [](LinearAttrs const &) { + return lift_task_id_t(task_id_t::LINEAR_INIT_TASK_ID); + }, + [](MultiHeadAttentionAttrs const &) { + return lift_task_id_t(task_id_t::ATTENTION_INIT_TASK_ID); + }, + [](NoopAttrs const &) { return 
default_noop_task(); }, + [](Pool2DAttrs const &) { + return lift_task_id_t(task_id_t::POOL2D_INIT_TASK_ID); + }, + [](ReduceAttrs const &) { + return lift_task_id_t(task_id_t::REDUCE_INIT_TASK_ID); + }, + [](ReshapeAttrs const &) { return default_noop_task(); }, + [](ReverseAttrs const &) { return default_noop_task(); }, + [](SoftmaxAttrs const &) { + return lift_task_id_t(task_id_t::SOFTMAX_INIT_TASK_ID); + }, + [](SplitAttrs const &) { return default_noop_task(); }, + [](TopKAttrs const &) { return default_noop_task(); }, + [](TransposeAttrs const &) { return default_noop_task(); }, + [](WeightAttrs const &) { return default_noop_task(); }, + }); +} + +task_id_with_noop_default_t + get_fwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { + + return op_attrs.visit(overload{ + [](BatchMatmulAttrs const &) { + return lift_task_id_t(task_id_t::BATCHMATMUL_FWD_TASK_ID); + }, + [](BatchNormAttrs const &) { + return lift_task_id_t(task_id_t::BATCHNORM_FWD_TASK_ID); + }, + [](BroadcastAttrs const &) { + return lift_task_id_t(task_id_t::BROADCAST_FWD_TASK_ID); + }, + [](CastAttrs const &) { + return lift_task_id_t(task_id_t::CAST_FWD_TASK_ID); + }, + [](ConcatAttrs const &) { + return lift_task_id_t(task_id_t::CONCAT_FWD_TASK_ID); + }, + [](Conv2DAttrs const &) { + return lift_task_id_t(task_id_t::CONV2D_FWD_TASK_ID); + }, + [](DropoutAttrs const &) { + return lift_task_id_t(task_id_t::DROPOUT_FWD_TASK_ID); + }, + [](ElementBinaryAttrs const &) { + return lift_task_id_t(task_id_t::ELEMENTBINARY_FWD_TASK_ID); + }, + [](ElementUnaryAttrs const &) { + return lift_task_id_t(task_id_t::ELEMENTBINARY_FWD_TASK_ID); + }, + [](EmbeddingAttrs const &) { + return lift_task_id_t(task_id_t::EMBED_FWD_TASK_ID); + }, + [](FlatAttrs const &) { + return lift_task_id_t(task_id_t::FLAT_FWD_TASK_ID); + }, + [](GatherAttrs const &) { + return lift_task_id_t(task_id_t::GATHER_FWD_TASK_ID); + }, + [](InputAttrs const &) { return default_noop_task(); }, + [](LayerNormAttrs const 
&) { + return lift_task_id_t(task_id_t::LAYERNORM_FWD_TASK_ID); + }, + [](LinearAttrs const &) { + return lift_task_id_t(task_id_t::LINEAR_FWD_TASK_ID); + }, + [](MultiHeadAttentionAttrs const &) { + return lift_task_id_t(task_id_t::ATTENTION_FWD_TASK_ID); + }, + [](NoopAttrs const &) { return default_noop_task(); }, + [](Pool2DAttrs const &) { + return lift_task_id_t(task_id_t::POOL2D_FWD_TASK_ID); + }, + [](ReduceAttrs const &) { + return lift_task_id_t(task_id_t::REDUCE_FWD_TASK_ID); + }, + [](ReshapeAttrs const &) { + return lift_task_id_t(task_id_t::RESHAPE_FWD_TASK_ID); + }, + [](ReverseAttrs const &) { + return lift_task_id_t(task_id_t::REVERSE_FWD_TASK_ID); + }, + [](SoftmaxAttrs const &) { + return lift_task_id_t(task_id_t::SOFTMAX_FWD_TASK_ID); + }, + [](SplitAttrs const &) { + return lift_task_id_t(task_id_t::SPLIT_FWD_TASK_ID); + }, + [](TopKAttrs const &) { + return lift_task_id_t(task_id_t::TOPK_FWD_TASK_ID); + }, + [](TransposeAttrs const &) { + return lift_task_id_t(task_id_t::TRANSPOSE_FWD_TASK_ID); + }, + [](WeightAttrs const &) { return default_noop_task(); }, + }); +} + +task_id_with_noop_default_t + get_bwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { + + return op_attrs.visit(overload{ + [](BatchMatmulAttrs const &) { + return lift_task_id_t(task_id_t::BATCHMATMUL_BWD_TASK_ID); + }, + [](BatchNormAttrs const &) { + return lift_task_id_t(task_id_t::BATCHNORM_BWD_TASK_ID); + }, + [](BroadcastAttrs const &) { + return lift_task_id_t(task_id_t::BROADCAST_BWD_TASK_ID); + }, + [](CastAttrs const &) { + return lift_task_id_t(task_id_t::CAST_BWD_TASK_ID); + }, + [](ConcatAttrs const &) { + return lift_task_id_t(task_id_t::CONCAT_BWD_TASK_ID); + }, + [](Conv2DAttrs const &) { + return lift_task_id_t(task_id_t::CONV2D_BWD_TASK_ID); + }, + [](DropoutAttrs const &) { + return lift_task_id_t(task_id_t::DROPOUT_BWD_TASK_ID); + }, + [](ElementBinaryAttrs const &) { + return lift_task_id_t(task_id_t::ELEMENTBINARY_BWD_TASK_ID); + }, + 
[](ElementUnaryAttrs const &) { + return lift_task_id_t(task_id_t::ELEMENTBINARY_BWD_TASK_ID); + }, + [](EmbeddingAttrs const &) { + return lift_task_id_t(task_id_t::EMBED_BWD_TASK_ID); + }, + [](FlatAttrs const &) { + return lift_task_id_t(task_id_t::FLAT_BWD_TASK_ID); + }, + [](GatherAttrs const &) { + return lift_task_id_t(task_id_t::GATHER_BWD_TASK_ID); + }, + [](InputAttrs const &) { return default_noop_task(); }, + [](LayerNormAttrs const &) { + return lift_task_id_t(task_id_t::LAYERNORM_BWD_TASK_ID); + }, + [](LinearAttrs const &) { + return lift_task_id_t(task_id_t::LINEAR_BWD_TASK_ID); + }, + [](MultiHeadAttentionAttrs const &) { + return lift_task_id_t(task_id_t::ATTENTION_BWD_TASK_ID); + }, + [](NoopAttrs const &) { return default_noop_task(); }, + [](Pool2DAttrs const &) { + return lift_task_id_t(task_id_t::POOL2D_BWD_TASK_ID); + }, + [](ReduceAttrs const &) { + return lift_task_id_t(task_id_t::REDUCE_BWD_TASK_ID); + }, + [](ReshapeAttrs const &) { + return lift_task_id_t(task_id_t::RESHAPE_BWD_TASK_ID); + }, + [](ReverseAttrs const &) { + return lift_task_id_t(task_id_t::REVERSE_BWD_TASK_ID); + }, + [](SoftmaxAttrs const &) { + return lift_task_id_t(task_id_t::SOFTMAX_BWD_TASK_ID); + }, + [](SplitAttrs const &) { + return lift_task_id_t(task_id_t::SPLIT_BWD_TASK_ID); + }, + [](TopKAttrs const &) { + return lift_task_id_t(task_id_t::TOPK_BWD_TASK_ID); + }, + [](TransposeAttrs const &) { + return lift_task_id_t(task_id_t::TRANSPOSE_BWD_TASK_ID); + }, + [](WeightAttrs const &) { return default_noop_task(); }, + }); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/task_invocation.cc b/lib/task-spec/src/task-spec/task_invocation.cc deleted file mode 100644 index 0677ff6e60..0000000000 --- a/lib/task-spec/src/task-spec/task_invocation.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include "task-spec/task_invocation.h" -#include "task-spec/task_arg_spec.h" -#include "utils/containers/keys.h" - -namespace FlexFlow { - -bool 
is_invocation_valid(TaskSignature const &sig, TaskInvocation const &inv) { - TaskBinding binding = inv.binding; - - for (std::pair const &arg_binding : - binding.get_arg_bindings()) { - if (sig.task_arg_types.count(arg_binding.first)) { - if (get_type_index(arg_binding.second) != - sig.task_arg_types.at(arg_binding.first)) { - return false; // incorrect arg type - } - } else { - return false; // slot doesn't exist in signature - } - } - - for (std::pair const - &tensor_binding : binding.get_tensor_bindings()) { - slot_id_t tensor_slot_id = tensor_binding.first.slot_id; - if (sig.tensor_guid_slots.count(tensor_slot_id)) { - if (tensor_binding.first.tensor_type == - sig.tensor_guid_slots.at(tensor_slot_id).tensor_type) { - return false; // incorrect tensor type - } - } else { - return false; // slot doesn't exist in signature - } - } - - return true; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/task_signature.cc b/lib/task-spec/src/task-spec/task_signature.cc deleted file mode 100644 index 3ac038e8c5..0000000000 --- a/lib/task-spec/src/task-spec/task_signature.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "task-spec/task_signature.h" - -namespace FlexFlow { - -TaskSignature make_empty_task_signature() { - return TaskSignature(std::nullopt, {}, {}); -} - -void add_slot(TaskSignature &task_signature, - int name, - TensorType tensor_type, - SlotType slot_type) { - add_slot(task_signature, slot_id_t{name}, tensor_type, slot_type); -} - -void add_slot(TaskSignature &task_signature, - slot_id_t name, - TensorType tensor_type, - SlotType slot_type) { - TensorTypeSlotSpec tensor_guid_slot_spec = - TensorTypeSlotSpec{name, tensor_type, slot_type}; - task_signature.tensor_guid_slots.insert({name, tensor_guid_slot_spec}); -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/task_signature_impl.cc b/lib/task-spec/src/task-spec/task_signature_impl.cc deleted file mode 100644 index 8da38b5840..0000000000 --- 
a/lib/task-spec/src/task-spec/task_signature_impl.cc +++ /dev/null @@ -1,323 +0,0 @@ -#include "task-spec/task_signature_impl.h" -#include "task-spec/ops/attention.h" -#include "task-spec/ops/batch_matmul.h" -#include "task-spec/ops/batch_norm.h" -#include "task-spec/ops/cast.h" -#include "task-spec/ops/concat.h" -#include "task-spec/ops/conv_2d.h" -#include "task-spec/ops/dropout.h" -#include "task-spec/ops/element_binary.h" -#include "task-spec/ops/element_unary.h" -#include "task-spec/ops/embedding.h" -#include "task-spec/ops/flat.h" -#include "task-spec/ops/gather.h" -#include "task-spec/ops/input.h" -#include "task-spec/ops/layer_norm.h" -#include "task-spec/ops/linear.h" -#include "task-spec/ops/noop.h" -#include "task-spec/ops/pool_2d.h" -#include "task-spec/ops/reduce.h" -#include "task-spec/ops/reshape.h" -#include "task-spec/ops/reverse.h" -#include "task-spec/ops/softmax.h" -#include "task-spec/ops/split.h" -#include "task-spec/ops/topk.h" -#include "task-spec/ops/transpose.h" -#include "task-spec/ops/weight.h" -#include "utils/overload.h" - -namespace FlexFlow { - -TaskSignatureAndImpl - get_task_signature_and_impl_for_task_id(task_id_t const &task_id) { - switch (task_id) { - case task_id_t::ELEMENTBINARY_INIT_TASK_ID: - return TaskSignatureAndImpl{get_element_binary_init_task_impl(), - get_element_binary_init_signature()}; - case task_id_t::ELEMENTBINARY_FWD_TASK_ID: - return TaskSignatureAndImpl{get_element_binary_fwd_task_impl(), - get_element_binary_fwd_signature()}; - case task_id_t::ELEMENTBINARY_BWD_TASK_ID: - return TaskSignatureAndImpl{get_element_binary_bwd_task_impl(), - get_element_binary_bwd_signature()}; - case task_id_t::ELEMENTUNARY_INIT_TASK_ID: - return TaskSignatureAndImpl{get_element_unary_init_task_impl(), - get_element_unary_init_signature()}; - case task_id_t::ELEMENTUNARY_FWD_TASK_ID: - return TaskSignatureAndImpl{get_element_unary_fwd_task_impl(), - get_element_unary_fwd_signature()}; - case task_id_t::ELEMENTUNARY_BWD_TASK_ID: 
- return TaskSignatureAndImpl{get_element_unary_bwd_task_impl(), - get_element_unary_bwd_signature()}; - case task_id_t::CONV2D_INIT_TASK_ID: - return TaskSignatureAndImpl{get_conv_2d_init_task_impl(), - get_conv_2d_init_signature()}; - case task_id_t::CONV2D_FWD_TASK_ID: - return TaskSignatureAndImpl{get_conv_2d_fwd_task_impl(), - get_conv_2d_fwd_signature()}; - case task_id_t::CONV2D_BWD_TASK_ID: - return TaskSignatureAndImpl{get_conv_2d_bwd_task_impl(), - get_conv_2d_bwd_signature()}; - case task_id_t::DROPOUT_INIT_TASK_ID: - return TaskSignatureAndImpl{get_dropout_init_task_impl(), - get_dropout_init_signature()}; - case task_id_t::DROPOUT_FWD_TASK_ID: - return TaskSignatureAndImpl{get_dropout_fwd_task_impl(), - get_dropout_fwd_signature()}; - case task_id_t::DROPOUT_BWD_TASK_ID: - return TaskSignatureAndImpl{get_dropout_bwd_task_impl(), - get_dropout_bwd_signature()}; - case task_id_t::EMBED_FWD_TASK_ID: - return TaskSignatureAndImpl{get_embedding_fwd_task_impl(), - get_embedding_fwd_signature()}; - case task_id_t::EMBED_BWD_TASK_ID: - return TaskSignatureAndImpl{get_embedding_bwd_task_impl(), - get_embedding_bwd_signature()}; - case task_id_t::GATHER_INIT_TASK_ID: - return TaskSignatureAndImpl{get_gather_init_task_impl(), - get_gather_init_signature()}; - case task_id_t::GATHER_FWD_TASK_ID: - return TaskSignatureAndImpl{get_gather_fwd_task_impl(), - get_gather_fwd_signature()}; - case task_id_t::GATHER_BWD_TASK_ID: - return TaskSignatureAndImpl{get_gather_bwd_task_impl(), - get_gather_bwd_signature()}; - case task_id_t::CAST_FWD_TASK_ID: - return TaskSignatureAndImpl{get_cast_fwd_task_impl(), - get_cast_fwd_signature()}; - case task_id_t::CAST_BWD_TASK_ID: - return TaskSignatureAndImpl{get_cast_bwd_task_impl(), - get_cast_bwd_signature()}; - case task_id_t::POOL2D_INIT_TASK_ID: - return TaskSignatureAndImpl{get_pool_2d_init_task_impl(), - get_pool_2d_init_signature()}; - case task_id_t::POOL2D_FWD_TASK_ID: - return 
TaskSignatureAndImpl{get_pool_2d_fwd_task_impl(), - get_pool_2d_fwd_signature()}; - case task_id_t::POOL2D_BWD_TASK_ID: - return TaskSignatureAndImpl{get_pool_2d_bwd_task_impl(), - get_pool_2d_bwd_signature()}; - case task_id_t::BATCHNORM_INIT_TASK_ID: - return TaskSignatureAndImpl{get_batch_norm_init_task_impl(), - get_batch_norm_init_signature()}; - case task_id_t::BATCHNORM_FWD_TASK_ID: - return TaskSignatureAndImpl{get_batch_norm_fwd_task_impl(), - get_batch_norm_fwd_signature()}; - case task_id_t::BATCHNORM_BWD_TASK_ID: - return TaskSignatureAndImpl{get_batch_norm_bwd_task_impl(), - get_batch_norm_bwd_signature()}; - case task_id_t::BATCHMATMUL_FWD_TASK_ID: - return TaskSignatureAndImpl{get_batch_matmul_fwd_task_impl(), - get_batch_matmul_fwd_signature()}; - case task_id_t::BATCHMATMUL_BWD_TASK_ID: - return TaskSignatureAndImpl{get_batch_matmul_bwd_task_impl(), - get_batch_matmul_bwd_signature()}; - case task_id_t::LAYERNORM_INIT_TASK_ID: - return TaskSignatureAndImpl{get_layer_norm_init_task_impl(), - get_layer_norm_init_signature()}; - case task_id_t::LAYERNORM_FWD_TASK_ID: - return TaskSignatureAndImpl{get_layer_norm_fwd_task_impl(), - get_layer_norm_init_signature()}; - case task_id_t::LAYERNORM_BWD_TASK_ID: - return TaskSignatureAndImpl{get_layer_norm_bwd_task_impl(), - get_layer_norm_bwd_signature()}; - case task_id_t::LINEAR_INIT_TASK_ID: - return TaskSignatureAndImpl{get_linear_init_task_impl(), - get_linear_init_signature()}; - case task_id_t::LINEAR_FWD_TASK_ID: - return TaskSignatureAndImpl{get_linear_fwd_task_impl(), - get_linear_fwd_signature()}; - case task_id_t::LINEAR_BWD_TASK_ID: - return TaskSignatureAndImpl{get_linear_bwd_task_impl(), - get_linear_bwd_signature()}; - case task_id_t::FLAT_FWD_TASK_ID: - return TaskSignatureAndImpl{get_flat_fwd_task_impl(), - get_flat_fwd_signature()}; - case task_id_t::FLAT_BWD_TASK_ID: - return TaskSignatureAndImpl{get_flat_bwd_task_impl(), - get_flat_bwd_signature()}; - case task_id_t::SOFTMAX_INIT_TASK_ID: 
- return TaskSignatureAndImpl{get_softmax_init_task_impl(), - get_softmax_init_signature()}; - case task_id_t::SOFTMAX_FWD_TASK_ID: - return TaskSignatureAndImpl{get_softmax_fwd_task_impl(), - get_softmax_fwd_signature()}; - case task_id_t::SOFTMAX_BWD_TASK_ID: - return TaskSignatureAndImpl{get_softmax_bwd_task_impl(), - get_softmax_bwd_signature()}; - case task_id_t::CONCAT_FWD_TASK_ID: - return TaskSignatureAndImpl{get_concat_fwd_task_impl(), - get_concat_fwd_signature()}; - case task_id_t::CONCAT_BWD_TASK_ID: - return TaskSignatureAndImpl{get_concat_bwd_task_impl(), - get_concat_bwd_signature()}; - case task_id_t::SPLIT_FWD_TASK_ID: - return TaskSignatureAndImpl{get_split_fwd_task_impl(), - get_split_fwd_signature()}; - case task_id_t::SPLIT_BWD_TASK_ID: - return TaskSignatureAndImpl{get_split_bwd_task_impl(), - get_split_bwd_signature()}; - case task_id_t::REDUCE_INIT_TASK_ID: - return TaskSignatureAndImpl{get_reduce_init_task_impl(), - get_reduce_init_signature()}; - case task_id_t::REDUCE_FWD_TASK_ID: - return TaskSignatureAndImpl{get_reduce_fwd_task_impl(), - get_reduce_fwd_signature()}; - case task_id_t::REDUCE_BWD_TASK_ID: - return TaskSignatureAndImpl{get_reduce_bwd_task_impl(), - get_reduce_bwd_signature()}; - case task_id_t::RESHAPE_FWD_TASK_ID: - return TaskSignatureAndImpl{get_reshape_fwd_task_impl(), - get_reshape_fwd_signature()}; - case task_id_t::RESHAPE_BWD_TASK_ID: - return TaskSignatureAndImpl{get_reshape_bwd_task_impl(), - get_reshape_bwd_signature()}; - case task_id_t::REVERSE_FWD_TASK_ID: - return TaskSignatureAndImpl{get_reverse_fwd_task_impl(), - get_reverse_fwd_signature()}; - case task_id_t::REVERSE_BWD_TASK_ID: - return TaskSignatureAndImpl{get_reverse_bwd_task_impl(), - get_reverse_bwd_signature()}; - case task_id_t::TOPK_FWD_TASK_ID: - return TaskSignatureAndImpl{get_topk_fwd_task_impl(), - get_topk_fwd_signature()}; - case task_id_t::TOPK_BWD_TASK_ID: - return TaskSignatureAndImpl{get_topk_bwd_task_impl(), - 
get_topk_bwd_signature()}; - case task_id_t::TRANSPOSE_FWD_TASK_ID: - return TaskSignatureAndImpl{get_transpose_fwd_task_impl(), - get_transpose_fwd_signature()}; - case task_id_t::TRANSPOSE_BWD_TASK_ID: - return TaskSignatureAndImpl{get_transpose_bwd_task_impl(), - get_transpose_bwd_signature()}; - case task_id_t::ATTENTION_INIT_TASK_ID: - return TaskSignatureAndImpl{get_attention_init_task_impl(), - get_attention_init_signature()}; - case task_id_t::ATTENTION_FWD_TASK_ID: - return TaskSignatureAndImpl{get_attention_fwd_task_impl(), - get_attention_fwd_signature()}; - case task_id_t::ATTENTION_BWD_TASK_ID: - return TaskSignatureAndImpl{get_attention_bwd_task_impl(), - get_attention_bwd_signature()}; - default: - PANIC("Unhandled task ID", task_id); - } -} - -std::vector get_task_ids(ComputationGraphOpAttrs const &op) { - return op.visit>(overload{ - [](BatchMatmulAttrs const &attrs) { return get_task_ids(attrs); }, - [](BatchNormAttrs const &attrs) { return get_task_ids(attrs); }, - [](CastAttrs const &attrs) { return get_task_ids(attrs); }, - [](ConcatAttrs const &attrs) { return get_task_ids(attrs); }, - [](Conv2DAttrs const &attrs) { return get_task_ids(attrs); }, - [](DropoutAttrs const &attrs) { return get_task_ids(attrs); }, - [](ElementBinaryAttrs const &attrs) { return get_task_ids(attrs); }, - [](ElementUnaryAttrs const &attrs) { return get_task_ids(attrs); }, - [](EmbeddingAttrs const &attrs) { return get_task_ids(attrs); }, - [](FlatAttrs const &attrs) { return get_task_ids(attrs); }, - [](GatherAttrs const &attrs) { return get_task_ids(attrs); }, - [](InputAttrs const &attrs) { return get_task_ids(attrs); }, - [](LayerNormAttrs const &attrs) { return get_task_ids(attrs); }, - [](LinearAttrs const &attrs) { return get_task_ids(attrs); }, - [](MultiHeadAttentionAttrs const &attrs) { return get_task_ids(attrs); }, - [](NoopAttrs const &attrs) { return get_task_ids(attrs); }, - [](Pool2DAttrs const &attrs) { return get_task_ids(attrs); }, - [](ReduceAttrs 
const &attrs) { return get_task_ids(attrs); }, - [](ReverseAttrs const &attrs) { return get_task_ids(attrs); }, - [](ReshapeAttrs const &attrs) { return get_task_ids(attrs); }, - [](SplitAttrs const &attrs) { return get_task_ids(attrs); }, - [](SoftmaxAttrs const &attrs) { return get_task_ids(attrs); }, - [](TopKAttrs const &attrs) { return get_task_ids(attrs); }, - [](TransposeAttrs const &attrs) { return get_task_ids(attrs); }, - [](WeightAttrs const &attrs) { return get_task_ids(attrs); }, - [](auto const &attrs) -> std::vector { - throw mk_runtime_error(fmt::format("Unhandled attr type: {}", attrs)); - }, - }); -} - -OpTaskInvocation - get_init_op_task_invocation(ComputationGraphOpAttrs const &op) { - return op.visit(overload{ - [](BatchNormAttrs const &attrs) { return init(attrs); }, - [](Conv2DAttrs const &attrs) { return init(attrs); }, - [](DropoutAttrs const &attrs) { return init(attrs); }, - [](ElementBinaryAttrs const &attrs) { return init(attrs); }, - [](ElementUnaryAttrs const &attrs) { return init(attrs); }, - [](GatherAttrs const &attrs) { return init(attrs); }, - [](LayerNormAttrs const &attrs) { return init(attrs); }, - [](LinearAttrs const &attrs) { return init(attrs); }, - [](MultiHeadAttentionAttrs const &attrs) { return init(attrs); }, - [](Pool2DAttrs const &attrs) { return init(attrs); }, - [](ReduceAttrs const &attrs) { return init(attrs); }, - [](SoftmaxAttrs const &attrs) { return init(attrs); }, - [](auto const &attrs) -> OpTaskInvocation { - PANIC("Unhandled attr type", attrs); - }, - }); -} - -OpTaskInvocation - get_forward_op_task_invocation(ComputationGraphOpAttrs const &op) { - return op.visit(overload{ - [](BatchMatmulAttrs const &attrs) { return forward(attrs); }, - [](BatchNormAttrs const &attrs) { return forward(attrs); }, - [](CastAttrs const &attrs) { return forward(attrs); }, - [](ConcatAttrs const &attrs) { return forward(attrs); }, - [](Conv2DAttrs const &attrs) { return forward(attrs); }, - [](DropoutAttrs const &attrs) { 
return forward(attrs); }, - [](ElementBinaryAttrs const &attrs) { return forward(attrs); }, - [](ElementUnaryAttrs const &attrs) { return forward(attrs); }, - [](EmbeddingAttrs const &attrs) { return forward(attrs); }, - [](FlatAttrs const &attrs) { return forward(attrs); }, - [](GatherAttrs const &attrs) { return forward(attrs); }, - [](LayerNormAttrs const &attrs) { return forward(attrs); }, - [](LinearAttrs const &attrs) { return forward(attrs); }, - [](MultiHeadAttentionAttrs const &attrs) { return forward(attrs); }, - [](Pool2DAttrs const &attrs) { return forward(attrs); }, - [](ReduceAttrs const &attrs) { return forward(attrs); }, - [](ReverseAttrs const &attrs) { return forward(attrs); }, - [](ReshapeAttrs const &attrs) { return forward(attrs); }, - [](SplitAttrs const &attrs) { return forward(attrs); }, - [](SoftmaxAttrs const &attrs) { return forward(attrs); }, - [](TopKAttrs const &attrs) { return forward(attrs); }, - [](TransposeAttrs const &attrs) { return forward(attrs); }, - [](auto const &attrs) -> OpTaskInvocation { - throw mk_runtime_error(fmt::format("Unhandled attr type {}", attrs)); - }, - }); -} - -OpTaskInvocation - get_backward_op_task_invocation(ComputationGraphOpAttrs const &op) { - return op.visit(overload{ - [](BatchMatmulAttrs const &attrs) { return backward(attrs); }, - [](BatchNormAttrs const &attrs) { return backward(attrs); }, - [](CastAttrs const &attrs) { return backward(attrs); }, - [](ConcatAttrs const &attrs) { return backward(attrs); }, - [](Conv2DAttrs const &attrs) { return backward(attrs); }, - [](DropoutAttrs const &attrs) { return backward(attrs); }, - [](ElementBinaryAttrs const &attrs) { return backward(attrs); }, - [](ElementUnaryAttrs const &attrs) { return backward(attrs); }, - [](EmbeddingAttrs const &attrs) { return backward(attrs); }, - [](FlatAttrs const &attrs) { return backward(attrs); }, - [](GatherAttrs const &attrs) { return backward(attrs); }, - [](LayerNormAttrs const &attrs) { return backward(attrs); }, - 
[](LinearAttrs const &attrs) { return backward(attrs); }, - [](MultiHeadAttentionAttrs const &attrs) { return backward(attrs); }, - [](Pool2DAttrs const &attrs) { return backward(attrs); }, - [](ReduceAttrs const &attrs) { return backward(attrs); }, - [](ReverseAttrs const &attrs) { return backward(attrs); }, - [](ReshapeAttrs const &attrs) { return backward(attrs); }, - [](SplitAttrs const &attrs) { return backward(attrs); }, - [](SoftmaxAttrs const &attrs) { return backward(attrs); }, - [](TopKAttrs const &attrs) { return backward(attrs); }, - [](TransposeAttrs const &attrs) { return backward(attrs); }, - [](auto const &attrs) -> OpTaskInvocation { - PANIC("Unhandled attr type", attrs); - }, - }); -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/training_computation_graph.cc b/lib/task-spec/src/task-spec/training_computation_graph.cc deleted file mode 100644 index f50930d684..0000000000 --- a/lib/task-spec/src/task-spec/training_computation_graph.cc +++ /dev/null @@ -1,183 +0,0 @@ -#include "task-spec/training_computation_graph.h" -#include "task-spec/loss_tensor_source.h" -#include "task-spec/training_tensor_group.h" -#include "task-spec/training_tensor_group_with_attrs.h" -#include "utils/containers/contains.h" -#include "utils/containers/filter_values.h" -#include "utils/containers/flatmap.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/get_only.h" -#include "utils/containers/keys.h" -#include "utils/containers/set_of.h" -#include "utils/containers/transform.h" -#include "utils/overload.h" - -namespace FlexFlow { - -TrainingComputationGraph generate_training_computation_graph( - ComputationGraph const &computation_graph, - OptimizerAttrs const &optimizer_attrs, - tensor_guid_t const &logit_tensor, - ForwardTensorSource &forward_tensor_source, - GradientTensorSource &gradient_tensor_source, - OptimizerTensorSource &optimizer_tensor_source, - LossTensorSource &loss_tensor_source) { - - loss_tensor_guid_t 
label_tensor = loss_tensor_source.new_loss_tensor(); - - return TrainingComputationGraph{ - /*computation_graph=*/computation_graph, - /*training_tensor_group_for_tensor=*/ - transform( - get_all_tensor_attrs(computation_graph), - [&](tensor_guid_t tensor_guid, TensorAttrs const &tensor_attrs) { - return std::pair{ - tensor_guid, - make_training_tensor_group_for_tensor_guid_t( - /*tensor_guid=*/tensor_guid, - /*tensor_attrs=*/tensor_attrs, - /*optimizer_attrs=*/optimizer_attrs, - /*forward_tensor_source=*/forward_tensor_source, - /*gradient_tensor_source=*/gradient_tensor_source, - /*optimizer_tensor_source=*/optimizer_tensor_source), - }; - }), - /*logit_tensor=*/logit_tensor, - /*label_tensor=*/label_tensor, - }; -} - -TrainingTensorGroup get_training_tensor_group_for_tensor_guid( - TrainingComputationGraph const &training_cg, tensor_guid_t tensor_guid) { - - return training_cg.training_tensor_group_for_tensor.at(tensor_guid); -} - -TrainingTensorGroupWithAttrs - get_training_tensor_group_with_attrs_for_tensor_guid( - TrainingComputationGraph const &training_cg, - tensor_guid_t tensor_guid) { - return make_training_tensor_group_with_attrs_from_group_and_attrs( - /*group=*/get_training_tensor_group_for_tensor_guid(training_cg, - tensor_guid), - /*attrs=*/get_tensor_attrs(training_cg.computation_graph, tensor_guid)); -} - -forward_tensor_guid_t get_forward_tensor_guid_for_tensor_guid( - TrainingComputationGraph const &training_cg, tensor_guid_t t) { - return training_cg.training_tensor_group_for_tensor.at(t).forward_tensor; -} - -gradient_tensor_guid_t get_gradient_tensor_guid_for_tensor_guid( - TrainingComputationGraph const &training_cg, tensor_guid_t t) { - return training_cg.training_tensor_group_for_tensor.at(t).gradient_tensor; -} - -std::vector get_optimizer_tensor_guids_for_tensor_guid( - TrainingComputationGraph const &training_cg, tensor_guid_t t) { - return training_cg.training_tensor_group_for_tensor.at(t).optimizer_tensors; -} - -tensor_guid_t 
get_tensor_guid_for_forward_tensor_guid( - TrainingComputationGraph const &training_cg, forward_tensor_guid_t t) { - return get_only(keys(filter_values( - training_cg.training_tensor_group_for_tensor, - [&](TrainingTensorGroup const &g) { return g.forward_tensor == t; }))); -} - -tensor_guid_t get_tensor_guid_for_gradient_tensor_guid( - TrainingComputationGraph const &training_cg, gradient_tensor_guid_t t) { - return get_only(keys(filter_values( - training_cg.training_tensor_group_for_tensor, - [&](TrainingTensorGroup const &g) { return g.gradient_tensor == t; }))); -} - -tensor_guid_t get_tensor_guid_for_optimizer_tensor_guid( - TrainingComputationGraph const &training_cg, optimizer_tensor_guid_t t) { - return get_only( - keys(filter_values(training_cg.training_tensor_group_for_tensor, - [&](TrainingTensorGroup const &g) { - return contains(g.optimizer_tensors, t); - }))); -} - -tensor_guid_t get_tensor_guid_for_training_tensor_guid( - TrainingComputationGraph const &training_cg, training_tensor_guid_t t) { - return t.visit(overload{ - [&](forward_tensor_guid_t forward_tensor) { - return get_tensor_guid_for_forward_tensor_guid(training_cg, - forward_tensor); - }, - [&](gradient_tensor_guid_t gradient_tensor) { - return get_tensor_guid_for_gradient_tensor_guid(training_cg, - gradient_tensor); - }, - [&](optimizer_tensor_guid_t optimizer_tensor) { - return get_tensor_guid_for_optimizer_tensor_guid(training_cg, - optimizer_tensor); - }, - [&](loss_tensor_guid_t loss_tensor) -> tensor_guid_t { - PANIC("no tensor_guid_t can exist for a loss_tensor_guid_t"); - }, - }); -} - -std::unordered_set - get_all_training_tensors_in_training_computation_graph( - TrainingComputationGraph const &training_cg) { - std::unordered_set result = flatmap( - unordered_set_of(keys(training_cg.training_tensor_group_for_tensor)), - [&](tensor_guid_t t) { - return get_all_training_tensors_in_tensor_group( - training_cg.training_tensor_group_for_tensor.at(t)); - }); - - 
result.insert(training_tensor_guid_t{training_cg.label_tensor}); - return result; -} - -TrainingLayerPlusContext - get_training_layer_plus_context(TrainingComputationGraph const &training_cg, - layer_guid_t layer_guid) { - auto get_tensor_group_with_attrs = - [&](tensor_guid_t t) -> TrainingTensorGroupWithAttrs { - return get_training_tensor_group_with_attrs_for_tensor_guid(training_cg, t); - }; - - return TrainingLayerPlusContext{ - /*layer_guid=*/layer_guid, - /*layer_attrs=*/ - get_layer_attrs(training_cg.computation_graph, layer_guid), - /*input_tensor_groups=*/ - transform(get_incoming_inputs(training_cg.computation_graph, layer_guid), - get_tensor_group_with_attrs), - /*weight_tensor_groups=*/ - transform(get_incoming_weights(training_cg.computation_graph, layer_guid), - get_tensor_group_with_attrs), - /*output_tensor_groups=*/ - transform(get_outgoing_tensors(training_cg.computation_graph, layer_guid), - get_tensor_group_with_attrs), - }; -} - -std::unordered_map - get_all_training_tensor_shapes( - TrainingComputationGraph const &training_cg) { - return generate_map( - get_all_training_tensors_in_training_computation_graph(training_cg), - [&](training_tensor_guid_t t) { - if (t.is_loss_tensor()) { - ASSERT(t == training_tensor_guid_t{training_cg.label_tensor}); - return get_tensor_attrs(training_cg.computation_graph, - training_cg.logit_tensor) - .shape; - } - - return get_tensor_attrs( - training_cg.computation_graph, - get_tensor_guid_for_training_tensor_guid(training_cg, t)) - .shape; - }); -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/training_layer_plus_context.cc b/lib/task-spec/src/task-spec/training_layer_plus_context.cc deleted file mode 100644 index 9adbc6b2a1..0000000000 --- a/lib/task-spec/src/task-spec/training_layer_plus_context.cc +++ /dev/null @@ -1,122 +0,0 @@ -#include "task-spec/training_layer_plus_context.h" -#include "task-spec/training_tensor_group_with_attrs.h" -#include "utils/containers/transform.h" - 
-namespace FlexFlow { - -std::vector - get_training_tensor_groups_with_attrs_for_role( - TrainingLayerPlusContext const &training_layer_plus_context, - TensorRole tensor_role) { - - switch (tensor_role) { - case TensorRole::INPUT: - return training_layer_plus_context.input_tensor_groups; - case TensorRole::WEIGHT: - return training_layer_plus_context.weight_tensor_groups; - case TensorRole::OUTPUT: - return training_layer_plus_context.output_tensor_groups; - default: - PANIC("Unhandled TensorRole {}", tensor_role); - } -} - -TrainingTensorGroupWithAttrs - get_training_tensor_group_with_attrs_for_role_and_index( - TrainingLayerPlusContext const &training_layer_plus_context, - TensorRole tensor_role, - nonnegative_int index) { - - return get_training_tensor_groups_with_attrs_for_role( - training_layer_plus_context, tensor_role) - .at(index.unwrap_nonnegative()); -} - -std::vector - get_input_tensors(TrainingLayerPlusContext const &l) { - return transform( - l.input_tensor_groups, - [](TrainingTensorGroupWithAttrs const &g) { return g.forward_tensor; }); -} - -std::vector - get_input_grad_tensors(TrainingLayerPlusContext const &l) { - return transform( - l.input_tensor_groups, - [](TrainingTensorGroupWithAttrs const &g) { return g.gradient_tensor; }); -} - -std::vector - get_input_tensor_shapes(TrainingLayerPlusContext const &l) { - return transform(l.input_tensor_groups, - [](TrainingTensorGroupWithAttrs const &g) { - return g.tensor_attrs.shape; - }); -} - -std::vector - get_weight_tensors(TrainingLayerPlusContext const &l) { - return transform( - l.weight_tensor_groups, - [](TrainingTensorGroupWithAttrs const &g) { return g.forward_tensor; }); -} - -std::vector - get_weight_grad_tensors(TrainingLayerPlusContext const &l) { - return transform( - l.weight_tensor_groups, - [](TrainingTensorGroupWithAttrs const &g) { return g.gradient_tensor; }); -} - -std::vector - get_weight_tensor_shapes(TrainingLayerPlusContext const &l) { - return transform(l.weight_tensor_groups, 
- [](TrainingTensorGroupWithAttrs const &g) { - return g.tensor_attrs.shape; - }); -} - -std::vector - get_output_tensors(TrainingLayerPlusContext const &l) { - return transform( - l.output_tensor_groups, - [](TrainingTensorGroupWithAttrs const &g) { return g.forward_tensor; }); -} - -std::vector - get_output_grad_tensors(TrainingLayerPlusContext const &l) { - return transform( - l.output_tensor_groups, - [](TrainingTensorGroupWithAttrs const &g) { return g.gradient_tensor; }); -} - -std::vector - get_output_tensor_shapes(TrainingLayerPlusContext const &l) { - return transform(l.output_tensor_groups, - [](TrainingTensorGroupWithAttrs const &g) { - return g.tensor_attrs.shape; - }); -} - -TrainingLayerTensorGroupSignature - get_tensor_group_signature(TrainingLayerPlusContext const &l) { - return TrainingLayerTensorGroupSignature{ - /*input_tensor_groups=*/transform(l.input_tensor_groups, - tensor_group_without_attrs), - /*weight_tensor_groups=*/ - transform(l.weight_tensor_groups, tensor_group_without_attrs), - /*output_tensor_groups=*/ - transform(l.output_tensor_groups, tensor_group_without_attrs), - }; -} - -CGOperatorTensorShapeSignature - get_cg_op_shape_signature(TrainingLayerPlusContext const &l) { - return CGOperatorTensorShapeSignature{ - /*input_shapes=*/get_input_tensor_shapes(l), - /*weight_shapes=*/get_weight_tensor_shapes(l), - /*output_shapes=*/get_output_tensor_shapes(l), - }; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/training_layer_tensor_group_signature.cc b/lib/task-spec/src/task-spec/training_layer_tensor_group_signature.cc deleted file mode 100644 index db8b8015ec..0000000000 --- a/lib/task-spec/src/task-spec/training_layer_tensor_group_signature.cc +++ /dev/null @@ -1,31 +0,0 @@ -#include "task-spec/training_layer_tensor_group_signature.h" -#include - -namespace FlexFlow { - -std::vector get_training_tensor_groups_for_role( - TrainingLayerTensorGroupSignature const &signature, - TensorRole tensor_role) { - - switch 
(tensor_role) { - case TensorRole::INPUT: - return signature.input_tensor_groups; - case TensorRole::WEIGHT: - return signature.weight_tensor_groups; - case TensorRole::OUTPUT: - return signature.output_tensor_groups; - default: - PANIC("Unhandled TensorRole {}", tensor_role); - } -} - -TrainingTensorGroup get_training_tensor_group_for_role_and_index( - TrainingLayerTensorGroupSignature const &signature, - TensorRole tensor_role, - nonnegative_int index) { - - return get_training_tensor_groups_for_role(signature, tensor_role) - .at(index.unwrap_nonnegative()); -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/training_tensor_group.cc b/lib/task-spec/src/task-spec/training_tensor_group.cc deleted file mode 100644 index 0f6710b80f..0000000000 --- a/lib/task-spec/src/task-spec/training_tensor_group.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "task-spec/training_tensor_group.h" -#include "pcg/optimizer_attrs.h" -#include "utils/containers/repeat.h" -#include "utils/containers/set_union.h" -#include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" - -namespace FlexFlow { - -TrainingTensorGroup make_training_tensor_group_for_tensor_guid_t( - tensor_guid_t tensor_guid, - TensorAttrs const &tensor_attrs, - OptimizerAttrs const &optimizer_attrs, - ForwardTensorSource &forward_tensor_source, - GradientTensorSource &gradient_tensor_source, - OptimizerTensorSource &optimizer_tensor_source) { - - nonnegative_int num_optimizer_tensors = [&]() { - if (tensor_attrs.create_grad == CreateGrad::YES) { - return get_num_optimizer_tensors(optimizer_attrs); - } else { - return 0_n; - } - }(); - - return TrainingTensorGroup{ - /*forward_tensor=*/forward_tensor_source.new_forward_tensor(), - /*gradient_tensor=*/gradient_tensor_source.new_gradient_tensor(), - /*optimizer_tensors=*/ - repeat(num_optimizer_tensors, - [&]() { return optimizer_tensor_source.new_optimizer_tensor(); }), - }; -} - -std::unordered_set - 
get_all_training_tensors_in_tensor_group(TrainingTensorGroup const &group) { - return set_union( - std::unordered_set{ - training_tensor_guid_t{group.forward_tensor}, - training_tensor_guid_t{group.gradient_tensor}, - }, - transform(unordered_set_of(group.optimizer_tensors), - [](optimizer_tensor_guid_t optimizer_tensor) { - return training_tensor_guid_t{optimizer_tensor}; - })); -} - -} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/training_tensor_group_with_attrs.cc b/lib/task-spec/src/task-spec/training_tensor_group_with_attrs.cc deleted file mode 100644 index 6014b46446..0000000000 --- a/lib/task-spec/src/task-spec/training_tensor_group_with_attrs.cc +++ /dev/null @@ -1,26 +0,0 @@ -#include "task-spec/training_tensor_group_with_attrs.h" - -namespace FlexFlow { - -TrainingTensorGroupWithAttrs - make_training_tensor_group_with_attrs_from_group_and_attrs( - TrainingTensorGroup const &group, TensorAttrs const &attrs) { - - return TrainingTensorGroupWithAttrs{ - /*tensor_attrs=*/attrs, - /*forward_tensor=*/group.forward_tensor, - /*gradient_tensor=*/group.gradient_tensor, - /*optimizer_tensors=*/group.optimizer_tensors, - }; -} - -TrainingTensorGroup - tensor_group_without_attrs(TrainingTensorGroupWithAttrs const &with_attrs) { - return TrainingTensorGroup{ - /*forward_tensor=*/with_attrs.forward_tensor, - /*gradient_tensor=*/with_attrs.gradient_tensor, - /*optimizer_tensors=*/with_attrs.optimizer_tensors, - }; -} - -} // namespace FlexFlow diff --git a/lib/task-spec/test/CMakeLists.txt b/lib/task-spec/test/CMakeLists.txt index 87abf10401..354d9358a5 100644 --- a/lib/task-spec/test/CMakeLists.txt +++ b/lib/task-spec/test/CMakeLists.txt @@ -2,13 +2,14 @@ ff_add_test_executable( NAME task-spec-tests SRC_PATTERNS - src/*.cc + src/task-spec/dynamic_graph/*.cc PRIVATE_INCLUDE src/ DEPS doctest utils-test-common - local-execution + # local-execution kernels + task-spec op-attrs ) diff --git a/lib/task-spec/test/src/task-spec/arg_ref.cc 
b/lib/task-spec/test/src/task-spec/arg_ref.cc deleted file mode 100644 index 5c331a1d71..0000000000 --- a/lib/task-spec/test/src/task-spec/arg_ref.cc +++ /dev/null @@ -1,31 +0,0 @@ -#include "task-spec/arg_ref.h" -#include -#include - -using namespace ::FlexFlow; - -enum class ExampleLabelType { - STRING, -}; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("ArgRefSpec::holds") { - ArgRefSpec arg_ref_spec = - ArgRefSpec::create( - ArgRef{ExampleLabelType::STRING}); - - SUBCASE("returns true if the type matches the ArgRef type") { - bool result = arg_ref_spec.holds(); - bool correct = true; - - CHECK(result == correct); - } - - SUBCASE("returns false otherwise") { - bool result = arg_ref_spec.holds(); - bool correct = false; - - CHECK(result == correct); - } - } -} diff --git a/lib/task-spec/test/src/task-spec/arg_ref_spec.cc b/lib/task-spec/test/src/task-spec/arg_ref_spec.cc new file mode 100644 index 0000000000..7dae8ee9cb --- /dev/null +++ b/lib/task-spec/test/src/task-spec/arg_ref_spec.cc @@ -0,0 +1,31 @@ +#include "task-spec/arg_ref_spec.h" +#include +#include + +using namespace ::FlexFlow; + +enum class ExampleLabelType { + STRING, +}; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ArgRefSpec::holds") { + ArgRefSpec arg_ref_spec = + ArgRefSpec::create( + ArgRef{ExampleLabelType::STRING}); + + SUBCASE("returns true if the type matches the ArgRef type") { + bool result = arg_ref_spec.holds(); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("returns false otherwise") { + bool result = arg_ref_spec.holds(); + bool correct = false; + + CHECK(result == correct); + } + } +} diff --git a/lib/task-spec/test/src/task-spec/device_specific.cc b/lib/task-spec/test/src/task-spec/device_specific.cc new file mode 100644 index 0000000000..b5ee11d109 --- /dev/null +++ b/lib/task-spec/test/src/task-spec/device_specific.cc @@ -0,0 +1,17 @@ +#include "task-spec/device_specific.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + 
TEST_CASE("DeviceSpecific") { + DeviceSpecific device_specific = + DeviceSpecific::create(device_id_t{gpu_id_t{1_n}}, + "hello world"); + + std::string result = fmt::to_string(device_specific); + std::string correct = "hi"; + + ASSERT(result == correct); + } +} diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc new file mode 100644 index 0000000000..58cd8c67ec --- /dev/null +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -0,0 +1,130 @@ +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "utils/graph/node/algorithms.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("dynamic_op_dataflow_graph_from_invocation_set") { + DynamicValueAttrs value_1 = DynamicValueAttrs{ + /*pcg_tensor_guid=*/parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{1}, + TensorSlotName::OUTPUT, + }, + }, + /*parallel_tensor_shape=*/std::nullopt, + /*shard_coord=*/std::nullopt, + /*accessor=*/std::nullopt, + /*tensor_type=*/std::nullopt, + }; + + DynamicValueAttrs value_2 = DynamicValueAttrs{ + /*pcg_tensor_guid=*/parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{2}, + TensorSlotName::OUTPUT, + }, + }, + /*parallel_tensor_shape=*/std::nullopt, + /*shard_coord=*/std::nullopt, + /*accessor=*/std::nullopt, + /*tensor_type=*/std::nullopt, + }; + + DynamicValueAttrs value_3 = DynamicValueAttrs{ + /*pcg_tensor_guid=*/parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{3}, + TensorSlotName::OUTPUT, + }, + }, + /*parallel_tensor_shape=*/std::nullopt, + /*shard_coord=*/std::nullopt, + /*accessor=*/std::nullopt, + /*tensor_type=*/std::nullopt, + }; + + DynamicNodeAttrs node_attrs = DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/std::nullopt, + /*pcg_layer_guid=*/parallel_layer_guid_t{Node{4}}, + 
/*per_device_op_state=*/std::nullopt, + }; + + DynamicNodeInvocation invocation_1 = DynamicNodeInvocation{ + /*inputs=*/std::unordered_map{ + {DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/std::nullopt, + }, + value_1}, + }, + /*node_attrs=*/node_attrs, + /*outputs=*/ + std::unordered_map{ + {DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/std::nullopt, + }, + value_2}, + }, + }; + + DynamicNodeInvocation invocation_2 = DynamicNodeInvocation{ + /*inputs=*/std::unordered_map{}, + /*node_attrs=*/node_attrs, + /*outputs=*/ + std::unordered_map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/std::nullopt, + }, + value_3, + }, + }, + }; + + DynamicNodeInvocation invocation_3 = DynamicNodeInvocation{ + /*inputs=*/std::unordered_map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/std::nullopt, + }, + value_1, + }, + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::WEIGHT, + /*slot_tensor_role=*/std::nullopt, + }, + value_2, + }, + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::BIAS, + /*slot_tensor_role=*/std::nullopt, + }, + value_1, + }, + }, + /*node_attrs=*/node_attrs, + /*outputs=*/std::unordered_map{}, + }; + + std::unordered_set invocation_set = { + invocation_1, + invocation_2, + invocation_3, + }; + + DynamicOpenDataflowGraph result = + dynamic_open_dataflow_graph_from_invocation_set(invocation_set); + + ASSERT(dynamic_graph_num_nodes(result) == 3); + } +} diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc new file mode 100644 index 0000000000..1e7162f741 --- /dev/null +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc @@ -0,0 +1,233 @@ +#include "task-spec/dynamic_graph/machine_slicing.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include + +using namespace 
::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("perform_machine_slicing_for_invocation") { + auto mk_machine_coord = + [](nonnegative_int node_idx, + nonnegative_int device_idx) -> MachineSpaceCoordinate { + return MachineSpaceCoordinate{ + /*node_idx=*/node_idx, + /*device_idx=*/device_idx, + /*device_type=*/DeviceType::GPU, + }; + }; + + auto mk_pt_coord = + [](nonnegative_int idx1, + nonnegative_int idx2, + nonnegative_int idx3, + nonnegative_int idx4) -> ParallelTensorSpaceCoordinate { + return ParallelTensorSpaceCoordinate{ + /*sum_component=*/idx1, + /*discard_copy_component=*/idx2, + /*shard_components=*/ + FFOrdered{ + idx3, + idx4, + }, + }; + }; + + MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); + MachineSpaceCoordinate mc2 = mk_machine_coord(2_n, 0_n); + MachineSpaceCoordinate mc3 = mk_machine_coord(4_n, 0_n); + + ParallelTensorSpaceCoordinate mc1_input_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc1_weight_coord = + mk_pt_coord(0_n, 1_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc1_output_1_coord = + mk_pt_coord(1_n, 0_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc1_output_2_coord = + mk_pt_coord(3_n, 0_n, 0_n, 0_n); + + ParallelTensorSpaceCoordinate mc2_input_coord = + mk_pt_coord(0_n, 1_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc2_weight_coord = + mk_pt_coord(0_n, 4_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc2_output_1_coord = + mk_pt_coord(1_n, 2_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc2_output_2_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + + auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, + }; + }; + + auto mk_value = + [](size_t src_node_id, + TensorSlotName src_slot_name, + std::optional const &shard_coord) + -> DynamicValueAttrs { + return DynamicValueAttrs{ + /*pcg_tensor_guid=*/parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{src_node_id}, + 
src_slot_name, + }, + }, + /*parallel_tensor_shape=*/std::nullopt, + /*shard_coord=*/shard_coord, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + }; + + size_t invocation1_id = 20; + size_t invocation2_id = 21; + size_t invocation3_id = 22; + + DynamicValueAttrs graph_input1 = + mk_value(0, TensorSlotName::OUTPUT, std::nullopt); + DynamicValueAttrs graph_input2 = + mk_value(1, TensorSlotName::OUTPUT, std::nullopt); + DynamicValueAttrs invocation1_output1 = + mk_value(invocation1_id, TensorSlotName::OUTPUT_1, std::nullopt); + DynamicValueAttrs invocation1_output2 = + mk_value(invocation1_id, TensorSlotName::OUTPUT_2, std::nullopt); + DynamicValueAttrs invocation2_output1 = + mk_value(invocation2_id, TensorSlotName::OUTPUT_4, std::nullopt); + DynamicValueAttrs invocation3_output1 = + mk_value(invocation3_id, TensorSlotName::OUTPUT_1, std::nullopt); + + DynamicNodeInvocation invocation1 = DynamicNodeInvocation{ + /*inputs=*/{ + { + mk_slot(TensorSlotName::INPUT), + graph_input1, + }, + { + mk_slot(TensorSlotName::WEIGHT), + graph_input2, + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/mc2, + /*mapping=*/std::nullopt, + /*op_attrs=*/std::nullopt, + /*pcg_layer_guid=*/parallel_layer_guid_t{Node{invocation1_id}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + mk_slot(TensorSlotName::OUTPUT_1), + invocation1_output1, + }, + { + mk_slot(TensorSlotName::OUTPUT_2), + invocation1_output2, + }, + }, + }; + + DynamicNodeInvocation invocation2 = DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT), invocation1_output2}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/mc1, + /*mapping=*/std::nullopt, + /*op_attrs=*/std::nullopt, + /*pcg_layer_guid=*/parallel_layer_guid_t{Node{invocation2_id}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + mk_slot(TensorSlotName::OUTPUT_4), + invocation2_output1, + }, + }, + }; 
+ + DynamicNodeInvocation invocation3 = DynamicNodeInvocation{ + /*inputs=*/{ + { + mk_slot(TensorSlotName::KEY), + invocation2_output1, + }, + { + mk_slot(TensorSlotName::QUERY), + graph_input2, + }, + { + mk_slot(TensorSlotName::VALUE), + invocation1_output1, + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/mc2, + /*mapping=*/std::nullopt, + /*op_attrs=*/std::nullopt, + /*pcg_layer_guid=*/parallel_layer_guid_t{Node{invocation3_id}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + mk_slot(TensorSlotName::OUTPUT_1), + invocation3_output1, + }, + }, + }; + + DynamicOpenDataflowGraph unsliced = + dynamic_open_dataflow_graph_from_invocation_set( + /*invocations=*/{ + invocation1, + invocation2, + invocation3, + }); + + SUBCASE("task exists on MachineCoord") { + SUBCASE("mc1") { + DynamicOpenDataflowGraph result = + perform_machine_slicing(unsliced, mc1); + + DynamicOpenDataflowGraph correct = + dynamic_open_dataflow_graph_from_invocation_set({ + invocation2, + }); + + CHECK(dynamic_open_dataflow_graphs_are_isomorphic(result, correct)); + } + + SUBCASE("mc2") { + DynamicOpenDataflowGraph result = + perform_machine_slicing(unsliced, mc2); + + DynamicOpenDataflowGraph correct = + dynamic_open_dataflow_graph_from_invocation_set({ + invocation1, + invocation3, + }); + + CHECK(dynamic_open_dataflow_graphs_are_isomorphic(result, correct)); + } + } + + SUBCASE("task does not exist on MachineCoord") { + DynamicOpenDataflowGraph result = perform_machine_slicing(unsliced, mc3); + + DynamicOpenDataflowGraph correct = + dynamic_open_dataflow_graph_from_invocation_set( + std::unordered_set{}); + + CHECK(dynamic_open_dataflow_graphs_are_isomorphic(result, correct)); + } + } +} diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc new file mode 100644 index 0000000000..9039f16cd9 --- /dev/null +++ 
b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -0,0 +1,383 @@ +#include "task-spec/dynamic_graph/pass_expansion.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("perform_fwd_pass_expansion_for_invocation") { + auto mk_value_attrs = + [](size_t node_id, std::optional const &tensor_role) + -> DynamicValueAttrs { + return DynamicValueAttrs{ + /*pcg_tensor_guid=*/parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{node_id}, + TensorSlotName::OUTPUT, + }, + }, + /*parallel_tensor_shape=*/std::nullopt, + /*shard_coord=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/tensor_role, + }; + }; + + auto mk_slot = + [](TensorSlotName const &slot_name, + std::optional role) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/role, + }; + }; + + parallel_layer_guid_t pcg_layer_guid = parallel_layer_guid_t{Node{20}}; + + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { + DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); + DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); + DynamicValueAttrs v3 = mk_value_attrs(2, std::nullopt); + + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, + {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, + {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, + {mk_slot(TensorSlotName::SCALE, std::nullopt), v1}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/std::nullopt, + /*pcg_layer_guid=*/pcg_layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v3}, + }, + }; + }(); + + DynamicNodeInvocation result = + perform_fwd_pass_expansion_for_invocation(invocation); + + 
DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { + DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; + + DynamicValueAttrs v1_fwd = mk_value_attrs(0, fwd_role); + DynamicValueAttrs v2_fwd = mk_value_attrs(1, fwd_role); + DynamicValueAttrs v3_fwd = mk_value_attrs(2, fwd_role); + + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, + {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::SCALE, fwd_role), v1_fwd}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/DynamicTaskType::FWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/std::nullopt, + /*pcg_layer_guid=*/pcg_layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::OUTPUT, fwd_role), v3_fwd}, + }, + }; + }(); + + ASSERT(result == correct); + } + + TEST_CASE("perform_bwd_pass_expansion_for_invocation") { + auto mk_value_attrs = + [](size_t node_id, std::optional const &tensor_role) + -> DynamicValueAttrs { + return DynamicValueAttrs{ + /*pcg_tensor_guid=*/parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{node_id}, + TensorSlotName::OUTPUT, + }, + }, + /*parallel_tensor_shape=*/std::nullopt, + /*shard_coord=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/tensor_role, + }; + }; + + auto mk_slot = + [](TensorSlotName const &slot_name, + std::optional role) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/role, + }; + }; + + parallel_layer_guid_t pcg_layer_guid = parallel_layer_guid_t{Node{20}}; + + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { + DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); + DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); + DynamicValueAttrs v3 = mk_value_attrs(2, std::nullopt); + + return DynamicNodeInvocation{ + /*inputs=*/{ + 
{mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, + {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, + {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, + {mk_slot(TensorSlotName::SCALE, std::nullopt), v1}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/std::nullopt, + /*pcg_layer_guid=*/pcg_layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v3}, + }, + }; + }(); + + DynamicNodeInvocation result = + perform_bwd_pass_expansion_for_invocation(invocation); + + DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { + DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; + DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; + + DynamicValueAttrs v1_fwd = mk_value_attrs(0, fwd_role); + DynamicValueAttrs v2_fwd = mk_value_attrs(1, fwd_role); + DynamicValueAttrs v3_fwd = mk_value_attrs(2, fwd_role); + DynamicValueAttrs v1_grad = mk_value_attrs(0, grad_role); + DynamicValueAttrs v2_grad = mk_value_attrs(1, grad_role); + DynamicValueAttrs v3_grad = mk_value_attrs(2, grad_role); + + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, + {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::SCALE, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::OUTPUT, fwd_role), v3_fwd}, + {mk_slot(TensorSlotName::OUTPUT, grad_role), v3_grad}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*pass_type=*/DynamicTaskType::BWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/std::nullopt, + /*pcg_layer_guid=*/pcg_layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, + {mk_slot(TensorSlotName::WEIGHT, grad_role), v2_grad}, + 
{mk_slot(TensorSlotName::BIAS, grad_role), v1_grad}, + {mk_slot(TensorSlotName::SCALE, grad_role), v1_grad}, + }, + }; + }(); + + ASSERT(result == correct); + } + + TEST_CASE("perform_pass_expansion(DynamicOpenDataflowGraph)") { + auto mk_node_attrs = [](size_t layer_id, + std::optional const &pass_type) + -> DynamicNodeAttrs { + return DynamicNodeAttrs{ + /*pass_type=*/pass_type, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/std::nullopt, + /*pcg_layer_guid=*/parallel_layer_guid_t{Node{layer_id}}, + /*per_device_op_state=*/std::nullopt, + }; + }; + + auto mk_value_attrs = + [](size_t node_id, std::optional const &tensor_type) + -> DynamicValueAttrs { + return DynamicValueAttrs{ + /*pcg_tensor_guid=*/parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{node_id}, + TensorSlotName::OUTPUT, + }, + }, + /*parallel_tensor_shape=*/std::nullopt, + /*shard_coord=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/tensor_type, + }; + }; + + DynamicOpenDataflowGraph input = [&]() -> DynamicOpenDataflowGraph { + DynamicNodeAttrs n1 = mk_node_attrs(10, std::nullopt); + DynamicNodeAttrs n2 = mk_node_attrs(11, std::nullopt); + + DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); + DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); + + std::unordered_set invocation_set = { + DynamicNodeInvocation{ + /*inputs=*/std::unordered_map{}, + /*node_attrs=*/n1, + /*outputs=*/ + std::unordered_map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/std::nullopt, + }, + v1, + }, + }, + }, + DynamicNodeInvocation{ + /*inputs=*/std::unordered_map{ + {DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/std::nullopt, + }, + v1}, + }, + /*node_attrs=*/n2, + /*outputs=*/ + std::unordered_map{ + {DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/std::nullopt, + }, + v2}, + }, + }, + }; + + return 
dynamic_open_dataflow_graph_from_invocation_set(invocation_set); + }(); + + DynamicOpenDataflowGraph result = perform_pass_expansion(input); + + DynamicOpenDataflowGraph correct = [&]() -> DynamicOpenDataflowGraph { + DynamicNodeAttrs n1_fwd = mk_node_attrs(10, DynamicTaskType::FWD); + DynamicNodeAttrs n2_fwd = mk_node_attrs(11, DynamicTaskType::FWD); + DynamicNodeAttrs n1_bwd = mk_node_attrs(10, DynamicTaskType::BWD); + DynamicNodeAttrs n2_bwd = mk_node_attrs(11, DynamicTaskType::BWD); + + DynamicValueAttrs v1_activation = + mk_value_attrs(0, mk_dynamic_tensor_role_fwd()); + DynamicValueAttrs v1_gradient = + mk_value_attrs(0, mk_dynamic_tensor_role_bwd()); + DynamicValueAttrs v2_activation = + mk_value_attrs(1, mk_dynamic_tensor_role_fwd()); + DynamicValueAttrs v2_gradient = + mk_value_attrs(1, mk_dynamic_tensor_role_bwd()); + + std::unordered_set invocation_set = { + DynamicNodeInvocation{ + /*inputs=*/std::unordered_map{}, + /*node_attrs=*/n1_fwd, + /*outputs=*/ + std::unordered_map{ + std::pair{ + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + }, + v1_activation, + }, + }, + }, + DynamicNodeInvocation{ + /*inputs=*/std::unordered_map{ + std::pair{ + DynamicTensorSlot{ + TensorSlotName::INPUT, + mk_dynamic_tensor_role_fwd(), + }, + v1_activation, + }, + }, + /*node_attrs=*/n2_fwd, + /*outputs=*/ + std::unordered_map{ + std::pair{ + DynamicTensorSlot{ + TensorSlotName::OUTPUT, + mk_dynamic_tensor_role_fwd(), + }, + v2_activation, + }, + }, + }, + DynamicNodeInvocation{ + /*inputs=*/std::unordered_map{ + std::pair{ + DynamicTensorSlot{ + TensorSlotName::INPUT, + mk_dynamic_tensor_role_fwd(), + }, + v1_activation, + }, + std::pair{ + DynamicTensorSlot{ + TensorSlotName::OUTPUT, + mk_dynamic_tensor_role_fwd(), + }, + v2_activation, + }, + std::pair{ + DynamicTensorSlot{ + TensorSlotName::OUTPUT, + mk_dynamic_tensor_role_bwd(), + }, + v2_gradient, + }, + }, + /*node_attrs=*/n2_bwd, + /*outputs=*/ + 
std::unordered_map{ + std::pair{ + DynamicTensorSlot{ + TensorSlotName::INPUT, + mk_dynamic_tensor_role_bwd(), + }, + v1_gradient, + }, + }, + }, + }; + + return dynamic_open_dataflow_graph_from_invocation_set(invocation_set); + }(); + + ASSERT(get_dynamic_invocation_set(result).size() == 3); + ASSERT(get_dynamic_invocation_set(result) == + get_dynamic_invocation_set(correct)); + ASSERT(dynamic_open_dataflow_graphs_are_isomorphic(result, correct)); + } +} diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc new file mode 100644 index 0000000000..f49a496647 --- /dev/null +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc @@ -0,0 +1,221 @@ +#include "task-spec/dynamic_graph/shard_expansion.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("perform_shard_expansion_for_invocation") { + auto mk_machine_coord = + [](nonnegative_int node_idx, + nonnegative_int device_idx) -> MachineSpaceCoordinate { + return MachineSpaceCoordinate{ + /*node_idx=*/node_idx, + /*device_idx=*/device_idx, + /*device_type=*/DeviceType::GPU, + }; + }; + + auto mk_pt_coord = + [](nonnegative_int idx1, + nonnegative_int idx2, + nonnegative_int idx3, + nonnegative_int idx4) -> ParallelTensorSpaceCoordinate { + return ParallelTensorSpaceCoordinate{ + /*sum_component=*/idx1, + /*discard_copy_component=*/idx2, + /*shard_components=*/ + FFOrdered{ + idx3, + idx4, + }, + }; + }; + + auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, + ParallelTensorSpaceCoordinate const &c2, + ParallelTensorSpaceCoordinate const &c3, + ParallelTensorSpaceCoordinate const &c4) + -> OperatorAtomicTaskShardBinding { + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::INPUT, + c1, + }, + { + TensorSlotName::WEIGHT, + c2, + }, + { + TensorSlotName::OUTPUT_1, + c3, + }, + { + 
TensorSlotName::OUTPUT_2, + c4, + }, + }, + }; + }; + + MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); + MachineSpaceCoordinate mc2 = mk_machine_coord(2_n, 0_n); + + ParallelTensorSpaceCoordinate mc1_input_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc1_weight_coord = + mk_pt_coord(0_n, 1_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc1_output_1_coord = + mk_pt_coord(1_n, 0_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc1_output_2_coord = + mk_pt_coord(3_n, 0_n, 0_n, 0_n); + + ParallelTensorSpaceCoordinate mc2_input_coord = + mk_pt_coord(0_n, 1_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc2_weight_coord = + mk_pt_coord(0_n, 4_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc2_output_1_coord = + mk_pt_coord(1_n, 2_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc2_output_2_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + + MappedOperatorTaskGroup mapped_task_group = MappedOperatorTaskGroup{ + bidict{ + { + mc1, + mk_shard_binding(mc1_input_coord, + mc1_weight_coord, + mc1_output_1_coord, + mc1_output_2_coord), + }, + { + mc2, + mk_shard_binding(mc2_input_coord, + mc2_weight_coord, + mc2_output_1_coord, + mc2_output_2_coord), + }, + }, + }; + + auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, + }; + }; + + auto mk_value = + [](size_t src_node_id, + TensorSlotName src_slot_name, + std::optional const &shard_coord) + -> DynamicValueAttrs { + return DynamicValueAttrs{ + /*pcg_tensor_guid=*/parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{src_node_id}, + src_slot_name, + }, + }, + /*parallel_tensor_shape=*/std::nullopt, + /*shard_coord=*/shard_coord, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + }; + + DynamicNodeInvocation input = DynamicNodeInvocation{ + /*inputs=*/{ + { + mk_slot(TensorSlotName::INPUT), + mk_value(0, TensorSlotName::OUTPUT, std::nullopt), + }, + { + 
mk_slot(TensorSlotName::WEIGHT), + mk_value(1, TensorSlotName::OUTPUT, std::nullopt), + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/mapped_task_group, + /*op_attrs=*/std::nullopt, + /*pcg_layer_guid=*/parallel_layer_guid_t{Node{20}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + mk_slot(TensorSlotName::OUTPUT_1), + mk_value(20, TensorSlotName::OUTPUT_1, std::nullopt), + }, + { + mk_slot(TensorSlotName::OUTPUT_2), + mk_value(20, TensorSlotName::OUTPUT_2, std::nullopt), + }, + }, + }; + + std::unordered_set result = + perform_shard_expansion_for_invocation(input); + + auto mk_invocation_shard = + [&](MachineSpaceCoordinate const &device_coord, + ParallelTensorSpaceCoordinate const &input_shard_coord, + ParallelTensorSpaceCoordinate const &weight_shard_coord, + ParallelTensorSpaceCoordinate const &output_1_shard_coord, + ParallelTensorSpaceCoordinate const &output_2_shard_coord) + -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + { + mk_slot(TensorSlotName::INPUT), + mk_value(0, TensorSlotName::OUTPUT, input_shard_coord), + }, + { + mk_slot(TensorSlotName::WEIGHT), + mk_value(1, TensorSlotName::OUTPUT, weight_shard_coord), + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/device_coord, + /*mapping=*/mapped_task_group, + /*op_attrs=*/std::nullopt, + /*pcg_layer_guid=*/parallel_layer_guid_t{Node{20}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + mk_slot(TensorSlotName::OUTPUT_1), + mk_value(20, TensorSlotName::OUTPUT_1, output_1_shard_coord), + }, + { + mk_slot(TensorSlotName::OUTPUT_2), + mk_value(20, TensorSlotName::OUTPUT_2, output_2_shard_coord), + }, + }, + }; + }; + + std::unordered_set correct = { + mk_invocation_shard(mc1, + mc1_input_coord, + mc1_weight_coord, + mc1_output_1_coord, + mc1_output_2_coord), + mk_invocation_shard(mc2, + mc2_input_coord, + 
mc2_weight_coord, + mc2_output_1_coord, + mc2_output_2_coord), + }; + + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } +} diff --git a/lib/task-spec/test/src/task-spec/op_ordered_slot_signature.cc b/lib/task-spec/test/src/task-spec/op_ordered_slot_signature.cc new file mode 100644 index 0000000000..c9da5953da --- /dev/null +++ b/lib/task-spec/test/src/task-spec/op_ordered_slot_signature.cc @@ -0,0 +1,10 @@ +#include "task-spec/op_ordered_slot_signature.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_op_ordered_slot_signature_for_binding") { + CHECK_MESSAGE(false, "TODO: get_op_ordered_slot_signature_for_binding"); + } +} diff --git a/lib/utils/CMakeLists.txt b/lib/utils/CMakeLists.txt index 6d8f22bc29..e2f7c433d6 100644 --- a/lib/utils/CMakeLists.txt +++ b/lib/utils/CMakeLists.txt @@ -9,7 +9,6 @@ ff_add_library( src/ DEPS expected - visit_struct fmt json cuda diff --git a/lib/utils/include/utils/any.h b/lib/utils/include/utils/any.h deleted file mode 100644 index 0e1e3c7b06..0000000000 --- a/lib/utils/include/utils/any.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_ANY_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_ANY_H - -#include "any.hpp" - -namespace FlexFlow { - -using namespace ::linb; - -} - -#endif diff --git a/lib/utils/include/utils/archetypes/jsonable_value_type.h b/lib/utils/include/utils/archetypes/jsonable_value_type.h new file mode 100644 index 0000000000..85f381d9a9 --- /dev/null +++ b/lib/utils/include/utils/archetypes/jsonable_value_type.h @@ -0,0 +1,77 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_JSONABLE_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_JSONABLE_VALUE_TYPE_H + +#include +#include +#include +#include +#include +#include + +namespace FlexFlow { + +template +struct jsonable_value_type { + jsonable_value_type() = delete; + + jsonable_value_type(jsonable_value_type const &) { + PANIC(); + } + jsonable_value_type 
&operator=(jsonable_value_type const &) { + PANIC(); + } + + jsonable_value_type(jsonable_value_type &&) { + PANIC(); + } + jsonable_value_type &operator=(jsonable_value_type &&) { + PANIC(); + } + + bool operator==(jsonable_value_type const &) const { + PANIC(); + } + bool operator!=(jsonable_value_type const &) const { + PANIC(); + } +}; + +template +std::string format_as(jsonable_value_type const &) { + PANIC(); +} + +template +std::ostream &operator<<(std::ostream &s, jsonable_value_type const &x) { + PANIC(); +} + +} // namespace FlexFlow + +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::jsonable_value_type> { + static ::FlexFlow::jsonable_value_type from_json(json const &) { + PANIC(); + } + + static void to_json(json &, ::FlexFlow::jsonable_value_type const &) { + PANIC(); + } +}; + +} // namespace nlohmann + +namespace std { + +template +struct hash<::FlexFlow::jsonable_value_type> { + size_t operator()(::FlexFlow::jsonable_value_type const &) const { + PANIC(); + }; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/archetypes/ordered_value_type.h b/lib/utils/include/utils/archetypes/ordered_value_type.h index b14f378667..3666d19d3f 100644 --- a/lib/utils/include/utils/archetypes/ordered_value_type.h +++ b/lib/utils/include/utils/archetypes/ordered_value_type.h @@ -37,6 +37,12 @@ struct ordered_value_type { bool operator>(ordered_value_type const &) const { PANIC(); } + bool operator<=(ordered_value_type const &) const { + PANIC(); + } + bool operator>=(ordered_value_type const &) const { + PANIC(); + } }; template diff --git a/lib/utils/include/utils/archetypes/rapidcheckable_value_type.h b/lib/utils/include/utils/archetypes/rapidcheckable_value_type.h new file mode 100644 index 0000000000..596bd5afa5 --- /dev/null +++ b/lib/utils/include/utils/archetypes/rapidcheckable_value_type.h @@ -0,0 +1,77 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_RAPIDCHECKABLE_VALUE_TYPE_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_RAPIDCHECKABLE_VALUE_TYPE_H + +#include +#include +#include +#include +#include +#include + +namespace FlexFlow { + +template +struct rapidcheckable_value_type { + rapidcheckable_value_type() = delete; + + rapidcheckable_value_type(rapidcheckable_value_type const &) { + PANIC(); + } + rapidcheckable_value_type &operator=(rapidcheckable_value_type const &) { + PANIC(); + } + + rapidcheckable_value_type(rapidcheckable_value_type &&) { + PANIC(); + } + rapidcheckable_value_type &operator=(rapidcheckable_value_type &&) { + PANIC(); + } + + bool operator==(rapidcheckable_value_type const &) const { + PANIC(); + } + bool operator!=(rapidcheckable_value_type const &) const { + PANIC(); + } + bool operator<(rapidcheckable_value_type const &) const { + PANIC(); + } +}; + +template +std::string format_as(rapidcheckable_value_type const &) { + PANIC(); +} + +template +std::ostream &operator<<(std::ostream &s, + rapidcheckable_value_type const &x) { + PANIC(); +} + +} // namespace FlexFlow + +namespace rc { + +template +struct Arbitrary<::FlexFlow::rapidcheckable_value_type> { + static Gen<::FlexFlow::rapidcheckable_value_type> arbitrary() { + PANIC(); + } +}; + +} // namespace rc + +namespace std { + +template +struct hash<::FlexFlow::rapidcheckable_value_type> { + size_t operator()(::FlexFlow::rapidcheckable_value_type const &) const { + PANIC(); + }; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/bidict_from_enumerating.h b/lib/utils/include/utils/bidict/algorithms/bidict_from_enumerating.h index 83afc32e0c..495cfcc667 100644 --- a/lib/utils/include/utils/bidict/algorithms/bidict_from_enumerating.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_from_enumerating.h @@ -2,11 +2,27 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_ENUMERATING_H #include "utils/bidict/bidict.h" +#include "utils/containers/contains_duplicates.h" #include 
"utils/nonnegative_int/nonnegative_int.h" +#include #include namespace FlexFlow { +template +bidict bidict_from_enumerating(std::vector const &s) { + ASSERT(!contains_duplicates(s)); + + bidict result; + nonnegative_int idx = 0_n; + for (T const &t : s) { + result.equate(idx, t); + idx++; + } + + return result; +} + template bidict bidict_from_enumerating(std::unordered_set const &s) { diff --git a/lib/utils/include/utils/bidict/algorithms/bidict_from_keys_and_values.h b/lib/utils/include/utils/bidict/algorithms/bidict_from_keys_and_values.h index 47af03591a..ddf1c92c75 100644 --- a/lib/utils/include/utils/bidict/algorithms/bidict_from_keys_and_values.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_from_keys_and_values.h @@ -4,21 +4,14 @@ #include "utils/bidict/algorithms/bidict_from_pairs.h" #include "utils/bidict/bidict.h" #include "utils/containers/zip.h" -#include "utils/exception.h" +#include namespace FlexFlow { template bidict bidict_from_keys_and_values(std::vector const &ls, std::vector const &rs) { - size_t l_size = ls.size(); - size_t r_size = rs.size(); - if (l_size != r_size) { - throw mk_runtime_error(fmt::format( - "recieved keys (of size {}) not matching values (of size {})", - l_size, - r_size)); - } + ASSERT(ls.size() == rs.size()); return bidict_from_pairs(zip(ls, rs)); } diff --git a/lib/utils/include/utils/bidict/algorithms/bidict_from_map.h b/lib/utils/include/utils/bidict/algorithms/bidict_from_map.h new file mode 100644 index 0000000000..b4d74df97c --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/bidict_from_map.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_MAP_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template +bidict bidict_from_map(std::unordered_map const &m) { + bidict result; + for (auto const &[k, v] : m) { + ASSERT(!result.contains_r(v)); + result.equate({k, v}); + } 
+ return result; +} + +template +bidict bidict_from_map(std::map const &m) { + bidict result; + for (auto const &[k, v] : m) { + ASSERT(!result.contains_r(v)); + result.equate({k, v}); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/bidict_from_unstructured_relation.h b/lib/utils/include/utils/bidict/algorithms/bidict_from_unstructured_relation.h new file mode 100644 index 0000000000..1f6d7893b6 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/bidict_from_unstructured_relation.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_UNSTRUCTURED_RELATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_UNSTRUCTURED_RELATION_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template +bidict bidict_from_unstructured_relation( + std::unordered_set> const &relation) { + bidict result; + for (auto const &lr : relation) { + result.equate_strict(lr); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/exhaustive_relational_join.h b/lib/utils/include/utils/bidict/algorithms/exhaustive_relational_join.h new file mode 100644 index 0000000000..65cd23bdb0 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/exhaustive_relational_join.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_EXHAUSTIVE_RELATIONAL_JOIN_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_EXHAUSTIVE_RELATIONAL_JOIN_H + +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/bidict.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +bidict exhaustive_relational_join(bidict const &fst, + bidict const &snd) { + ASSERT(right_entries(fst) == left_entries(snd)); + + bidict result; + + for (auto const &[v1, v2] : fst) { + result.equate({v1, snd.at_l(v2)}); + 
} + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/filter_keys.h b/lib/utils/include/utils/bidict/algorithms/filter_keys.h new file mode 100644 index 0000000000..2734dfaeb5 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/filter_keys.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_KEYS_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template +bidict filter_keys(bidict const &m, F &&f) { + bidict result; + for (auto const &kv : m) { + if (f(kv.first)) { + result.equate(kv); + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/filter_values.h b/lib/utils/include/utils/bidict/algorithms/filter_values.h new file mode 100644 index 0000000000..5817578e79 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/filter_values.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_VALUES_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template +bidict filter_values(bidict const &m, F &&f) { + bidict result; + for (auto const &kv : m) { + if (f(kv.second)) { + result.equate(kv); + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/filtrans_keys.h b/lib/utils/include/utils/bidict/algorithms/filtrans_keys.h new file mode 100644 index 0000000000..df6495b400 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/filtrans_keys.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_KEYS_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template ::value_type> +bidict 
filtrans_keys(bidict const &m, F &&f) { + bidict result; + for (auto const &[k, v] : m) { + std::optional new_k = f(k); + if (new_k.has_value()) { + result.equate(new_k.value(), v); + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/filtrans_values.h b/lib/utils/include/utils/bidict/algorithms/filtrans_values.h new file mode 100644 index 0000000000..11180938b8 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/filtrans_values.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_VALUES_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template ::value_type> +bidict filtrans_values(bidict const &m, F &&f) { + bidict result; + for (auto const &[k, v] : m) { + std::optional new_v = f(v); + if (new_v.has_value()) { + result.equate(k, new_v.value()); + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/transform.h b/lib/utils/include/utils/bidict/algorithms/transform.h new file mode 100644 index 0000000000..7fbdd07db7 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/transform.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template ::first_type, + typename V2 = typename std::invoke_result_t::second_type> +bidict transform(bidict const &m, F &&f) { + bidict result; + for (auto const &[k, v] : m) { + result.equate(f(k, v)); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/transform_keys.h b/lib/utils/include/utils/bidict/algorithms/transform_keys.h new file mode 100644 index 0000000000..8ecb10c401 --- /dev/null +++ 
b/lib/utils/include/utils/bidict/algorithms/transform_keys.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_KEYS_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template > +bidict transform_keys(bidict const &m, F &&f) { + bidict result; + for (auto const &kv : m) { + result.equate(f(kv.first), kv.second); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/transform_values.h b/lib/utils/include/utils/bidict/algorithms/transform_values.h new file mode 100644 index 0000000000..ef5b34ebe9 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/transform_values.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_VALUES_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template > +bidict transform_values(bidict const &m, F &&f) { + bidict result; + for (auto const &kv : m) { + result.equate({kv.first, f(kv.second)}); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/unordered_set_of.h b/lib/utils/include/utils/bidict/algorithms/unordered_set_of.h new file mode 100644 index 0000000000..b3df2514cf --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/unordered_set_of.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNORDERED_SET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNORDERED_SET_OF_H + +#include "utils/bidict/bidict.h" +#include "utils/hash/pair.h" + +namespace FlexFlow { + +template +std::unordered_set> unordered_set_of(bidict const &c) { + std::unordered_set> result; + + for (auto const &lr : c) { + result.insert(lr); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/bidict/algorithms/unstructured_relation_from_bidict.h b/lib/utils/include/utils/bidict/algorithms/unstructured_relation_from_bidict.h new file mode 100644 index 0000000000..2ceb527b96 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/unstructured_relation_from_bidict.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNSTRUCTURED_RELATION_FROM_BIDICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNSTRUCTURED_RELATION_FROM_BIDICT_H + +#include "utils/bidict/algorithms/unordered_set_of.h" +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template +std::unordered_set> + unstructured_relation_from_bidict(bidict const &b) { + return unordered_set_of(b); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/bidict.h b/lib/utils/include/utils/bidict/bidict.h index 8b19313002..5dbd1c603d 100644 --- a/lib/utils/include/utils/bidict/bidict.h +++ b/lib/utils/include/utils/bidict/bidict.h @@ -1,10 +1,17 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_BIDICT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_BIDICT_H +#include "utils/containers/keys.h" +#include "utils/containers/map_from_keys_and_values.h" #include "utils/fmt/unordered_map.h" #include "utils/hash/unordered_map.h" +#include "utils/json/check_is_json_deserializable.h" +#include "utils/json/check_is_json_serializable.h" +#include "utils/ord/unordered_map.h" #include +#include #include +#include #include namespace FlexFlow { @@ -65,6 +72,20 @@ struct bidict { bwd_map.insert({lr.second, lr.first}); } + void equate_strict(L const &l, R const &r) { + ASSERT(this->contains_l(l) == this->contains_r(r)); + + if (this->contains_l(l)) { + ASSERT(this->at_l(l) == r); + } else { + this->equate(l, r); + } + } + + void equate_strict(std::pair const &lr) { + this->equate_strict(lr.first, lr.second); + } + bool operator==(bidict const &other) const { bool result = this->fwd_map == other.fwd_map; 
assert(result == (this->bwd_map == other.bwd_map)); @@ -85,6 +106,14 @@ struct bidict { return bwd_map.at(r); } + std::unordered_set left_values() const { + return keys(this->fwd_map); + } + + std::unordered_set right_values() const { + return keys(this->bwd_map); + } + std::size_t size() const { assert(fwd_map.size() == bwd_map.size()); return fwd_map.size(); @@ -195,6 +224,12 @@ struct bidict { std::unordered_map bwd_map; }; +template +std::enable_if_t && is_lt_comparable_v, bool> + operator<(bidict const &lhs, bidict const &rhs) { + return lhs.as_unordered_map() < rhs.as_unordered_map(); +} + template std::unordered_map format_as(bidict const &b) { return b; @@ -208,96 +243,54 @@ std::ostream &operator<<(std::ostream &s, bidict const &b) { return s << fmt::to_string(b); } -template ()(std::declval()))> -bidict map_keys(bidict const &m, F const &f) { - bidict result; - for (auto const &kv : m) { - result.equate(f(kv.first), kv.second); - } - return result; -} +} // namespace FlexFlow -template ()(std::declval()))> -bidict map_values(bidict const &m, F const &f) { - bidict result; - for (auto const &kv : m) { - result.equate({kv.first, f(kv.second)}); - } - return result; -} +namespace nlohmann { -template -bidict filter_keys(bidict const &m, F const &f) { - bidict result; - for (auto const &kv : m) { - if (f(kv.first)) { - result.equate(kv); - } - } - return result; -} +template +struct adl_serializer<::FlexFlow::bidict> { + static ::FlexFlow::bidict from_json(json const &j) { + CHECK_IS_JSON_DESERIALIZABLE(L); + CHECK_IS_JSON_DESERIALIZABLE(R); -template -bidict filter_values(bidict const &m, F const &f) { - bidict result; - for (auto const &kv : m) { - if (f(kv.second)) { - result.equate(kv); - } - } - return result; -} + std::unordered_map m = j; -template ::value_type> -bidict filtermap_keys(bidict const &m, F const &f) { - bidict result; - for (auto const &[k, v] : m) { - std::optional new_k = f(k); - if (new_k.has_value()) { - result.equate(new_k.value(), 
v); - } + ::FlexFlow::bidict b{m.cbegin(), m.cend()}; + + return b; } - return result; -} + static void to_json(json &j, ::FlexFlow::bidict const &b) { + CHECK_IS_JSON_SERIALIZABLE(L); + CHECK_IS_JSON_SERIALIZABLE(R); -template ::value_type> -bidict filtermap_values(bidict const &m, F const &f) { - bidict result; - for (auto const &[k, v] : m) { - std::optional new_v = f(v); - if (new_v.has_value()) { - result.equate(k, new_v.value()); - } + j = b.as_unordered_map(); } - return result; -} +}; + +} // namespace nlohmann + +namespace rc { -template ::first_type, - typename V2 = typename std::invoke_result_t::second_type> -bidict transform(bidict const &m, F const &f) { - bidict result; - for (auto const &[k, v] : m) { - result.equate(f(k, v)); +template +struct Arbitrary<::FlexFlow::bidict> { + static Gen<::FlexFlow::bidict> arbitrary() { + return gen::map( + gen::withSize([](int size) -> Gen> { + return gen::apply( + [](std::vector const &keys, + std::vector const &values) -> std::unordered_map { + return ::FlexFlow::map_from_keys_and_values(keys, values); + }, + gen::unique>(size, gen::arbitrary()), + gen::unique>(size, gen::arbitrary())); + }), + [](std::unordered_map const &m) { + return ::FlexFlow::bidict{m.cbegin(), m.cend()}; + }); } - return result; -} +}; -} // namespace FlexFlow +} // namespace rc namespace std { diff --git a/lib/utils/include/utils/bijection/bijection.dtg.toml b/lib/utils/include/utils/bijection/bijection.dtg.toml new file mode 100644 index 0000000000..23606e8336 --- /dev/null +++ b/lib/utils/include/utils/bijection/bijection.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "Bijection" +type = "struct" +features = [] + +template_params = [ "L", "R" ] + +includes = [ + "" +] + +[[fields]] +name = "l_to_r" +type = "std::function" + +[[fields]] +name = "r_to_l" +type = "std::function" diff --git a/lib/utils/include/utils/bijection/bijection.h b/lib/utils/include/utils/bijection/bijection.h new file mode 100644 index 
0000000000..a8fdac22ff --- /dev/null +++ b/lib/utils/include/utils/bijection/bijection.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIJECTION_BIJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIJECTION_BIJECTION_H + +#include "utils/bijection/bijection.dtg.h" + +namespace FlexFlow { + +template +Bijection flip_bijection(Bijection const &b) { + return Bijection{ + /*l_to_r=*/b.r_to_l, + /*r_to_l=*/b.l_to_r, + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bijection/to.dtg.toml b/lib/utils/include/utils/bijection/to.dtg.toml new file mode 100644 index 0000000000..c8a39c4bbf --- /dev/null +++ b/lib/utils/include/utils/bijection/to.dtg.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "To" +type = "struct" +features = [] + +template_params = [ "L", "R" ] + +includes = [ + "" +] + +[[fields]] +name = "func" +type = "std::function" diff --git a/lib/utils/include/utils/cli/cli_argument_key.dtg.toml b/lib/utils/include/utils/cli/cli_argument_key.dtg.toml new file mode 100644 index 0000000000..bea9ed3eb8 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_argument_key.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "CLIArgumentKey" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/cli/cli_positional_argument_key.dtg.h", + "utils/cli/cli_flag_key.dtg.h", +] + +[[values]] +type = "::FlexFlow::CLIPositionalArgumentKey" + +[[values]] +type = "::FlexFlow::CLIFlagKey" diff --git a/lib/utils/include/utils/cli/cli_argument_key.variant.toml b/lib/utils/include/utils/cli/cli_argument_key.variant.toml deleted file mode 100644 index be118160ce..0000000000 --- a/lib/utils/include/utils/cli/cli_argument_key.variant.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "CLIArgumentKey" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "", - "utils/cli/cli_positional_argument_key.dtg.h", - "utils/cli/cli_flag_key.dtg.h", -] - -[[values]] -type = 
"::FlexFlow::CLIPositionalArgumentKey" - -[[values]] -type = "::FlexFlow::CLIFlagKey" diff --git a/lib/utils/include/utils/cli/cli_flag_key.dtg.toml b/lib/utils/include/utils/cli/cli_flag_key.dtg.toml new file mode 100644 index 0000000000..9fdc33e91b --- /dev/null +++ b/lib/utils/include/utils/cli/cli_flag_key.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "CLIFlagKey" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "raw_idx" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/utils/include/utils/cli/cli_flag_key.struct.toml b/lib/utils/include/utils/cli/cli_flag_key.struct.toml deleted file mode 100644 index 9c02fddc3e..0000000000 --- a/lib/utils/include/utils/cli/cli_flag_key.struct.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "CLIFlagKey" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "raw_idx" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/utils/include/utils/cli/cli_flag_spec.dtg.toml b/lib/utils/include/utils/cli/cli_flag_spec.dtg.toml new file mode 100644 index 0000000000..1bb29953e6 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_flag_spec.dtg.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "CLIFlagSpec" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "long_flag" +type = "std::string" + +[[fields]] +name = "short_flag" +type = "std::optional" + +[[fields]] +name = "description" +type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_flag_spec.struct.toml b/lib/utils/include/utils/cli/cli_flag_spec.struct.toml deleted file mode 100644 index 66a47de067..0000000000 --- a/lib/utils/include/utils/cli/cli_flag_spec.struct.toml +++ /dev/null @@ -1,28 +0,0 @@ -namespace = "FlexFlow" -name = 
"CLIFlagSpec" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "", - "", -] - -src_includes = [ - "utils/fmt/optional.h", -] - -[[fields]] -name = "long_flag" -type = "std::string" - -[[fields]] -name = "short_flag" -type = "std::optional" - -[[fields]] -name = "description" -type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_parse_result.dtg.toml b/lib/utils/include/utils/cli/cli_parse_result.dtg.toml new file mode 100644 index 0000000000..fc52613e51 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse_result.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "CLIParseResult" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", + "utils/cli/cli_flag_key.dtg.h", + "utils/cli/cli_positional_argument_key.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "flags" +type = "std::unordered_map<::FlexFlow::CLIFlagKey, bool>" + +[[fields]] +name = "positional_arguments" +type = "std::unordered_map<::FlexFlow::CLIPositionalArgumentKey, std::string>" diff --git a/lib/utils/include/utils/cli/cli_parse_result.struct.toml b/lib/utils/include/utils/cli/cli_parse_result.struct.toml deleted file mode 100644 index b63da7be14..0000000000 --- a/lib/utils/include/utils/cli/cli_parse_result.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "CLIParseResult" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "", - "", - "utils/cli/cli_flag_key.dtg.h", - "utils/cli/cli_positional_argument_key.dtg.h", -] - -src_includes = [ - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", -] - -[[fields]] -name = "flags" -type = "std::unordered_map<::FlexFlow::CLIFlagKey, bool>" - -[[fields]] -name = "positional_arguments" -type = "std::unordered_map<::FlexFlow::CLIPositionalArgumentKey, std::string>" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_key.dtg.toml 
b/lib/utils/include/utils/cli/cli_positional_argument_key.dtg.toml new file mode 100644 index 0000000000..2ed6eed5b4 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_positional_argument_key.dtg.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "CLIPositionalArgumentKey" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "raw_idx" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml b/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml deleted file mode 100644 index 4c50c277c0..0000000000 --- a/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "CLIPositionalArgumentKey" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "raw_idx" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_spec.dtg.toml b/lib/utils/include/utils/cli/cli_positional_argument_spec.dtg.toml new file mode 100644 index 0000000000..34312f6a04 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_positional_argument_spec.dtg.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "CLIPositionalArgumentSpec" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "name" +type = "std::string" + +[[fields]] +name = "choices" +type = "std::optional>" + +[[fields]] +name = "description" +type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml b/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml deleted file mode 100644 index b1e74701ee..0000000000 --- 
a/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml +++ /dev/null @@ -1,31 +0,0 @@ -namespace = "FlexFlow" -name = "CLIPositionalArgumentSpec" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "", - "", - "", -] - -src_includes = [ - "utils/fmt/optional.h", - "utils/fmt/vector.h", - "utils/hash/vector.h", -] - -[[fields]] -name = "name" -type = "std::string" - -[[fields]] -name = "choices" -type = "std::optional>" - -[[fields]] -name = "description" -type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_spec.dtg.toml b/lib/utils/include/utils/cli/cli_spec.dtg.toml new file mode 100644 index 0000000000..39ff325eac --- /dev/null +++ b/lib/utils/include/utils/cli/cli_spec.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "CLISpec" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/cli/cli_flag_spec.dtg.h", + "utils/cli/cli_positional_argument_spec.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "flags" +type = "std::vector<::FlexFlow::CLIFlagSpec>" + +[[fields]] +name = "positional_arguments" +type = "std::vector<::FlexFlow::CLIPositionalArgumentSpec>" diff --git a/lib/utils/include/utils/cli/cli_spec.struct.toml b/lib/utils/include/utils/cli/cli_spec.struct.toml deleted file mode 100644 index 9f64f62c15..0000000000 --- a/lib/utils/include/utils/cli/cli_spec.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "CLISpec" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "", - "utils/cli/cli_flag_spec.dtg.h", - "utils/cli/cli_positional_argument_spec.dtg.h", - "", -] - -src_includes = [ - "utils/fmt/unordered_set.h", - "utils/hash/unordered_set.h", - "utils/fmt/vector.h", - "utils/hash/vector.h", -] - -[[fields]] -name = "flags" -type = "std::vector<::FlexFlow::CLIFlagSpec>" - -[[fields]] -name = "positional_arguments" 
-type = "std::vector<::FlexFlow::CLIPositionalArgumentSpec>" diff --git a/lib/utils/include/utils/containers/all_of.h b/lib/utils/include/utils/containers/all_of.h index fb44aeaed8..ef5aac1c41 100644 --- a/lib/utils/include/utils/containers/all_of.h +++ b/lib/utils/include/utils/containers/all_of.h @@ -1,10 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ALL_OF_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ALL_OF_H +#include +#include +#include + namespace FlexFlow { template -bool all_of(C const &c, F const &f) { +bool all_of(C const &c, F &&f) { for (auto const &v : c) { if (!f(v)) { return false; @@ -13,6 +17,30 @@ bool all_of(C const &c, F const &f) { return true; } +template +bool all_of(std::unordered_map const &m, F &&f) { + for (auto const &[k, v] : m) { + if (!f(k, v)) { + return false; + } + } + + return true; +} + +template +bool all_of(std::map const &m, F &&f) { + for (auto const &[k, v] : m) { + if (!f(k, v)) { + return false; + } + } + + return true; +} + +bool all_of(std::vector const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/at_idx.h b/lib/utils/include/utils/containers/at_idx.h index fdc13a0231..2442a759ac 100644 --- a/lib/utils/include/utils/containers/at_idx.h +++ b/lib/utils/include/utils/containers/at_idx.h @@ -2,18 +2,17 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_AT_IDX_H #include "utils/nonnegative_int/nonnegative_int.h" +#include #include #include namespace FlexFlow { template -std::optional at_idx(std::vector const &v, nonnegative_int idx) { - if (idx >= v.size()) { - return std::nullopt; - } else { - return v.at(idx.unwrap_nonnegative()); - } +E at_idx(std::vector const &v, nonnegative_int idx) { + ASSERT(idx < v.size()); + + return v.at(idx.unwrap_nonnegative()); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/binary_cartesian_product.h b/lib/utils/include/utils/containers/binary_cartesian_product.h new file mode 100644 index 
0000000000..1e9f5febbf --- /dev/null +++ b/lib/utils/include/utils/containers/binary_cartesian_product.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_CARTESIAN_PRODUCT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_CARTESIAN_PRODUCT_H + +#include "utils/hash/pair.h" +#include + +namespace FlexFlow { + +template +std::unordered_set> + binary_cartesian_product(std::unordered_set const &lhs, + std::unordered_set const &rhs) { + std::unordered_set> result; + + for (A const &a : lhs) { + for (B const &b : rhs) { + result.insert({a, b}); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h b/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h new file mode 100644 index 0000000000..06a42327e1 --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_DISJOINT_MAPS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_DISJOINT_MAPS_H + +#include "utils/containers/binary_merge_maps_with.h" +#include + +namespace FlexFlow { + +template +std::unordered_map + binary_merge_disjoint_maps(std::unordered_map const &lhs, + std::unordered_map const &rhs) { + + std::unordered_set lhs_keys = keys(lhs); + std::unordered_set rhs_keys = keys(rhs); + + std::unordered_set shared_keys = intersection(lhs_keys, rhs_keys); + ASSERT(shared_keys.empty()); + + return binary_merge_maps_with( + lhs, rhs, [](V const &, V const &) -> V { PANIC(); }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/binary_merge_maps_with.h b/lib/utils/include/utils/containers/binary_merge_maps_with.h new file mode 100644 index 0000000000..a7c196d061 --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_maps_with.h @@ -0,0 +1,42 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_H + +#include "utils/containers/generate_map.h" +#include "utils/containers/intersection.h" +#include "utils/containers/keys.h" +#include "utils/containers/merge_maps_with_right_dominating.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/set_minus.h" +#include + +namespace FlexFlow { + +template +std::unordered_map + binary_merge_maps_with(std::unordered_map const &lhs, + std::unordered_map const &rhs, + F &&f) { + + std::unordered_set l_keys = keys(lhs); + std::unordered_set r_keys = keys(rhs); + + std::unordered_set l_only_keys = set_minus(l_keys, r_keys); + std::unordered_set r_only_keys = set_minus(r_keys, l_keys); + std::unordered_set both_keys = intersection(r_keys, l_keys); + + std::unordered_map l_only = restrict_keys(lhs, l_only_keys); + std::unordered_map r_only = restrict_keys(rhs, r_only_keys); + + std::unordered_map merged = generate_map( + both_keys, [&](K const &k) { return f(lhs.at(k), rhs.at(k)); }); + + return merge_maps_with_right_dominating(std::vector{ + l_only, + r_only, + merged, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/binary_merge_maps_with_left_dominating.h b/lib/utils/include/utils/containers/binary_merge_maps_with_left_dominating.h new file mode 100644 index 0000000000..f6e23af11c --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_maps_with_left_dominating.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_LEFT_DOMINATING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_LEFT_DOMINATING_H + +#include "utils/containers/merge_in_map.h" + +namespace FlexFlow { + +template +std::unordered_map binary_merge_maps_with_left_dominating( + std::unordered_map const &lhs, std::unordered_map const &rhs) { + std::unordered_map result; + 
merge_in_map(rhs, result); + merge_in_map(lhs, result); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/binary_merge_maps_with_right_dominating.h b/lib/utils/include/utils/containers/binary_merge_maps_with_right_dominating.h new file mode 100644 index 0000000000..e5e29dfcb9 --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_maps_with_right_dominating.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_RIGHT_DOMINATING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_RIGHT_DOMINATING_H + +#include "utils/containers/merge_in_map.h" + +namespace FlexFlow { + +template +std::unordered_map binary_merge_maps_with_right_dominating( + std::unordered_map const &lhs, std::unordered_map const &rhs) { + std::unordered_map result; + merge_in_map(lhs, result); + merge_in_map(rhs, result); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/contains_duplicates.h b/lib/utils/include/utils/containers/contains_duplicates.h new file mode 100644 index 0000000000..8203c5e882 --- /dev/null +++ b/lib/utils/include/utils/containers/contains_duplicates.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CONTAINS_DUPLICATES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CONTAINS_DUPLICATES_H + +#include "utils/containers/unordered_set_of.h" +#include +#include + +namespace FlexFlow { + +template +bool contains_duplicates(std::vector const &s) { + return unordered_set_of(s).size() != s.size(); +} + +template +bool contains_duplicates(std::unordered_multiset const &s) { + return unordered_set_of(s).size() != s.size(); +} + +template +bool contains_duplicates(std::multiset const &s) { + return unordered_set_of(s).size() != s.size(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/count.h 
b/lib/utils/include/utils/containers/count.h index bae4ba104c..60955b3268 100644 --- a/lib/utils/include/utils/containers/count.h +++ b/lib/utils/include/utils/containers/count.h @@ -1,14 +1,15 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_COUNT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_COUNT_H +#include "utils/nonnegative_int/nonnegative_int.h" #include #include namespace FlexFlow { template -int count(C const &c, F const &f) { - int result = 0; +nonnegative_int count(C const &c, F const &f) { + nonnegative_int result = 0_n; for (auto const &v : c) { if (f(v)) { result++; @@ -17,8 +18,6 @@ int count(C const &c, F const &f) { return result; } -std::vector count(size_t n); - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/extend.h b/lib/utils/include/utils/containers/extend.h index fa4e2d24a8..4a07d07110 100644 --- a/lib/utils/include/utils/containers/extend.h +++ b/lib/utils/include/utils/containers/extend.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_EXTEND_H #include "utils/containers/extend_vector.h" -#include +#include #include namespace FlexFlow { @@ -12,24 +12,26 @@ void extend(std::vector &lhs, C const &rhs) { extend_vector(lhs, rhs); } -template -void extend(std::vector &lhs, std::optional const &rhs) { - if (rhs.has_value()) { - extend(lhs, std::vector{rhs.value()}); - } +template +void extend(std::unordered_set &lhs, C const &rhs) { + lhs.reserve(lhs.size() + std::distance(rhs.begin(), rhs.end())); + lhs.insert(rhs.cbegin(), rhs.cend()); } template -void extend(std::unordered_set &lhs, C const &rhs) { +void extend(std::unordered_multiset &lhs, C const &rhs) { lhs.reserve(lhs.size() + std::distance(rhs.begin(), rhs.end())); lhs.insert(rhs.cbegin(), rhs.cend()); } -template -void extend(std::unordered_set &lhs, std::optional const &rhs) { - if (rhs.has_value()) { - extend(lhs, std::vector{rhs.value()}); - } +template +void extend(std::set &lhs, C const &rhs) { + 
lhs.insert(rhs.cbegin(), rhs.cend()); +} + +template +void extend(std::multiset &lhs, C const &rhs) { + lhs.insert(rhs.cbegin(), rhs.cend()); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/filter_idxs.h b/lib/utils/include/utils/containers/filter_idxs.h new file mode 100644 index 0000000000..0de9a05130 --- /dev/null +++ b/lib/utils/include/utils/containers/filter_idxs.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTER_IDXS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTER_IDXS_H + +#include "utils/nonnegative_int/nonnegative_int.h" +#include "utils/nonnegative_int/num_elements.h" +#include "utils/nonnegative_int/range.h" +#include +#include + +namespace FlexFlow { + +template +std::vector filter_idxs(std::vector const &input, + std::function const &f) { + std::vector result; + + for (nonnegative_int idx : range(num_elements(input))) { + if (f(idx)) { + result.push_back(input.at(idx.unwrap_nonnegative())); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/flatmap.h b/lib/utils/include/utils/containers/flatmap.h index eaa8d1dbef..70de6b5020 100644 --- a/lib/utils/include/utils/containers/flatmap.h +++ b/lib/utils/include/utils/containers/flatmap.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FLATMAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FLATMAP_H +#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/extend.h" #include "utils/containers/get_element_type.h" -#include "utils/containers/merge_maps.h" #include #include #include @@ -13,7 +13,7 @@ namespace FlexFlow { template ::value_type> -std::vector flatmap(std::vector const &v, F const &f) { +std::vector flatmap(std::vector const &v, F &&f) { std::vector result; for (auto const &elem : v) { extend(result, f(elem)); @@ -24,7 +24,7 @@ std::vector flatmap(std::vector const &v, F const &f) { template >> 
-std::unordered_set flatmap(std::unordered_set const &v, F const &f) { +std::unordered_set flatmap(std::unordered_set const &v, F &&f) { std::unordered_set result; for (auto const &elem : v) { extend(result, f(elem)); @@ -32,10 +32,12 @@ std::unordered_set flatmap(std::unordered_set const &v, F const &f) { return result; } -template -std::unordered_set flatmap_v2(std::unordered_set const &v, - std::unordered_set (*f)(In const &)) { - std::unordered_set result; +template >> +std::unordered_multiset flatmap(std::unordered_multiset const &v, + F &&f) { + std::unordered_multiset result; for (auto const &elem : v) { extend(result, f(elem)); } @@ -45,7 +47,7 @@ std::unordered_set flatmap_v2(std::unordered_set const &v, template >> -std::set flatmap(std::set const &v, F const &f) { +std::set flatmap(std::set const &v, F &&f) { std::set result; for (auto const &elem : v) { extend(result, f(elem)); @@ -53,6 +55,17 @@ std::set flatmap(std::set const &v, F const &f) { return result; } +template >> +std::multiset flatmap(std::multiset const &v, F &&f) { + std::multiset result; + for (auto const &elem : v) { + extend(result, f(elem)); + } + return result; +} + template < typename InK, typename InV, @@ -64,14 +77,26 @@ std::unordered_map flatmap(std::unordered_map const &m, std::unordered_map result; for (auto const &[k, v] : m) { - result = merge_disjoint_maps(result, f(k, v)); + result = binary_merge_disjoint_maps(result, f(k, v)); } return result; } +template ::value_type> +std::optional flatmap(std::optional const &o, F &&f) { + if (o.has_value()) { + std::optional r = f(o.value()); + return r; + } else { + return std::nullopt; + } +} + template -std::string flatmap(std::string const &input, F const &f) { +std::string flatmap(std::string const &input, F &&f) { std::string result = ""; for (char c : input) { diff --git a/lib/utils/include/utils/containers/foldl.h b/lib/utils/include/utils/containers/foldl.h index 16851d7d9b..5b99b23a7c 100644 --- 
a/lib/utils/include/utils/containers/foldl.h +++ b/lib/utils/include/utils/containers/foldl.h @@ -25,7 +25,7 @@ namespace FlexFlow { * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:foldl */ template -T foldl(C const &c, T init, F func) { +T foldl(C const &c, T const &init, F func) { T result = init; for (auto const &elem : c) { result = func(result, elem); @@ -33,40 +33,6 @@ T foldl(C const &c, T init, F func) { return result; } -/** - * @brief - * Applies `func` to the elements of `c` from left to right, accumulating the - * result. The first element of `c` is used as the starting point for the - * accumulation. - * - * @example - * std::vector nums = {1, 2, 3, 4}; - * int result = foldl1(nums, [](int a, int b) { return a + b; }); - * result -> (((1+2)+3)+4) = 10 - * - * @note - * For more information, see - * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:foldl1 - * @throws std::runtime_error if the container is empty. - */ -template -E foldl1(C const &c, F func) { - if (c.empty()) { - throw mk_runtime_error( - fmt::format("foldl1 received empty container: {}", c)); - } - std::optional result = std::nullopt; - - for (E const &e : c) { - if (!result.has_value()) { - result = e; - } else { - result = func(result.value(), e); - } - } - return result.value(); -} - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/foldl1.h b/lib/utils/include/utils/containers/foldl1.h index f542f8cf00..9125da1eb8 100644 --- a/lib/utils/include/utils/containers/foldl1.h +++ b/lib/utils/include/utils/containers/foldl1.h @@ -1,27 +1,41 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDL1_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDL1_H -#include "utils/exception.h" +#include #include namespace FlexFlow { -template -T foldl1(std::vector const &vec, F f) { - if (vec.empty()) { - throw mk_runtime_error(fmt::format( - "foldl1 expected non-empty vector, but receieved empty vector")); +/** + * 
@brief + * Applies `func` to the elements of `c` from left to right, accumulating the + * result. The first element of `c` is used as the starting point for the + * accumulation. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * int result = foldl1(nums, [](int a, int b) { return a + b; }); + * result -> (((1+2)+3)+4) = 10 + * + * @note + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:foldl1 + * @throws std::runtime_error if the container is empty. + */ +template +E foldl1(C const &c, F func) { + ASSERT(!c.empty(), + "foldl1 expected non-empty vector, but received empty vector"); + std::optional result = std::nullopt; + + for (E const &e : c) { + if (!result.has_value()) { + result = e; + } else { + result = func(result.value(), e); + } } - - auto it = vec.cbegin(); - T result = *it; - it++; - - for (; it != vec.cend(); it++) { - result = f(result, *it); - } - - return result; + return result.value(); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/foldr.h b/lib/utils/include/utils/containers/foldr.h new file mode 100644 index 0000000000..08601b1fa2 --- /dev/null +++ b/lib/utils/include/utils/containers/foldr.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR_H + +#include "utils/exception.h" +#include + +namespace FlexFlow { + +/** + * @brief + * Iteratively applies `func` to the elements of `c` from right to left. + * `init` is used as the starting value. 
+ * + * @example + * std::vector nums = {1, 2, 3, 4}; + * int result = foldl(nums, 0, [](int a, int b) { return a + b; }); + * result -> (0+(1+(2+(3+4)))) = 10 + * + * @note + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:foldr + */ +template +T foldr(C const &c, T const &init, F func) { + T result = init; + for (auto const &elem : c) { + result = func(result, elem); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/generate_vector.h b/lib/utils/include/utils/containers/generate_vector.h new file mode 100644 index 0000000000..71e8cc5acb --- /dev/null +++ b/lib/utils/include/utils/containers/generate_vector.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_VECTOR_H + +#include "utils/nonnegative_int/nonnegative_int.h" +#include "utils/nonnegative_int/range.h" +#include + +namespace FlexFlow { + +template > +std::vector generate_vector(nonnegative_int length, F &&f) { + std::vector result; + for (nonnegative_int idx : range(length)) { + result.push_back(f(idx)); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h b/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h index 0a7e9d16c2..6201845a64 100644 --- a/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h +++ b/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h @@ -1,6 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_PERMUTATIONS_WITH_REPETITION_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_PERMUTATIONS_WITH_REPETITION_H +#include "utils/containers/transform.h" +#include "utils/hash/vector.h" #include "utils/nonnegative_int/nonnegative_int.h" #include #include @@ -27,11 +29,13 @@ 
std::unordered_multiset> std::vector indices(n.unwrap_nonnegative(), 0); while (true) { - std::vector perm(n.unwrap_nonnegative()); + std::vector> perm(n.unwrap_nonnegative()); for (int i = 0; i < n; ++i) { perm[i] = elements[indices[i]]; } - result.insert(perm); + std::vector unwrapped_perm = + transform(perm, [](std::optional const &t) { return t.value(); }); + result.insert(unwrapped_perm); int i = n.unwrap_nonnegative() - 1; while (i != -1 && ++indices[i] == elements.size()) { diff --git a/lib/utils/include/utils/containers/get_only.h b/lib/utils/include/utils/containers/get_only.h index 201095c47d..ed44b26c36 100644 --- a/lib/utils/include/utils/containers/get_only.h +++ b/lib/utils/include/utils/containers/get_only.h @@ -15,6 +15,13 @@ typename C::value_type get_only(C const &c) { }); } +template +std::pair get_only(std::unordered_map const &m) { + ASSERT(m.size() == 1); + + return *m.cbegin(); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/group_by.h b/lib/utils/include/utils/containers/group_by.h index 6abffbfed0..e5269e6958 100644 --- a/lib/utils/include/utils/containers/group_by.h +++ b/lib/utils/include/utils/containers/group_by.h @@ -1,18 +1,39 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GROUP_BY_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GROUP_BY_H +#include "utils/one_to_many/one_to_many.h" +#include #include #include #include +#include namespace FlexFlow { template > -std::unordered_map> - group_by(std::unordered_set const &vs, F f) { - std::unordered_map> result; +OneToMany group_by(std::unordered_set const &vs, F &&f) { + OneToMany result; for (V const &v : vs) { - result[f(v)].insert(v); + result.insert({f(v), v}); + } + return result; +} + +template > +OneToMany group_by(std::set const &vs, F &&f) { + OneToMany result; + for (V const &v : vs) { + result.insert({f(v), v}); + } + return result; +} + +template > +std::unordered_map> group_by(std::vector const &vs, + F &&f) { + 
std::unordered_map> result; + for (V const &v : vs) { + result[f(v)].push_back(v); } return result; } diff --git a/lib/utils/include/utils/containers/intersection.h b/lib/utils/include/utils/containers/intersection.h index 938ebd68c9..55e6c7a5f8 100644 --- a/lib/utils/include/utils/containers/intersection.h +++ b/lib/utils/include/utils/containers/intersection.h @@ -3,6 +3,7 @@ #include "utils/containers/contains.h" #include +#include #include namespace FlexFlow { @@ -19,6 +20,17 @@ std::unordered_set intersection(std::unordered_set const &l, return result; } +template +std::set intersection(std::set const &l, std::set const &r) { + std::set result; + for (T const &ll : l) { + if (contains(r, ll)) { + result.insert(ll); + } + } + return result; +} + template std::optional intersection(C const &c) { std::optional result; diff --git a/lib/utils/include/utils/containers/is_subseteq_of.h b/lib/utils/include/utils/containers/is_subseteq_of.h index 705c092962..e435aa24dd 100644 --- a/lib/utils/include/utils/containers/is_subseteq_of.h +++ b/lib/utils/include/utils/containers/is_subseteq_of.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUBSETEQ_OF_H #include "utils/containers/contains.h" +#include #include namespace FlexFlow { @@ -21,6 +22,20 @@ bool is_subseteq_of(std::unordered_set const &sub, return true; } +template +bool is_subseteq_of(std::set const &l, std::set const &r) { + if (l.size() > r.size()) { + return false; + } + + for (auto const &ll : l) { + if (!contains(r, ll)) { + return false; + } + } + return true; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/lift_optional_through_map.h b/lib/utils/include/utils/containers/lift_optional_through_map.h new file mode 100644 index 0000000000..189d1b0519 --- /dev/null +++ b/lib/utils/include/utils/containers/lift_optional_through_map.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_LIFT_OPTIONAL_THROUGH_MAP_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_LIFT_OPTIONAL_THROUGH_MAP_H + +#include "utils/containers/all_of.h" +#include "utils/containers/map_values.h" +#include "utils/containers/values.h" +#include +#include +#include + +namespace FlexFlow { + +template +static std::optional> lift_optional_through_map( + std::unordered_map> const &m) { + ASSERT(!m.empty()); + + std::unordered_multiset> m_values = values(m); + + bool has_all_values = all_of(m_values, [](std::optional const &t) -> bool { + return t.has_value(); + }); + + bool has_no_values = all_of(m_values, [](std::optional const &t) -> bool { + return !t.has_value(); + }); + + ASSERT(has_all_values || has_no_values); + if (has_no_values) { + return std::nullopt; + } else { + return map_values(m, + [](std::optional const &t) -> V { return t.value(); }); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/make_counter_func.h b/lib/utils/include/utils/containers/make_counter_func.h new file mode 100644 index 0000000000..eddc81adeb --- /dev/null +++ b/lib/utils/include/utils/containers/make_counter_func.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAKE_COUNTER_FUNC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAKE_COUNTER_FUNC_H + +#include + +namespace FlexFlow { + +std::function make_counter_func(int start = 0); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/map_from_keys_and_values.h b/lib/utils/include/utils/containers/map_from_keys_and_values.h index 499965dc5e..8a9e36ff4e 100644 --- a/lib/utils/include/utils/containers/map_from_keys_and_values.h +++ b/lib/utils/include/utils/containers/map_from_keys_and_values.h @@ -2,7 +2,8 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_KEYS_AND_VALUES_H #include "utils/containers/zip.h" -#include "utils/exception.h" +#include +#include #include namespace FlexFlow { @@ -11,12 +12,8 @@ template std::unordered_map 
map_from_keys_and_values(std::vector const &keys, std::vector const &values) { - if (keys.size() != values.size()) { - throw mk_runtime_error(fmt::format( - "recieved keys (of size {}) not matching values (of size {})", - keys.size(), - values.size())); - } + ASSERT(keys.size() == values.size()); + std::unordered_map result; for (auto const &[k, v] : zip(keys, values)) { result.insert({k, v}); diff --git a/lib/utils/include/utils/containers/map_from_pairs.h b/lib/utils/include/utils/containers/map_from_pairs.h new file mode 100644 index 0000000000..7c470d4d3e --- /dev/null +++ b/lib/utils/include/utils/containers/map_from_pairs.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_PAIRS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_PAIRS_H + +#include +#include + +namespace FlexFlow { + +template +std::unordered_map + map_from_pairs(std::unordered_set> const &pairs) { + + std::unordered_map result(pairs.cbegin(), pairs.cend()); + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/map_keys.h b/lib/utils/include/utils/containers/map_keys.h index 4e5352748d..5cd44d8a5d 100644 --- a/lib/utils/include/utils/containers/map_keys.h +++ b/lib/utils/include/utils/containers/map_keys.h @@ -25,10 +25,9 @@ std::unordered_map map_keys(std::unordered_map const &m, for (auto const &kv : m) { result.insert({f(kv.first), kv.second}); } - if (keys(m).size() != keys(result).size()) { - throw mk_runtime_error( - "keys passed to map_keys must be transformed into distinct keys"); - } + + ASSERT(keys(m).size() == keys(result).size(), + "keys passed to map_keys must be transformed into distinct keys"); return result; } diff --git a/lib/utils/include/utils/containers/map_keys2.h b/lib/utils/include/utils/containers/map_keys2.h new file mode 100644 index 0000000000..fd848f18d8 --- /dev/null +++ b/lib/utils/include/utils/containers/map_keys2.h @@ -0,0 +1,30 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS2_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS2_H + +#include "utils/containers/keys.h" +#include +#include + +namespace FlexFlow { + +template > +std::unordered_map map_keys2(std::unordered_map const &m, + F const &f) { + + std::unordered_map result; + for (auto const &kv : m) { + result.insert({f(kv.first, kv.second), kv.second}); + } + + ASSERT(keys(m).size() == keys(result).size(), + "keys passed to map_keys must be transformed into distinct keys"); + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/map_keys_with_value_merging.h b/lib/utils/include/utils/containers/map_keys_with_value_merging.h new file mode 100644 index 0000000000..93c046f017 --- /dev/null +++ b/lib/utils/include/utils/containers/map_keys_with_value_merging.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_WITH_VALUE_MERGING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_WITH_VALUE_MERGING_H + +#include "utils/containers/contains_key.h" +#include + +namespace FlexFlow { + +template > +std::unordered_map map_keys_with_value_merging( + std::unordered_map const &m, F &&key_func, MergeF &&merge_values) { + + std::unordered_map result; + + for (auto const &kv : m) { + K k = kv.first; + V v = kv.second; + + K2 k2 = key_func(k); + + if (contains_key(result, k2)) { + result.at(k2) = merge_values(result.at(k2), v); + } else { + result.insert({k2, v}); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/map_values.h b/lib/utils/include/utils/containers/map_values.h index 9f7a4f4add..bf377b2c93 100644 --- a/lib/utils/include/utils/containers/map_values.h +++ b/lib/utils/include/utils/containers/map_values.h @@ -10,11 +10,10 @@ template > -std::unordered_map map_values(std::unordered_map const &m, - F const &f) { +std::unordered_map map_values(std::unordered_map const 
&m, F &&f) { std::unordered_map result; - for (auto const &kv : m) { - result.insert({kv.first, f(kv.second)}); + for (std::pair const &kv : m) { + result.insert(std::pair{kv.first, f(kv.second)}); } return result; } diff --git a/lib/utils/include/utils/containers/map_values2.h b/lib/utils/include/utils/containers/map_values2.h new file mode 100644 index 0000000000..752a8babd3 --- /dev/null +++ b/lib/utils/include/utils/containers/map_values2.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_VALUES2_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_VALUES2_H + +#include +#include + +namespace FlexFlow { + +template > +std::unordered_map map_values2(std::unordered_map const &m, + F &&f) { + std::unordered_map result; + for (std::pair const &kv : m) { + result.insert(std::pair{kv.first, f(kv.first, kv.second)}); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/maximum.h b/lib/utils/include/utils/containers/maximum.h index b3d6d0c6d7..6b3cbb6ebf 100644 --- a/lib/utils/include/utils/containers/maximum.h +++ b/lib/utils/include/utils/containers/maximum.h @@ -1,15 +1,16 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H -#include "utils/exception.h" #include +#include +#include namespace FlexFlow { template typename C::value_type maximum(C const &c) { if (c.empty()) { - throw mk_runtime_error( + PANIC( fmt::format("maximum expected non-empty container but received {}", c)); } diff --git a/lib/utils/include/utils/containers/merge_disjoint_maps.h b/lib/utils/include/utils/containers/merge_disjoint_maps.h new file mode 100644 index 0000000000..eccb06180a --- /dev/null +++ b/lib/utils/include/utils/containers/merge_disjoint_maps.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_DISJOINT_MAPS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_DISJOINT_MAPS_H + 
+#include "utils/containers/binary_merge_disjoint_maps.h" +#include "utils/containers/foldl.h" + +namespace FlexFlow { + +template +std::unordered_map merge_disjoint_maps(C const &c) { + std::unordered_map empty = {}; + return foldl(c, + /*init=*/empty, + [](std::unordered_map const &lhs, + std::unordered_map const &rhs) { + return binary_merge_disjoint_maps(lhs, rhs); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/merge_in_map.h b/lib/utils/include/utils/containers/merge_in_map.h new file mode 100644 index 0000000000..edae4b8a6a --- /dev/null +++ b/lib/utils/include/utils/containers/merge_in_map.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_IN_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_IN_MAP_H + +#include + +namespace FlexFlow { + +template +void merge_in_map(std::unordered_map const &m, + std::unordered_map &result) { + for (auto const &[k, v] : m) { + auto it = result.find(k); + if (it != result.end()) { + it->second = v; + } else { + result.insert({k, v}); + } + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/merge_maps.h b/lib/utils/include/utils/containers/merge_maps.h deleted file mode 100644 index bfc2446d99..0000000000 --- a/lib/utils/include/utils/containers/merge_maps.h +++ /dev/null @@ -1,69 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_MAPS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_MAPS_H - -#include "utils/containers/are_disjoint.h" -#include "utils/containers/keys.h" -#include "utils/containers/merge_method.dtg.h" -#include "utils/exception.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" -#include - -namespace FlexFlow { - -template -void merge_in_map(std::unordered_map const &m, - std::unordered_map &result) { - for (auto const &[k, v] : m) { - auto it = result.find(k); - if (it != result.end()) { - it->second = v; - } else { - 
result.insert({k, v}); - } - } -} - -template -std::unordered_map - merge_disjoint_maps(std::unordered_map const &lhs, - std::unordered_map const &rhs) { - - std::unordered_set lhs_keys = keys(lhs); - std::unordered_set rhs_keys = keys(rhs); - std::unordered_set shared_keys = intersection(lhs_keys, rhs_keys); - if (!shared_keys.empty()) { - throw mk_runtime_error( - fmt::format("merge_maps expected disjoint maps, but maps share keys {}", - shared_keys)); - } - - std::unordered_map result; - merge_in_map(lhs, result); - merge_in_map(rhs, result); - return result; -} - -template -std::unordered_map - merge_map_left_dominates(std::unordered_map const &lhs, - std::unordered_map const &rhs) { - std::unordered_map result; - merge_in_map(rhs, result); - merge_in_map(lhs, result); - return result; -} - -template -std::unordered_map - merge_map_right_dominates(std::unordered_map const &lhs, - std::unordered_map const &rhs) { - std::unordered_map result; - merge_in_map(lhs, result); - merge_in_map(rhs, result); - return result; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers/merge_maps_with.h b/lib/utils/include/utils/containers/merge_maps_with.h new file mode 100644 index 0000000000..2f5a09e26e --- /dev/null +++ b/lib/utils/include/utils/containers/merge_maps_with.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_MAPS_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_MAPS_WITH_H + +#include "utils/containers/binary_merge_maps_with.h" +#include "utils/containers/foldl.h" +#include +#include + +namespace FlexFlow { + +template +std::unordered_map + merge_maps_with(std::vector> const &to_merge, + F &&f) { + return foldl(to_merge, + std::unordered_map{}, + [&](std::unordered_map const &accum, + std::unordered_map const &m) { + return binary_merge_maps_with(accum, m, f); + }); +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/containers/merge_maps_with_right_dominating.h b/lib/utils/include/utils/containers/merge_maps_with_right_dominating.h new file mode 100644 index 0000000000..1d4f8536d8 --- /dev/null +++ b/lib/utils/include/utils/containers/merge_maps_with_right_dominating.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_MAPS_WITH_RIGHT_DOMINATING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_MAPS_WITH_RIGHT_DOMINATING_H + +#include "utils/containers/merge_in_map.h" + +namespace FlexFlow { + +template +std::unordered_map merge_maps_with_right_dominating(C const &c) { + std::unordered_map result; + + for (std::unordered_map const &m : c) { + merge_in_map(m, result); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/merge_method.dtg.toml b/lib/utils/include/utils/containers/merge_method.dtg.toml new file mode 100644 index 0000000000..73c14a41c9 --- /dev/null +++ b/lib/utils/include/utils/containers/merge_method.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MergeMethod" +type = "enum" +features = [ + "json", + "hash", + "fmt", + "rapidcheck", +] + +[[values]] +name = "REQUIRE_DISJOINT" + +[[values]] +name = "LEFT_DOMINATES" + +[[values]] +name = "RIGHT_DOMINATES" diff --git a/lib/utils/include/utils/containers/merge_method.enum.toml b/lib/utils/include/utils/containers/merge_method.enum.toml deleted file mode 100644 index ec0ed067dd..0000000000 --- a/lib/utils/include/utils/containers/merge_method.enum.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "MergeMethod" -features = [ - "json", - "hash", - "fmt", - "rapidcheck", -] - -[[values]] -name = "REQUIRE_DISJOINT" - -[[values]] -name = "LEFT_DOMINATES" - -[[values]] -name = "RIGHT_DOMINATES" diff --git a/lib/utils/include/utils/containers/multiset_of.h b/lib/utils/include/utils/containers/multiset_of.h new file mode 100644 index 0000000000..79bfbc40a3 --- /dev/null +++ 
b/lib/utils/include/utils/containers/multiset_of.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_OF_H + +#include + +namespace FlexFlow { + +template +std::multiset multiset_of(C const &c) { + return std::multiset{std::cbegin(c), std::cend(c)}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/permute_with_key.h b/lib/utils/include/utils/containers/permute_with_key.h new file mode 100644 index 0000000000..6579d6d50a --- /dev/null +++ b/lib/utils/include/utils/containers/permute_with_key.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_PERMUTE_WITH_KEY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_PERMUTE_WITH_KEY_H + +#include "utils/containers/product.h" +#include "utils/containers/range.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template +std::vector permute_with_key(int key, std::vector const &v) { + int max_permutations = 10000; + int reduced_key = key % max_permutations; + + std::vector permutation = range(v.size()); + + for (int i = 0; i < reduced_key; i++) { + std::next_permutation(permutation.begin(), permutation.end()); + } + + return transform(permutation, [&](int permutation_entry) { + return v.at(permutation_entry); + }); +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_all_same.h b/lib/utils/include/utils/containers/require_all_same.h new file mode 100644 index 0000000000..9bfebd5494 --- /dev/null +++ b/lib/utils/include/utils/containers/require_all_same.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME_H + +#include "utils/containers/require_all_same1.h" +#include +#include + +namespace FlexFlow { + +template +std::optional require_all_same(C const &c) { + if (c.empty()) { + return std::nullopt; 
+ } else { + return require_all_same1(c); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_all_same1.h b/lib/utils/include/utils/containers/require_all_same1.h index 2f42243857..3e210e1c11 100644 --- a/lib/utils/include/utils/containers/require_all_same1.h +++ b/lib/utils/include/utils/containers/require_all_same1.h @@ -2,26 +2,17 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME1_H #include -#include +#include namespace FlexFlow { template -tl::expected require_all_same1(C const &c) { - if (c.empty()) { - return tl::unexpected(fmt::format( - "require_all_same1 expected non-empty container, but received {}", c)); - } +T require_all_same1(C const &c) { + ASSERT(!c.empty()); T const &first = *c.cbegin(); for (T const &v : c) { - if (v != first) { - return tl::unexpected(fmt::format("require_all_same1 found non-same " - "elements {} and {} in containers {}", - first, - v, - c)); - } + ASSERT(v == first); } return first; } diff --git a/lib/utils/include/utils/containers/require_only_key.h b/lib/utils/include/utils/containers/require_only_key.h new file mode 100644 index 0000000000..c63ff4d440 --- /dev/null +++ b/lib/utils/include/utils/containers/require_only_key.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ONLY_KEY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ONLY_KEY_H + +#include "utils/containers/contains_key.h" +#include +#include + +namespace FlexFlow { + +template +V require_only_key(std::unordered_map const &m, K const &k) { + ASSERT(m.size() == 1); + ASSERT(contains_key(m, k)); + + return m.at(k); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_same.h b/lib/utils/include/utils/containers/require_same.h index f638e1da1a..2f3439db32 100644 --- a/lib/utils/include/utils/containers/require_same.h +++ b/lib/utils/include/utils/containers/require_same.h @@ -8,14 +8,21 @@ namespace FlexFlow { 
template T const &require_same(T const &l, T const &r) { - if (l != r) { - throw mk_runtime_error( - fmt::format("require_same received non-equal inputs: {} != {}", l, r)); - } + ASSERT(l == r, "require_same received non-equal inputs"); return l; } +template +T const &require_same(T const &t1, T const &t2, T const &t3) { + return require_same(require_same(t1, t2), t3); +} + +template +T const &require_same(T const &t1, T const &t2, T const &t3, T const &t4) { + return require_same(require_same(require_same(t1, t2), t3), t4); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/require_two_keys.h b/lib/utils/include/utils/containers/require_two_keys.h new file mode 100644 index 0000000000..8da683c2f0 --- /dev/null +++ b/lib/utils/include/utils/containers/require_two_keys.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_TWO_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_TWO_KEYS_H + +#include +#include + +namespace FlexFlow { + +template +std::pair require_two_keys(std::unordered_map const &m, + K const &k1, + K const &k2) { + ASSERT(k1 != k2); + ASSERT(m.size() == 2); + + return {m.at(k1), m.at(k2)}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/scanl.h b/lib/utils/include/utils/containers/scanl.h index a30a9e1576..0d5a4fd7c4 100644 --- a/lib/utils/include/utils/containers/scanl.h +++ b/lib/utils/include/utils/containers/scanl.h @@ -1,14 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL_H -#include #include namespace FlexFlow { /** * @brief - * Applies `op` to the elements of `c` from left to right, accumulating + * Applies `f` to the elements of `c` from left to right, accumulating * the intermediate results in a vector. `init` is used as the starting point * for the accumulation. 
* @@ -23,55 +22,18 @@ namespace FlexFlow { * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl */ template -std::vector scanl(C const &c, T init, F const &op) { +std::vector scanl(C const &c, T init, F &&f) { std::vector result; result.push_back(init); for (auto const &elem : c) { - init = op(init, elem); + init = f(init, elem); result.push_back(init); } return result; } -/** - * @brief - * Applies `op` to the elements of `c` from left to right, accumulating - * the intermediate results in a vector. The first item of `c` is used as the - * starting point for the accumulation. - * - * @example - * std::vector nums = {1, 2, 3, 4}; - * auto result = scanl1(nums, [](int a, int b) {return a+b;}); - * result -> {1,3,6,10} - * - * @note - * Essentially a foldl1 which stores the intermediate results. - * For more information, see - * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl1 - */ -template -std::vector scanl1(C const &c, F op) { - - if (c.empty()) { - return std::vector(); - } - - std::optional init = std::nullopt; - std::vector result; - - for (T const &elem : c) { - if (!init.has_value()) { - init = elem; - } else { - init = op(init.value(), elem); - } - result.push_back(init.value()); - } - return result; -} - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/scanl1.h b/lib/utils/include/utils/containers/scanl1.h new file mode 100644 index 0000000000..a8411cfcba --- /dev/null +++ b/lib/utils/include/utils/containers/scanl1.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL1_H + +#include +#include + +namespace FlexFlow { + +/** + * @brief + * Applies `f` to the elements of `c` from left to right, accumulating + * the intermediate results in a vector. The first item of `c` is used as the + * starting point for the accumulation. 
+ * + * @example + * std::vector nums = {1, 2, 3, 4}; + * auto result = scanl1(nums, [](int a, int b) {return a+b;}); + * result -> {1,3,6,10} + * + * @note + * Essentially a foldl1 which stores the intermediate results. + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl1 + */ +template +std::vector scanl1(C const &c, F &&f) { + + if (c.empty()) { + return std::vector(); + } + + std::optional init = std::nullopt; + std::vector result; + + for (T const &elem : c) { + if (!init.has_value()) { + init = elem; + } else { + init = f(init.value(), elem); + } + result.push_back(init.value()); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/scanr.h b/lib/utils/include/utils/containers/scanr.h new file mode 100644 index 0000000000..03fc94d8c6 --- /dev/null +++ b/lib/utils/include/utils/containers/scanr.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANR_H + +#include "utils/containers/reversed.h" +#include + +namespace FlexFlow { + +/** + * @brief + * Applies `f` to the elements of `c` from right to left, accumulating + * the intermediate results in a vector. `init` is used as the starting point + * for the accumulation. 
+ * + * @example + * std::vector nums = {1, 2, 3, 4}; + * auto result = scanl(nums, 0, [](int a, int b) {return a+b;}); + * result -> {10, 9, 7, 4, 0} + * + * @note + * Essentially a foldl which stores the intermediate results + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scan4 + */ +template +std::vector scanr(C const &c, T init, F &&f) { + std::vector result; + + result.push_back(init); + for (auto it = c.crbegin(); it != c.crend(); it++) { + init = f(*it, init); + result.push_back(init); + } + + return reversed(result); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/scanr1.h b/lib/utils/include/utils/containers/scanr1.h new file mode 100644 index 0000000000..7197d2c4ec --- /dev/null +++ b/lib/utils/include/utils/containers/scanr1.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANR1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANR1_H + +#include "utils/containers/reversed.h" +#include +#include + +namespace FlexFlow { + +/** + * @brief + * Applies `op` to the elements of `c` from right to left, accumulating + * the intermediate results in a vector. The first item of `c` is used as the + * starting point for the accumulation. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * auto result = scanl1(nums, [](int a, int b) {return a+b;}); + * result -> {10, 9, 7, 4} + * + * @note + * Essentially a foldr1 which stores the intermediate results. 
+ * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanr1 + */ +template +std::vector scanr1(C const &c, F &&f) { + + if (c.empty()) { + return std::vector(); + } + + std::optional init = std::nullopt; + std::vector result; + + for (auto it = c.crbegin(); it != c.crend(); it++) { + if (!init.has_value()) { + init = *it; + } else { + init = f(*it, init.value()); + } + result.push_back(init.value()); + } + return reversed(result); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/set_union.h b/lib/utils/include/utils/containers/set_union.h index 0f7b895f7a..cd29b1e02e 100644 --- a/lib/utils/include/utils/containers/set_union.h +++ b/lib/utils/include/utils/containers/set_union.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_UNION_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_UNION_H +#include #include namespace FlexFlow { @@ -13,6 +14,13 @@ std::unordered_set set_union(std::unordered_set const &l, return result; } +template +std::set set_union(std::set const &l, std::set const &r) { + std::set result = l; + result.insert(r.cbegin(), r.cend()); + return result; +} + template std::unordered_set set_union(C const &sets) { std::unordered_set result; diff --git a/lib/utils/include/utils/containers/slice.h b/lib/utils/include/utils/containers/slice.h index a82fb383b5..29a4858a0d 100644 --- a/lib/utils/include/utils/containers/slice.h +++ b/lib/utils/include/utils/containers/slice.h @@ -10,7 +10,7 @@ namespace FlexFlow { template std::vector slice(std::vector const &v, - int const &maybe_start, + int maybe_start, std::optional const &maybe_end) { auto begin_iter = v.cbegin(); auto end_iter = v.cend(); diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index a8a6a749cd..14ef782690 100644 --- a/lib/utils/include/utils/containers/transform.h +++ 
b/lib/utils/include/utils/containers/transform.h @@ -4,6 +4,7 @@ #include "utils/containers/vector_transform.h" #include "utils/required_core.h" #include +#include #include #include #include @@ -80,6 +81,19 @@ std::unordered_map transform(std::unordered_map const &m, return result; } +template ::first_type, + typename V2 = typename std::invoke_result_t::second_type> +std::unordered_map transform(std::map const &m, F const &f) { + std::unordered_map result; + for (auto const &[k, v] : m) { + result.insert(f(k, v)); + } + return result; +} + template std::optional> transform(std::optional const &o, F &&f) { diff --git a/lib/utils/include/utils/containers/try_at_idx.h b/lib/utils/include/utils/containers/try_at_idx.h new file mode 100644 index 0000000000..7c16efe218 --- /dev/null +++ b/lib/utils/include/utils/containers/try_at_idx.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_IDX_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_IDX_H + +#include "utils/nonnegative_int/nonnegative_int.h" +#include +#include + +namespace FlexFlow { + +template +std::optional try_at_idx(std::vector const &v, nonnegative_int idx) { + if (idx >= v.size()) { + return std::nullopt; + } else { + return v.at(idx.unwrap_nonnegative()); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/try_get_one_of.h b/lib/utils/include/utils/containers/try_get_one_of.h new file mode 100644 index 0000000000..5749164574 --- /dev/null +++ b/lib/utils/include/utils/containers/try_get_one_of.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_GET_ONE_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_GET_ONE_OF_H + +#include +#include +#include + +namespace FlexFlow { + +template +std::optional try_get_one_of(std::unordered_set const &s) { + if (s.empty()) { + return std::nullopt; + } else { + return *s.cbegin(); + } +} + +template +std::optional try_get_one_of(std::set const &s) { + if 
(s.empty()) { + return std::nullopt; + } else { + return *s.cbegin(); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/uncurry.h b/lib/utils/include/utils/containers/uncurry.h new file mode 100644 index 0000000000..bd76931ca8 --- /dev/null +++ b/lib/utils/include/utils/containers/uncurry.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNCURRY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNCURRY_H + +#include +#include + +namespace FlexFlow { + +template > +std::function const &)> uncurry(F &&f) { + return [f](std::pair const &p) -> Result { + return f(p.first, p.second); + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/unordered_set_of.h b/lib/utils/include/utils/containers/unordered_set_of.h index 722ae66d43..74c7683460 100644 --- a/lib/utils/include/utils/containers/unordered_set_of.h +++ b/lib/utils/include/utils/containers/unordered_set_of.h @@ -1,6 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNIQUE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNIQUE_H +#include "utils/hash/pair.h" +#include #include namespace FlexFlow { @@ -10,6 +12,16 @@ std::unordered_set unordered_set_of(C const &c) { return std::unordered_set{c.cbegin(), c.cend()}; } +template +std::unordered_set> + unordered_set_of(std::unordered_map const &m) { + std::unordered_set> result; + for (auto const &[k, v] : m) { + result.insert({k, v}); + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/unstructured_exhaustive_relational_join.h b/lib/utils/include/utils/containers/unstructured_exhaustive_relational_join.h new file mode 100644 index 0000000000..3894f4099d --- /dev/null +++ b/lib/utils/include/utils/containers/unstructured_exhaustive_relational_join.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNSTRUCTURED_EXHAUSTIVE_RELATIONAL_JOIN_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNSTRUCTURED_EXHAUSTIVE_RELATIONAL_JOIN_H + +#include "utils/containers/transform.h" +#include "utils/hash/pair.h" +#include +#include + +namespace FlexFlow { + +template +std::unordered_set> unstructured_exhaustive_relational_join( + std::unordered_set> const &lhs, + std::unordered_set> const &rhs) { + std::unordered_set> result; + + std::unordered_set lhs_ls = + transform(lhs, [](std::pair const &lc) { return lc.first; }); + std::unordered_set lhs_cs = + transform(lhs, [](std::pair const &lc) { return lc.second; }); + std::unordered_set rhs_cs = + transform(rhs, [](std::pair const &cr) { return cr.first; }); + std::unordered_set rhs_rs = + transform(rhs, [](std::pair const &cr) { return cr.second; }); + + ASSERT(lhs_cs == rhs_cs); + + std::unordered_set result_ls; + std::unordered_set result_rs; + + for (auto const &[l, c1] : lhs) { + for (auto const &[c2, r] : rhs) { + if (c1 == c2) { + result.insert({l, r}); + result_ls.insert(l); + result_rs.insert(r); + } + } + } + + ASSERT(result_ls == lhs_ls); + ASSERT(result_rs == rhs_rs); + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/vector_from_idx_map.h b/lib/utils/include/utils/containers/vector_from_idx_map.h new file mode 100644 index 0000000000..cd32385d1f --- /dev/null +++ b/lib/utils/include/utils/containers/vector_from_idx_map.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_FROM_IDX_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_FROM_IDX_MAP_H + +#include "utils/containers/contains_key.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::optional> + vector_from_idx_map(std::unordered_map const &m) { + std::vector result; + + for (nonnegative_int i = 0_n; i < m.size(); i++) { + if (!contains_key(m, i)) { + return std::nullopt; + } + result.push_back(m.at(i)); + } + + return result; +} + 
+} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip3_with.h b/lib/utils/include/utils/containers/zip3_with.h new file mode 100644 index 0000000000..fd79c02591 --- /dev/null +++ b/lib/utils/include/utils/containers/zip3_with.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_WITH_H + +#include + +namespace FlexFlow { + +template > +std::vector zip3_with(std::vector const &v_a, + std::vector const &v_b, + std::vector const &v_c, + F &&f) { + std::vector result; + for (int i = 0; i < std::min(v_a.size(), std::min(v_b.size(), v_c.size())); + i++) { + result.push_back(f(v_a.at(i), v_b.at(i), v_c.at(i))); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip3_with_strict.h b/lib/utils/include/utils/containers/zip3_with_strict.h new file mode 100644 index 0000000000..055ae5a7fe --- /dev/null +++ b/lib/utils/include/utils/containers/zip3_with_strict.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_WITH_STRICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_WITH_STRICT_H + +#include "utils/containers/zip3_with.h" +#include +#include + +namespace FlexFlow { + +template > +std::vector zip3_with_strict(std::vector const &v_a, + std::vector const &v_b, + std::vector const &v_c, + F &&f) { + ASSERT(v_a.size() == v_b.size() && v_b.size() == v_c.size(), + "zip3_with_strict requires inputs to have the same length, but " + "received mismatched lengths", + v_a.size(), + v_b.size(), + v_c.size()); + + return zip3_with(v_a, v_b, v_c, f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_values_strict.h b/lib/utils/include/utils/containers/zip_values_strict.h new file mode 100644 index 0000000000..60a7985bc5 --- /dev/null +++ b/lib/utils/include/utils/containers/zip_values_strict.h @@ -0,0 +1,29 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_H + +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/require_same.h" +#include +#include + +namespace FlexFlow { + +template +std::unordered_map> + zip_values_strict(std::unordered_map const &m1, + std::unordered_map const &m2) { + + ASSERT(keys(m1) == keys(m2)); + + return generate_map(require_same(keys(m1), keys(m2)), [&](K const &k) { + return std::pair{ + m1.at(k), + m2.at(k), + }; + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_values_strict_with.h b/lib/utils/include/utils/containers/zip_values_strict_with.h new file mode 100644 index 0000000000..3b0530db8a --- /dev/null +++ b/lib/utils/include/utils/containers/zip_values_strict_with.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_WITH_H + +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/require_same.h" +#include +#include + +namespace FlexFlow { + +template > +std::unordered_map + zip_values_strict_with(std::unordered_map const &m1, + std::unordered_map const &m2, + F &&f) { + + ASSERT(keys(m1) == keys(m2)); + + return generate_map(require_same(keys(m1), keys(m2)), + [&](K const &k) -> Out { return f(m1.at(k), m2.at(k)); }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_with.h b/lib/utils/include/utils/containers/zip_with.h index 7ae91a7336..2fb54d85a7 100644 --- a/lib/utils/include/utils/containers/zip_with.h +++ b/lib/utils/include/utils/containers/zip_with.h @@ -13,7 +13,8 @@ std::vector zip_with(std::vector const &l, std::vector const &r, F &&f) { std::vector result; for (int i = 0; i < l.size() && i < r.size(); i++) { - 
result.push_back(f(l.at(i), r.at(i))); + Result elem = f(l.at(i), r.at(i)); + result.push_back(elem); } return result; diff --git a/lib/utils/include/utils/containers/zip_with_strict.h b/lib/utils/include/utils/containers/zip_with_strict.h index fd1e2fa7fd..b9b2e47c84 100644 --- a/lib/utils/include/utils/containers/zip_with_strict.h +++ b/lib/utils/include/utils/containers/zip_with_strict.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_STRICT_H #include "utils/containers/zip_with.h" -#include "utils/exception.h" #include "utils/fmt/vector.h" +#include #include namespace FlexFlow { @@ -15,15 +15,9 @@ template zip_with_strict(std::vector const &lhs, std::vector const &rhs, F &&f) { - if (lhs.size() != rhs.size()) { - throw mk_runtime_error(fmt::format( - "zip_with_strict requires inputs to have the same length, but received " - "lhs = {} (length {}) and rhs = {} (length {})", - lhs, - lhs.size(), - rhs, - rhs.size())); - } + ASSERT(lhs.size() == rhs.size(), + "zip_with_strict requires inputs to have the same length." 
+ "For a similar function without this requirement, see zip_with."); return zip_with(lhs, rhs, f); } diff --git a/lib/utils/include/utils/exception.h b/lib/utils/include/utils/exception.h index 959edcff8a..610d560c7b 100644 --- a/lib/utils/include/utils/exception.h +++ b/lib/utils/include/utils/exception.h @@ -15,8 +15,7 @@ namespace FlexFlow { "Function " __FUNC__ " not yet implemented " __FILE__ \ ":" __LINE__); #else -#define NOT_IMPLEMENTED() \ - throw not_implemented(__PRETTY_FUNCTION__, __FILE__, __LINE__); +#define NOT_IMPLEMENTED() PANIC("Not implemented"); #endif class not_implemented : public std::logic_error { diff --git a/lib/utils/include/utils/ffi/opaque.h b/lib/utils/include/utils/ffi/opaque.h index bf4f62cca8..87367f9460 100644 --- a/lib/utils/include/utils/ffi/opaque.h +++ b/lib/utils/include/utils/ffi/opaque.h @@ -13,7 +13,7 @@ template struct LibraryUtils { template - using err = expected; + using err = tl::expected; template static err allocate_opaque(T const &t) { diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index ee008f7bfe..378b2d07b9 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_UTILS_INCLUDE_FMT_H #include "utils/check_fmtable.h" -#include "utils/test_types.h" #include "utils/type_traits_core.h" #include #include diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index 20a08916fc..e42f528f8c 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -2,10 +2,12 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_SET_H #include "utils/check_fmtable.h" +#include "utils/containers/sorted.h" #include "utils/join_strings.h" #include "utils/type_traits_core.h" #include #include +#include namespace fmt { @@ -13,7 +15,30 @@ template struct formatter< ::std::unordered_set, Char, - std::enable_if_t>::value>> + 
std::enable_if_t<(!detail::has_format_as>::value) && + ::FlexFlow::is_lt_comparable_v>> + : formatter<::std::string> { + template + auto format(::std::unordered_set const &m, FormatContext &ctx) const + -> decltype(ctx.out()) { + CHECK_FMTABLE(T); + + std::vector v = ::FlexFlow::sorted(m); + + std::string result = + ::FlexFlow::join_strings(v.cbegin(), v.cend(), ", ", [](T const &t) { + return fmt::to_string(t); + }); + return formatter::format("{" + result + "}", ctx); + } +}; + +template +struct formatter< + ::std::unordered_set, + Char, + std::enable_if_t<(!detail::has_format_as>::value) && + (!::FlexFlow::is_lt_comparable_v)>> : formatter<::std::string> { template auto format(::std::unordered_set const &m, FormatContext &ctx) const diff --git a/lib/utils/include/utils/for_internal_use_only.h b/lib/utils/include/utils/for_internal_use_only.h deleted file mode 100644 index 2464406459..0000000000 --- a/lib/utils/include/utils/for_internal_use_only.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_FOR_INTERNAL_USER_ONLY_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_FOR_INTERNAL_USER_ONLY_H - -namespace FlexFlow { - -struct for_internal_use_only { - explicit for_internal_use_only(); -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path.dtg.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path.dtg.toml new file mode 100644 index 0000000000..543622c5fe --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "BinaryTreePath" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +includes = [ + "utils/full_binary_tree/binary_tree_path_entry.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "entries" +type = "std::vector<::FlexFlow::BinaryTreePathEntry>" diff --git 
a/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml deleted file mode 100644 index 08955c2d75..0000000000 --- a/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "BinaryTreePath" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", - "rapidcheck", -] - -includes = [ - "utils/full_binary_tree/binary_tree_path_entry.dtg.h", - "", -] - -src_includes = [ - "utils/fmt/vector.h", - "utils/hash/vector.h", -] - -[[fields]] -name = "entries" -type = "std::vector<::FlexFlow::BinaryTreePathEntry>" diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.dtg.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.dtg.toml new file mode 100644 index 0000000000..c4567a0e87 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "BinaryTreePathEntry" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LEFT_CHILD" +key = "left" + +[[values]] +name = "RIGHT_CHILD" +key = "right" diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml deleted file mode 100644 index 6c81123dcf..0000000000 --- a/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "BinaryTreePathEntry" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "LEFT_CHILD" -key = "left" - -[[values]] -name = "RIGHT_CHILD" -key = "right" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.dtg.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.dtg.toml new file mode 100644 index 0000000000..c47f358845 
--- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.dtg.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeImplementation" +type = "struct" +features = [] + +template_params = [ + "Tree", + "Parent", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "get_left_child" +type = "std::function" + +[[fields]] +name = "get_right_child" +type = "std::function" + +[[fields]] +name = "is_leaf" +type = "std::function" + +[[fields]] +name = "require_leaf" +type = "std::function" + +[[fields]] +name = "require_parent" +type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml deleted file mode 100644 index bf08701840..0000000000 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml +++ /dev/null @@ -1,33 +0,0 @@ -namespace = "FlexFlow" -name = "FullBinaryTreeImplementation" -features = [] - -template_params = [ - "Tree", - "Parent", - "Leaf", -] - -includes = [ - "", -] - -[[fields]] -name = "get_left_child" -type = "std::function" - -[[fields]] -name = "get_right_child" -type = "std::function" - -[[fields]] -name = "is_leaf" -type = "std::function" - -[[fields]] -name = "require_leaf" -type = "std::function" - -[[fields]] -name = "require_parent" -type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.dtg.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.dtg.toml new file mode 100644 index 0000000000..dc49c0b696 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeNodeType" +type = "enum" +features = [ + "hash", + "fmt", + "json", + "rapidcheck", +] + +[[values]] +name = "PARENT" +key = "parent" + +[[values]] +name = "LEAF" +key = "leaf" diff 
--git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml deleted file mode 100644 index 1f8af17cf3..0000000000 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "FullBinaryTreeNodeType" -features = [ - "hash", - "fmt", - "json", - "rapidcheck", -] - -[[values]] -name = "PARENT" -key = "parent" - -[[values]] -name = "LEAF" -key = "leaf" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.dtg.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.dtg.toml new file mode 100644 index 0000000000..96bb067c4a --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeVisitor" +type = "struct" +features = [] + +template_params = [ + "Result", + "Tree", + "Parent", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "parent_func" +type = "std::function" + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml deleted file mode 100644 index 7418d7a016..0000000000 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "FullBinaryTreeVisitor" -features = [] - -template_params = [ - "Result", - "Tree", - "Parent", - "Leaf", -] - -includes = [ - "", -] - -[[fields]] -name = "parent_func" -type = "std::function" - -[[fields]] -name = "leaf_func" -type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/get_path_to_leaf_map.h b/lib/utils/include/utils/full_binary_tree/get_path_to_leaf_map.h new file mode 100644 index 0000000000..fd77509e4d --- 
/dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_path_to_leaf_map.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_PATH_TO_LEAF_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_PATH_TO_LEAF_MAP_H + +#include "utils/containers/binary_merge_disjoint_maps.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/multiset_union.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" +#include "utils/full_binary_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::unordered_map get_path_to_leaf_map( + Tree const &tree, + FullBinaryTreeImplementation const &impl) { + + auto visitor = FullBinaryTreeVisitor, + Tree, + Parent, + Leaf>{ + [&](Parent const &parent) -> std::unordered_map { + std::unordered_map left_map = map_keys( + get_path_to_leaf_map(impl.get_left_child(parent), impl), + [](BinaryTreePath const &p) { return nest_inside_left_child(p); }); + + std::unordered_map right_map = map_keys( + get_path_to_leaf_map(impl.get_right_child(parent), impl), + [](BinaryTreePath const &p) { return nest_inside_right_child(p); }); + + return binary_merge_disjoint_maps(left_map, right_map); + }, + [](Leaf const &leaf) -> std::unordered_map { + return std::unordered_map{ + {binary_tree_root_path(), leaf}, + }; + }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index ca59f997c7..4fdec740f8 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -1,150 +1,54 @@ #ifndef _FLEXFLOW_UTILS_GRAPH_ALGORITHMS_H #define _FLEXFLOW_UTILS_GRAPH_ALGORITHMS_H +#include "utils/dot_file.h" #include "utils/graph/digraph/digraph.h" +#include "utils/graph/graph_split.dtg.h" #include 
"utils/graph/node/graph.h" #include "utils/graph/undirected/undirected_graph.h" -// #include "utils/graph/open_multidigraph/open_multidigraph.h" -// #include -// "utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h" -// #include -// "utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h" -#include "utils/dot_file.h" -#include "utils/graph/graph_split.dtg.h" namespace FlexFlow { std::vector add_nodes(Graph &, int); std::vector add_nodes(UndirectedGraph &, int); std::vector add_nodes(DiGraph &, int); -// std::vector add_nodes(MultiDiGraph &, int); -// std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes); - -// std::unordered_set get_nodes(OpenMultiDiEdge const &); std::unordered_set query_nodes(GraphView const &, std::unordered_set const &); -// void remove_node(MultiDiGraph &, Node const &); void remove_node(DiGraph &, Node const &); void remove_node(UndirectedGraph &, Node const &); -// void remove_node_if_unused(MultiDiGraph &, Node const &); void remove_node_if_unused(DiGraph &, Node const &); void remove_node_if_unused(UndirectedGraph &, Node const &); -// void contract_node_inplace(MultiDiGraph &, Node const &from, Node const -// &into); -void contract_node_inplace(DiGraph &, Node const &from, Node const &into); -void contract_node_inplace(UndirectedGraph &, - Node const &from, - Node const &into); - -// void contract_out_node_inplace(MultiDiGraph &, Node const &); -void contract_out_node_inplace(DiGraph &, Node const &); -void contract_out_node_inplace(UndirectedGraph &, Node const &); - -// MultiDiGraphView contract_out_node(MultiDiGraphView const &, Node const &); -DiGraphView contract_out_node(DiGraphView const &, Node const &); -UndirectedGraphView contract_out_node(UndirectedGraphView const &, - Node const &); - -// MultiDiGraphView -// contract_node(MultiDiGraphView const &, Node const &from, Node const -// &into); -UndirectedGraphView contract_node(UndirectedGraphView const &, - Node const &from, - Node const 
&into); - -// MultiDiGraphView apply_contraction(MultiDiGraphView const &, -// std::unordered_map const &); -DiGraphView apply_contraction(DiGraphView const &, - std::unordered_map const &); -UndirectedGraphView apply_contraction(UndirectedGraphView const &, - std::unordered_map const &); - bool empty(GraphView const &); -// void add_edges(MultiDiGraph &, std::vector const &); void add_edges(DiGraph &, std::vector const &); void add_edges(UndirectedGraph &, std::vector const &); bool contains_node(GraphView const &, Node const &); -// bool contains_edge(MultiDiGraphView const &, MultiDiEdge const &); bool contains_edge(DiGraphView const &, DirectedEdge const &); bool contains_edge(UndirectedGraphView const &, UndirectedEdge const &); -// void remove_edges(MultiDiGraph &, std::unordered_set const &); void remove_edges(DiGraph &, std::unordered_set const &); void remove_edges(UndirectedGraph &, std::vector const &); -std::unordered_set get_endpoints(UndirectedEdge const &); - -// std::unordered_set get_edges(MultiDiGraphView const &); -std::unordered_set get_edges(DiGraphView const &); std::unordered_set get_edges(UndirectedGraphView const &); -// std::unordered_set -// get_edges(UpwardOpenMultiDiGraphView const &); -// std::unordered_set -// get_edges(DownwardOpenMultiDiGraphView const &); -// std::unordered_set get_edges(OpenMultiDiGraphView const &); std::unordered_set get_node_edges(UndirectedGraphView const &, Node const &); -// std::unordered_set -// get_open_outputs(OpenMultiDiGraphView const &); -// std::unordered_set -// get_open_inputs(OpenMultiDiGraphView const &); - -// std::unordered_set -// get_incoming_edges(UpwardOpenMultiDiGraphView const &, Node const &); -// std::unordered_set -// get_incoming_edges(DownwardOpenMultiDiGraphView const &, Node const &); -// std::unordered_set -// get_incoming_edges(OpenMultiDiGraphView const &, Node const &); - -// std::unordered_set get_incoming_edges(MultiDiGraphView const &, -// std::unordered_set); - -// 
std::unordered_set get_outgoing_edges(MultiDiGraphView const &, -// Node const &); -// std::unordered_set -// get_outgoing_edges(UpwardOpenMultiDiGraphView const &, Node const &); -// std::unordered_set -// get_outgoing_edges(DownwardOpenMultiDiGraphView const &, Node const &); -// std::unordered_set -// get_outgoing_edges(OpenMultiDiGraphView const &, Node const &); - -// std::unordered_set -// get_outgoing_edges(MultiDiGraphView const &, -// std::unordered_set const &); - std::unordered_set get_node_edges(UndirectedGraphView const &, Node const &); std::unordered_set get_node_edges(UndirectedGraphView const &, std::unordered_set const &); -// Node get_src_node(MultiDiEdge const &); -// Node get_dst_node(MultiDiEdge const &); -// Node get_dst_node(InputMultiDiEdge const &); -// Node get_src_node(OutputMultiDiEdge const &); - std::unordered_set get_neighbors(UndirectedGraphView const &, Node const &); std::unordered_set get_neighbors(DiGraphView const &, Node const &); -// std::unordered_set get_neighbors(MultiDiGraphView const &, Node const -// &); - -// std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g); -// std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g); -// std::unordered_set get_open_sources(OpenMultiDiGraphView const &g); -// std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g); - -// std::optional get_imm_post_dominator(MultiDiGraphView const &, -// Node const &); std::vector get_dfs_ordering(DiGraphView const &, @@ -155,47 +59,15 @@ std::vector std::vector get_bfs_ordering(DiGraphView const &, std::unordered_set const &starting_points); -std::vector get_topological_ordering(DiGraphView const &); std::vector get_unchecked_topological_ordering(DiGraphView const &); -std::vector get_edge_topological_ordering(DiGraphView const &); -// std::vector -// get_edge_topological_ordering(MultiDiGraphView const &); - -// std::unordered_set> -// get_weakly_connected_components(MultiDiGraphView const &); - 
std::unordered_set get_transitive_reduction_delta(DiGraphView const &); -// std::pair split_edge(MultiDiEdge const -// &e); MultiDiEdge unsplit_edge(OutputMultiDiEdge const &, InputMultiDiEdge -// const &); - -// std::unordered_set get_cut_set(MultiDiGraphView const &, -// GraphSplit const &); - -// std::unordered_set get_cut_set(MultiDiGraphView const &, -// std::unordered_set const -// &); - -// bidict> -// get_edge_splits(MultiDiGraphView const &, GraphSplit const &); - UndirectedGraphView get_subgraph(UndirectedGraphView const &, std::unordered_set const &); DiGraphView get_subgraph(DiGraphView const &, std::unordered_set const &); -// MultiDiGraphView get_subgraph(MultiDiGraphView const &, -// std::unordered_set const &); - -// template -// OpenMultiDiGraphView get_subgraph(OpenMultiDiGraphView const &g, -// std::unordered_set const &nodes) { -// return OpenMultiDiGraphView::create(g, nodes); -// } -// MultiDiGraphView join(MultiDiGraphView const &lhs, MultiDiGraphView const -// &rhs); DiGraphView join(DiGraphView const &lhs, DiGraphView const &rhs); UndirectedGraphView join(UndirectedGraphView const &lhs, UndirectedGraphView const &rhs); @@ -206,9 +78,7 @@ DiGraphView with_added_edges(DiGraphView const &, std::unordered_set const &); UndirectedGraphView as_undirected(DiGraphView const &); -// MultiDiGraphView as_multidigraph(DiGraphView const &); DiGraphView as_digraph(UndirectedGraphView const &); -// OpenMultiDiGraphView as_openmultidigraph(MultiDiGraphView const &); void export_as_dot( DotFile &, diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.dtg.toml new file mode 100644 index 0000000000..78efd9fba2 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "DataflowGraphIsomorphism" +type = "struct" +features = [ 
+ "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "node_mapping" +type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.struct.toml deleted file mode 100644 index 082c25f6ea..0000000000 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "DataflowGraphIsomorphism" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/bidict/bidict.h", - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "node_mapping" -type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h index be0e57435a..0296bf545e 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_SPLIT_H #include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.dtg.h" -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include 
"utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph_view.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h index e53bb876a1..09135acb51 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_EDGES_ACROSS_SPLIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_EDGES_ACROSS_SPLIT_H -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph_view.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h index ad8eadda0e..00b213845d 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h +++ 
b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_OUTPUTS_ACROSS_SPLIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_OUTPUTS_ACROSS_SPLIT_H -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph_view.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.dtg.toml new file mode 100644 index 0000000000..6b23df3f4f --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "SplitBoundaryNodes" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "pre_split_boundary" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "post_split_boundary" +type = "std::unordered_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml deleted 
file mode 100644 index 32582a6b74..0000000000 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "SplitBoundaryNodes" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h", - "", -] - -src_includes = [ - "utils/hash/unordered_set.h", - "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "pre_split_boundary" -type = "std::unordered_set<::FlexFlow::Node>" - -[[fields]] -name = "post_split_boundary" -type = "std::unordered_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h deleted file mode 100644 index 916e8f7896..0000000000 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_H - -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" - -namespace FlexFlow { - -TransitiveReducedDataflowGraphView - get_dataflow_graph_transitive_reduction(DataflowGraphView const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml deleted file mode 100644 index 54c710b26e..0000000000 --- 
a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "TransitiveReducedDataflowGraphView" -features = [] - -includes = [ - "utils/graph/dataflow_graph/dataflow_graph_view.h", - "utils/graph/digraph/digraph_view.h", -] - -[[fields]] -name = "full_dataflow_graph" -type = "::FlexFlow::DataflowGraphView" - -[[fields]] -name = "transitive_reduction" -type = "::FlexFlow::DiGraphView" - diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph_view.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph_view.dtg.toml new file mode 100644 index 0000000000..1734d07112 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph_view.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "TransitiveReducedDataflowGraphView" +type = "struct" +features = [] + +includes = [ + "utils/graph/dataflow_graph/dataflow_graph_view.h", + "utils/graph/digraph/digraph_view.h", +] + +[[fields]] +name = "full_dataflow_graph" +type = "::FlexFlow::DataflowGraphView" + +[[fields]] +name = "transitive_reduction" +type = "::FlexFlow::DiGraphView" + diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph_view.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph_view.h new file mode 100644 index 0000000000..6c2c8927a7 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph_view.h @@ -0,0 +1,13 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph_view.dtg.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_dataflow_graph_transitive_reduction(DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.dtg.toml new file mode 100644 index 0000000000..7a73c1a8aa --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "DataflowEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_output.dtg.h", + "utils/graph/dataflow_graph/dataflow_input.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::DataflowOutput" + +[[fields]] +name = "dst" +type = "::FlexFlow::DataflowInput" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.struct.toml deleted file mode 100644 index a3237dde09..0000000000 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "DataflowEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_output.dtg.h", - "utils/graph/dataflow_graph/dataflow_input.dtg.h", -] - -[[fields]] -name = "src" -type = "::FlexFlow::DataflowOutput" - -[[fields]] -name = "dst" -type = "::FlexFlow::DataflowInput" diff --git 
a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.toml new file mode 100644 index 0000000000..bf4f4e80e8 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "DataflowEdgeQuery" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "src_nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "src_idxs" +type = "::FlexFlow::query_set<::FlexFlow::nonnegative_int>" + +[[fields]] +name = "dst_nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "dst_idxs" +type = "::FlexFlow::query_set<::FlexFlow::nonnegative_int>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml deleted file mode 100644 index aed0c28aeb..0000000000 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "DataflowEdgeQuery" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/query_set.h", - "utils/graph/node/node.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "src_nodes" -type = "::FlexFlow::query_set<::FlexFlow::Node>" - -[[fields]] -name = "src_idxs" -type = "::FlexFlow::query_set<::FlexFlow::nonnegative_int>" - -[[fields]] -name = "dst_nodes" -type = "::FlexFlow::query_set<::FlexFlow::Node>" - -[[fields]] -name = "dst_idxs" -type = "::FlexFlow::query_set<::FlexFlow::nonnegative_int>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.toml new 
file mode 100644 index 0000000000..8169d1f736 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "DataflowInput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "idx" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml deleted file mode 100644 index eb9c30d558..0000000000 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "DataflowInput" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "node" -type = "::FlexFlow::Node" - -[[fields]] -name = "idx" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.toml new file mode 100644 index 0000000000..dee7152aa2 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "DataflowOutput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "idx" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml deleted file mode 100644 index 19d92a3d4c..0000000000 --- 
a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "DataflowOutput" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "node" -type = "::FlexFlow::Node" - -[[fields]] -name = "idx" -type = "::FlexFlow::nonnegative_int" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.dtg.toml new file mode 100644 index 0000000000..2bd1068f4f --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "DataflowOutputQuery" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +src_includes = [ + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "output_idxs" +type = "::FlexFlow::query_set<::FlexFlow::nonnegative_int>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml deleted file mode 100644 index d1af6d5c0d..0000000000 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "DataflowOutputQuery" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/query_set.h", - "utils/graph/node/node.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -src_includes = [ - "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "nodes" -type = "::FlexFlow::query_set<::FlexFlow::Node>" - -[[fields]] -name = "output_idxs" -type = 
"::FlexFlow::query_set<::FlexFlow::nonnegative_int>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/node_added_result.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/node_added_result.dtg.toml new file mode 100644 index 0000000000..4093006906 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/node_added_result.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "NodeAddedResult" +type = "struct" +features = [ + "eq", + "ord", + "fmt", +] + +includes = [ + "", + "utils/graph/node/node.dtg.h", + "utils/graph/dataflow_graph/dataflow_output.dtg.h", +] + +src_includes = [ + "utils/fmt/vector.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "outputs" +type = "std::vector<::FlexFlow::DataflowOutput>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml deleted file mode 100644 index df0d601530..0000000000 --- a/lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml +++ /dev/null @@ -1,23 +0,0 @@ -namespace = "FlexFlow" -name = "NodeAddedResult" - -features = [ - "eq", - "ord", - "fmt", -] - -includes = [ - "", - "utils/graph/node/node.dtg.h", - "utils/fmt/vector.h", - "utils/graph/dataflow_graph/dataflow_output.dtg.h", -] - -[[fields]] -name = "node" -type = "::FlexFlow::Node" - -[[fields]] -name = "outputs" -type = "std::vector<::FlexFlow::DataflowOutput>" diff --git a/lib/utils/include/utils/graph/digraph/algorithms.h b/lib/utils/include/utils/graph/digraph/algorithms.h deleted file mode 100644 index 67cfba13ff..0000000000 --- a/lib/utils/include/utils/graph/digraph/algorithms.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_H - -#include "utils/graph/digraph/digraph.h" - -namespace FlexFlow { - -std::unordered_set get_edges(DiGraphView const &); 
- -/** - * @brief Returns the set of nodes in the graph with no incoming edges. - */ -std::unordered_set get_initial_nodes(DiGraphView const &graph); - -/** - * @brief Returns the set of nodes in the graph with no outgoing edges. - */ -std::unordered_set get_terminal_nodes(DiGraphView const &graph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/bipartite_component.dtg.toml b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/bipartite_component.dtg.toml new file mode 100644 index 0000000000..be7c58a9f6 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/bipartite_component.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "BipartiteComponent" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", +] + +[[fields]] +name = "head_nodes" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "tail_nodes" +type = "std::unordered_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/bipartite_component.struct.toml b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/bipartite_component.struct.toml deleted file mode 100644 index 92732f0d89..0000000000 --- a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/bipartite_component.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "BipartiteComponent" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "", - "utils/graph/node/node.dtg.h", -] - -src_includes = [ - "utils/fmt/unordered_set.h", - "utils/hash/unordered_set.h", -] - -[[fields]] -name = "head_nodes" -type = "std::unordered_set<::FlexFlow::Node>" - -[[fields]] -name = "tail_nodes" -type = 
"std::unordered_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/complete_bipartite_composite_decomposition.dtg.toml b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/complete_bipartite_composite_decomposition.dtg.toml new file mode 100644 index 0000000000..386466d62e --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/complete_bipartite_composite_decomposition.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "CompleteBipartiteCompositeDecomposition" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/digraph/algorithms/complete_bipartite_composite/bipartite_component.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", +] + +[[fields]] +name = "subgraphs" +type = "std::unordered_set<::FlexFlow::BipartiteComponent>" diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/complete_bipartite_composite_decomposition.struct.toml b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/complete_bipartite_composite_decomposition.struct.toml deleted file mode 100644 index d0274799c4..0000000000 --- a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/complete_bipartite_composite_decomposition.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "CompleteBipartiteCompositeDecomposition" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/graph/digraph/algorithms/complete_bipartite_composite/bipartite_component.dtg.h", -] - -src_includes = [ - "utils/fmt/unordered_set.h", - "utils/hash/unordered_set.h", -] - -[[fields]] -name = "subgraphs" -type = "std::unordered_set<::FlexFlow::BipartiteComponent>" diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_edge_topological_ordering.h 
b/lib/utils/include/utils/graph/digraph/algorithms/get_edge_topological_ordering.h new file mode 100644 index 0000000000..a67fd716c5 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_edge_topological_ordering.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGE_TOPOLOGICAL_ORDERING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGE_TOPOLOGICAL_ORDERING_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::vector get_edge_topological_ordering(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_edges.h b/lib/utils/include/utils/graph/digraph/algorithms/get_edges.h new file mode 100644 index 0000000000..d4340f6a0b --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_edges.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGES_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::unordered_set get_edges(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_initial_nodes.h b/lib/utils/include/utils/graph/digraph/algorithms/get_initial_nodes.h new file mode 100644 index 0000000000..bf907bac52 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_initial_nodes.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_INITIAL_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_INITIAL_NODES_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +/** + * @brief Returns the set of nodes in the graph with no incoming edges. 
+ */ +std::unordered_set get_initial_nodes(DiGraphView const &graph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_terminal_nodes.h b/lib/utils/include/utils/graph/digraph/algorithms/get_terminal_nodes.h new file mode 100644 index 0000000000..3c8620e134 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_terminal_nodes.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_TERMINAL_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_TERMINAL_NODES_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +/** + * @brief Returns the set of nodes in the graph with no outgoing edges. + */ +std::unordered_set get_terminal_nodes(DiGraphView const &graph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/inverse_line_graph/inverse_line_graph_result.dtg.toml b/lib/utils/include/utils/graph/digraph/algorithms/inverse_line_graph/inverse_line_graph_result.dtg.toml new file mode 100644 index 0000000000..4e363dbd8b --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/inverse_line_graph/inverse_line_graph_result.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "InverseLineGraphResult" +type = "struct" +features = [ ] + +includes = [ + "utils/graph/multidigraph/multidigraph_view.h", + "utils/bidict/bidict.h", +] + +[[fields]] +name = "graph" +type = "::FlexFlow::MultiDiGraphView" + +[[fields]] +name = "inverse_edge_to_line_node_bidict" +type = "::FlexFlow::bidict<::FlexFlow::MultiDiEdge, ::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/digraph/algorithms/inverse_line_graph/inverse_line_graph_result.struct.toml b/lib/utils/include/utils/graph/digraph/algorithms/inverse_line_graph/inverse_line_graph_result.struct.toml deleted file mode 100644 index 59a6f02429..0000000000 --- 
a/lib/utils/include/utils/graph/digraph/algorithms/inverse_line_graph/inverse_line_graph_result.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "InverseLineGraphResult" -features = [ ] - -includes = [ - "utils/graph/multidigraph/multidigraph_view.h", - "utils/bidict/bidict.h", -] - -[[fields]] -name = "graph" -type = "::FlexFlow::MultiDiGraphView" - -[[fields]] -name = "inverse_edge_to_line_node_bidict" -type = "::FlexFlow::bidict<::FlexFlow::MultiDiEdge, ::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/digraph/di_input.dtg.toml b/lib/utils/include/utils/graph/digraph/di_input.dtg.toml new file mode 100644 index 0000000000..b2c47b01dd --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/di_input.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "DiInput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "dst" +type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/digraph/di_input.struct.toml b/lib/utils/include/utils/graph/digraph/di_input.struct.toml deleted file mode 100644 index 1bd11e069c..0000000000 --- a/lib/utils/include/utils/graph/digraph/di_input.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "DiInput" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "dst" -type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/digraph/di_output.dtg.toml b/lib/utils/include/utils/graph/digraph/di_output.dtg.toml new file mode 100644 index 0000000000..4a48333e8b --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/di_output.dtg.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "DiOutput" +type = "struct" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::Node" diff --git 
a/lib/utils/include/utils/graph/digraph/di_output.struct.toml b/lib/utils/include/utils/graph/digraph/di_output.struct.toml deleted file mode 100644 index 27a71743f6..0000000000 --- a/lib/utils/include/utils/graph/digraph/di_output.struct.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "DiOutput" -features = [ - "eq", - "ord", - "hash", -] - -includes = [ - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "src" -type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/digraph/directed_edge.dtg.toml b/lib/utils/include/utils/graph/digraph/directed_edge.dtg.toml new file mode 100644 index 0000000000..ebabc4e93b --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/directed_edge.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "DirectedEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::Node" + +[[fields]] +name = "dst" +type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/digraph/directed_edge.struct.toml b/lib/utils/include/utils/graph/digraph/directed_edge.struct.toml deleted file mode 100644 index 9c17bb0325..0000000000 --- a/lib/utils/include/utils/graph/digraph/directed_edge.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "DirectedEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "src" -type = "::FlexFlow::Node" - -[[fields]] -name = "dst" -type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/digraph/directed_edge_query.dtg.toml b/lib/utils/include/utils/graph/digraph/directed_edge_query.dtg.toml new file mode 100644 index 0000000000..1a268d9938 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/directed_edge_query.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "DirectedEdgeQuery" +type = "struct" +features = [ + "eq", + "ord", + 
"hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "srcs" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "dsts" +type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml b/lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml deleted file mode 100644 index 3447cdb4b6..0000000000 --- a/lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "DirectedEdgeQuery" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/query_set.h", - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "srcs" -type = "::FlexFlow::query_set<::FlexFlow::Node>" - -[[fields]] -name = "dsts" -type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/graph_split.dtg.toml b/lib/utils/include/utils/graph/graph_split.dtg.toml new file mode 100644 index 0000000000..05624318c9 --- /dev/null +++ b/lib/utils/include/utils/graph/graph_split.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "GraphSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/graph/node/node.dtg.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "first" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "second" +type = "std::unordered_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/graph_split.struct.toml b/lib/utils/include/utils/graph/graph_split.struct.toml deleted file mode 100644 index 1f393a9318..0000000000 --- a/lib/utils/include/utils/graph/graph_split.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "GraphSplit" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "", - "utils/graph/node/node.dtg.h", - "utils/hash/unordered_set.h", 
- "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "first" -type = "std::unordered_set<::FlexFlow::Node>" - -[[fields]] -name = "second" -type = "std::unordered_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/instances/adjacency_digraph.h b/lib/utils/include/utils/graph/instances/adjacency_digraph.h index 5ff2eff876..b6ae76fd55 100644 --- a/lib/utils/include/utils/graph/instances/adjacency_digraph.h +++ b/lib/utils/include/utils/graph/instances/adjacency_digraph.h @@ -21,9 +21,6 @@ class AdjacencyDiGraph : public IDiGraph { query_edges(DirectedEdgeQuery const &) const override; std::unordered_set query_nodes(NodeQuery const &) const override; - // bool operator==(AdjacencyDiGraph const &) const; - // bool operator!=(AdjacencyDiGraph const & const; - AdjacencyDiGraph *clone() const override; private: diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..2b20b94c96 --- /dev/null +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h @@ -0,0 +1,238 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_TASK_SET_OPEN_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_TASK_SET_OPEN_KWARG_DATAFLOW_GRAPH_H + +#include "utils/containers/contains_key.h" +#include "utils/containers/enumerate.h" +#include "utils/containers/extend.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/map_values.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph.h" +#include 
"utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node_source.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_edges.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +struct UnorderedSetLabelledOpenKwargDataflowGraph final + : public ILabelledOpenKwargDataflowGraph, + public ILabelledKwargDataflowGraph { +public: + UnorderedSetLabelledOpenKwargDataflowGraph() = default; + + KwargNodeAddedResult add_node( + NodeLabel const &node_label, + std::unordered_map> const &inputs, + std::unordered_map const &output_labels) override { + return this->add_node( + node_label, + map_values(inputs, + [](KwargDataflowOutput const &o) { + return OpenKwargDataflowValue{o}; + }), + output_labels); + }; + + KwargNodeAddedResult add_node( + NodeLabel const &node_label, + std::unordered_map> const + &inputs, + std::unordered_map const &output_labels) override { + Node new_node = this->node_source.new_node(); + this->nodes.insert({new_node, node_label}); + + for (auto const &[input_slot_name, input_val] : inputs) { + KwargDataflowInput dst = KwargDataflowInput{ + new_node, + input_slot_name, + }; + + OpenKwargDataflowEdge in_edge = + mk_open_kwarg_dataflow_edge_from_src_val_and_dst(input_val, dst); + + this->edges.insert(in_edge); + } + + std::unordered_map> outputs = + generate_map( + keys(output_labels), + [&](SlotName const &output_slot) -> KwargDataflowOutput { + ValueLabel value_label = output_labels.at(output_slot); + + KwargDataflowOutput output = + KwargDataflowOutput{ + /*node=*/new_node, + /*slot_name=*/output_slot, + }; + + this->outputs.insert({ + output, + value_label, + }); + + return output; + }); + + return KwargNodeAddedResult{ 
+ /*node=*/new_node, + /*outputs=*/outputs, + }; + } + + KwargDataflowGraphInput + add_input(GraphInputName const &name, + ValueLabel const &value_label) override { + KwargDataflowGraphInput input = + KwargDataflowGraphInput{name}; + + ASSERT(!contains_key(this->graph_inputs, input)); + this->graph_inputs.insert({input, value_label}); + + return input; + } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return filter(keys(this->nodes), + [&](Node const &n) { return includes(q.nodes, n); }); + } + + std::unordered_set> + query_edges(OpenKwargDataflowEdgeQuery const &q) + const override { + return filter( + this->edges, + [&](OpenKwargDataflowEdge const &e) { + return open_kwarg_dataflow_edge_query_includes(q, e); + }); + } + + std::unordered_set> query_outputs( + KwargDataflowOutputQuery const &q) const override { + return filter(keys(this->outputs), + [&](KwargDataflowOutput const &output) { + return kwarg_dataflow_output_query_includes(q, output); + }); + } + + std::unordered_set> + get_inputs() const override { + return keys(this->graph_inputs); + } + + NodeLabel at(Node const &n) const override { + return this->nodes.at(n); + } + + ValueLabel at(OpenKwargDataflowValue const &v) + const override { + return v.template visit(overload{ + [&](KwargDataflowOutput const &o) -> ValueLabel { + return this->outputs.at(o); + }, + [&](KwargDataflowGraphInput const &gi) -> ValueLabel { + return this->graph_inputs.at(gi); + }}); + } + + void inplace_materialize_from( + LabelledKwargDataflowGraphView const + &view) override { + std::unordered_set view_nodes = get_nodes(view); + std::unordered_set> view_edges = + get_all_kwarg_dataflow_edges(view); + std::unordered_set> view_outputs = + get_all_kwarg_dataflow_outputs(view); + + this->graph_inputs.clear(); + this->nodes = + generate_map(view_nodes, [&](Node const &n) { return view.at(n); }); + + this->edges = + transform(view_edges, + [&](KwargDataflowEdge const &e) + -> OpenKwargDataflowEdge { + return 
OpenKwargDataflowEdge{e}; + }); + this->outputs = + generate_map(view_outputs, [&](KwargDataflowOutput const &o) { + return view.at(o); + }); + } + + void inplace_materialize_from( + LabelledOpenKwargDataflowGraphView const &view) override { + std::unordered_set> view_inputs = + get_all_kwarg_dataflow_graph_inputs(view); + std::unordered_set view_nodes = get_nodes(view); + std::unordered_set> + view_edges = get_all_open_kwarg_dataflow_edges(view); + std::unordered_set> view_outputs = + get_all_kwarg_dataflow_outputs(view); + + this->graph_inputs = generate_map( + view_inputs, [&](KwargDataflowGraphInput const &i) { + return view.at(OpenKwargDataflowValue{i}); + }); + this->nodes = + generate_map(view_nodes, [&](Node const &n) { return view.at(n); }); + + this->edges = view_edges; + this->outputs = + generate_map(view_outputs, [&](KwargDataflowOutput const &o) { + return view.at(OpenKwargDataflowValue{o}); + }); + } + + UnorderedSetLabelledOpenKwargDataflowGraph *clone() const override { + return new UnorderedSetLabelledOpenKwargDataflowGraph{ + this->node_source, + this->graph_inputs, + this->nodes, + this->edges, + this->outputs, + }; + } + +private: + UnorderedSetLabelledOpenKwargDataflowGraph( + NodeSource const &node_source, + std::unordered_map, + ValueLabel> const &graph_inputs, + std::unordered_map const &nodes, + std::unordered_set> const + &edges, + std::unordered_map, ValueLabel> const + &outputs) + : node_source(node_source), graph_inputs(graph_inputs), nodes(nodes), + edges(edges), outputs(outputs) {} + +private: + NodeSource node_source; + + std::unordered_map, ValueLabel> + graph_inputs; + std::unordered_map nodes; + std::unordered_set> edges; + std::unordered_map, ValueLabel> outputs; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..3c66b2c689 
--- /dev/null +++ b/lib/utils/include/utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h @@ -0,0 +1,130 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_OPEN_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_OPEN_KWARG_DATAFLOW_GRAPH_H + +#include "utils/containers/generate_map.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h" +#include "utils/graph/node/node_source.h" +#include "utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge_query.h" + +namespace FlexFlow { + +template +struct UnorderedSetOpenKwargDataflowGraph final + : public IOpenKwargDataflowGraph { + UnorderedSetOpenKwargDataflowGraph() = default; + + KwargNodeAddedResult add_node( + std::unordered_map> const + &inputs, + std::unordered_set const &output_slots) override { + Node new_node = this->node_source.new_node(); + this->nodes.insert(new_node); + + for (auto const &[input_slot_name, input_val] : inputs) { + KwargDataflowInput dst = KwargDataflowInput{ + new_node, + input_slot_name, + }; + + OpenKwargDataflowEdge in_edge = + mk_open_kwarg_dataflow_edge_from_src_val_and_dst(input_val, dst); + + this->edges.insert(in_edge); + } + + std::unordered_map> outputs = + generate_map( + output_slots, + [&](SlotName const &output_slot) -> KwargDataflowOutput { + KwargDataflowOutput output = + KwargDataflowOutput{ + /*node=*/new_node, + /*slot_name=*/output_slot, + }; + + this->outputs.insert(output); + + return output; + }); + + return KwargNodeAddedResult{ + /*node=*/new_node, + /*outputs=*/outputs, + }; + } + + KwargDataflowGraphInput + add_input(GraphInputName const &name) override { + KwargDataflowGraphInput input = + KwargDataflowGraphInput{name}; + + this->graph_inputs.insert(input); + + return input; + } + + std::unordered_set 
query_nodes(NodeQuery const &q) const override { + return filter(this->nodes, + [&](Node const &n) { return includes(q.nodes, n); }); + } + + std::unordered_set> + query_edges(OpenKwargDataflowEdgeQuery const &q) + const override { + return filter( + this->edges, + [&](OpenKwargDataflowEdge const &e) { + return open_kwarg_dataflow_edge_query_includes(q, e); + }); + } + + std::unordered_set> query_outputs( + KwargDataflowOutputQuery const &q) const override { + return filter(this->outputs, + [&](KwargDataflowOutput const &output) { + return kwarg_dataflow_output_query_includes(q, output); + }); + } + + std::unordered_set> + get_inputs() const override { + return this->graph_inputs; + } + + UnorderedSetOpenKwargDataflowGraph *clone() const override { + return new UnorderedSetOpenKwargDataflowGraph{ + this->node_source, + this->graph_inputs, + this->nodes, + this->edges, + this->outputs, + }; + } + +private: + UnorderedSetOpenKwargDataflowGraph( + NodeSource const &node_source, + std::unordered_set> const + &graph_inputs, + std::unordered_set const &nodes, + std::unordered_set> const + &edges, + std::unordered_set> const &outputs) + : node_source(node_source), graph_inputs(graph_inputs), nodes(nodes), + edges(edges), outputs(outputs) {} + +private: + NodeSource node_source; + + std::unordered_set> graph_inputs; + std::unordered_set nodes; + std::unordered_set> edges; + std::unordered_set> outputs; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h new file mode 100644 index 0000000000..146408123d --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h @@ -0,0 +1,35 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_BETWEEN_KWARG_DATAFLOW_GRAPHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_BETWEEN_KWARG_DATAFLOW_GRAPHS_H + +#include "utils/containers/get_one_of.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/view_as_open_kwarg_dataflow_graph.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/find_isomorphisms_between_open_kwarg_dataflow_graphs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.dtg.h" + +namespace FlexFlow { + +template +std::optional> + find_isomorphism_between_kwarg_dataflow_graphs( + KwargDataflowGraphView const &lhs, + KwargDataflowGraphView const &rhs) { + + std::unordered_set> open_isomorphisms = + find_isomorphisms_between_open_kwarg_dataflow_graphs( + view_as_open_kwarg_dataflow_graph(lhs), + view_as_open_kwarg_dataflow_graph(rhs)); + + if (open_isomorphisms.empty()) { + return std::nullopt; + } else { + OpenKwargDataflowGraphIsomorphism chosen = + get_one_of(open_isomorphisms); + ASSERT(chosen.input_mapping.empty()); + return chosen.node_mapping; + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h new file mode 100644 index 0000000000..b881cd9584 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_EDGES_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.h" +#include 
"utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_all_kwarg_dataflow_edges(KwargDataflowGraphView const &g) { + return g.query_edges(kwarg_dataflow_edge_query_all()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h new file mode 100644 index 0000000000..2405284fd3 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_OUTPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_OUTPUTS_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_all_kwarg_dataflow_outputs( + KwargDataflowGraphView const &view) { + return view.query_outputs(kwarg_dataflow_output_query_all()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_edges_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_edges_for_node.h new file mode 100644 index 0000000000..2c57970736 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_edges_for_node.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_KWARG_DATAFLOW_EDGES_FOR_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_KWARG_DATAFLOW_EDGES_FOR_NODE_H + +#include 
"utils/containers/unordered_map_from_pairs.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_map> + get_incoming_kwarg_dataflow_edges_for_node( + KwargDataflowGraphView const &g, Node const &n) { + KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set::matchall(), + /*src_slots=*/query_set::matchall(), + /*dst_nodes=*/query_set{n}, + /*dst_slots=*/query_set::matchall(), + }; + + return unordered_map_from_pairs( + transform(g.query_edges(query), [](KwargDataflowEdge const &e) { + return std::pair{ + e.dst.slot_name, + e, + }; + })); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_outputs_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_outputs_for_node.h new file mode 100644 index 0000000000..b0940570b0 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_outputs_for_node.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_KWARG_DATAFLOW_OUTPUTS_FOR_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_KWARG_DATAFLOW_OUTPUTS_FOR_NODE_H + +#include "utils/containers/map_values.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_edges_for_node.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_map> + get_incoming_kwarg_dataflow_outputs_for_node( + KwargDataflowGraphView const &g, Node const &n) { + return map_values(get_incoming_kwarg_dataflow_edges_for_node(g, n), + [](KwargDataflowEdge const &e) { return e.src; }); +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h new file mode 100644 index 0000000000..3fe2d48c6a --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_EDGES_FROM_NODE_TO_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_EDGES_FROM_NODE_TO_NODE_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_kwarg_dataflow_edges_from_node_to_node( + KwargDataflowGraphView const &g, + Node const &src, + Node const &dst) { + KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set{src}, + /*src_slots=*/query_set::matchall(), + /*dst_nodes=*/query_set{dst}, + /*dst_slots=*/query_set::matchall(), + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_incoming_edges.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_incoming_edges.h new file mode 100644 index 0000000000..8e9feaf3b5 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_incoming_edges.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_SUBGRAPH_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_SUBGRAPH_INCOMING_EDGES_H + +#include "utils/containers/set_minus.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include 
"utils/graph/node/algorithms.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_kwarg_dataflow_subgraph_incoming_edges( + KwargDataflowGraphView const &g, + std::unordered_set const &subgraph) { + std::unordered_set all_nodes = get_nodes(g); + query_set src_query = query_set{set_minus(all_nodes, subgraph)}; + + KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ + /*src_nodes=*/src_query, + /*src_slots=*/query_set::matchall(), + /*dst_nodes=*/query_set{subgraph}, + /*dst_slots=*/query_set::matchall(), + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_outgoing_edges.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_outgoing_edges.h new file mode 100644 index 0000000000..532e37f8ec --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_outgoing_edges.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_SUBGRAPH_OUTGOING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_SUBGRAPH_OUTGOING_EDGES_H + +#include "utils/containers/set_minus.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_kwarg_dataflow_subgraph_outgoing_edges( + KwargDataflowGraphView const &g, + std::unordered_set const &subgraph) { + std::unordered_set all_nodes = get_nodes(g); + query_set dst_query = query_set{set_minus(all_nodes, subgraph)}; + + KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set{subgraph}, + /*src_slots=*/query_set::matchall(), + /*dst_nodes=*/dst_query, + /*dst_slots=*/query_set::matchall(), + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow + 
+#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h new file mode 100644 index 0000000000..7ab7b80199 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_KWARG_DATAFLOW_EDGES_FOR_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_KWARG_DATAFLOW_EDGES_FOR_NODE_H + +#include "utils/containers/group_by.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +OneToMany> + get_outgoing_kwarg_dataflow_edges_for_node( + KwargDataflowGraphView const &g, Node const &n) { + KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set{n}, + /*src_slots=*/query_set::matchall(), + /*dst_nodes=*/query_set::matchall(), + /*dst_slots=*/query_set::matchall(), + }; + + return group_by( + g.query_edges(query), + [](KwargDataflowEdge const &e) { return e.src.slot_name; }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h new file mode 100644 index 0000000000..72eb9c810d --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_KWARG_DATAFLOW_OUTPUTS_FOR_NODE_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_KWARG_DATAFLOW_OUTPUTS_FOR_NODE_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_map> + get_outgoing_kwarg_dataflow_outputs_for_node( + KwargDataflowGraphView const &g, Node const &n) { + KwargDataflowOutputQuery query = KwargDataflowOutputQuery{ + /*nodes=*/query_set{n}, + /*output_idxs=*/query_set::matchall(), + }; + + std::unordered_map> result; + + for (KwargDataflowOutput const &output : g.query_outputs(query)) { + result.insert({output.slot_name, output}); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_kwarg_dataflow_graph_transitive_reduction.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_kwarg_dataflow_graph_transitive_reduction.h new file mode 100644 index 0000000000..7c4fb1835f --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_kwarg_dataflow_graph_transitive_reduction.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_KWARG_DATAFLOW_GRAPH_GET_KWARG_DATAFLOW_GRAPH_TRANSITIVE_REDUCTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_KWARG_DATAFLOW_GRAPH_GET_KWARG_DATAFLOW_GRAPH_TRANSITIVE_REDUCTION_H + +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/transitive_reduced_kwarg_dataflow_graph_view.dtg.h" + +namespace FlexFlow { + +template +TransitiveReducedKwargDataflowGraphView + get_kwarg_dataflow_graph_transitive_reduction( + KwargDataflowGraphView const &g) { + + DiGraphView as_digraph = g; + 
DiGraphView transitive_reduced = transitive_reduction(as_digraph); + + return TransitiveReducedKwargDataflowGraphView{ + /*full_dataflow_graph=*/g, + /*transitive_reduction=*/transitive_reduced, + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_boundary_nodes_for_kwarg_dataflow_graph_split.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_boundary_nodes_for_kwarg_dataflow_graph_split.h new file mode 100644 index 0000000000..af7b288aee --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_boundary_nodes_for_kwarg_dataflow_graph_split.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_KWARG_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_KWARG_DATAFLOW_GRAPH_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_KWARG_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_KWARG_DATAFLOW_GRAPH_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_kwarg_dataflow_edges_across_split.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/transitive_reduced_kwarg_dataflow_graph_view.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +template +SplitBoundaryNodes + get_transitive_reduced_boundary_nodes_for_kwarg_dataflow_graph_split( + TransitiveReducedKwargDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + + std::unordered_set> 
edges = + get_transitive_reduced_kwarg_dataflow_edges_across_split(tr_g, split); + + std::unordered_set src_boundary_nodes = transform( + edges, [](KwargDataflowEdge const &e) { return e.src.node; }); + + std::unordered_set dst_boundary_nodes = transform( + edges, [](KwargDataflowEdge const &e) { return e.dst.node; }); + + return SplitBoundaryNodes{ + /*pre_split_boundary=*/src_boundary_nodes, + /*post_split_boundary=*/dst_boundary_nodes, + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_kwarg_dataflow_edges_across_split.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_kwarg_dataflow_edges_across_split.h new file mode 100644 index 0000000000..56a90c833f --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_kwarg_dataflow_edges_across_split.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_KWARG_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_KWARG_DATAFLOW_EDGES_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_KWARG_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_KWARG_DATAFLOW_EDGES_ACROSS_SPLIT_H + +#include "utils/containers/flatmap.h" +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/transitive_reduced_kwarg_dataflow_graph_view.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" + 
+namespace FlexFlow { + +template +std::unordered_set> + get_transitive_reduced_kwarg_dataflow_edges_across_split( + TransitiveReducedKwargDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + + std::unordered_set src_subgraph = + unordered_set_of(get_leaves(split.get_left_child())); + std::unordered_set dst_subgraph = + unordered_set_of(get_leaves(split.get_right_child())); + + std::unordered_set raw_edges = + get_edges_from_subgraph_to_subgraph( + tr_g.transitive_reduction, src_subgraph, dst_subgraph); + + return flatmap(raw_edges, [&](DirectedEdge const &e) { + return get_kwarg_dataflow_edges_from_node_to_node( + tr_g.full_kwarg_dataflow_graph, e.src, e.dst); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_kwarg_dataflow_outputs_across_split.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_kwarg_dataflow_outputs_across_split.h new file mode 100644 index 0000000000..82804acc04 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_kwarg_dataflow_outputs_across_split.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_KWARG_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_KWARG_DATAFLOW_OUTPUTS_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_KWARG_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_KWARG_DATAFLOW_OUTPUTS_ACROSS_SPLIT_H + +#include "utils/containers/transform.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/get_transitive_reduced_kwarg_dataflow_edges_across_split.h" +#include 
"utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/transitive_reduced_kwarg_dataflow_graph_view.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_transitive_reduced_kwarg_dataflow_outputs_across_split( + TransitiveReducedKwargDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + + return transform( + get_transitive_reduced_kwarg_dataflow_edges_across_split(tr_g, split), + [](KwargDataflowEdge const &e) { return e.src; }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/transitive_reduced_kwarg_dataflow_graph_view.dtg.toml b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/transitive_reduced_kwarg_dataflow_graph_view.dtg.toml new file mode 100644 index 0000000000..91a6923c49 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/transitive_reduced_kwarg_dataflow_graph/transitive_reduced_kwarg_dataflow_graph_view.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TransitiveReducedKwargDataflowGraphView" +type = "struct" +features = [] + +template_params = [ + "SlotName", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h", + "utils/graph/digraph/digraph_view.h", +] + +[[fields]] +name = "full_kwarg_dataflow_graph" +type = "::FlexFlow::KwargDataflowGraphView" + +[[fields]] +name = "transitive_reduction" +type = "::FlexFlow::DiGraphView" diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/view_as_open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/view_as_open_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..e1ceddc466 --- /dev/null +++ 
b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/view_as_open_kwarg_dataflow_graph.h @@ -0,0 +1,58 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_AS_OPEN_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_AS_OPEN_KWARG_DATAFLOW_GRAPH_H + +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct KwargDataflowGraphAsOpenView final + : public IOpenKwargDataflowGraphView { +public: + KwargDataflowGraphAsOpenView() = delete; + KwargDataflowGraphAsOpenView(KwargDataflowGraphView const &g) + : g(g) {} + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->g.query_nodes(q); + } + + std::unordered_set> + query_edges(OpenKwargDataflowEdgeQuery const &q) + const override { + return transform(this->g.query_edges(q.standard_edge_query), + [](KwargDataflowEdge const &e) { + return OpenKwargDataflowEdge{ + e}; + }); + } + + std::unordered_set> query_outputs( + KwargDataflowOutputQuery const &q) const override { + return this->g.query_outputs(q); + } + + std::unordered_set> + get_inputs() const override { + return {}; + } + + KwargDataflowGraphAsOpenView *clone() const override { + return new KwargDataflowGraphAsOpenView{this->g}; + } + +private: + KwargDataflowGraphView g; +}; + +template +OpenKwargDataflowGraphView + view_as_open_kwarg_dataflow_graph( + KwargDataflowGraphView const &g) { + return OpenKwargDataflowGraphView::template create< + KwargDataflowGraphAsOpenView>(g); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..36f36b3fc8 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph.h @@ -0,0 +1,31 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_I_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_I_KWARG_DATAFLOW_GRAPH_H + +#include "utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph_view.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.h" + +namespace FlexFlow { + +template +struct IKwargDataflowGraph : virtual public IKwargDataflowGraphView { + virtual KwargNodeAddedResult add_node( + std::unordered_map> const &inputs, + std::unordered_set const &outputs) = 0; + + virtual void add_node_unsafe( + Node const &node, + std::unordered_map> const &inputs, + std::unordered_map> const + &outputs) = 0; + + virtual void + inplace_materialize_from(KwargDataflowGraphView const &) = 0; + + virtual IKwargDataflowGraph *clone() const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IKwargDataflowGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph_view.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph_view.h new file mode 100644 index 0000000000..6b85d9f470 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph_view.h @@ -0,0 +1,42 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_I_KWARG_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_I_KWARG_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/dataflow_graph/dataflow_edge.dtg.h" +#include "utils/graph/digraph/i_digraph_view.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.dtg.h" + +namespace FlexFlow { + +template +struct 
IKwargDataflowGraphView : virtual public IDiGraphView { + virtual std::unordered_set> + query_edges(KwargDataflowEdgeQuery const &) const = 0; + virtual std::unordered_set> + query_outputs(KwargDataflowOutputQuery const &) const = 0; + + std::unordered_set + query_edges(DirectedEdgeQuery const &q) const override final { + KwargDataflowEdgeQuery dataflow_query = KwargDataflowEdgeQuery{ + q.srcs, + matchall(), + q.dsts, + matchall(), + }; + std::unordered_set> dataflow_edges = + this->query_edges(dataflow_query); + + return transform(dataflow_edges, [](KwargDataflowEdge const &e) { + return DirectedEdge{e.src.node, e.dst.node}; + }); + }; + + virtual ~IKwargDataflowGraphView() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IKwargDataflowGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.toml b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.toml new file mode 100644 index 0000000000..b0b67cbdd9 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "KwargDataflowEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +template_params = [ + "SlotName", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h", + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_input.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::KwargDataflowOutput" + +[[fields]] +name = "dst" +type = "::FlexFlow::KwargDataflowInput" diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.dtg.toml b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.dtg.toml new file mode 100644 index 0000000000..b54731d5db --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.dtg.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = 
"KwargDataflowEdgeQuery" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +template_params = [ + "SlotName", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "src_nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "src_slots" +type = "::FlexFlow::query_set" + +[[fields]] +name = "dst_nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "dst_slots" +type = "::FlexFlow::query_set" diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.h new file mode 100644 index 0000000000..cb2a6c17fc --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_KWARG_DATAFLOW_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_KWARG_DATAFLOW_EDGE_QUERY_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.dtg.h" + +namespace FlexFlow { + +template +KwargDataflowEdgeQuery kwarg_dataflow_edge_query_all() { + return KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set::matchall(), + /*src_slots=*/query_set::matchall(), + /*dst_nodes=*/query_set::matchall(), + /*dst_slots=*/query_set::matchall(), + }; +} + +template +KwargDataflowEdgeQuery kwarg_dataflow_edge_query_none() { + return KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set::match_none(), + /*src_slots=*/query_set::match_none(), + /*dst_nodes=*/query_set::match_none(), + /*dst_slots=*/query_set::match_none(), + }; +} + +template +bool kwarg_dataflow_edge_query_includes( + KwargDataflowEdgeQuery const &query, + KwargDataflowEdge const &edge) { + return includes(query.src_nodes, edge.src.node) && + includes(query.src_slots, edge.src.slot_name) && 
+ includes(query.dst_nodes, edge.dst.node) && + includes(query.dst_slots, edge.dst.slot_name); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph.h new file mode 100644 index 0000000000..2b6fd22f86 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph.h @@ -0,0 +1,75 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_KWARG_DATAFLOW_GRAPH_H + +#include "utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.h" + +namespace FlexFlow { + +template +struct KwargDataflowGraph : virtual public KwargDataflowGraphView { +public: + KwargNodeAddedResult add_node( + std::unordered_map> const &inputs, + std::unordered_set const &outputs) { + return this->get_interface().add_node(inputs, outputs); + } + + void add_node_unsafe( + Node const &node, + std::unordered_map> const &inputs, + std::unordered_map> const + &outputs) { + return this->get_interface().add_node_unsafe(node, inputs, outputs); + } + + std::unordered_set query_nodes(NodeQuery const &q) const { + return this->get_interface().query_nodes(q); + } + + std::unordered_set> + query_edges(KwargDataflowEdgeQuery const &q) const { + return this->get_interface().query_edges(q); + } + + std::unordered_set> + query_outputs(KwargDataflowOutputQuery const &q) const { + return this->get_interface().query_outputs(q); + } + + template + static typename std::enable_if< + std::is_base_of, T>::value, + KwargDataflowGraph>::type + create() { + return KwargDataflowGraph(make_cow_ptr()); + } + + template + static std::enable_if_t, T>, + KwargDataflowGraph> + 
create_copy_of(KwargDataflowGraphView const &view) { + cow_ptr_t impl = make_cow_ptr(); + impl.get_mutable()->inplace_materialize_from(view); + return KwargDataflowGraph(std::move(impl)); + } + +protected: + using KwargDataflowGraphView::KwargDataflowGraphView; + +private: + IKwargDataflowGraph &get_interface() { + return *std::dynamic_pointer_cast>( + GraphView::ptr.get_mutable()); + } + + IKwargDataflowGraph const &get_interface() const { + return *std::dynamic_pointer_cast const>( + GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h new file mode 100644 index 0000000000..9726b0eb34 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h @@ -0,0 +1,50 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_KWARG_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_KWARG_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph_view.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.dtg.h" + +namespace FlexFlow { + +template +struct KwargDataflowGraphView : virtual public DiGraphView { + KwargDataflowGraphView(KwargDataflowGraphView const &) = default; + KwargDataflowGraphView &operator=(KwargDataflowGraphView const &) = default; + + std::unordered_set query_nodes(NodeQuery const &q) const { + return this->get_interface().query_nodes(q); + } + + std::unordered_set> + query_edges(KwargDataflowEdgeQuery const &q) const { + return this->get_interface().query_edges(q); + } + + std::unordered_set> + query_outputs(KwargDataflowOutputQuery const &q) const { + return this->get_interface().query_outputs(q); + } 
+ + template + static typename std::enable_if< + std::is_base_of, T>::value, + KwargDataflowGraphView>::type + create(Args &&...args) { + return DataflowGraphView(make_cow_ptr(std::forward(args)...)); + } + +protected: + using DiGraphView::DiGraphView; + +private: + IKwargDataflowGraphView const &get_interface() const { + return *std::dynamic_pointer_cast const>( + GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_input.dtg.toml b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_input.dtg.toml new file mode 100644 index 0000000000..5f092f0540 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_input.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "KwargDataflowInput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +template_params = [ + "SlotName", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "slot_name" +type = "SlotName" diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.toml b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.toml new file mode 100644 index 0000000000..f286fb90a7 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "KwargDataflowOutput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +template_params = [ + "SlotName", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "slot_name" +type = "SlotName" diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.dtg.toml 
b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.dtg.toml new file mode 100644 index 0000000000..8b5de44cc3 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.dtg.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "KwargDataflowOutputQuery" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +template_params = [ + "SlotName", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +src_includes = [ + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "output_idxs" +type = "::FlexFlow::query_set" diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h new file mode 100644 index 0000000000..0b1015721e --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_KWARG_DATAFLOW_OUTPUT_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_KWARG_DATAFLOW_OUTPUT_QUERY_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.dtg.h" + +namespace FlexFlow { + +template +KwargDataflowOutputQuery kwarg_dataflow_output_query_all() { + return KwargDataflowOutputQuery{ + /*nodes=*/query_set::matchall(), + /*output_idxs=*/query_set::matchall(), + }; +} + +template +bool kwarg_dataflow_output_query_includes( + KwargDataflowOutputQuery const &query, + KwargDataflowOutput const &output) { + return includes(query.nodes, output.node) && + includes(query.output_idxs, output.slot_name); +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.toml b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.toml new file mode 100644 index 0000000000..5686368f66 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "KwargNodeAddedResult" +type = "struct" +features = [ + "eq", + "ord", + "fmt", +] + +template_params = [ + "SlotName", +] + +includes = [ + "", + "utils/graph/node/node.dtg.h", + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "outputs" +type = "std::unordered_map>" diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h new file mode 100644 index 0000000000..1364e9ceb0 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_AS_DOT_H + +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::string labelled_open_kwarg_dataflow_graph_view_as_dot( + LabelledOpenKwargDataflowGraphView const &g, + std::function const &, + std::function const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h new file mode 100644 index 0000000000..156036f451 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_KWARG_DATAFLOW_GRAPH_NODE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_KWARG_DATAFLOW_GRAPH_NODE_LABELS_H + +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/view_as_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_node_labels.h" + +namespace FlexFlow { + +template > +LabelledKwargDataflowGraphView + rewrite_labelled_kwarg_dataflow_graph_node_labels( + LabelledKwargDataflowGraphView const + &g, + F f) { + return rewrite_labelled_open_kwarg_dataflow_graph_node_labels( + view_as_labelled_open_kwarg_dataflow_graph(g), + f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_value_labels.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_value_labels.h new file mode 100644 index 0000000000..1bcaed100e --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_value_labels.h @@ -0,0 +1,42 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_KWARG_DATAFLOW_GRAPH_VALUE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_KWARG_DATAFLOW_GRAPH_VALUE_LABELS_H + +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/view_as_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_value_labels.h" + +namespace FlexFlow { + +template const &, + ValueLabel const &>> +LabelledKwargDataflowGraphView + rewrite_labelled_kwarg_dataflow_graph_value_labels( + LabelledKwargDataflowGraphView const + &g, + F f) { + auto label_func = [&](OpenKwargDataflowValue const &v, + ValueLabel const &l) -> NewValueLabel { + return v.template visit(overload{ + [](KwargDataflowGraphInput const &) -> NewValueLabel { PANIC(); }, + [&](KwargDataflowOutput const &o) -> NewValueLabel { + return f(o, l); + }}); + }; + + return rewrite_labelled_open_kwarg_dataflow_graph_value_labels( + view_as_labelled_open_kwarg_dataflow_graph(g), + label_func); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/view_as_labelled_open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/view_as_labelled_open_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..58e71a4587 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/view_as_labelled_open_kwarg_dataflow_graph.h @@ -0,0 +1,89 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_AS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_AS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_H + 
+#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct LabelledKwargDataflowGraphAsOpenView final + : public ILabelledOpenKwargDataflowGraphView { +public: + LabelledKwargDataflowGraphAsOpenView() = delete; + LabelledKwargDataflowGraphAsOpenView( + LabelledKwargDataflowGraphView const &g) + : g(g) {} + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->g.query_nodes(q); + } + + std::unordered_set> + query_edges(OpenKwargDataflowEdgeQuery const &q) + const override { + return transform(this->g.query_edges(q.standard_edge_query), + [](KwargDataflowEdge const &e) { + return OpenKwargDataflowEdge{ + e}; + }); + } + + std::unordered_set> query_outputs( + KwargDataflowOutputQuery const &q) const override { + return this->g.query_outputs(q); + } + + std::unordered_set> + get_inputs() const override { + return {}; + } + + NodeLabel at(Node const &n) const override { + return this->g.at(n); + } + + ValueLabel at(OpenKwargDataflowValue const &v) + const override { + return this->g.at(v.require_internal()); + } + + LabelledKwargDataflowGraphAsOpenView *clone() const override { + return new LabelledKwargDataflowGraphAsOpenView{this->g}; + } + +private: + LabelledKwargDataflowGraphView g; +}; + +template +LabelledOpenKwargDataflowGraphView + view_as_labelled_open_kwarg_dataflow_graph( + LabelledKwargDataflowGraphView const + &g) { + return LabelledOpenKwargDataflowGraphView:: + template create>(g); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph.h new file mode 100644 index 
0000000000..9bf0e51413 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_I_LABELLED_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_I_LABELLED_KWARG_DATAFLOW_GRAPH_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct ILabelledKwargDataflowGraph + : virtual public ILabelledKwargDataflowGraphView { +public: + virtual KwargNodeAddedResult add_node( + NodeLabel const &node_label, + std::unordered_map> const &inputs, + std::unordered_map const &output_labels) = 0; + virtual void inplace_materialize_from( + LabelledKwargDataflowGraphView const + &) = 0; + + virtual ~ILabelledKwargDataflowGraph() = default; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph_view.h new file mode 100644 index 0000000000..108e791e07 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph_view.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_I_LABELLED_KWARG_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_I_LABELLED_KWARG_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct ILabelledKwargDataflowGraphView + : virtual public IKwargDataflowGraphView { + virtual NodeLabel at(Node const &) const = 0; + 
virtual ValueLabel at(KwargDataflowOutput const &) const = 0; + + virtual ~ILabelledKwargDataflowGraphView() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT( + ILabelledKwargDataflowGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..3f4b469471 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h @@ -0,0 +1,62 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_H + +#include "utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct LabelledKwargDataflowGraph + : virtual LabelledKwargDataflowGraphView { +private: + using Interface = + ILabelledKwargDataflowGraph; + +public: + LabelledKwargDataflowGraph(LabelledKwargDataflowGraph const &) = default; + LabelledKwargDataflowGraph & + operator=(LabelledKwargDataflowGraph const &) = default; + + KwargNodeAddedResult add_node( + NodeLabel const &node_label, + std::unordered_map> const &inputs, + std::unordered_map const &output_labels) { + return this->get_interface().add_node(node_label, inputs, output_labels); + } + + template + static typename std::enable_if::value, + LabelledKwargDataflowGraph>::type + create(Args &&...args) { + return LabelledKwargDataflowGraph( + make_cow_ptr(std::forward(args)...)); + } + + template + static typename std::enable_if::value, + LabelledKwargDataflowGraph>::type + create_copy_of( + LabelledKwargDataflowGraphView const + &view) { + cow_ptr_t impl = make_cow_ptr(); + 
impl.get_mutable()->inplace_materialize_from(view); + return LabelledKwargDataflowGraph(std::move(impl)); + } + +protected: + using LabelledKwargDataflowGraphView:: + LabelledKwargDataflowGraphView; + +private: + Interface &get_interface() { + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + } + Interface const &get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h new file mode 100644 index 0000000000..b7a148c348 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct LabelledKwargDataflowGraphView + : virtual public KwargDataflowGraphView { +private: + using Interface = + ILabelledKwargDataflowGraphView; + +public: + LabelledKwargDataflowGraphView(LabelledKwargDataflowGraphView const &) = + default; + LabelledKwargDataflowGraphView & + operator=(LabelledKwargDataflowGraphView const &) = default; + + NodeLabel at(Node const &n) const { + return this->get_interface().at(n); + } + + OutputLabel at(KwargDataflowOutput const &o) const { + return this->get_interface().at(o); + } + + template + static typename std::enable_if::value, + LabelledKwargDataflowGraphView>::type + create(Args &&...args) { + return LabelledKwargDataflowGraphView( + 
make_cow_ptr(std::forward(args)...)); + } + +protected: + using KwargDataflowGraphView::KwargDataflowGraphView; + +private: + Interface const &get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h index ecf9c22143..df8207251f 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H +#include "utils/bidict/algorithms/transform_values.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h" @@ -17,14 +18,14 @@ bool is_isomorphic_under( OpenDataflowGraphIsomorphism const &candidate_isomorphism) { bidict node_permutation = - map_values(candidate_isomorphism.node_mapping, [](Node const &dst_node) { - return NewNode{dst_node}; - }).reversed(); + transform_values(candidate_isomorphism.node_mapping, + [](Node const &dst_node) { return NewNode{dst_node}; }) + .reversed(); bidict input_permutation = - map_values(candidate_isomorphism.input_mapping, - [](DataflowGraphInput const &dst_input) { - return NewDataflowGraphInput{dst_input}; - }) + transform_values(candidate_isomorphism.input_mapping, + [](DataflowGraphInput const &dst_input) { + return NewDataflowGraphInput{dst_input}; + }) .reversed(); return get_graph_data(permute_input_ids( 
permute_node_ids(src, node_permutation), input_permutation)) == diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.dtg.toml b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.dtg.toml new file mode 100644 index 0000000000..32880d697b --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.dtg.toml @@ -0,0 +1,42 @@ +namespace = "FlexFlow" +name = "LabelledOpenDataflowGraphData" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = ["NodeLabel", "ValueLabel"] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", + "", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_map.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "node_data" +type = "std::unordered_map<::FlexFlow::Node, NodeLabel>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::OpenDataflowEdge>" + +[[fields]] +name = "inputs" +type = "std::unordered_set<::FlexFlow::DataflowGraphInput>" + +[[fields]] +name = "value_data" +type = "std::unordered_map<::FlexFlow::OpenDataflowValue, ValueLabel>" diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.struct.toml b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.struct.toml deleted file mode 100644 index 082b61e691..0000000000 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.struct.toml +++ /dev/null @@ -1,41 +0,0 @@ -namespace = "FlexFlow" -name = "LabelledOpenDataflowGraphData" -features = [ - 
"eq", - "hash", - "fmt", -] - -template_params = ["NodeLabel", "ValueLabel"] - -includes = [ - "utils/graph/node/node.dtg.h", - "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", - "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", - "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", - "", - "", -] - -src_includes = [ - "utils/hash/unordered_map.h", - "utils/hash/unordered_set.h", - "utils/fmt/unordered_map.h", - "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "node_data" -type = "std::unordered_map<::FlexFlow::Node, NodeLabel>" - -[[fields]] -name = "edges" -type = "std::unordered_set<::FlexFlow::OpenDataflowEdge>" - -[[fields]] -name = "inputs" -type = "std::unordered_set<::FlexFlow::DataflowGraphInput>" - -[[fields]] -name = "value_data" -type = "std::unordered_map<::FlexFlow::OpenDataflowValue, ValueLabel>" diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h index 01777909cd..798efa90dd 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h @@ -22,14 +22,6 @@ struct ILabelledOpenDataflowGraph virtual void inplace_materialize_from( LabelledOpenDataflowGraphView const &) = 0; - // NodeAddedResult add_node(NodeLabel const &node_label, - // std::vector const &inputs, - // std::vector const &output_labels) - // override final { - // return this->add_node(node_label, transform(inputs, [](DataflowOutput - // const &o) { return OpenDataflowValue{o}; }), output_labels); - // } - virtual ~ILabelledOpenDataflowGraph() = default; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledOpenDataflowGraph); diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/find_isomorphism_between_labelled_open_kwarg_dataflow_graphs.h 
b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/find_isomorphism_between_labelled_open_kwarg_dataflow_graphs.h new file mode 100644 index 0000000000..d50670ed41 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/find_isomorphism_between_labelled_open_kwarg_dataflow_graphs.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_BETWEEN_LABELLED_OPEN_KWARG_DATAFLOW_GRAPHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_BETWEEN_LABELLED_OPEN_KWARG_DATAFLOW_GRAPHS_H + +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graphs_are_isomorphic_under.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/find_isomorphisms_between_open_kwarg_dataflow_graphs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.dtg.h" + +namespace FlexFlow { + +template +std::optional> + find_isomorphism_between_labelled_open_kwarg_dataflow_graphs( + LabelledOpenKwargDataflowGraphView const &src, + LabelledOpenKwargDataflowGraphView const &dst) { + std::unordered_set> + unlabelled_isomorphisms = + find_isomorphisms_between_open_kwarg_dataflow_graphs( + static_cast>( + src), + static_cast>( + dst)); + + for (OpenKwargDataflowGraphIsomorphism const + &candidate_isomorphism : unlabelled_isomorphisms) { + if (labelled_open_kwarg_dataflow_graphs_are_isomorphic_under( + src, dst, candidate_isomorphism)) { + return candidate_isomorphism; + } + } + + return std::nullopt; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h 
b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h new file mode 100644 index 0000000000..d60c396274 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h @@ -0,0 +1,45 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H + +#include "utils/containers/generate_map.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_edges.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_values.h" + +namespace FlexFlow { + +template +LabelledOpenKwargDataflowGraphData + get_labelled_open_kwarg_dataflow_graph_data( + LabelledOpenKwargDataflowGraphView const &g) { + return LabelledOpenKwargDataflowGraphData{ + /*nodes=*/generate_map( + get_nodes(g), [&](Node const &n) -> NodeLabel { return g.at(n); }), + /*edges=*/get_all_open_kwarg_dataflow_edges(g), + /*inputs=*/get_all_kwarg_dataflow_graph_inputs(g), + /*outputs=*/ + generate_map( + get_all_open_kwarg_dataflow_values(g), + [&](OpenKwargDataflowValue const &v) + -> ValueLabel { return g.at(v); }), + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.toml 
b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.toml new file mode 100644 index 0000000000..a3ba15c6ff --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.toml @@ -0,0 +1,47 @@ +namespace = "FlexFlow" +name = "LabelledOpenKwargDataflowGraphData" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "NodeLabel", + "ValueLabel", + "GraphInputName", + "SlotName", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.h", + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h", + "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h", + "", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_map.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "node_data" +type = "std::unordered_map<::FlexFlow::Node, NodeLabel>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::OpenKwargDataflowEdge>" + +[[fields]] +name = "inputs" +type = "std::unordered_set<::FlexFlow::KwargDataflowGraphInput>" + +[[fields]] +name = "value_data" +type = "std::unordered_map<::FlexFlow::OpenKwargDataflowValue, ValueLabel>" diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h new file mode 100644 index 0000000000..d06b96e37f --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h @@ -0,0 +1,54 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H + +#include "utils/containers/filtrans.h" +#include "utils/containers/keys.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.h" + +namespace FlexFlow { + +template +OpenKwargDataflowGraphData + labelled_open_kwarg_dataflow_graph_data_without_labels( + LabelledOpenKwargDataflowGraphData const &labelled_data) { + OpenKwargDataflowGraphData result = + OpenKwargDataflowGraphData{ + /*nodes=*/keys(labelled_data.node_data), + /*edges=*/labelled_data.edges, + /*inputs=*/labelled_data.inputs, + /*outputs=*/ + filtrans( + keys(labelled_data.value_data), + [](OpenKwargDataflowValue const &v) { + return v.try_require_internal(); + }), + }; + + require_open_kwarg_dataflow_graph_data_is_valid(result); + + return result; +} + +template +void require_labelled_open_kwarg_dataflow_graph_data_is_valid( + LabelledOpenKwargDataflowGraphData const &data) { + labelled_open_kwarg_dataflow_graph_data_without_labels(data); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graphs_are_isomorphic_under.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graphs_are_isomorphic_under.h new file mode 100644 index 0000000000..015b388fa9 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graphs_are_isomorphic_under.h @@ -0,0 +1,60 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_UNDER_H 
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_UNDER_H + +#include "utils/bidict/algorithms/transform_values.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.dtg.h" + +namespace FlexFlow { + +template +bool labelled_open_kwarg_dataflow_graphs_are_isomorphic_under( + LabelledOpenKwargDataflowGraphView const &src, + LabelledOpenKwargDataflowGraphView const &dst, + OpenKwargDataflowGraphIsomorphism const + &candidate_isomorphism) { + bidict new_node_to_old_node = + transform_values(candidate_isomorphism.node_mapping, [](Node const &n) { + return NewNode{n}; + }).reversed(); + + bidict, + KwargDataflowGraphInput> + new_input_to_old_input = candidate_isomorphism.input_mapping.reversed(); + + LabelledOpenKwargDataflowGraphData + permuted_data = get_labelled_open_kwarg_dataflow_graph_data( + permute_labelled_open_kwarg_dataflow_graph_input_ids( + permute_labelled_open_kwarg_dataflow_graph_node_ids( + src, new_node_to_old_node), + new_input_to_old_input)); + + LabelledOpenKwargDataflowGraphData + dst_data = get_labelled_open_kwarg_dataflow_graph_data(dst); + + return permuted_data == dst_data; +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h new file mode 100644 index 0000000000..bee9ad0c37 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h @@ -0,0 +1,100 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_WITH_LABELLING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_WITH_LABELLING_H + +#include "utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct OpenKwargDataflowGraphLabellingWrapper final + : public ILabelledOpenKwargDataflowGraphView { +public: + OpenKwargDataflowGraphLabellingWrapper() = delete; + OpenKwargDataflowGraphLabellingWrapper( + OpenKwargDataflowGraphView const &unlabelled, + std::unordered_map const &node_labels, + std::unordered_map, + ValueLabel> const &value_labels) + : unlabelled(unlabelled), node_labels(node_labels), + value_labels(value_labels) {} + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->unlabelled.query_nodes(q); + } + + std::unordered_set> + query_edges(OpenKwargDataflowEdgeQuery const &q) + const override { + return this->unlabelled.query_edges(q); + } + + std::unordered_set> query_outputs( + KwargDataflowOutputQuery const &q) const override { + return this->unlabelled.query_outputs(q); + } + + std::unordered_set> + get_inputs() const override { + return 
this->unlabelled.get_inputs(); + } + + NodeLabel at(Node const &n) const override { + return this->node_labels.at(n); + } + + ValueLabel at(OpenKwargDataflowValue const &v) + const override { + return this->value_labels.at(v); + } + + OpenKwargDataflowGraphLabellingWrapper *clone() const override { + return new OpenKwargDataflowGraphLabellingWrapper{ + this->unlabelled, + this->node_labels, + this->value_labels, + }; + } + +private: + OpenKwargDataflowGraphView unlabelled; + std::unordered_map node_labels; + std::unordered_map, + ValueLabel> + value_labels; +}; + +template +LabelledOpenKwargDataflowGraphView + open_kwarg_dataflow_graph_view_with_labelling( + OpenKwargDataflowGraphView const &g, + std::unordered_map const &node_labels, + std::unordered_map, + ValueLabel> const &value_labels) { + return LabelledOpenKwargDataflowGraphView:: + template create>( + g, node_labels, value_labels); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h new file mode 100644 index 0000000000..223c2e7673 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h @@ -0,0 +1,75 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_INPUT_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_INPUT_IDS_H + +#include "utils/containers/generate_map.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h" +#include 
"utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_values.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/permute_open_kwarg_dataflow_graph_input_ids.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +LabelledOpenKwargDataflowGraphView + permute_labelled_open_kwarg_dataflow_graph_input_ids( + LabelledOpenKwargDataflowGraphView const &g, + bidict, + KwargDataflowGraphInput> const + &new_input_to_old_input) { + + OpenKwargDataflowGraphView permuted = + permute_open_kwarg_dataflow_graph_input_ids( + static_cast>(g), + new_input_to_old_input); + + auto old_input_from_new = + [&](KwargDataflowGraphInput const &i) + -> KwargDataflowGraphInput { + return new_input_to_old_input.at_l(i); + }; + + auto old_value_from_new = + [&](OpenKwargDataflowValue const &new_value) { + return new_value + .template visit>( + overload{ + [](KwargDataflowOutput const &o) { + return OpenKwargDataflowValue{ + o}; + }, + [&](KwargDataflowGraphInput const &i) { + return OpenKwargDataflowValue{ + old_input_from_new(i)}; + }, + }); + }; + + std::unordered_map node_labels = + generate_map(get_nodes(permuted), [&](Node const &n) { return g.at(n); }); + + std::unordered_map, + ValueLabel> + value_labels = generate_map( + get_all_open_kwarg_dataflow_values(permuted), + [&](OpenKwargDataflowValue const + &new_value) { return g.at(old_value_from_new(new_value)); }); + + return open_kwarg_dataflow_graph_view_with_labelling( + permuted, node_labels, value_labels); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h 
b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h new file mode 100644 index 0000000000..06728949df --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h @@ -0,0 +1,76 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_NODE_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_NODE_IDS_H + +#include "utils/containers/generate_map.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_values.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/permute_open_kwarg_dataflow_graph_node_ids.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +LabelledOpenKwargDataflowGraphView + permute_labelled_open_kwarg_dataflow_graph_node_ids( + LabelledOpenKwargDataflowGraphView const &g, + bidict const &new_node_tofrom_old_node) { + + OpenKwargDataflowGraphView permuted = + permute_open_kwarg_dataflow_graph_node_ids( + static_cast>(g), + new_node_tofrom_old_node); + + auto old_node_from_new = [&](Node const &new_node) { + return new_node_tofrom_old_node.at_l(NewNode{new_node}); + }; + + auto old_value_from_new = + [&](OpenKwargDataflowValue const &new_value) { + return new_value + .template visit>( + overload{ + [&](KwargDataflowOutput const &new_o) { + return OpenKwargDataflowValue{ + KwargDataflowOutput{ + old_node_from_new(new_o.node), + new_o.slot_name, + }, + }; + }, + [](KwargDataflowGraphInput 
const &i) { + return OpenKwargDataflowValue{ + i}; + }, + }); + }; + + std::unordered_map node_labels = + generate_map(get_nodes(permuted), [&](Node const &new_node) { + return g.at(old_node_from_new(new_node)); + }); + + std::unordered_map, + ValueLabel> + value_labels = generate_map( + get_all_open_kwarg_dataflow_values(permuted), + [&](OpenKwargDataflowValue const + &new_value) { return g.at(old_value_from_new(new_value)); }); + + return open_kwarg_dataflow_graph_view_with_labelling( + permuted, node_labels, value_labels); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h new file mode 100644 index 0000000000..a632cd7b64 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h @@ -0,0 +1,53 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_LABELS_H + +#include "utils/containers/generate_map.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_values.h" + +namespace FlexFlow { + +template , + typename NewValueLabel = std::invoke_result_t< + F, + OpenKwargDataflowValue const &, + ValueLabel const &>> +LabelledOpenKwargDataflowGraphView + rewrite_labelled_open_kwarg_dataflow_graph_labels( + 
LabelledOpenKwargDataflowGraphView const &g, + F f) { + auto get_new_node_label = [&](Node const &n) -> NewNodeLabel { + return f(n, g.at(n)); + }; + + auto get_new_value_label = + [&](OpenKwargDataflowValue const &v) + -> NewValueLabel { return f(v, g.at(v)); }; + + std::unordered_map node_labels = + generate_map(get_nodes(g), get_new_node_label); + std::unordered_map, + NewValueLabel> + value_labels = generate_map(get_all_open_kwarg_dataflow_values(g), + get_new_value_label); + return open_kwarg_dataflow_graph_view_with_labelling( + g, node_labels, value_labels); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_node_labels.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_node_labels.h new file mode 100644 index 0000000000..6ed0ea337d --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_node_labels.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_NODE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_NODE_LABELS_H + +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template > +LabelledOpenKwargDataflowGraphView + rewrite_labelled_open_kwarg_dataflow_graph_node_labels( + LabelledOpenKwargDataflowGraphView const &g, + F f) { + + return rewrite_labelled_open_kwarg_dataflow_graph_labels( + g, + overload{ + [&](Node const &n, NodeLabel const &l) { return f(n, l); }, + [](OpenKwargDataflowValue const &, + ValueLabel const &l) { return l; }, + }); +} + +} // 
namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_value_labels.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_value_labels.h new file mode 100644 index 0000000000..fc501f7e79 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_value_labels.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_VALUE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_VALUE_LABELS_H + +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template const &, + ValueLabel const &>> +LabelledOpenKwargDataflowGraphView + rewrite_labelled_open_kwarg_dataflow_graph_value_labels( + LabelledOpenKwargDataflowGraphView const &g, + F f) { + + return rewrite_labelled_open_kwarg_dataflow_graph_labels( + g, + overload{ + [](Node const &, NodeLabel const &l) { return l; }, + [&](OpenKwargDataflowValue const &v, + ValueLabel const &l) { return f(v, l); }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/view_from_labelled_open_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/view_from_labelled_open_kwarg_dataflow_graph_data.h new file mode 100644 index 0000000000..1bf435bafa --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/view_from_labelled_open_kwarg_dataflow_graph_data.h @@ -0,0 +1,38 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_FROM_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_FROM_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H + +#include "utils/containers/filtrans.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h" + +namespace FlexFlow { + +template +LabelledOpenKwargDataflowGraphView + view_from_labelled_open_kwarg_dataflow_graph_data( + LabelledOpenKwargDataflowGraphData const &data) { + OpenKwargDataflowGraphData unlabelled_data = + labelled_open_kwarg_dataflow_graph_data_without_labels(data); + + return open_kwarg_dataflow_graph_view_with_labelling( + view_from_open_kwarg_dataflow_graph_data(unlabelled_data), + data.node_data, + data.value_data); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..bec1c540ea --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_I_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_H 
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_I_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct ILabelledOpenKwargDataflowGraph + : virtual public ILabelledOpenKwargDataflowGraphView, + virtual public ILabelledKwargDataflowGraphView { + virtual KwargNodeAddedResult add_node( + NodeLabel const &node_label, + std::unordered_map> const + &inputs, + std::unordered_map const &output_labels) = 0; + + virtual KwargDataflowGraphInput + add_input(GraphInputName const &name, ValueLabel const &value_label) = 0; + + virtual void inplace_materialize_from( + LabelledOpenKwargDataflowGraphView const &) = 0; + + virtual ~ILabelledOpenKwargDataflowGraph() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT( + ILabelledOpenKwargDataflowGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph_view.h new file mode 100644 index 0000000000..a241021c3f --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph_view.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_I_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_I_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/labelled_kwarg_dataflow_graph/i_labelled_kwarg_dataflow_graph_view.h" +#include 
"utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" + +namespace FlexFlow { + +template +struct ILabelledOpenKwargDataflowGraphView + : virtual public ILabelledKwargDataflowGraphView, + virtual public IOpenKwargDataflowGraphView { + virtual NodeLabel at(Node const &) const override = 0; + virtual ValueLabel + at(OpenKwargDataflowValue const &) const = 0; + + ValueLabel at(KwargDataflowOutput const &o) const override final { + return this->at(OpenKwargDataflowValue{o}); + } + + virtual ~ILabelledOpenKwargDataflowGraphView() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT( + ILabelledOpenKwargDataflowGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..1f03f4d341 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph.h @@ -0,0 +1,82 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_H + +#include "utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct LabelledOpenKwargDataflowGraph + : virtual public LabelledOpenKwargDataflowGraphView { +private: + using Interface = ILabelledOpenKwargDataflowGraph; + +public: + LabelledOpenKwargDataflowGraph(LabelledOpenKwargDataflowGraph const &) = + default; + LabelledOpenKwargDataflowGraph & + operator=(LabelledOpenKwargDataflowGraph const &) = default; 
+ + KwargNodeAddedResult add_node( + NodeLabel const &node_label, + std::unordered_map> const + &inputs, + std::unordered_map const &output_labels) { + return this->get_interface().add_node(node_label, inputs, output_labels); + } + + KwargDataflowGraphInput + add_input(GraphInputName const &name, ValueLabel const &value_label) { + return this->get_interface().add_input(name, value_label); + } + + template + static typename std::enable_if::value, + LabelledOpenKwargDataflowGraph>::type + create() { + return LabelledOpenKwargDataflowGraph(make_cow_ptr()); + } + + template + static std::enable_if_t, + LabelledOpenKwargDataflowGraph> + create_copy_of(LabelledOpenKwargDataflowGraphView const &view) { + cow_ptr_t impl = make_cow_ptr(); + impl.get_mutable()->inplace_materialize_from(view); + return LabelledOpenKwargDataflowGraph(std::move(impl)); + } + +protected: + using LabelledOpenKwargDataflowGraphView< + NodeLabel, + ValueLabel, + GraphInputName, + SlotName>::LabelledOpenKwargDataflowGraphView; + +private: + Interface &get_interface() { + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + } + + Interface const &get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h new file mode 100644 index 0000000000..0740df1505 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h @@ -0,0 +1,65 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_H + +#include 
"utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/i_labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct LabelledOpenKwargDataflowGraphView + : virtual public LabelledKwargDataflowGraphView, + virtual public OpenKwargDataflowGraphView { +private: + using Interface = ILabelledOpenKwargDataflowGraphView; + +public: + LabelledOpenKwargDataflowGraphView( + LabelledOpenKwargDataflowGraphView const &) = default; + LabelledOpenKwargDataflowGraphView & + operator=(LabelledOpenKwargDataflowGraphView const &) = default; + + NodeLabel at(Node const &n) const { + return this->get_interface().at(n); + } + + ValueLabel + at(OpenKwargDataflowValue const &v) const { + return this->get_interface().at(v); + } + + template + static typename std::enable_if< + std::is_base_of::value, + LabelledOpenKwargDataflowGraphView>::type + create(Args &&...args) { + return LabelledOpenKwargDataflowGraphView( + static_cast>( + make_cow_ptr(std::forward(args)...))); + } + +protected: + using OpenKwargDataflowGraphView::OpenKwargDataflowGraphView; + +private: + Interface const &get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/multidigraph/multidiedge.dtg.toml b/lib/utils/include/utils/graph/multidigraph/multidiedge.dtg.toml new file mode 100644 index 0000000000..27bc02c42d --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph/multidiedge.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "MultiDiEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "", +] + +[[fields]] +name = "raw_uid" +type = "size_t" diff --git a/lib/utils/include/utils/graph/multidigraph/multidiedge.struct.toml 
b/lib/utils/include/utils/graph/multidigraph/multidiedge.struct.toml deleted file mode 100644 index 687aa1ff69..0000000000 --- a/lib/utils/include/utils/graph/multidigraph/multidiedge.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "MultiDiEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "", -] - -[[fields]] -name = "raw_uid" -type = "size_t" diff --git a/lib/utils/include/utils/graph/multidigraph/multidiedge_query.dtg.toml b/lib/utils/include/utils/graph/multidigraph/multidiedge_query.dtg.toml new file mode 100644 index 0000000000..7796886a0c --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph/multidiedge_query.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "MultiDiEdgeQuery" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "srcs" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "dsts" +type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/multidigraph/multidiedge_query.struct.toml b/lib/utils/include/utils/graph/multidigraph/multidiedge_query.struct.toml deleted file mode 100644 index 1d555b2626..0000000000 --- a/lib/utils/include/utils/graph/multidigraph/multidiedge_query.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "MultiDiEdgeQuery" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/query_set.h", - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "srcs" -type = "::FlexFlow::query_set<::FlexFlow::Node>" - -[[fields]] -name = "dsts" -type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/node/algorithms/new_node.dtg.toml b/lib/utils/include/utils/graph/node/algorithms/new_node.dtg.toml new file mode 100644 index 0000000000..512bf6d1d4 --- /dev/null +++ 
b/lib/utils/include/utils/graph/node/algorithms/new_node.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "NewNode" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "raw_node" +type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/node/algorithms/new_node.struct.toml b/lib/utils/include/utils/graph/node/algorithms/new_node.struct.toml deleted file mode 100644 index f3b8244573..0000000000 --- a/lib/utils/include/utils/graph/node/algorithms/new_node.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "NewNode" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "raw_node" -type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/node/node.dtg.toml b/lib/utils/include/utils/graph/node/node.dtg.toml new file mode 100644 index 0000000000..603406a9cc --- /dev/null +++ b/lib/utils/include/utils/graph/node/node.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "Node" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +includes = [ + "", +] + +[[fields]] +name = "raw_uid" +type = "size_t" diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml deleted file mode 100644 index d5c22e5d3d..0000000000 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "Node" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", -] - -includes = [ - "", -] - -[[fields]] -name = "raw_uid" -type = "size_t" diff --git a/lib/utils/include/utils/graph/node/node_query.dtg.toml b/lib/utils/include/utils/graph/node/node_query.dtg.toml new file mode 100644 index 0000000000..87602b2d4b --- /dev/null +++ b/lib/utils/include/utils/graph/node/node_query.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = 
"NodeQuery" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/query_set.h", +] + +[[fields]] +name = "nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/node/node_query.struct.toml b/lib/utils/include/utils/graph/node/node_query.struct.toml deleted file mode 100644 index 0519e01650..0000000000 --- a/lib/utils/include/utils/graph/node/node_query.struct.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "NodeQuery" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h", - "utils/graph/query_set.h", -] - -[[fields]] -name = "nodes" -type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.toml new file mode 100644 index 0000000000..0d7e8440af --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "NewDataflowGraphInput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "raw_input" +type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.struct.toml deleted file mode 100644 index 76b062e211..0000000000 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "NewDataflowGraphInput" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - 
"utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", -] - -[[fields]] -name = "raw_input" -type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.toml new file mode 100644 index 0000000000..98285329fd --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.toml @@ -0,0 +1,37 @@ +namespace = "FlexFlow" +name = "OpenDataflowGraphData" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/dataflow_graph/dataflow_output.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "nodes" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::OpenDataflowEdge>" + +[[fields]] +name = "inputs" +type = "std::unordered_set<::FlexFlow::DataflowGraphInput>" + +[[fields]] +name = "outputs" +type = "std::unordered_set<::FlexFlow::DataflowOutput>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.struct.toml deleted file mode 100644 index 467ca73b3f..0000000000 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.struct.toml +++ /dev/null @@ -1,36 +0,0 @@ -namespace = "FlexFlow" -name = "OpenDataflowGraphData" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h", - "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", - "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", - 
"utils/graph/dataflow_graph/dataflow_output.dtg.h", - "", -] - -src_includes = [ - "utils/hash/unordered_set.h", - "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "nodes" -type = "std::unordered_set<::FlexFlow::Node>" - -[[fields]] -name = "edges" -type = "std::unordered_set<::FlexFlow::OpenDataflowEdge>" - -[[fields]] -name = "inputs" -type = "std::unordered_set<::FlexFlow::DataflowGraphInput>" - -[[fields]] -name = "outputs" -type = "std::unordered_set<::FlexFlow::DataflowOutput>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.toml new file mode 100644 index 0000000000..d85a29403f --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "OpenDataflowGraphIsomorphism" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "utils/graph/node/node.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "node_mapping" +type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::Node>" + +[[fields]] +name = "input_mapping" +type = "::FlexFlow::bidict<::FlexFlow::DataflowGraphInput, ::FlexFlow::DataflowGraphInput>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.struct.toml deleted file mode 100644 index bafe3c7117..0000000000 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "OpenDataflowGraphIsomorphism" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/bidict/bidict.h", - 
"utils/graph/node/node.dtg.h", - "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", -] - -[[fields]] -name = "node_mapping" -type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::Node>" - -[[fields]] -name = "input_mapping" -type = "::FlexFlow::bidict<::FlexFlow::DataflowGraphInput, ::FlexFlow::DataflowGraphInput>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.toml new file mode 100644 index 0000000000..10339080d5 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "OpenDataflowSubgraphResult" +type = "struct" +features = [] + +includes = [ + "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h", + "utils/bidict/bidict.h", + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "graph" +type = "::FlexFlow::OpenDataflowGraphView" + +[[fields]] +name = "full_graph_values_to_subgraph_inputs" +type = "::FlexFlow::bidict<::FlexFlow::OpenDataflowValue, ::FlexFlow::DataflowGraphInput>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.struct.toml deleted file mode 100644 index 99e1ea5dd2..0000000000 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "OpenDataflowSubgraphResult" -features = [] - -includes = [ - "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h", - "utils/bidict/bidict.h", - "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", - 
"utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", -] - -[[fields]] -name = "graph" -type = "::FlexFlow::OpenDataflowGraphView" - -[[fields]] -name = "full_graph_values_to_subgraph_inputs" -type = "::FlexFlow::bidict<::FlexFlow::OpenDataflowValue, ::FlexFlow::DataflowGraphInput>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.toml b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.toml new file mode 100644 index 0000000000..ba26d75e03 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "DataflowGraphInput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "", +] + +[[fields]] +name = "idx" +type = "size_t" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml deleted file mode 100644 index e9e52be893..0000000000 --- a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "DataflowGraphInput" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "", -] - -[[fields]] -name = "idx" -type = "size_t" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.toml b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.toml new file mode 100644 index 0000000000..138f17edee --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "DataflowInputEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/dataflow_graph/dataflow_input.dtg.h", +] + +[[fields]] +name = "src" +type = 
"::FlexFlow::DataflowGraphInput" + +[[fields]] +name = "dst" +type = "::FlexFlow::DataflowInput" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.struct.toml deleted file mode 100644 index fdfcfcf511..0000000000 --- a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "DataflowInputEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", - "utils/graph/dataflow_graph/dataflow_input.dtg.h", -] - -[[fields]] -name = "src" -type = "::FlexFlow::DataflowGraphInput" - -[[fields]] -name = "dst" -type = "::FlexFlow::DataflowInput" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.toml b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.toml new file mode 100644 index 0000000000..d1f81c2d71 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "DataflowInputEdgeQuery" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/node/node.dtg.h", + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "srcs" +type = "::FlexFlow::query_set<::FlexFlow::DataflowGraphInput>" + +[[fields]] +name = "dst_nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "dst_idxs" +type = "::FlexFlow::query_set<::FlexFlow::nonnegative_int>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.struct.toml deleted file mode 100644 index 
f67e8b88e0..0000000000 --- a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "DataflowInputEdgeQuery" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/query_set.h", - "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", - "utils/graph/node/node.dtg.h", - "utils/nonnegative_int/nonnegative_int.h", -] - -[[fields]] -name = "srcs" -type = "::FlexFlow::query_set<::FlexFlow::DataflowGraphInput>" - -[[fields]] -name = "dst_nodes" -type = "::FlexFlow::query_set<::FlexFlow::Node>" - -[[fields]] -name = "dst_idxs" -type = "::FlexFlow::query_set<::FlexFlow::nonnegative_int>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.toml b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.toml new file mode 100644 index 0000000000..722b1ba682 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OpenDataflowEdge" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h", + "utils/graph/dataflow_graph/dataflow_edge.dtg.h", +] + +[[values]] +type = "::FlexFlow::DataflowInputEdge" + +[[values]] +type = "::FlexFlow::DataflowEdge" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.variant.toml b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.variant.toml deleted file mode 100644 index 29f14fcf0d..0000000000 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.variant.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "OpenDataflowEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h", - 
"utils/graph/dataflow_graph/dataflow_edge.dtg.h", -] - -[[values]] -type = "::FlexFlow::DataflowInputEdge" - -[[values]] -type = "::FlexFlow::DataflowEdge" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.toml b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.toml new file mode 100644 index 0000000000..ae28e5b204 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "OpenDataflowEdgeQuery" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.h", +] + +[[fields]] +name = "input_edge_query" +type = "::FlexFlow::DataflowInputEdgeQuery" + +[[fields]] +name = "standard_edge_query" +type = "::FlexFlow::DataflowEdgeQuery" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.struct.toml deleted file mode 100644 index 1e2bb9221e..0000000000 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "OpenDataflowEdgeQuery" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h", - "utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.h", -] - -[[fields]] -name = "input_edge_query" -type = "::FlexFlow::DataflowInputEdgeQuery" - -[[fields]] -name = "standard_edge_query" -type = "::FlexFlow::DataflowEdgeQuery" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.dtg.toml b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.dtg.toml new file mode 100644 index 0000000000..f77af58411 --- /dev/null +++ 
b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OpenDataflowValue" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_output.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[values]] +type = "::FlexFlow::DataflowOutput" + +[[values]] +type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.variant.toml b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.variant.toml deleted file mode 100644 index ba28a8772a..0000000000 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.variant.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "OpenDataflowValue" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/dataflow_graph/dataflow_output.dtg.h", - "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", -] - -[[values]] -type = "::FlexFlow::DataflowOutput" - -[[values]] -type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/find_isomorphisms_between_open_kwarg_dataflow_graphs.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/find_isomorphisms_between_open_kwarg_dataflow_graphs.h new file mode 100644 index 0000000000..179211f44b --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/find_isomorphisms_between_open_kwarg_dataflow_graphs.h @@ -0,0 +1,264 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISMS_BETWEEN_OPEN_KWARG_DATAFLOW_GRAPHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISMS_BETWEEN_OPEN_KWARG_DATAFLOW_GRAPHS_H + +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include 
"utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/containers/get_all_permutations.h" +#include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/containers/zip_values_strict.h" +#include "utils/graph/digraph/algorithms/get_terminal_nodes.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_unused_open_kwarg_dataflow_graph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graphs_are_isomorphic_under.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" + +namespace FlexFlow { + +template +std::optional> + find_isomorphism_under_sink_node_mapping( + OpenKwargDataflowGraphView const &src_g, + OpenKwargDataflowGraphView const &dst_g, + bidict const &sink_node_mapping, + bidict, + KwargDataflowGraphInput> const + &unused_graph_inputs_mapping) { + { + std::unordered_set already_mapped_src_nodes = + left_entries(sink_node_mapping); + std::unordered_set src_g_sink_nodes = get_terminal_nodes(src_g); + ASSERT(already_mapped_src_nodes == src_g_sink_nodes); + } + + { + std::unordered_set already_mapped_dst_nodes = + right_entries(sink_node_mapping); + std::unordered_set dst_g_sink_nodes = get_terminal_nodes(dst_g); + ASSERT(already_mapped_dst_nodes == dst_g_sink_nodes); + } + + { + std::unordered_set> + already_mapped_src_inputs = right_entries(unused_graph_inputs_mapping); + std::unordered_set> + src_g_unused_inputs = + get_unused_open_kwarg_dataflow_graph_inputs(src_g); + ASSERT(already_mapped_src_inputs == src_g_unused_inputs); 
+ } + + { + std::unordered_set> + already_mapped_dst_inputs = right_entries(unused_graph_inputs_mapping); + std::unordered_set> + dst_g_unused_inputs = + get_unused_open_kwarg_dataflow_graph_inputs(dst_g); + ASSERT(already_mapped_dst_inputs == dst_g_unused_inputs); + } + + std::optional> result = + OpenKwargDataflowGraphIsomorphism{ + bidict{}, + unused_graph_inputs_mapping, + }; + + auto fail = [&]() -> void { result = std::nullopt; }; + + auto has_failed = [&]() -> bool { return result == std::nullopt; }; + + std::function unify_nodes; + std::function const &, + OpenKwargDataflowEdge const &)> + unify_edges; + std::function const &, + KwargDataflowGraphInput const &)> + unify_graph_inputs; + std::function const &, + OpenKwargDataflowValue const &)> + unify_values; + std::function const &, + KwargDataflowOutput const &)> + unify_outputs; + + unify_outputs = [&](KwargDataflowOutput const &src_output, + KwargDataflowOutput const &dst_output) { + if (has_failed()) { + return; + } + + if (src_output.slot_name != dst_output.slot_name) { + result = std::nullopt; + return; + } + + unify_nodes(src_output.node, dst_output.node); + }; + + unify_values = + [&](OpenKwargDataflowValue const &src_val, + OpenKwargDataflowValue const &dst_val) { + if (has_failed()) { + return; + } + + if (src_val.index() != dst_val.index()) { + fail(); + return; + } + + if (src_val.is_internal()) { + unify_outputs(src_val.require_internal(), dst_val.require_internal()); + } else { + unify_graph_inputs(src_val.require_external(), + dst_val.require_external()); + } + }; + + unify_graph_inputs = [&](KwargDataflowGraphInput const &src, + KwargDataflowGraphInput const &dst) { + if (has_failed()) { + return; + } + + if (result->input_mapping.contains_l(src) && + result->input_mapping.at_l(src) != dst) { + fail(); + return; + } + if (result->input_mapping.contains_r(dst) && + result->input_mapping.at_r(dst) != src) { + fail(); + return; + } + + result->input_mapping.equate(src, dst); + }; + + 
unify_edges = + [&](OpenKwargDataflowEdge const &src_edge, + OpenKwargDataflowEdge const &dst_edge) { + if (has_failed()) { + return; + } + + ASSERT(get_dst_of_open_kwarg_dataflow_edge(src_edge).slot_name == + get_dst_of_open_kwarg_dataflow_edge(dst_edge).slot_name); + ASSERT(get_dst_of_open_kwarg_dataflow_edge(src_edge).node == + result->node_mapping.at_r( + get_dst_of_open_kwarg_dataflow_edge(dst_edge).node)); + + unify_values(get_src_of_open_kwarg_dataflow_edge(src_edge), + get_src_of_open_kwarg_dataflow_edge(dst_edge)); + }; + + unify_nodes = [&](Node const &src_node, Node const &dst_node) { + if (has_failed()) { + return; + } + + if (result->node_mapping.contains(src_node, dst_node)) { + return; + } + + if (result->node_mapping.contains_l(src_node) && + result->node_mapping.at_l(src_node) != dst_node) { + fail(); + return; + } + if (result->node_mapping.contains_r(dst_node) && + result->node_mapping.at_r(dst_node) != src_node) { + fail(); + return; + } + + result->node_mapping.equate(src_node, dst_node); + + std::unordered_map> + src_incoming_edges = + get_incoming_open_kwarg_dataflow_edges_for_node(src_g, src_node); + std::unordered_map> + dst_incoming_edges = + get_incoming_open_kwarg_dataflow_edges_for_node(dst_g, dst_node); + + if (src_incoming_edges.size() != dst_incoming_edges.size()) { + fail(); + return; + } + + for (auto const &[src_edge, dst_edge] : + values(zip_values_strict(src_incoming_edges, dst_incoming_edges))) { + unify_edges(src_edge, dst_edge); + } + }; + + for (auto const &[src_node, dst_node] : sink_node_mapping) { + unify_nodes(src_node, dst_node); + } + + return result; +} + +template +std::unordered_set> + find_isomorphisms_between_open_kwarg_dataflow_graphs( + OpenKwargDataflowGraphView const &src, + OpenKwargDataflowGraphView const &dst) { + std::unordered_set> result; + + std::vector src_sink_nodes = vector_of(get_terminal_nodes(src)); + std::unordered_set dst_sink_nodes = get_terminal_nodes(dst); + + if (src_sink_nodes.size() != 
dst_sink_nodes.size()) { + return {}; + } + + std::vector> src_unused_graph_inputs = + vector_of(get_unused_open_kwarg_dataflow_graph_inputs(src)); + std::unordered_set> + dst_unused_graph_inputs = + get_unused_open_kwarg_dataflow_graph_inputs(dst); + + if (src_unused_graph_inputs.size() != dst_unused_graph_inputs.size()) { + return {}; + } + + for (std::vector const &dst_sink_nodes : + get_all_permutations(dst_sink_nodes)) { + + bidict sink_node_mapping = + bidict_from_keys_and_values(src_sink_nodes, dst_sink_nodes); + + for (std::vector> const + &dst_unused_graph_inputs : + get_all_permutations(dst_unused_graph_inputs)) { + + bidict, + KwargDataflowGraphInput> + unused_graph_inputs_mapping = bidict_from_keys_and_values( + src_unused_graph_inputs, dst_unused_graph_inputs); + + std::optional> found = + find_isomorphism_under_sink_node_mapping( + src, dst, sink_node_mapping, unused_graph_inputs_mapping); + + if (found.has_value()) { + ASSERT(open_kwarg_dataflow_graphs_are_isomorphic_under( + src, dst, found.value())); + + result.insert(found.value()); + } + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/generate_new_kwarg_dataflow_graph_input_id_permutation.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/generate_new_kwarg_dataflow_graph_input_id_permutation.h new file mode 100644 index 0000000000..060463d6be --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/generate_new_kwarg_dataflow_graph_input_id_permutation.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GENERATE_NEW_KWARG_DATAFLOW_GRAPH_INPUT_ID_PERMUTATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GENERATE_NEW_KWARG_DATAFLOW_GRAPH_INPUT_ID_PERMUTATION_H + +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/contains.h" +#include 
"utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +bidict, + KwargDataflowGraphInput> + generate_new_kwarg_dataflow_graph_input_id_permutation( + OpenKwargDataflowGraphView const &g, + std::function const &input_id_source) { + std::unordered_set> old_graph_inputs = + get_all_kwarg_dataflow_graph_inputs(g); + + auto fresh_input_id = [&]() -> GraphInputName { + while (true) { + GraphInputName candidate = input_id_source(); + + if (!contains(old_graph_inputs, KwargDataflowGraphInput{candidate})) { + return candidate; + } + } + }; + + return generate_bidict(old_graph_inputs, + [&](KwargDataflowGraphInput const &) { + return KwargDataflowGraphInput{ + fresh_input_id(), + }; + }) + .reversed(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h new file mode 100644 index 0000000000..b14deaebf4 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_GRAPH_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_GRAPH_INPUTS_H + +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_all_kwarg_dataflow_graph_inputs( + OpenKwargDataflowGraphView const &view) { + return view.get_inputs(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_edges.h 
b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_edges.h new file mode 100644 index 0000000000..7459dac065 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_edges.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_OPEN_KWARG_DATAFLOW_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_OPEN_KWARG_DATAFLOW_EDGES_H + +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge_query.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_all_open_kwarg_dataflow_edges( + OpenKwargDataflowGraphView const &view) { + return view.query_edges( + open_kwarg_dataflow_edge_query_all()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_values.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_values.h new file mode 100644 index 0000000000..73b0e9c29d --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_values.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_OPEN_KWARG_DATAFLOW_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_OPEN_KWARG_DATAFLOW_VALUES_H + +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" +#include 
"utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_all_open_kwarg_dataflow_values( + OpenKwargDataflowGraphView const &g) { + std::unordered_set> internal_values = + get_all_kwarg_dataflow_outputs(g); + + std::unordered_set> external_values = + get_all_kwarg_dataflow_graph_inputs(g); + + return set_union( + transform(internal_values, + [](KwargDataflowOutput const &o) { + return OpenKwargDataflowValue(o); + }), + transform(external_values, + [](KwargDataflowGraphInput const &i) { + return OpenKwargDataflowValue(i); + })); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h new file mode 100644 index 0000000000..661105b3d6 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_OPEN_KWARG_DATAFLOW_EDGES_FOR_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_OPEN_KWARG_DATAFLOW_EDGES_FOR_NODE_H + +#include "utils/containers/unordered_map_from_pairs.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_map> + get_incoming_open_kwarg_dataflow_edges_for_node( + OpenKwargDataflowGraphView const &g, + Node const &n) { + OpenKwargDataflowEdgeQuery query = + OpenKwargDataflowEdgeQuery{ + /*input_edge_query=*/ + KwargDataflowInputEdgeQuery{ + /*srcs=*/query_set::matchall(), + /*dst_nodes=*/query_set{n}, + /*dst_slots=*/query_set::matchall(), + }, + 
/*standard_edge_query=*/ + KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set::matchall(), + /*src_slots=*/query_set::matchall(), + /*dst_nodes=*/query_set{n}, + /*dst_slots=*/query_set::matchall(), + }, + }; + + return unordered_map_from_pairs( + transform(g.query_edges(query), + [](OpenKwargDataflowEdge const &e) { + return std::pair{ + get_dst_of_open_kwarg_dataflow_edge(e).slot_name, + e, + }; + })); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h new file mode 100644 index 0000000000..e30d554e89 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_OPEN_KWARG_DATAFLOW_VALUES_FOR_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_OPEN_KWARG_DATAFLOW_VALUES_FOR_NODE_H + +#include "utils/containers/map_values.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_map> + get_incoming_open_kwarg_dataflow_values_for_node( + OpenKwargDataflowGraphView const &g, + Node const &n) { + return map_values(get_incoming_open_kwarg_dataflow_edges_for_node(g, n), + [](OpenKwargDataflowEdge const &e) + -> OpenKwargDataflowValue { + return get_src_of_open_kwarg_dataflow_edge(e); + }); +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_data.h new file mode 100644 index 0000000000..b27ad13cc9 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_data.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H + +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_edges.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +OpenKwargDataflowGraphData + get_open_kwarg_dataflow_graph_data( + OpenKwargDataflowGraphView const &g) { + return OpenKwargDataflowGraphData{ + /*nodes=*/get_nodes(g), + /*edges=*/get_all_open_kwarg_dataflow_edges(g), + /*inputs=*/get_all_kwarg_dataflow_graph_inputs(g), + /*outputs=*/get_all_kwarg_dataflow_outputs(g), + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_subgraph.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_subgraph.h new file mode 100644 index 0000000000..d45fdfa640 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_subgraph.h @@ -0,0 +1,137 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_KWARG_DATAFLOW_GRAPH_SUBGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_KWARG_DATAFLOW_GRAPH_SUBGRAPH_H + +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/set_union.h" +#include "utils/containers/values.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_incoming_edges.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_subgraph_result.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +OpenKwargDataflowSubgraphResult + get_open_kwarg_dataflow_graph_subgraph( + OpenKwargDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes, + std::function const &input_source) { + bidict, + KwargDataflowGraphInput> + full_graph_values_to_subgraph_inputs = + get_full_kwarg_dataflow_graph_values_to_subgraph_inputs( + g, subgraph_nodes, input_source); + + return OpenKwargDataflowSubgraphResult{ + view_from_open_kwarg_dataflow_graph_data( + get_open_kwarg_dataflow_subgraph_data( + g, subgraph_nodes, full_graph_values_to_subgraph_inputs)), + full_graph_values_to_subgraph_inputs, + }; +} + +template +bidict, + KwargDataflowGraphInput> + get_full_kwarg_dataflow_graph_values_to_subgraph_inputs( + OpenKwargDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes, + std::function const &input_source) { + return generate_bidict( + get_open_kwarg_dataflow_subgraph_inputs(g, subgraph_nodes), + [&](OpenKwargDataflowValue const &v) + -> 
KwargDataflowGraphInput { + return v.template visit>( + overload{ + [](KwargDataflowGraphInput const &i) { + return i; + }, + [&](KwargDataflowOutput const &) { + return KwargDataflowGraphInput{ + input_source(), + }; + }, + }); + }); +} + +template +OpenKwargDataflowGraphData + get_open_kwarg_dataflow_subgraph_data( + OpenKwargDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes, + bidict, + KwargDataflowGraphInput> const + &full_graph_values_to_subgraph_inputs) { + std::unordered_set> + subgraph_input_edges = transform( + get_open_kwarg_dataflow_subgraph_incoming_edges(g, subgraph_nodes), + [&](OpenKwargDataflowEdge const &edge) { + return edge.template visit< + OpenKwargDataflowEdge>(overload{ + [&](KwargDataflowInputEdge const &e) + -> OpenKwargDataflowEdge { + return OpenKwargDataflowEdge{ + KwargDataflowInputEdge{ + full_graph_values_to_subgraph_inputs.at_l( + OpenKwargDataflowValue{ + e.src}), + e.dst}, + }; + }, + [&](KwargDataflowEdge const &e) { + return OpenKwargDataflowEdge{ + KwargDataflowInputEdge{ + full_graph_values_to_subgraph_inputs.at_l( + OpenKwargDataflowValue{ + e.src}), + e.dst, + }, + }; + }, + }); + }); + + OpenKwargDataflowEdgeQuery + subgraph_interior_edges_query = + OpenKwargDataflowEdgeQuery{ + KwargDataflowInputEdgeQuery{ + /*srcs=*/query_set::match_none(), + /*dst_nodes=*/query_set::match_none(), + /*dst_slots=*/query_set::match_none(), + }, + KwargDataflowEdgeQuery{ + /*srcs=*/query_set{subgraph_nodes}, + /*src_slots=*/query_set::matchall(), + /*dsts=*/query_set{subgraph_nodes}, + /*dst_slots=*/query_set::matchall(), + }, + }; + + std::unordered_set> + subgraph_interior_edges = g.query_edges(subgraph_interior_edges_query); + + std::unordered_set> subgraph_inputs = + unordered_set_of(values(full_graph_values_to_subgraph_inputs)); + + std::unordered_set> subgraph_outputs = + filter(g.query_outputs(kwarg_dataflow_output_query_all()), + [&](KwargDataflowOutput const &o) { + return contains(subgraph_nodes, o.node); + }); 
+ + return OpenKwargDataflowGraphData{ + subgraph_nodes, + set_union(subgraph_input_edges, subgraph_interior_edges), + subgraph_inputs, + subgraph_outputs, + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_incoming_edges.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_incoming_edges.h new file mode 100644 index 0000000000..c711f79100 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_incoming_edges.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_KWARG_DATAFLOW_SUBGRAPH_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_KWARG_DATAFLOW_SUBGRAPH_INCOMING_EDGES_H + +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_open_kwarg_dataflow_subgraph_incoming_edges( + OpenKwargDataflowGraphView const &g, + std::unordered_set const &subgraph) { + std::unordered_set all_nodes = get_nodes(g); + query_set src_query = query_set{set_minus(all_nodes, subgraph)}; + + OpenKwargDataflowEdgeQuery query = + OpenKwargDataflowEdgeQuery{ + /*input_edge_query=*/KwargDataflowInputEdgeQuery{ + /*srcs=*/query_set::matchall(), + /*dst_nodes=*/query_set{subgraph}, + /*dst_slots=*/query_set::matchall(), + }, + /*standard_edge_query=*/ + KwargDataflowEdgeQuery{ + /*src_nodes=*/src_query, + /*src_slots=*/query_set::matchall(), + /*dst_nodes=*/query_set{subgraph}, + /*dst_slots=*/query_set::matchall(), + }, + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_inputs.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_inputs.h new file mode 100644 index 0000000000..2e724bc217 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_inputs.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_KWARG_DATAFLOW_SUBGRAPH_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_KWARG_DATAFLOW_SUBGRAPH_INPUTS_H + +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_open_kwarg_dataflow_subgraph_inputs( + OpenKwargDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes) { + + return transform( + get_open_kwarg_dataflow_subgraph_incoming_edges(g, subgraph_nodes), + [](OpenKwargDataflowEdge const &e) { + return get_src_of_open_kwarg_dataflow_edge(e); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_value_uses.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_value_uses.h new file mode 100644 index 0000000000..94f729abbf --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_value_uses.h @@ -0,0 +1,54 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_KWARG_DATAFLOW_VALUE_USES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_KWARG_DATAFLOW_VALUE_USES_H + 
+#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.h" +#include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge_query.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_open_kwarg_dataflow_value_uses( + OpenKwargDataflowGraphView const &g, + OpenKwargDataflowValue const &v) { + + OpenKwargDataflowEdgeQuery query = v.template visit< + OpenKwargDataflowEdgeQuery>(overload{ + [&](KwargDataflowOutput const &o) { + return OpenKwargDataflowEdgeQuery{ + kwarg_dataflow_input_edge_query_none(), + KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set{o.node}, + /*src_slots=*/query_set{o.slot_name}, + /*dst_nodes=*/query_set::matchall(), + /*dst_slots=*/query_set::matchall(), + }, + }; + }, + [&](KwargDataflowGraphInput const &i) { + return OpenKwargDataflowEdgeQuery{ + KwargDataflowInputEdgeQuery{ + /*srcs=*/query_set{i.name}, + /*dst_nodes=*/query_set::matchall(), + /*dst_slots=*/query_set::matchall(), + }, + kwarg_dataflow_edge_query_none(), + }; + }}); + + std::unordered_set> edges = + g.query_edges(query); + + return transform( + edges, [&](OpenKwargDataflowEdge const &e) { + return get_dst_of_open_kwarg_dataflow_edge(e); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_unused_open_kwarg_dataflow_graph_inputs.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_unused_open_kwarg_dataflow_graph_inputs.h new file mode 100644 index 0000000000..b6880abecd --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_unused_open_kwarg_dataflow_graph_inputs.h @@ -0,0 +1,25 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_UNUSED_OPEN_KWARG_DATAFLOW_GRAPH_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_UNUSED_OPEN_KWARG_DATAFLOW_GRAPH_INPUTS_H + +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_graph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_value_uses.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_unused_open_kwarg_dataflow_graph_inputs( + OpenKwargDataflowGraphView const &g) { + return filter( + get_all_kwarg_dataflow_graph_inputs(g), + [&](KwargDataflowGraphInput const &i) { + return get_open_kwarg_dataflow_value_uses( + g, OpenKwargDataflowValue{i}) + .empty(); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/new_kwarg_dataflow_graph_input.dtg.toml b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/new_kwarg_dataflow_graph_input.dtg.toml new file mode 100644 index 0000000000..f790504fbb --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/new_kwarg_dataflow_graph_input.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "NewKwargDataflowGraphInput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +template_params = [ + "GraphInputName", +] + +includes = [ + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "raw_input" +type = "::FlexFlow::KwargDataflowGraphInput" diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.toml b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.toml new file mode 100644 index 0000000000..49c139cd0a --- /dev/null +++ 
b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.toml @@ -0,0 +1,42 @@ +namespace = "FlexFlow" +name = "OpenKwargDataflowGraphData" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "GraphInputName", + "SlotName", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.h", + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h", + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "nodes" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::OpenKwargDataflowEdge>" + +[[fields]] +name = "inputs" +type = "std::unordered_set<::FlexFlow::KwargDataflowGraphInput>" + +[[fields]] +name = "outputs" +type = "std::unordered_set<::FlexFlow::KwargDataflowOutput>" diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.h new file mode 100644 index 0000000000..46a3577a06 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H + +#include "utils/containers/filtrans.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/transform.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.h" + +namespace FlexFlow { + +template +void require_open_kwarg_dataflow_graph_data_is_valid( + 
OpenKwargDataflowGraphData const &data) { + std::unordered_set> + inputs_from_edges = filtrans( + data.edges, + [](OpenKwargDataflowEdge const &e) + -> std::optional> { + return transform( + e.try_require_input_edge(), + [](KwargDataflowInputEdge const + &input_e) -> KwargDataflowGraphInput { + return input_e.src; + }); + }); + + ASSERT(is_subseteq_of(inputs_from_edges, data.inputs)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.dtg.toml b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.dtg.toml new file mode 100644 index 0000000000..f7b08229d5 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "OpenKwargDataflowGraphIsomorphism" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "GraphInputName", +] + +includes = [ + "utils/bidict/bidict.h", + "utils/graph/node/node.dtg.h", + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "node_mapping" +type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::Node>" + +[[fields]] +name = "input_mapping" +type = "::FlexFlow::bidict<::FlexFlow::KwargDataflowGraphInput, ::FlexFlow::KwargDataflowGraphInput>" diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.h new file mode 100644 index 0000000000..a0317e7f89 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.h @@ -0,0 +1,77 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_OPEN_KWARG_DATAFLOW_GRAPH_ISOMORPHISM_H 
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_OPEN_KWARG_DATAFLOW_GRAPH_ISOMORPHISM_H + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +OpenKwargDataflowValue + isomorphism_map_r_open_kwarg_dataflow_value_from_l( + OpenKwargDataflowGraphIsomorphism const &iso, + OpenKwargDataflowValue const &l_value) { + return l_value + .template visit>( + overload{ + [&](KwargDataflowGraphInput const &l_input) { + return OpenKwargDataflowValue{ + iso.input_mapping.at_l(l_input), + }; + }, + [&](KwargDataflowOutput const &l_output) { + return OpenKwargDataflowValue{ + isomorphism_map_r_kwarg_dataflow_output_from_l(iso, + l_output), + }; + }, + }); +} + +template +OpenKwargDataflowValue + isomorphism_map_l_open_kwarg_dataflow_value_from_r( + OpenKwargDataflowGraphIsomorphism const &iso, + OpenKwargDataflowValue const &r_value) { + return r_value + .template visit>( + overload{ + [&](KwargDataflowGraphInput const &r_input) { + return OpenKwargDataflowValue{ + iso.input_mapping.at_r(r_input), + }; + }, + [&](KwargDataflowOutput const &r_output) { + return OpenKwargDataflowValue{ + isomorphism_map_l_kwarg_dataflow_output_from_r(iso, + r_output), + }; + }, + }); +} + +template +KwargDataflowOutput isomorphism_map_r_kwarg_dataflow_output_from_l( + OpenKwargDataflowGraphIsomorphism const &iso, + KwargDataflowOutput const &l_output) { + return KwargDataflowOutput{ + iso.node_mapping.at_l(l_output.node), + l_output.slot_name, + }; +} + +template +KwargDataflowOutput isomorphism_map_l_kwarg_dataflow_output_from_r( + OpenKwargDataflowGraphIsomorphism const &iso, + KwargDataflowOutput const &r_output) { + return KwargDataflowOutput{ + 
iso.node_mapping.at_r(r_output.node), + r_output.slot_name, + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graphs_are_isomorphic_under.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graphs_are_isomorphic_under.h new file mode 100644 index 0000000000..63c367a987 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graphs_are_isomorphic_under.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_OPEN_KWARG_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_UNDER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_OPEN_KWARG_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_UNDER_H + +#include "utils/bidict/algorithms/transform_values.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/permute_open_kwarg_dataflow_graph_input_ids.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/permute_open_kwarg_dataflow_graph_node_ids.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +bool open_kwarg_dataflow_graphs_are_isomorphic_under( + OpenKwargDataflowGraphView const &src, + OpenKwargDataflowGraphView const &dst, + OpenKwargDataflowGraphIsomorphism const &isomorphism) { + bidict new_node_to_old_node = + transform_values(isomorphism.node_mapping, [](Node const &n) { + return NewNode{n}; + }).reversed(); + + bidict, + KwargDataflowGraphInput> + new_input_to_old_input = isomorphism.input_mapping.reversed(); + + OpenKwargDataflowGraphData permuted_data 
= + get_open_kwarg_dataflow_graph_data( + permute_open_kwarg_dataflow_graph_input_ids( + permute_open_kwarg_dataflow_graph_node_ids(src, + new_node_to_old_node), + new_input_to_old_input)); + + OpenKwargDataflowGraphData dst_data = + get_open_kwarg_dataflow_graph_data(dst); + + return permuted_data == dst_data; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_subgraph_result.dtg.toml b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_subgraph_result.dtg.toml new file mode 100644 index 0000000000..fbc4f46782 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_subgraph_result.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "OpenKwargDataflowSubgraphResult" +type = "struct" +features = [] + +template_params = [ + "GraphInputName", + "SlotName", +] + +includes = [ + "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h", + "utils/bidict/bidict.h", + "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h", + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "graph" +type = "::FlexFlow::OpenKwargDataflowGraphView" + +[[fields]] +name = "full_graph_values_to_subgraph_inputs" +type = "::FlexFlow::bidict<::FlexFlow::OpenKwargDataflowValue, ::FlexFlow::KwargDataflowGraphInput>" diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/permute_open_kwarg_dataflow_graph_input_ids.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/permute_open_kwarg_dataflow_graph_input_ids.h new file mode 100644 index 0000000000..38c9fe89d2 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/permute_open_kwarg_dataflow_graph_input_ids.h @@ -0,0 +1,66 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_OPEN_KWARG_DATAFLOW_GRAPH_INPUT_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_OPEN_KWARG_DATAFLOW_GRAPH_INPUT_IDS_H + +#include "utils/containers/transform.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +OpenKwargDataflowGraphView + permute_open_kwarg_dataflow_graph_input_ids( + OpenKwargDataflowGraphView const &g, + bidict, + KwargDataflowGraphInput> const + &new_input_to_old_input) { + std::unordered_set> g_inputs = + get_all_kwarg_dataflow_graph_inputs(g); + ASSERT(g_inputs == new_input_to_old_input.right_values()); + + auto new_input_from_old = + [&](KwargDataflowGraphInput const &i) + -> KwargDataflowGraphInput { + return new_input_to_old_input.at_r(i); + }; + + auto new_edge_from_old = + [&](OpenKwargDataflowEdge const &e) + -> OpenKwargDataflowEdge { + return e.template visit< + OpenKwargDataflowEdge>(overload{ + [&](KwargDataflowInputEdge const &input_edge) + -> OpenKwargDataflowEdge { + return OpenKwargDataflowEdge{ + KwargDataflowInputEdge{ + /*src=*/new_input_from_old(input_edge.src), + /*dst=*/input_edge.dst, + }, + }; + }, + [](KwargDataflowEdge const &standard_edge) + -> OpenKwargDataflowEdge { + return OpenKwargDataflowEdge{standard_edge}; + }, + }); + }; + + OpenKwargDataflowGraphData old_data = + get_open_kwarg_dataflow_graph_data(g); + + OpenKwargDataflowGraphData permuted_data = + OpenKwargDataflowGraphData{ + /*nodes=*/old_data.nodes, + /*edges=*/transform(old_data.edges, new_edge_from_old), + /*inputs=*/transform(old_data.inputs, new_input_from_old), + /*outputs=*/old_data.outputs, + }; + + return 
view_from_open_kwarg_dataflow_graph_data(permuted_data); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/permute_open_kwarg_dataflow_graph_node_ids.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/permute_open_kwarg_dataflow_graph_node_ids.h new file mode 100644 index 0000000000..c28d8d5999 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/permute_open_kwarg_dataflow_graph_node_ids.h @@ -0,0 +1,75 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_OPEN_KWARG_DATAFLOW_GRAPH_NODE_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_OPEN_KWARG_DATAFLOW_GRAPH_NODE_IDS_H + +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +OpenKwargDataflowGraphView + permute_open_kwarg_dataflow_graph_node_ids( + OpenKwargDataflowGraphView const &g, + bidict const &new_node_to_old_node) { + auto new_node_from_old = [&](Node const &n) -> Node { + return new_node_to_old_node.at_r(n).raw_node; + }; + + auto new_output_from_old = [&](KwargDataflowOutput const &o) + -> KwargDataflowOutput { + return KwargDataflowOutput{ + /*node=*/new_node_from_old(o.node), + /*slot_name=*/o.slot_name, + }; + }; + + auto new_input_from_old = [&](KwargDataflowInput const &i) + -> KwargDataflowInput { + return KwargDataflowInput{ + /*node=*/new_node_from_old(i.node), + /*slot_name=*/i.slot_name, + }; + }; + + auto new_edge_from_old = [&](OpenKwargDataflowEdge const &e) { + return e.template visit>( + overload{[&](KwargDataflowInputEdge const 
+ &input_edge) { + return OpenKwargDataflowEdge{ + KwargDataflowInputEdge{ + /*src=*/input_edge.src, + /*dst=*/new_input_from_old(input_edge.dst), + }, + }; + }, + [&](KwargDataflowEdge const &standard_edge) { + return OpenKwargDataflowEdge{ + KwargDataflowEdge{ + /*src=*/new_output_from_old(standard_edge.src), + /*dst=*/new_input_from_old(standard_edge.dst), + }, + }; + }}); + }; + + OpenKwargDataflowGraphData old_data = + get_open_kwarg_dataflow_graph_data(g); + + OpenKwargDataflowGraphData permuted_data = + OpenKwargDataflowGraphData{ + /*nodes=*/transform(old_data.nodes, new_node_from_old), + /*edges=*/transform(old_data.edges, new_edge_from_old), + /*inputs=*/old_data.inputs, + /*outputs=*/transform(old_data.outputs, new_output_from_old), + }; + + return view_from_open_kwarg_dataflow_graph_data(permuted_data); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/try_find_isomorphism_between_open_kwarg_dataflow_graphs.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/try_find_isomorphism_between_open_kwarg_dataflow_graphs.h new file mode 100644 index 0000000000..fd644c355a --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/try_find_isomorphism_between_open_kwarg_dataflow_graphs.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_TRY_FIND_ISOMORPHISM_BETWEEN_OPEN_KWARG_DATAFLOW_GRAPHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_TRY_FIND_ISOMORPHISM_BETWEEN_OPEN_KWARG_DATAFLOW_GRAPHS_H + +#include "utils/containers/try_get_one_of.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/find_isomorphisms_between_open_kwarg_dataflow_graphs.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::optional> + try_find_isomorphism_between_open_kwarg_dataflow_graphs( + 
OpenKwargDataflowGraphView const &src, + OpenKwargDataflowGraphView const &dst) { + std::unordered_set> + isomorphisms = + find_isomorphisms_between_open_kwarg_dataflow_graphs(src, dst); + + return try_get_one_of(isomorphisms); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h new file mode 100644 index 0000000000..4dcde44f4d --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h @@ -0,0 +1,69 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_FROM_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_FROM_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge_query.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct ViewFromOpenKwargDataflowGraphView final + : virtual public IOpenKwargDataflowGraphView { + ViewFromOpenKwargDataflowGraphView( + OpenKwargDataflowGraphData const &data) + : data(data) {} + + std::unordered_set query_nodes(NodeQuery const &query) const override { + return apply_node_query(query, this->data.nodes); + } + + std::unordered_set> + get_inputs() const override { + return this->data.inputs; + } + + std::unordered_set> + query_edges(OpenKwargDataflowEdgeQuery const + &query) const override { + return 
filter( + this->data.edges, + [&](OpenKwargDataflowEdge const &e) { + return open_kwarg_dataflow_edge_query_includes(query, e); + }); + } + + std::unordered_set> query_outputs( + KwargDataflowOutputQuery const &query) const override { + return filter(this->data.outputs, + [&](KwargDataflowOutput const &o) { + return kwarg_dataflow_output_query_includes(query, o); + }); + } + + ViewFromOpenKwargDataflowGraphView * + clone() const override { + return new ViewFromOpenKwargDataflowGraphView{ + this->data}; + } + +private: + OpenKwargDataflowGraphData data; +}; + +template +OpenKwargDataflowGraphView + view_from_open_kwarg_dataflow_graph_data( + OpenKwargDataflowGraphData const &data) { + require_open_kwarg_dataflow_graph_data_is_valid(data); + + return OpenKwargDataflowGraphView::template create< + ViewFromOpenKwargDataflowGraphView>(data); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..7f59628705 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_I_OPEN_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_I_OPEN_KWARG_DATAFLOW_GRAPH_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" + +namespace FlexFlow { + +template +struct IOpenKwargDataflowGraph + : virtual public IOpenKwargDataflowGraphView { + virtual KwargNodeAddedResult add_node( + std::unordered_map> const + &inputs, + std::unordered_set const &outputs) = 0; + virtual KwargDataflowGraphInput + 
add_input(GraphInputName const &name) = 0; + virtual IOpenKwargDataflowGraph *clone() const = 0; + + virtual ~IOpenKwargDataflowGraph() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOpenKwargDataflowGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph_view.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph_view.h new file mode 100644 index 0000000000..aee878a268 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph_view.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_I_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_I_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge_query.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge_query.dtg.h" + +namespace FlexFlow { + +template +struct IOpenKwargDataflowGraphView + : virtual public IKwargDataflowGraphView { + virtual std::unordered_set> + get_inputs() const = 0; + virtual std::unordered_set> + query_edges(OpenKwargDataflowEdgeQuery const &) + const = 0; + + std::unordered_set> query_edges( + KwargDataflowEdgeQuery const &query) const override final { + OpenKwargDataflowEdgeQuery open_query = + OpenKwargDataflowEdgeQuery{ + /*input_edge_query=*/kwarg_dataflow_input_edge_query_none< + GraphInputName, + SlotName>(), + /*standard_edge_query=*/query, + }; + + std::unordered_set> + open_edges = this->query_edges(open_query); + + return transform( + open_edges, + [](OpenKwargDataflowEdge const &e) + -> KwargDataflowEdge { + return 
e.require_internal_edge(); + }); + } + + virtual ~IOpenKwargDataflowGraphView() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOpenKwargDataflowGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.toml b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.toml new file mode 100644 index 0000000000..05f677a901 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "KwargDataflowGraphInput" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +template_params = [ + "T" +] + +includes = [] +src_includes = [] + +[[fields]] +name = "name" +type = "T" diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge.dtg.toml b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge.dtg.toml new file mode 100644 index 0000000000..eec48bb4e0 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "KwargDataflowInputEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +template_params = [ + "GraphInputName", + "SlotName", +] + +includes = [ + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h", + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_input.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::KwargDataflowGraphInput" + +[[fields]] +name = "dst" +type = "::FlexFlow::KwargDataflowInput" diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge_query.dtg.toml b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge_query.dtg.toml new file mode 100644 index 0000000000..88a9076d9f --- /dev/null +++ 
b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge_query.dtg.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "KwargDataflowInputEdgeQuery" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +template_params = [ + "GraphInputName", + "SlotName", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "srcs" +type = "::FlexFlow::query_set" + +[[fields]] +name = "dst_nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "dst_slots" +type = "::FlexFlow::query_set" diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge_query.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge_query.h new file mode 100644 index 0000000000..dc810920e5 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge_query.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_KWARG_DATAFLOW_INPUT_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_KWARG_DATAFLOW_INPUT_EDGE_QUERY_H + +#include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge_query.dtg.h" + +namespace FlexFlow { + +template +KwargDataflowInputEdgeQuery + kwarg_dataflow_input_edge_query_all() { + return KwargDataflowInputEdgeQuery{ + /*srcs=*/query_set::matchall(), + /*dst_nodes=*/query_set::matchall(), + /*dst_slots=*/query_set::matchall(), + }; +} + +template +KwargDataflowInputEdgeQuery + kwarg_dataflow_input_edge_query_none() { + return KwargDataflowInputEdgeQuery{ + /*srcs=*/query_set::match_none(), + /*dst_nodes=*/query_set::match_none(), + /*dst_slots=*/query_set::match_none(), + }; +} + +template +bool kwarg_dataflow_input_edge_query_includes( 
+ KwargDataflowInputEdgeQuery const &query, + KwargDataflowInputEdge const &edge) { + return includes(query.srcs, edge.src.name) && + includes(query.dst_nodes, edge.dst.node) && + includes(query.dst_slots, edge.dst.slot_name); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.toml b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.toml new file mode 100644 index 0000000000..cc9d11c608 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "OpenKwargDataflowEdge" +type = "variant" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +template_params = [ + "GraphInputName", + "SlotName", +] + +includes = [ + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge.dtg.h", + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.h", +] + +[[values]] +type = "::FlexFlow::KwargDataflowInputEdge" +key = "input_edge" + +[[values]] +type = "::FlexFlow::KwargDataflowEdge" +key = "internal_edge" diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h new file mode 100644 index 0000000000..3ba930dc8a --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h @@ -0,0 +1,65 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_OPEN_KWARG_DATAFLOW_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_OPEN_KWARG_DATAFLOW_EDGE_H + +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +OpenKwargDataflowEdge + mk_open_kwarg_dataflow_edge_from_src_val_and_dst( 
+ OpenKwargDataflowValue const &src, + KwargDataflowInput const &dst) { + return src.template visit>( + overload{[&](KwargDataflowOutput const &output) { + return OpenKwargDataflowEdge{ + KwargDataflowEdge{ + /*src=*/output, + /*dst=*/dst, + }, + }; + }, + [&](KwargDataflowGraphInput const &graph_input) { + return OpenKwargDataflowEdge{ + KwargDataflowInputEdge{ + /*src=*/graph_input, + /*dst=*/dst, + }, + }; + }}); +} + +template +OpenKwargDataflowValue + get_src_of_open_kwarg_dataflow_edge( + OpenKwargDataflowEdge const &e) { + return e.template visit>( + overload{[](KwargDataflowInputEdge const + &external_edge) { + return OpenKwargDataflowValue{ + external_edge.src, + }; + }, + [](KwargDataflowEdge const &internal_edge) { + return OpenKwargDataflowValue{ + internal_edge.src, + }; + }}); +} + +template +KwargDataflowInput get_dst_of_open_kwarg_dataflow_edge( + OpenKwargDataflowEdge const &e) { + return e.template visit>( + overload{[](KwargDataflowInputEdge const + &external_edge) { return external_edge.dst; }, + [](KwargDataflowEdge const &internal_edge) { + return internal_edge.dst; + }}); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge_query.dtg.toml b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge_query.dtg.toml new file mode 100644 index 0000000000..de86771398 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge_query.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "OpenKwargDataflowEdgeQuery" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +template_params = [ + "GraphInputName", + "SlotName", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.dtg.h", + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge_query.dtg.h", +] + +[[fields]] +name = "input_edge_query" +type = "::FlexFlow::KwargDataflowInputEdgeQuery" + 
+[[fields]] +name = "standard_edge_query" +type = "::FlexFlow::KwargDataflowEdgeQuery" diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge_query.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge_query.h new file mode 100644 index 0000000000..d15c080abe --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge_query.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_OPEN_KWARG_DATAFLOW_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_OPEN_KWARG_DATAFLOW_EDGE_QUERY_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.h" +#include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_input_edge_query.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge_query.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +OpenKwargDataflowEdgeQuery + open_kwarg_dataflow_edge_query_all() { + + return OpenKwargDataflowEdgeQuery{ + /*input_edge_query=*/kwarg_dataflow_input_edge_query_all(), + /*standard_edge_query=*/kwarg_dataflow_edge_query_all(), + }; +} + +template +bool open_kwarg_dataflow_edge_query_includes( + OpenKwargDataflowEdgeQuery const &query, + OpenKwargDataflowEdge const &edge) { + return edge.template visit(overload{ + [&](KwargDataflowInputEdge const &input_edge) { + return kwarg_dataflow_input_edge_query_includes(query.input_edge_query, + input_edge); + }, + [&](KwargDataflowEdge const &internal_edge) { + return kwarg_dataflow_edge_query_includes(query.standard_edge_query, + internal_edge); + }}); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph.h new 
file mode 100644 index 0000000000..1c903b1dab --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph.h @@ -0,0 +1,57 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.h" +#include "utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" + +namespace FlexFlow { + +template +struct OpenKwargDataflowGraph + : virtual public OpenKwargDataflowGraphView { +public: + KwargNodeAddedResult add_node( + std::unordered_map> const + &inputs, + std::unordered_set const &outputs) { + return this->get_interface().add_node(inputs, outputs); + } + + KwargDataflowGraphInput add_input(GraphInputName const &n) { + return this->get_interface().add_input(n); + } + + template + static typename std::enable_if< + std::is_base_of, + T>::value, + OpenKwargDataflowGraph>::type + create(Args &&...args) { + return OpenKwargDataflowGraph(make_cow_ptr(std::forward(args)...)); + } + +protected: + using OpenKwargDataflowGraphView::OpenKwargDataflowGraphView; + +private: + IOpenKwargDataflowGraph &get_interface() { + return *std::dynamic_pointer_cast< + IOpenKwargDataflowGraph>( + GraphView::ptr.get_mutable()); + } + + IOpenKwargDataflowGraph const & + get_interface() const { + return *std::dynamic_pointer_cast< + IOpenKwargDataflowGraph const>( + GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h new file mode 100644 index 
0000000000..3153d1cbe9 --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h @@ -0,0 +1,52 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct OpenKwargDataflowGraphView + : virtual public KwargDataflowGraphView { +public: + OpenKwargDataflowGraphView(OpenKwargDataflowGraphView const &) = default; + OpenKwargDataflowGraphView & + operator=(OpenKwargDataflowGraphView const &) = default; + + std::unordered_set> + get_inputs() const { + return this->get_interface().get_inputs(); + } + + std::unordered_set> + query_edges( + OpenKwargDataflowEdgeQuery const &q) const { + return this->get_interface().query_edges(q); + } + + template + static typename std::enable_if< + std::is_base_of, + T>::value, + OpenKwargDataflowGraphView>::type + create(Args &&...args) { + return OpenKwargDataflowGraphView( + make_cow_ptr(std::forward(args)...)); + } + +protected: + using KwargDataflowGraphView::KwargDataflowGraphView; + +private: + IOpenKwargDataflowGraphView const & + get_interface() const { + return *std::dynamic_pointer_cast< + IOpenKwargDataflowGraphView const>( + GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.toml b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.toml new file mode 100644 index 0000000000..24a2021eca --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "OpenKwargDataflowValue" +type 
= "variant" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +template_params = [ + "GraphInputName", + "SlotName", +] + +includes = [ + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h", + "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h", +] + +[[values]] +type = "::FlexFlow::KwargDataflowOutput" +key = "internal" + +[[values]] +type = "::FlexFlow::KwargDataflowGraphInput" +key = "external" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.toml new file mode 100644 index 0000000000..bb607264d8 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "BinaryParallelSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct BinarySPDecompositionTree", +] + +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml deleted file mode 100644 index 37e3bbee09..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "BinaryParallelSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct BinarySPDecompositionTree", -] - -post_includes = [ - 
"utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", -] - -[[fields]] -name = "left_child" -type = "::FlexFlow::BinarySPDecompositionTree" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::BinarySPDecompositionTree" -indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.toml new file mode 100644 index 0000000000..0e4cfaadcd --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "BinarySeriesSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct BinarySPDecompositionTree", +] + +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml deleted file mode 100644 index 7e6e86ba76..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "BinarySeriesSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct BinarySPDecompositionTree", -] - -post_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", -] - -[[fields]] -name = "left_child" -type = "::FlexFlow::BinarySPDecompositionTree" 
-indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::BinarySPDecompositionTree" -indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.toml new file mode 100644 index 0000000000..faaf38626e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "BinarySPDecompositionTree" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h", + "utils/graph/node/node.dtg.h", +] + +[[values]] +type = "::FlexFlow::BinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::BinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::Node" +key = "node" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml deleted file mode 100644 index c586b49d9d..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "BinarySPDecompositionTree" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h", - "utils/graph/node/node.dtg.h", -] - -[[values]] -type = "::FlexFlow::BinarySeriesSplit" -key = "series" - -[[values]] -type = 
"::FlexFlow::BinaryParallelSplit" -key = "parallel" - -[[values]] -type = "::FlexFlow::Node" -key = "node" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.toml new file mode 100644 index 0000000000..1f772e8be6 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.toml @@ -0,0 +1,48 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTreeImplementation" +type = "struct" +features = [] + +template_params = [ + "Tree", + "Series", + "Parallel", + "Leaf", +] + +includes = [ + "", + "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h", +] + +[[fields]] +name = "series_get_left_child" +type = "std::function" + +[[fields]] +name = "parallel_get_left_child" +type = "std::function" + +[[fields]] +name = "series_get_right_child" +type = "std::function" + +[[fields]] +name = "parallel_get_right_child" +type = "std::function" + +[[fields]] +name = "get_node_type" +type = "std::function<::FlexFlow::SPDecompositionTreeNodeType(Tree const &)>" + +[[fields]] +name = "require_series" +type = "std::function" + +[[fields]] +name = "require_parallel" +type = "std::function" + +[[fields]] +name = "require_leaf" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml deleted file mode 100644 index 
3ccbfd959b..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml +++ /dev/null @@ -1,47 +0,0 @@ -namespace = "FlexFlow" -name = "GenericBinarySPDecompositionTreeImplementation" -features = [] - -template_params = [ - "Tree", - "Series", - "Parallel", - "Leaf", -] - -includes = [ - "", - "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h", -] - -[[fields]] -name = "series_get_left_child" -type = "std::function" - -[[fields]] -name = "parallel_get_left_child" -type = "std::function" - -[[fields]] -name = "series_get_right_child" -type = "std::function" - -[[fields]] -name = "parallel_get_right_child" -type = "std::function" - -[[fields]] -name = "get_node_type" -type = "std::function<::FlexFlow::SPDecompositionTreeNodeType(Tree const &)>" - -[[fields]] -name = "require_series" -type = "std::function" - -[[fields]] -name = "require_parallel" -type = "std::function" - -[[fields]] -name = "require_leaf" -type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.toml new file mode 100644 index 0000000000..627c962a42 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTreeVisitor" +type = "struct" +features = [] + +template_params = [ + "ReturnType", + "Tree", + "Series", + "Parallel", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "series_func" +type = "std::function" + +[[fields]] +name = 
"parallel_func" +type = "std::function" + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml deleted file mode 100644 index 6275c82a0c..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml +++ /dev/null @@ -1,27 +0,0 @@ -namespace = "FlexFlow" -name = "GenericBinarySPDecompositionTreeVisitor" -features = [] - -template_params = [ - "ReturnType", - "Tree", - "Series", - "Parallel", - "Leaf", -] - -includes = [ - "", -] - -[[fields]] -name = "series_func" -type = "std::function" - -[[fields]] -name = "parallel_func" -type = "std::function" - -[[fields]] -name = "leaf_func" -type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_path_to_leaf_map.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_path_to_leaf_map.h new file mode 100644 index 0000000000..a07b2c3926 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_path_to_leaf_map.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_PATH_TO_LEAF_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_PATH_TO_LEAF_MAP_H + +#include "utils/full_binary_tree/get_path_to_leaf_map.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" + +namespace FlexFlow { + +template +std::unordered_map get_path_to_leaf_map( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return get_path_to_leaf_map(tree, full_binary_impl); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.dtg.toml b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.dtg.toml new file mode 100644 index 0000000000..331b0eed8a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "ExtendedParallelReduction" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +docstring = """\ +@brief An ExtendedParallelReduction is a unordered collection of +`MultiDiEdge`s such that they share a common source and destination node. 
+""" + +includes = [ + "utils/graph/multidigraph/multidiedge.dtg.h", + "" +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml deleted file mode 100644 index ca43a987e2..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "ExtendedParallelReduction" -features = [ - "eq", - "hash", - "fmt", -] - -docstring = """\ -@brief An ExtendedParallelReduction is a unordered collection of -`MultiDiEdge`s such that they share a common source and destination node. -""" - -includes = [ - "utils/graph/multidigraph/multidiedge.dtg.h", - "" -] - -src_includes = [ - "utils/hash/unordered_set.h", - "utils/fmt/unordered_set.h", -] - -[[fields]] -name = "edges" -type = "std::unordered_set<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.dtg.toml b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.dtg.toml new file mode 100644 index 0000000000..166cb71b46 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.dtg.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "ExtendedSeriesReduction" +type = "struct" + +docstring = """\ +@details An `ExtendedSeriesReduction` is an ordered collection of +`MultiDiEdges` such that: +- The destination node of the nth edge is the same as the source node of the + (n+1)th edge. +- Such a node (intermediate node) has exactly two edges: one incoming (nth + edge) and one outgoing ((n+1)th edge). 
+""" + +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/multidigraph/multidiedge.dtg.h", + "" +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "edges" +type = "std::vector<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml deleted file mode 100644 index ed999a22df..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml +++ /dev/null @@ -1,31 +0,0 @@ -namespace = "FlexFlow" -name = "ExtendedSeriesReduction" - -docstring = """\ -@details An `ExtendedSeriesReduction` is an ordered collection of -`MultiDiEdges` such that: -- The destination node of the nth edge is the same as the source node of the - (n+1)th edge. -- Such a node (intermediate node) has exactly two edges: one incoming (nth - edge) and one outgoing ((n+1)th edge). 
-""" - -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/graph/multidigraph/multidiedge.dtg.h", - "" -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "edges" -type = "std::vector<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.toml b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.toml new file mode 100644 index 0000000000..1508c2d25f --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "IntermediateSpDecompositionTree" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/split_type.dtg.h", + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", + "utils/fmt/variant.h" +] + +[[fields]] +name = "type" +type = "::FlexFlow::SplitType" + +[[fields]] +name = "children" +type = "std::vector>" diff --git a/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml deleted file mode 100644 index e7666fcd3f..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "IntermediateSpDecompositionTree" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/series_parallel/split_type.dtg.h", - "", - "", - "utils/graph/node/node.dtg.h", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", - "utils/fmt/variant.h" -] - -[[fields]] -name = "type" -type = "::FlexFlow::SplitType" - -[[fields]] -name = "children" -type = "std::vector>" diff --git 
a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.dtg.toml b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.dtg.toml new file mode 100644 index 0000000000..3344ad04ba --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "ParallelReduction" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/multidigraph/multidiedge.dtg.h", + "utils/commutative_pair.h", +] + +[[fields]] +name = "edges" +type = "::FlexFlow::commutative_pair<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.struct.toml deleted file mode 100644 index aa531ed1ea..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.struct.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelReduction" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/multidigraph/multidiedge.dtg.h", - "utils/commutative_pair.h", -] - -[[fields]] -name = "edges" -type = "::FlexFlow::commutative_pair<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/parallel_split.dtg.toml new file mode 100644 index 0000000000..a3315d506b --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/parallel_split.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "ParallelSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct SeriesSplit" +] + +post_includes = [ + "utils/graph/series_parallel/series_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] 
+name = "children" +type = "std::unordered_multiset>" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml deleted file mode 100644 index dd68adf3f6..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml +++ /dev/null @@ -1,32 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct SeriesSplit" -] - -post_includes = [ - "utils/graph/series_parallel/series_split.dtg.h", -] - -includes = [ - "", - "", - "utils/graph/node/node.dtg.h", -] - -src_includes = [ - "utils/fmt/variant.h", - "utils/fmt/unordered_multiset.h", - "utils/hash/unordered_multiset.h", -] - -[[fields]] -name = "children" -type = "std::unordered_multiset>" -indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.dtg.toml b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.dtg.toml new file mode 100644 index 0000000000..4635bdd877 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.dtg.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "SeriesParallelDecomposition" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/parallel_split.dtg.h", + "utils/graph/series_parallel/series_split.dtg.h", + "utils/graph/node/node.dtg.h", +] + +[[values]] +type = "::FlexFlow::SeriesSplit" + +[[values]] +type = "::FlexFlow::ParallelSplit" + +[[values]] +type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml deleted file mode 100644 index 921499ebd1..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml 
+++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "SeriesParallelDecomposition" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/graph/series_parallel/series_parallel_splits.h", - "utils/graph/node/node.dtg.h", -] - -[[values]] -type = "::FlexFlow::SeriesSplit" - -[[values]] -type = "::FlexFlow::ParallelSplit" - -[[values]] -type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h deleted file mode 100644 index 7374b45a60..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h +++ /dev/null @@ -1,76 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H - -#include "utils/graph/series_parallel/parallel_split.dtg.h" -#include "utils/graph/series_parallel/series_split.dtg.h" - -namespace FlexFlow { - -// struct SeriesSplit { -// public: -// SeriesSplit() = delete; -// explicit SeriesSplit(std::vector> const -// &); explicit SeriesSplit( -// std::initializer_list> const &); -// -// bool operator==(SeriesSplit const &) const; -// bool operator!=(SeriesSplit const &) const; -// -// public: -// std::vector> children; -// -// private: -// using Tie = std::tuple; -// Tie tie() const; -// }; -// -// std::string format_as(SeriesSplit const &); -// std::ostream &operator<<(std::ostream &, SeriesSplit const &); -// -// } // namespace FlexFlow -// -// namespace std { -// -// template <> -// struct hash<::FlexFlow::SeriesSplit> { -// size_t operator()(::FlexFlow::SeriesSplit const &) const; -// }; -// -// } // namespace std -// -// namespace FlexFlow { -// -// struct ParallelSplit { -// public: -// ParallelSplit() = delete; -// explicit ParallelSplit( -// std::unordered_multiset> const &); -// explicit ParallelSplit( -// std::initializer_list> 
const &); -// -// bool operator==(ParallelSplit const &) const; -// bool operator!=(ParallelSplit const &) const; -// -// public: -// std::unordered_multiset> children; -// -// private: -// using Tie = std::tuple; -// Tie tie() const; -// }; -// -// std::string format_as(ParallelSplit const &); -// std::ostream &operator<<(std::ostream &, ParallelSplit const &); -// -// } // namespace FlexFlow -// -// namespace std { -// -// template <> -// struct hash<::FlexFlow::ParallelSplit> { -// size_t operator()(::FlexFlow::ParallelSplit const &) const; -// }; - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_reduction.dtg.toml b/lib/utils/include/utils/graph/series_parallel/series_reduction.dtg.toml new file mode 100644 index 0000000000..b99518581a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_reduction.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "SeriesReduction" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/multidigraph/multidiedge.dtg.h", +] + +[[fields]] +name = "first" +type = "::FlexFlow::MultiDiEdge" + +[[fields]] +name = "second" +type = "::FlexFlow::MultiDiEdge" diff --git a/lib/utils/include/utils/graph/series_parallel/series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/series_reduction.struct.toml deleted file mode 100644 index b9cc02af1c..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/series_reduction.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "SeriesReduction" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/multidigraph/multidiedge.dtg.h", -] - -[[fields]] -name = "first" -type = "::FlexFlow::MultiDiEdge" - -[[fields]] -name = "second" -type = "::FlexFlow::MultiDiEdge" diff --git a/lib/utils/include/utils/graph/series_parallel/series_split.dtg.toml 
b/lib/utils/include/utils/graph/series_parallel/series_split.dtg.toml new file mode 100644 index 0000000000..e37762a059 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_split.dtg.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "SeriesSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ParallelSplit" +] + +post_includes = [ + "utils/graph/series_parallel/parallel_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "children" +type = "std::vector>" diff --git a/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml deleted file mode 100644 index fdb0a29972..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml +++ /dev/null @@ -1,31 +0,0 @@ -namespace = "FlexFlow" -name = "SeriesSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "struct ParallelSplit" -] - -post_includes = [ - "utils/graph/series_parallel/parallel_split.dtg.h", -] - -includes = [ - "", - "", - "utils/graph/node/node.dtg.h", -] - -src_includes = [ - "utils/fmt/variant.h", - "utils/fmt/vector.h", - "utils/hash/vector.h", -] - -[[fields]] -name = "children" -type = "std::vector>" diff --git a/lib/utils/include/utils/graph/series_parallel/sink_settings.dtg.toml b/lib/utils/include/utils/graph/series_parallel/sink_settings.dtg.toml new file mode 100644 index 0000000000..6a4a7befc1 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sink_settings.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "SinkSettings" +type = "enum" +features = [ + "hash", + "fmt", + "json", + "rapidcheck", +] + +[[values]] +name = "INCLUDE_SINK_NODES" + +[[values]] +name = "EXCLUDE_SINK_NODES" diff --git 
a/lib/utils/include/utils/graph/series_parallel/sink_settings.enum.toml b/lib/utils/include/utils/graph/series_parallel/sink_settings.enum.toml deleted file mode 100644 index 5668556543..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/sink_settings.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "SinkSettings" -features = [ - "hash", - "fmt", - "json", - "rapidcheck", -] - -[[values]] -name = "INCLUDE_SINK_NODES" - -[[values]] -name = "EXCLUDE_SINK_NODES" diff --git a/lib/utils/include/utils/graph/series_parallel/source_settings.dtg.toml b/lib/utils/include/utils/graph/series_parallel/source_settings.dtg.toml new file mode 100644 index 0000000000..27352f0836 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/source_settings.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "SourceSettings" +type = "enum" +features = [ + "hash", + "fmt", + "json", + "rapidcheck", +] + +[[values]] +name = "INCLUDE_SOURCE_NODES" + +[[values]] +name = "EXCLUDE_SOURCE_NODES" diff --git a/lib/utils/include/utils/graph/series_parallel/source_settings.enum.toml b/lib/utils/include/utils/graph/series_parallel/source_settings.enum.toml deleted file mode 100644 index 8d17dc4d77..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/source_settings.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "SourceSettings" -features = [ - "hash", - "fmt", - "json", - "rapidcheck", -] - -[[values]] -name = "INCLUDE_SOURCE_NODES" - -[[values]] -name = "EXCLUDE_SOURCE_NODES" diff --git a/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.toml b/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.toml new file mode 100644 index 0000000000..21d06de50c --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "SPDecompositionTreeNodeType" +type = "enum" 
+features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "SERIES" + +[[values]] +name = "PARALLEL" + +[[values]] +name = "NODE" diff --git a/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml b/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml deleted file mode 100644 index 2050800cbd..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "SPDecompositionTreeNodeType" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "SERIES" - -[[values]] -name = "PARALLEL" - -[[values]] -name = "NODE" diff --git a/lib/utils/include/utils/graph/series_parallel/split_type.dtg.toml b/lib/utils/include/utils/graph/series_parallel/split_type.dtg.toml new file mode 100644 index 0000000000..bae99530b2 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/split_type.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "SplitType" +type = "enum" +features = [ + "hash", + "json", + "fmt", + "rapidcheck", +] + +[[values]] +name = "SERIES" + +[[values]] +name = "PARALLEL" diff --git a/lib/utils/include/utils/graph/series_parallel/split_type.enum.toml b/lib/utils/include/utils/graph/series_parallel/split_type.enum.toml deleted file mode 100644 index c1a1cb5978..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/split_type.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "SplitType" -features = [ - "hash", - "json", - "fmt", - "rapidcheck", -] - -[[values]] -name = "SERIES" - -[[values]] -name = "PARALLEL" diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.dtg.toml b/lib/utils/include/utils/graph/undirected/undirected_edge.dtg.toml new file mode 100644 index 0000000000..910c4fe0f3 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.dtg.toml @@ -0,0 
+1,18 @@ +namespace = "FlexFlow" +name = "UndirectedEdge" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/commutative_pair.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "endpoints" +type = "::FlexFlow::commutative_pair<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.h b/lib/utils/include/utils/graph/undirected/undirected_edge.h index d051413faa..1eeea7b3c2 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.h @@ -8,6 +8,8 @@ namespace FlexFlow { bool is_connected_to(UndirectedEdge const &e, Node const &n); +std::unordered_set get_endpoints(UndirectedEdge const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml deleted file mode 100644 index 0ad8232339..0000000000 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "UndirectedEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/commutative_pair.h", - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "endpoints" -type = "::FlexFlow::commutative_pair<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge_query.dtg.toml b/lib/utils/include/utils/graph/undirected/undirected_edge_query.dtg.toml new file mode 100644 index 0000000000..3099edea10 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_edge_query.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "UndirectedEdgeQuery" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git 
a/lib/utils/include/utils/graph/undirected/undirected_edge_query.struct.toml b/lib/utils/include/utils/graph/undirected/undirected_edge_query.struct.toml deleted file mode 100644 index 239194a275..0000000000 --- a/lib/utils/include/utils/graph/undirected/undirected_edge_query.struct.toml +++ /dev/null @@ -1,17 +0,0 @@ -namespace = "FlexFlow" -name = "UndirectedEdgeQuery" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/query_set.h", - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "nodes" -type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/views/join_node_key.dtg.toml b/lib/utils/include/utils/graph/views/join_node_key.dtg.toml new file mode 100644 index 0000000000..df48138a19 --- /dev/null +++ b/lib/utils/include/utils/graph/views/join_node_key.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "JoinNodeKey" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/views/lr_direction.dtg.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "direction" +type = "::FlexFlow::LRDirection" diff --git a/lib/utils/include/utils/graph/views/join_node_key.struct.toml b/lib/utils/include/utils/graph/views/join_node_key.struct.toml deleted file mode 100644 index 9dce99f0a0..0000000000 --- a/lib/utils/include/utils/graph/views/join_node_key.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "JoinNodeKey" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/node/node.dtg.h", - "utils/graph/views/lr_direction.dtg.h", -] - -[[fields]] -name = "node" -type = "::FlexFlow::Node" - -[[fields]] -name = "direction" -type = "::FlexFlow::LRDirection" diff --git a/lib/utils/include/utils/graph/views/lr_direction.dtg.toml b/lib/utils/include/utils/graph/views/lr_direction.dtg.toml new file mode 100644 index 0000000000..098018b040 --- 
/dev/null +++ b/lib/utils/include/utils/graph/views/lr_direction.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "LRDirection" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LEFT" + +[[values]] +name = "RIGHT" diff --git a/lib/utils/include/utils/graph/views/lr_direction.enum.toml b/lib/utils/include/utils/graph/views/lr_direction.enum.toml deleted file mode 100644 index 878a937b0b..0000000000 --- a/lib/utils/include/utils/graph/views/lr_direction.enum.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "LRDirection" -features = [ - "hash", - "fmt", - "rapidcheck", - "json", -] - -[[values]] -name = "LEFT" - -[[values]] -name = "RIGHT" diff --git a/lib/utils/include/utils/int_ge_two/algorithms/try_int_ge_two_from_positive_int.h b/lib/utils/include/utils/int_ge_two/algorithms/try_int_ge_two_from_positive_int.h new file mode 100644 index 0000000000..77e58309b6 --- /dev/null +++ b/lib/utils/include/utils/int_ge_two/algorithms/try_int_ge_two_from_positive_int.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_INT_GE_TWO_ALGORITHMS_TRY_INT_GE_TWO_FROM_POSITIVE_INT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_INT_GE_TWO_ALGORITHMS_TRY_INT_GE_TWO_FROM_POSITIVE_INT_H + +#include "utils/int_ge_two/int_ge_two.h" + +namespace FlexFlow { + +std::optional try_int_ge_two_from_positive_int(positive_int); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/int_ge_two/int_ge_two.h b/lib/utils/include/utils/int_ge_two/int_ge_two.h new file mode 100644 index 0000000000..c22254b219 --- /dev/null +++ b/lib/utils/include/utils/int_ge_two/int_ge_two.h @@ -0,0 +1,132 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_INT_GE_TWO_INT_GE_TWO_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_INT_GE_TWO_INT_GE_TWO_H + +#include "utils/positive_int/positive_int.h" +#include + +namespace FlexFlow { + +struct int_ge_two { + int_ge_two() = delete; + explicit int_ge_two(int value); + explicit 
int_ge_two(size_t value); + explicit int_ge_two(nonnegative_int value); + explicit int_ge_two(positive_int value); + + explicit operator int() const noexcept; + explicit operator nonnegative_int() const noexcept; + explicit operator positive_int() const noexcept; + + bool operator<(int_ge_two other) const; + bool operator==(int_ge_two other) const; + bool operator>(int_ge_two other) const; + bool operator<=(int_ge_two other) const; + bool operator!=(int_ge_two other) const; + bool operator>=(int_ge_two other) const; + + bool operator<(positive_int other) const; + bool operator==(positive_int other) const; + bool operator>(positive_int other) const; + bool operator<=(positive_int other) const; + bool operator!=(positive_int other) const; + bool operator>=(positive_int other) const; + + friend bool operator<(positive_int lhs, int_ge_two rhs); + friend bool operator==(positive_int lhs, int_ge_two rhs); + friend bool operator>(positive_int lhs, int_ge_two rhs); + friend bool operator<=(positive_int lhs, int_ge_two rhs); + friend bool operator!=(positive_int lhs, int_ge_two rhs); + friend bool operator>=(positive_int lhs, int_ge_two rhs); + + bool operator<(nonnegative_int other) const; + bool operator==(nonnegative_int other) const; + bool operator>(nonnegative_int other) const; + bool operator<=(nonnegative_int other) const; + bool operator!=(nonnegative_int other) const; + bool operator>=(nonnegative_int other) const; + + friend bool operator<(nonnegative_int lhs, int_ge_two rhs); + friend bool operator==(nonnegative_int lhs, int_ge_two rhs); + friend bool operator>(nonnegative_int lhs, int_ge_two rhs); + friend bool operator<=(nonnegative_int lhs, int_ge_two rhs); + friend bool operator!=(nonnegative_int lhs, int_ge_two rhs); + friend bool operator>=(nonnegative_int lhs, int_ge_two rhs); + + bool operator<(int other) const; + bool operator==(int other) const; + bool operator>(int other) const; + bool operator<=(int other) const; + bool operator!=(int other) const; + 
bool operator>=(int other) const; + + friend bool operator<(int lhs, int_ge_two rhs); + friend bool operator==(int lhs, int_ge_two rhs); + friend bool operator>(int lhs, int_ge_two rhs); + friend bool operator<=(int lhs, int_ge_two rhs); + friend bool operator!=(int lhs, int_ge_two rhs); + friend bool operator>=(int lhs, int_ge_two rhs); + + int_ge_two operator+(int_ge_two other) const; + int_ge_two operator+(positive_int other) const; + int_ge_two operator+(nonnegative_int other) const; + int_ge_two &operator++(); + int_ge_two operator++(int); + int_ge_two &operator+=(int_ge_two other); + int_ge_two &operator+=(positive_int other); + int_ge_two &operator+=(nonnegative_int other); + + friend int_ge_two operator+(nonnegative_int lhs, int_ge_two rhs); + friend int_ge_two operator+(positive_int lhs, int_ge_two rhs); + + int_ge_two operator*(int_ge_two other) const; + int_ge_two &operator*=(int_ge_two other); + int_ge_two operator*(positive_int other) const; + int_ge_two &operator*=(positive_int other); + nonnegative_int operator*(nonnegative_int other) const; + + friend int_ge_two operator*(positive_int lhs, int_ge_two rhs); + friend nonnegative_int operator*(nonnegative_int lhs, int_ge_two rhs); + + int int_from_int_ge_two() const; + nonnegative_int nonnegative_int_from_int_ge_two() const; + positive_int positive_int_from_int_ge_two() const; + + friend std::ostream &operator<<(std::ostream &os, int_ge_two n); + + friend int format_as(int_ge_two); + +private: + void check_invariant() const; + +private: + int value_; +}; + +int_ge_two operator""_ge2(unsigned long long int); + +std::optional try_int_ge_two_from_positive_int(positive_int); + +} // namespace FlexFlow + +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::int_ge_two> { + static ::FlexFlow::int_ge_two from_json(json const &j); + static void to_json(json &j, ::FlexFlow::int_ge_two t); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary<::FlexFlow::int_ge_two> { + 
static Gen<::FlexFlow::int_ge_two> arbitrary(); +}; +} // namespace rc + +namespace std { +template <> +struct hash<::FlexFlow::int_ge_two> { + std::size_t operator()(FlexFlow::int_ge_two n) const noexcept; +}; +} // namespace std +#endif diff --git a/lib/utils/include/utils/internal_only_tag.h b/lib/utils/include/utils/internal_only_tag.h deleted file mode 100644 index 1e5f8571d0..0000000000 --- a/lib/utils/include/utils/internal_only_tag.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_INTERNAL_ONLY_TAG_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_INTERNAL_ONLY_TAG_H - -namespace FlexFlow { -struct should_only_be_used_internally_tag_t { - explicit should_only_be_used_internally_tag_t() = default; -}; -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/json/check_is_json_deserializable.h b/lib/utils/include/utils/json/check_is_json_deserializable.h index dd5f397c19..02edb37026 100644 --- a/lib/utils/include/utils/json/check_is_json_deserializable.h +++ b/lib/utils/include/utils/json/check_is_json_deserializable.h @@ -5,9 +5,9 @@ namespace FlexFlow { -#define CHECK_IS_JSON_DESERIALIZABLE(TYPENAME) \ - static_assert(::FlexFlow::is_json_deserializable::value, \ - #TYPENAME " should be json deserializeable") +#define CHECK_IS_JSON_DESERIALIZABLE(...) \ + static_assert(::FlexFlow::is_json_deserializable<__VA_ARGS__>::value, \ + #__VA_ARGS__ " should be json deserializeable") } // namespace FlexFlow diff --git a/lib/utils/include/utils/json/check_is_json_serializable.h b/lib/utils/include/utils/json/check_is_json_serializable.h index dfcb26081d..2533a5894c 100644 --- a/lib/utils/include/utils/json/check_is_json_serializable.h +++ b/lib/utils/include/utils/json/check_is_json_serializable.h @@ -5,9 +5,9 @@ namespace FlexFlow { -#define CHECK_IS_JSON_SERIALIZABLE(TYPENAME) \ - static_assert(::FlexFlow::is_json_serializable::value, \ - #TYPENAME " should be json serializeable") +#define CHECK_IS_JSON_SERIALIZABLE(...) 
\ + static_assert(::FlexFlow::is_json_serializable<__VA_ARGS__>::value, \ + #__VA_ARGS__ " should be json serializeable") } // namespace FlexFlow diff --git a/lib/utils/include/utils/json/check_is_jsonable.h b/lib/utils/include/utils/json/check_is_jsonable.h index 41a64a1b83..9d4f65f005 100644 --- a/lib/utils/include/utils/json/check_is_jsonable.h +++ b/lib/utils/include/utils/json/check_is_jsonable.h @@ -6,11 +6,11 @@ namespace FlexFlow { -#define CHECK_IS_JSONABLE(TYPENAME) \ - static_assert(is_json_serializable::value, \ - #TYPENAME " should be json serializeable"); \ - static_assert(is_json_deserializable::value, \ - #TYPENAME " should be json deserializeable") +#define CHECK_IS_JSONABLE(...) \ + static_assert(is_json_serializable<__VA_ARGS__>::value, \ + #__VA_ARGS__ " should be json serializeable"); \ + static_assert(is_json_deserializable<__VA_ARGS__>::value, \ + #__VA_ARGS__ " should be json deserializeable") } // namespace FlexFlow diff --git a/lib/utils/include/utils/json/monostate.h b/lib/utils/include/utils/json/monostate.h new file mode 100644 index 0000000000..14d2fe233f --- /dev/null +++ b/lib/utils/include/utils/json/monostate.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_MONOSTATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_MONOSTATE_H + +#include +#include + +namespace nlohmann { + +template <> +struct adl_serializer { + static void to_json(json &, std::monostate); + static void from_json(json const &, std::monostate &); +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/json/variant.h b/lib/utils/include/utils/json/variant.h deleted file mode 100644 index fe2c3f3b6c..0000000000 --- a/lib/utils/include/utils/json/variant.h +++ /dev/null @@ -1,89 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_VARIANT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_VARIANT_H - -#include "utils/json/is_jsonable.h" -#include - -namespace FlexFlow { - -struct VariantToJsonFunctor { - 
VariantToJsonFunctor(nlohmann::json &j) : j(j) {} - - nlohmann::json &j; - - template - void operator()(T const &t) { - static_assert(is_jsonable::value, ""); - - j = t; - } -}; - -template -void variant_to_json(json &j, std::variant const &v) { - json jval; - visit(::FlexFlow::VariantToJsonFunctor{jval}, v); - j["value"] = jval; - j["index"] = v.index(); -} - -template -std::optional variant_from_json_impl(json const &j) { - using Type = typename std::variant_alternative::type; - - if (j.at("index").get() == Idx) { - return j.at("value").get(); - } - return std::nullopt; -} - -template -std::optional variant_from_json_impl(json const &j, - std::index_sequence) { - // If there were no errors when parsing, all but one element of the array - // will be nullopt. This is because each call to variant_from_json_impl will - // have a unique index and exactly one of them will match the index in the - // json object. - std::array, sizeof...(Is)> results{ - variant_from_json_impl(j)...}; - for (std::optional &maybe : results) { - if (maybe) { - return maybe.value(); - } - } - return std::nullopt; -} - -template -std::variant variant_from_json(json const &j) { - using Variant = std::variant; - std::optional result = variant_from_json_impl( - j, std::make_index_sequence()); - if (!result.has_value()) { - throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", - j.at("index").get()); - } - return result.value(); -} - -} // namespace FlexFlow - -namespace nlohmann { - -template -struct adl_serializer, - typename std::enable_if<::FlexFlow::elements_satisfy< - ::FlexFlow::is_json_serializable, - std::variant>::value>::type> { - static void to_json(json &j, std::variant const &v) { - return ::FlexFlow::variant_to_json(j, v); - } - - static std::variant from_json(json const &j) { - return ::FlexFlow::variant_from_json(j); - } -}; - -} // namespace nlohmann - -#endif diff --git a/lib/utils/include/utils/json/visitable.h b/lib/utils/include/utils/json/visitable.h deleted 
file mode 100644 index abc20065de..0000000000 --- a/lib/utils/include/utils/json/visitable.h +++ /dev/null @@ -1,152 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H - -#include "utils/json/is_json_deserializable.h" -#include "utils/json/is_json_serializable.h" -#include "utils/json/is_jsonable.h" -#include "utils/json_core.h" -#include "utils/optional.h" -#include "utils/sequence.h" -#include "utils/type_traits.h" -#include "utils/variant.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct json_serialization_visitor { - json_serialization_visitor() = delete; - json_serialization_visitor(json &j) : j(j) {} - - json &j; - - template - void operator()(char const *field_name, T const &field_value) { - j[field_name] = field_value; - } -}; - -struct json_deserialization_visitor { - json_deserialization_visitor() = delete; - json_deserialization_visitor(json const &j) : j(j) {} - - json const &j; - - template - void operator()(char const *field_name, T &field_value) { - j.at(field_name).get_to(field_value); - } -}; - -static_assert(std::is_same>, - std::tuple>::value, - ""); -static_assert(std::is_same>, - std::tuple<>>::value, - ""); - -template -typename std::enable_if<(idx >= std::tuple_size>::value), - std::tuple<>>::type - tuple_from_json_impl(json const &j) { - return std::tuple<>{}; -} - -template -struct TupleFromJson { - tuple_tail_t> operator()(json const &j) { - using FieldT = visit_struct::type_at; - - FieldT field = - j.at(visit_struct::get_name()).template get(); - - return std::tuple_cat(std::tuple(field), - TupleFromJson<(idx + 1), T>{}(j)); - } -}; - -template -struct TupleFromJson< - idx, - T, - typename std::enable_if<( - idx > std::tuple_size>::value)>::type> { - std::tuple<> operator()(json const &j) { - return {}; - } -}; - -template -visit_as_tuple_t tuple_from_json(json const &j) { - return TupleFromJson<0, T>{}(j); -} - -template -void visit_json_serialize(json &j, T const &t) { - 
static_assert(is_visitable::value, "Type must be visitable"); - static_assert(elements_satisfy::value, - "Elements must be deserializable"); - - json_serialization_visitor vis(j); - visit_struct::for_each(t, vis); -} - -template -void visit_json_deserialize(json const &j, T &t) { - static_assert(is_visitable::value, "Type must be visitable"); - static_assert(elements_satisfy::value, - "Elements must be deserializable"); - - json_deserialization_visitor vis(j); - visit_struct::for_each(t, vis); -} - -template -T moveonly_visit_json_deserialize(json const &j) { - static_assert(is_visitable::value, "Type must be visitable"); - static_assert(!std::is_default_constructible::value, ""); - static_assert(elements_satisfy::value, - "Elements must be deserializable"); - - return visitable_from_tuple(tuple_from_json(j)); -} - -} // namespace FlexFlow - -namespace nlohmann { - -template -struct adl_serializer< - T, - typename std::enable_if<::FlexFlow::conjunction< - ::FlexFlow::is_visitable, - ::FlexFlow::elements_satisfy<::FlexFlow::is_json_serializable, T>, - std::is_default_constructible>::value>::type> { - static void to_json(json &j, T const &t) { - ::FlexFlow::visit_json_serialize(j, t); - } - - static void from_json(json const &j, T &t) { - ::FlexFlow::visit_json_deserialize(j, t); - } -}; - -template -struct adl_serializer< - T, - typename std::enable_if<::FlexFlow::conjunction< - ::FlexFlow::is_visitable, - ::FlexFlow::elements_satisfy<::FlexFlow::is_json_serializable, T>, - ::FlexFlow::negation>, - std::is_move_constructible>::value>::type> { - static void to_json(json &j, T const &t) { - ::FlexFlow::visit_json_serialize(j, t); - } - - static T from_json(json const &j) { - return ::FlexFlow::moveonly_visit_json_deserialize(j); - } -}; - -} // namespace nlohmann - -#endif diff --git a/lib/utils/include/utils/json_core.h b/lib/utils/include/utils/json_core.h deleted file mode 100644 index eb99463e6a..0000000000 --- a/lib/utils/include/utils/json_core.h +++ /dev/null 
@@ -1,12 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_CORE_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_CORE_H - -#include "nlohmann/json.hpp" - -namespace FlexFlow { - -using json = nlohmann::json; - -} - -#endif diff --git a/lib/utils/include/utils/many_to_one/exhaustive_relational_join.h b/lib/utils/include/utils/many_to_one/exhaustive_relational_join.h new file mode 100644 index 0000000000..c908a9dcec --- /dev/null +++ b/lib/utils/include/utils/many_to_one/exhaustive_relational_join.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_EXHAUSTIVE_RELATIONAL_JOIN_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_EXHAUSTIVE_RELATIONAL_JOIN_H + +#include "utils/many_to_one/many_to_one.h" + +namespace FlexFlow { + +template +ManyToOne exhaustive_relational_join(ManyToOne const &fst, + ManyToOne const &snd) { + ManyToOne result; + + if (fst.right_values() != snd.left_values()) { + throw mk_runtime_error( + fmt::format("exhaustive_relational_join for ManyToOne received inputs " + "with non-matching inner dimensions: right dimension of " + "fst is {} while left dimension of snd is {}", + fst.right_values(), + snd.left_values())); + } + + for (T3 const &t3 : snd.right_values()) { + for (T2 const &t2 : snd.at_r(t3)) { + for (T1 const &t1 : fst.at_r(t2)) { + result.insert({t1, t3}); + } + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/many_to_one/invert_many_to_one.h b/lib/utils/include/utils/many_to_one/invert_many_to_one.h new file mode 100644 index 0000000000..7fdf36859f --- /dev/null +++ b/lib/utils/include/utils/many_to_one/invert_many_to_one.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_INVERT_MANY_TO_ONE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_INVERT_MANY_TO_ONE_H + +#include "utils/many_to_one/many_to_one.h" +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +OneToMany invert_many_to_one(ManyToOne 
const &many_to_one) { + OneToMany result; + + for (L const &l : many_to_one.left_values()) { + result.insert({many_to_one.at_l(l), l}); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/many_to_one/many_to_one.h b/lib/utils/include/utils/many_to_one/many_to_one.h new file mode 100644 index 0000000000..d2f727661c --- /dev/null +++ b/lib/utils/include/utils/many_to_one/many_to_one.h @@ -0,0 +1,184 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_H + +#include "utils/containers/keys.h" +#include "utils/containers/try_at.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/exception.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" +#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" +#include "utils/hash/unordered_map.h" +#include "utils/hash/unordered_set.h" +#include "utils/json/check_is_json_deserializable.h" +#include "utils/json/check_is_json_serializable.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { + +template +struct ManyToOne { +public: + ManyToOne() : m_l_to_r(), m_r_to_l() {} + + template + ManyToOne(It start, It end) : ManyToOne() { + for (; start < end; start++) { + ASSERT(start->first.size() > 0); + for (L const &l : start->first) { + this->insert(std::pair{l, start->second}); + } + } + } + + ManyToOne(std::initializer_list, R>> const + &l_to_r) + : ManyToOne(l_to_r.begin(), l_to_r.end()) {} + + bool operator==(ManyToOne const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(ManyToOne const &other) const { + return this->tie() != other.tie(); + } + + void insert(std::pair const &p) { + L l = p.first; + R r = p.second; + + std::optional found_r = try_at(this->m_l_to_r, l); + + if (!found_r.has_value()) { + this->m_l_to_r.insert({l, r}); + this->m_r_to_l[r].insert(l); + } 
else if (found_r.value() == r) { + return; + } else { + PANIC(fmt::format( + "Existing mapping found for left value {}: tried to map to right " + "value {}, but is already bound to right value {}", + l, + r, + found_r.value())); + } + } + + bool contains_l(L const &l) const { + return contains_key(this->m_l_to_r, l); + } + + bool contains_r(R const &r) const { + return contains_key(this->m_r_to_l, r); + } + + R const &at_l(L const &l) const { + return this->m_l_to_r.at(l); + } + + std::unordered_set const &at_r(R const &r) const { + return this->m_r_to_l.at(r); + } + + std::unordered_set left_values() const { + return keys(this->m_l_to_r); + } + + std::unordered_set> left_groups() const { + return unordered_set_of(values(this->m_r_to_l)); + } + + std::unordered_set right_values() const { + return keys(this->m_r_to_l); + } + + std::unordered_map const &l_to_r() const { + return this->m_l_to_r; + } + + std::unordered_map> const &r_to_l() const { + return this->m_r_to_l; + } + +private: + std::unordered_map m_l_to_r; + std::unordered_map> m_r_to_l; + +private: + std::tuple + tie() const { + return std::tie(this->m_l_to_r, this->m_r_to_l); + } + + friend struct std::hash>; +}; + +template +std::unordered_map, R> + format_as(ManyToOne const &m) { + std::unordered_map, R> result; + + for (R const &r : m.right_values()) { + result.insert({m.at_r(r), r}); + } + + return result; +} + +template +std::ostream &operator<<(std::ostream &s, ManyToOne const &m) { + return (s << fmt::to_string(m)); +} + +} // namespace FlexFlow + +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::ManyToOne> { + static ::FlexFlow::ManyToOne from_json(json const &j) { + CHECK_IS_JSON_DESERIALIZABLE(L); + CHECK_IS_JSON_DESERIALIZABLE(R); + + NOT_IMPLEMENTED(); + } + + static void to_json(json &j, ::FlexFlow::ManyToOne const &m) { + CHECK_IS_JSON_SERIALIZABLE(L); + CHECK_IS_JSON_SERIALIZABLE(R); + + NOT_IMPLEMENTED(); + } +}; + +} // namespace nlohmann + +namespace rc { + +template 
+struct Arbitrary<::FlexFlow::ManyToOne> { + static Gen<::FlexFlow::ManyToOne> arbitrary() { + NOT_IMPLEMENTED(); + } +}; + +} // namespace rc + +namespace std { + +template +struct hash<::FlexFlow::ManyToOne> { + size_t operator()(::FlexFlow::ManyToOne const &m) { + return ::FlexFlow::get_std_hash(m.tie()); + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/many_to_one/many_to_one_from_bidict.h b/lib/utils/include/utils/many_to_one/many_to_one_from_bidict.h new file mode 100644 index 0000000000..ba50a960c2 --- /dev/null +++ b/lib/utils/include/utils/many_to_one/many_to_one_from_bidict.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_BIDICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_BIDICT_H + +#include "utils/bidict/bidict.h" +#include "utils/many_to_one/many_to_one.h" + +namespace FlexFlow { + +template +ManyToOne many_to_one_from_bidict(bidict const &b) { + ManyToOne result; + + for (auto const &[l, r] : b) { + result.insert({l, r}); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/many_to_one/many_to_one_from_map.h b/lib/utils/include/utils/many_to_one/many_to_one_from_map.h new file mode 100644 index 0000000000..e0484d2131 --- /dev/null +++ b/lib/utils/include/utils/many_to_one/many_to_one_from_map.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_MAP_H + +#include "utils/many_to_one/many_to_one.h" + +namespace FlexFlow { + +template +ManyToOne many_to_one_from_map(std::unordered_map const &m) { + ManyToOne result; + + for (auto const &[l, r] : m) { + result.insert({l, r}); + } + + return result; +} + +template +ManyToOne many_to_one_from_map(std::map const &m) { + ManyToOne result; + + for (auto const &[l, r] : m) { + result.insert({l, r}); + } + + return result; +} + +} // namespace FlexFlow + 
+#endif diff --git a/lib/utils/include/utils/many_to_one/many_to_one_from_unstructured_relation.h b/lib/utils/include/utils/many_to_one/many_to_one_from_unstructured_relation.h new file mode 100644 index 0000000000..171c6c15d6 --- /dev/null +++ b/lib/utils/include/utils/many_to_one/many_to_one_from_unstructured_relation.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_UNSTRUCTURED_RELATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_UNSTRUCTURED_RELATION_H + +#include "utils/many_to_one/many_to_one.h" + +namespace FlexFlow { + +template +ManyToOne many_to_one_from_unstructured_relation( + std::unordered_set> const &relation) { + ManyToOne result; + for (auto const &lr : relation) { + result.insert(lr); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/many_to_one/unstructured_relation_from_many_to_one.h b/lib/utils/include/utils/many_to_one/unstructured_relation_from_many_to_one.h new file mode 100644 index 0000000000..676c5efa5d --- /dev/null +++ b/lib/utils/include/utils/many_to_one/unstructured_relation_from_many_to_one.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_UNSTRUCTURED_RELATION_FROM_MANY_TO_ONE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_UNSTRUCTURED_RELATION_FROM_MANY_TO_ONE_H + +#include "utils/containers/unordered_set_of.h" +#include "utils/many_to_one/many_to_one.h" + +namespace FlexFlow { + +template +std::unordered_set> + unstructured_relation_from_many_to_one(ManyToOne const &many_to_one) { + return unordered_set_of(many_to_one.l_to_r()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/metafunction.h b/lib/utils/include/utils/metafunction.h index 4abd109e3e..b7b2c6f581 100644 --- a/lib/utils/include/utils/metafunction.h +++ b/lib/utils/include/utils/metafunction.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_METAFUNCTION_H #define 
_FLEXFLOW_UTILS_INCLUDE_UTILS_METAFUNCTION_H -#include "type_traits_core.h" +#include "utils/type_traits_core.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/nonnegative_int/ceildiv.h b/lib/utils/include/utils/nonnegative_int/ceildiv.h deleted file mode 100644 index e2ff0bc52a..0000000000 --- a/lib/utils/include/utils/nonnegative_int/ceildiv.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_CEILDIV_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_CEILDIV_H - -#include "utils/nonnegative_int/nonnegative_int.h" - -namespace FlexFlow { - -nonnegative_int ceildiv(nonnegative_int numerator, nonnegative_int denominator); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/nonnegative_int/range.h b/lib/utils/include/utils/nonnegative_int/range.h new file mode 100644 index 0000000000..7b1fa8d480 --- /dev/null +++ b/lib/utils/include/utils/nonnegative_int/range.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_RANGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_RANGE_H + +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +namespace FlexFlow { + +std::vector + range(nonnegative_int start, nonnegative_int end, int step = 1); +std::vector range(nonnegative_int end); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/one_to_many/exhaustive_relational_join.h b/lib/utils/include/utils/one_to_many/exhaustive_relational_join.h new file mode 100644 index 0000000000..a959df2398 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/exhaustive_relational_join.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_EXHAUSTIVE_RELATIONAL_JOIN_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_EXHAUSTIVE_RELATIONAL_JOIN_H + +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +OneToMany exhaustive_relational_join(OneToMany const &fst, + OneToMany const &snd) { + OneToMany 
result; + + if (fst.right_values() != snd.left_values()) { + throw mk_runtime_error( + fmt::format("exhaustive_relational_join for OneToMany received inputs " + "with non-matching inner dimensions: right dimension of " + "fst is {} while left dimension of snd is {}", + fst.right_values(), + snd.left_values())); + } + + for (T1 const &t1 : fst.left_values()) { + for (T2 const &t2 : fst.at_l(t1)) { + for (T3 const &t3 : snd.at_l(t2)) { + result.insert({t1, t3}); + } + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/one_to_many/invert_one_to_many.h b/lib/utils/include/utils/one_to_many/invert_one_to_many.h new file mode 100644 index 0000000000..bde623d387 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/invert_one_to_many.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_INVERT_ONE_TO_MANY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_INVERT_ONE_TO_MANY_H + +#include "utils/many_to_one/many_to_one.h" +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +ManyToOne invert_one_to_many(OneToMany const &one_to_many) { + ManyToOne result; + + for (R const &r : one_to_many.right_values()) { + result.insert({r, one_to_many.at_r(r)}); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many.h b/lib/utils/include/utils/one_to_many/one_to_many.h new file mode 100644 index 0000000000..798ae2fb87 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/one_to_many.h @@ -0,0 +1,171 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_H + +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/try_at.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/exception.h" +#include "utils/fmt/unordered_map.h" 
+#include "utils/fmt/unordered_set.h" +#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" +#include "utils/hash/unordered_map.h" +#include "utils/hash/unordered_set.h" +#include "utils/json/check_is_json_deserializable.h" +#include "utils/json/check_is_json_serializable.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { + +template +struct OneToMany { +public: + OneToMany() : m_l_to_r(), m_r_to_l() {} + + template + OneToMany(It start, It end) : OneToMany() { + for (; start < end; start++) { + ASSERT(start->second.size() > 0); + for (R const &r : start->second) { + this->insert(std::pair{start->first, r}); + } + } + } + + OneToMany(std::initializer_list>> const + &l_to_r) + : OneToMany(l_to_r.begin(), l_to_r.end()) {} + + bool operator==(OneToMany const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(OneToMany const &other) const { + return this->tie() != other.tie(); + } + + void insert(std::pair const &p) { + L l = p.first; + R r = p.second; + + std::optional found_l = try_at(this->m_r_to_l, r); + + if (!found_l.has_value()) { + this->m_r_to_l.insert({r, l}); + this->m_l_to_r[l].insert(r); + } else if (found_l.value() == l) { + return; + } else { + throw mk_runtime_error( + fmt::format("Existing mapping found for right value {}: tried to map " + "to left value {}, but is already bound to left value {}", + r, + l, + found_l.value())); + } + } + + std::unordered_set const &at_l(L const &l) const { + return this->m_l_to_r.at(l); + } + + L const &at_r(R const &r) const { + return this->m_r_to_l.at(r); + } + + std::unordered_set left_values() const { + return keys(this->m_l_to_r); + } + + std::unordered_set right_values() const { + return keys(this->m_r_to_l); + } + + std::unordered_set> right_groups() const { + return unordered_set_of(values(this->m_l_to_r)); + } + + std::unordered_map> const &l_to_r() const { + return this->m_l_to_r; + } + + std::unordered_map const &r_to_l() const { + return 
this->m_r_to_l; + } + +private: + std::unordered_map> m_l_to_r; + std::unordered_map m_r_to_l; + +private: + std::tuple + tie() const { + return std::tie(this->m_l_to_r, this->m_r_to_l); + } + + friend struct std::hash>; +}; + +template +std::unordered_map> + format_as(OneToMany const &m) { + return generate_map(m.left_values(), [&](L const &l) { return m.at_l(l); }); +} + +template +std::ostream &operator<<(std::ostream &s, OneToMany const &m) { + return (s << fmt::to_string(m)); +} + +} // namespace FlexFlow + +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::OneToMany> { + static ::FlexFlow::OneToMany from_json(json const &j) { + CHECK_IS_JSON_DESERIALIZABLE(L); + CHECK_IS_JSON_DESERIALIZABLE(R); + + NOT_IMPLEMENTED(); + } + + static void to_json(json &j, ::FlexFlow::OneToMany const &m) { + CHECK_IS_JSON_SERIALIZABLE(L); + CHECK_IS_JSON_SERIALIZABLE(R); + + NOT_IMPLEMENTED(); + } +}; + +} // namespace nlohmann + +namespace rc { + +template +struct Arbitrary<::FlexFlow::OneToMany> { + static Gen<::FlexFlow::OneToMany> arbitrary() { + NOT_IMPLEMENTED(); + } +}; + +} // namespace rc + +namespace std { + +template +struct hash<::FlexFlow::OneToMany> { + size_t operator()(::FlexFlow::OneToMany const &m) { + return ::FlexFlow::get_std_hash(m.tie()); + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many_from_bidict.h b/lib/utils/include/utils/one_to_many/one_to_many_from_bidict.h new file mode 100644 index 0000000000..3783f1f663 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/one_to_many_from_bidict.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FROM_BIDICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FROM_BIDICT_H + +#include "utils/bidict/bidict.h" +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +OneToMany one_to_many_from_bidict(bidict const &b) { + OneToMany result; + + for (auto const &[l, r] : b) 
{ + result.insert({l, r}); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many_from_l_to_r_mapping.h b/lib/utils/include/utils/one_to_many/one_to_many_from_l_to_r_mapping.h new file mode 100644 index 0000000000..bed62caaf6 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/one_to_many_from_l_to_r_mapping.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_FROM_L_TO_R_MAPPING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_FROM_L_TO_R_MAPPING_H + +#include "utils/one_to_many/one_to_many.h" +#include + +namespace FlexFlow { + +template +OneToMany one_to_many_from_l_to_r_mapping( + std::unordered_map> const &m) { + OneToMany result; + + for (auto const &[l, rs] : m) { + ASSERT(rs.size() > 0); + for (auto const &r : rs) { + result.insert({l, r}); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many_from_unstructured_relation.h b/lib/utils/include/utils/one_to_many/one_to_many_from_unstructured_relation.h new file mode 100644 index 0000000000..11c0a767d6 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/one_to_many_from_unstructured_relation.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FROM_UNSTRUCTURED_RELATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FROM_UNSTRUCTURED_RELATION_H + +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +OneToMany one_to_many_from_unstructured_relation( + std::unordered_set> const &rel) { + OneToMany result; + for (auto const &lr : rel) { + result.insert(lr); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/one_to_many/unstructured_relation_from_one_to_many.h b/lib/utils/include/utils/one_to_many/unstructured_relation_from_one_to_many.h new file mode 100644 index 0000000000..02ed225610 --- /dev/null +++ 
b/lib/utils/include/utils/one_to_many/unstructured_relation_from_one_to_many.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_UNSTRUCTURED_RELATION_FROM_ONE_TO_MANY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_UNSTRUCTURED_RELATION_FROM_ONE_TO_MANY_H + +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +std::unordered_set> + unstructured_relation_from_one_to_many(OneToMany const &one_to_many) { + return transform(unordered_set_of(one_to_many.r_to_l()), + [](std::pair const &rl) -> std::pair { + return std::pair{rl.second, rl.first}; + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 377561d70c..cec3734907 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -28,7 +28,7 @@ T const &unwrap(std::optional const &o, F const &f) { template T const &assert_unwrap(std::optional const &o) { - assert(o.has_value()); + ASSERT(o.has_value()); return o.value(); } diff --git a/lib/utils/include/utils/ord/unordered_map.h b/lib/utils/include/utils/ord/unordered_map.h new file mode 100644 index 0000000000..1cfbdb27b6 --- /dev/null +++ b/lib/utils/include/utils/ord/unordered_map.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORD_UNORDERED_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORD_UNORDERED_MAP_H + +#include "utils/type_traits_core.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::enable_if_t>, bool> + operator<(std::unordered_map const &lhs, + std::unordered_map const &rhs) { + CHECK_LT_COMPARABLE(K); + CHECK_LT_COMPARABLE(V); + + std::map lhs_ordered(lhs.cbegin(), lhs.cend()); + std::map rhs_ordered(rhs.cbegin(), rhs.cend()); + + return lhs_ordered < rhs_ordered; +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/ord/unordered_set.h b/lib/utils/include/utils/ord/unordered_set.h new file mode 100644 index 0000000000..902b2df474 --- /dev/null +++ b/lib/utils/include/utils/ord/unordered_set.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORD_UNORDERED_SET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORD_UNORDERED_SET_H + +#include "utils/type_traits_core.h" +#include +#include + +namespace FlexFlow { + +template +std::enable_if_t, bool> + operator<(std::unordered_set const &lhs, + std::unordered_set const &rhs) { + CHECK_LT_COMPARABLE(T); + + std::set lhs_ordered(lhs.cbegin(), lhs.cend()); + std::set rhs_ordered(rhs.cbegin(), rhs.cend()); + + return lhs_ordered < rhs_ordered; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/ord/vector.h b/lib/utils/include/utils/ord/vector.h new file mode 100644 index 0000000000..fca09d457a --- /dev/null +++ b/lib/utils/include/utils/ord/vector.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORD_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORD_VECTOR_H + +#include "utils/type_traits_core.h" +#include +#include + +namespace FlexFlow { + +template +std::enable_if_t, bool> + operator<(std::vector const &lhs, std::vector const &rhs) { + CHECK_LT_COMPARABLE(T); + + return std::lexicographical_compare( + lhs.cbegin(), lhs.cend(), rhs.cbegin(), rhs.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/dim_coord.dtg.toml b/lib/utils/include/utils/orthotope/dim_coord.dtg.toml new file mode 100644 index 0000000000..3ac729fc7e --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_coord.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "DimCoord" +type = "struct" +features = [ + "eq", + "ord", + "fmt", + "hash", +] + +template_params = [ + "T", +] + +includes = [ + "", + "utils/nonnegative_int/nonnegative_int.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] 
+name = "raw" +type = "std::unordered_map" diff --git a/lib/utils/include/utils/orthotope/dim_coord.h b/lib/utils/include/utils/orthotope/dim_coord.h new file mode 100644 index 0000000000..87a05a7315 --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_coord.h @@ -0,0 +1,174 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_COORD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_COORD_H + +#include "utils/containers/all_of.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_all_assignments.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/keys.h" +#include "utils/containers/map_from_keys_and_values.h" +#include "utils/containers/map_values.h" +#include "utils/containers/product.h" +#include "utils/containers/require_same.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/scanr.h" +#include "utils/containers/sorted_by.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/zip_with_strict.h" +#include "utils/exception.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/orthotope/dim_coord.dtg.h" +#include "utils/orthotope/dim_domain.dtg.h" +#include "utils/orthotope/dim_domain.h" +#include "utils/orthotope/minimal_dim_domain.h" +#include "utils/orthotope/orthotope.h" + +namespace FlexFlow { + +template +std::unordered_set get_coord_dims(DimCoord const &coord) { + return keys(coord.raw); +} + +template +DimCoord restrict_coord_to_dims(DimCoord const &coord, + std::unordered_set const &dims) { + return DimCoord{ + restrict_keys(coord.raw, dims), + }; +} + +template +OrthotopeCoord + orthotope_coord_from_dim_coord(DimCoord const &coord, + DimOrdering const &dim_ordering) { + return OrthotopeCoord{ + transform(sorted_by(get_coord_dims(coord), dim_ordering.lt), + [&](T const &t) { return coord.raw.at(t); }), + }; +} + +template +DimCoord 
dim_coord_from_orthotope_coord(OrthotopeCoord const &coord, + std::unordered_set const &dims, + DimOrdering const &dim_ordering) { + return DimCoord{ + map_from_keys_and_values(sorted_by(dims, dim_ordering.lt), coord.raw), + }; +} + +template +DimCoord lift_dim_coord(DimCoord const &coord, + std::unordered_set const &lifted_dims) { + ASSERT(is_subseteq_of(get_coord_dims(coord), lifted_dims)); + + return DimCoord{ + generate_map(lifted_dims, + [&](T const &dim) { + if (contains_key(coord.raw, dim)) { + return coord.raw.at(dim); + } else { + return 0_n; + } + }), + }; +} + +template +std::unordered_set> + get_coords_in_dim_domain(DimDomain const &dim_domain) { + std::unordered_map> + component_possible_values = map_values( + dim_domain.dims, + [](positive_int component_size) + -> std::unordered_set { + return unordered_set_of(nonnegative_range(component_size)); + }); + + return transform( + get_all_assignments(component_possible_values), + [](std::unordered_map const &assignment) { + return DimCoord{ + assignment, + }; + }); +} + +template +std::unordered_set> get_coords_in_minimal_dim_domain( + MinimalDimDomain const &minimal_dim_domain) { + return get_coords_in_dim_domain(lift_minimal_dim_domain(minimal_dim_domain)); +} + +template +DimCoord get_maximum_coord_in_domain(DimDomain const &domain) { + return DimCoord{ + map_values(domain.dims, + [](positive_int dim) -> nonnegative_int { + return nonnegative_int{ + dim.int_from_positive_int() - 1, + }; + }), + }; +} + +template +DimDomain get_domain_for_maximum_coord(DimCoord const &max_coord) { + return DimDomain{ + map_values(max_coord.raw, + [](nonnegative_int dim) -> positive_int { return dim + 1_p; }), + }; +} + +template +bool dim_domain_contains_coord(DimDomain const &domain, + DimCoord const &coord) { + ASSERT(get_domain_dims(domain) == get_coord_dims(coord)); + + std::unordered_set dims = + require_same(get_domain_dims(domain), get_coord_dims(coord)); + return all_of(dims, [&](T const &dim) { + return 
coord.raw.at(dim) < domain.dims.at(dim); + }); +} + +template +bool minimal_dim_domain_contains_coord(MinimalDimDomain const &domain, + DimCoord const &coord) { + return dim_domain_contains_coord(lift_minimal_dim_domain(domain), coord); +} + +template +nonnegative_int flatten_dim_coord(DimCoord const &coord, + DimDomain const &domain, + DimOrdering const &dim_ordering) { + ASSERT( + get_coord_dims(coord) == get_domain_dims(domain), + "flatten_dim_coord expected coord dimensions to match domain dimensions", + coord, + domain); + + OrthotopeCoord orthotope_coord = + orthotope_coord_from_dim_coord(coord, dim_ordering); + Orthotope orthotope_domain = orthotope_from_dim_domain(domain, dim_ordering); + + return flatten_orthotope_coord(orthotope_coord, orthotope_domain); +} + +template +DimCoord unflatten_dim_coord(nonnegative_int flattened, + DimDomain const &domain, + DimOrdering const &dim_ordering) { + Orthotope orthotope_domain = orthotope_from_dim_domain(domain, dim_ordering); + OrthotopeCoord orthotope_coord = + unflatten_orthotope_coord(flattened, orthotope_domain); + + return dim_coord_from_orthotope_coord( + orthotope_coord, get_domain_dims(domain), dim_ordering); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/dim_domain.dtg.toml b/lib/utils/include/utils/orthotope/dim_domain.dtg.toml new file mode 100644 index 0000000000..ccad639aac --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_domain.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "DimDomain" +type = "struct" +features = [ + "eq", + "ord", + "fmt", + "hash", + "json", +] + +template_params = [ + "T", +] + +includes = [ + "", + "utils/positive_int/positive_int.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "dims" +type = "std::unordered_map" diff --git a/lib/utils/include/utils/orthotope/dim_domain.h b/lib/utils/include/utils/orthotope/dim_domain.h new file mode 100644 index 
0000000000..c940745a78 --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_domain.h @@ -0,0 +1,71 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_DOMAIN_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_DOMAIN_H + +#include "utils/containers/filter.h" +#include "utils/containers/keys.h" +#include "utils/containers/map_from_keys_and_values.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/set_minus.h" +#include "utils/containers/sorted_by.h" +#include "utils/containers/transform.h" +#include "utils/nonnegative_int/num_elements.h" +#include "utils/orthotope/dim_domain.dtg.h" +#include "utils/orthotope/dim_ordering.dtg.h" +#include "utils/orthotope/orthotope.dtg.h" + +namespace FlexFlow { + +template +DimDomain empty_dim_domain() { + return DimDomain{{}}; +}; + +template +nonnegative_int dim_domain_num_dims(DimDomain const &domain) { + return num_elements(domain.dims); +} + +template +std::unordered_set get_domain_dims(DimDomain const &domain) { + return keys(domain.dims); +} + +template +std::unordered_set get_trivial_domain_dims(DimDomain const &domain) { + return filter(get_domain_dims(domain), + [&](T const &idx) { return domain.dims.at(idx) == 1; }); +} + +template +std::unordered_set get_nontrivial_domain_dims(DimDomain const &domain) { + return set_minus(get_domain_dims(domain), get_trivial_domain_dims(domain)); +} + +template +DimDomain restrict_domain_to_dims(DimDomain const &domain, + std::unordered_set const &allowed) { + return DimDomain{restrict_keys(domain.dims, allowed)}; +} + +template +Orthotope orthotope_from_dim_domain(DimDomain const &domain, + DimOrdering const &dim_ordering) { + return Orthotope{ + transform(sorted_by(get_domain_dims(domain), dim_ordering.lt), + [&](T const &t) { return domain.dims.at(t); }), + }; +} + +template +DimDomain dim_domain_from_orthotope(Orthotope const &orthotope, + std::unordered_set const &dims, + DimOrdering const &dim_ordering) { + return DimDomain{ + 
map_from_keys_and_values(sorted_by(dims, dim_ordering.lt), + orthotope.dims), + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/dim_domain_mapping.h b/lib/utils/include/utils/orthotope/dim_domain_mapping.h new file mode 100644 index 0000000000..01440b4921 --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_domain_mapping.h @@ -0,0 +1,180 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_DOMAIN_MAPPING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_DOMAIN_MAPPING_H + +#include "utils/bidict/algorithms/exhaustive_relational_join.h" +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/bidict.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/hash/tuple.h" +#include "utils/orthotope/dim_coord.dtg.h" +#include "utils/orthotope/dim_coord.h" +#include "utils/orthotope/dim_domain.dtg.h" +#include "utils/orthotope/dim_ordering.dtg.h" +#include "utils/orthotope/dim_projection.h" +#include "utils/orthotope/minimal_dim_domain.dtg.h" + +namespace FlexFlow { + +template +struct DimDomainMapping { +public: + explicit DimDomainMapping( + bidict, DimCoord> const &coord_mapping, + DimDomain const &l_domain, + DimDomain const &r_domain) + : coord_mapping(coord_mapping), l_domain(l_domain), r_domain(r_domain) { + ASSERT(get_coords_in_dim_domain(l_domain) == left_entries(coord_mapping)); + ASSERT(get_coords_in_dim_domain(r_domain) == right_entries(coord_mapping)); + } + + DimCoord at_l(DimCoord const &l_coord) const { + ASSERT(dim_domain_contains_coord(this->l_domain, l_coord)); + + return this->coord_mapping.at_l(l_coord); + } + + DimCoord at_r(DimCoord const &r_coord) const { + ASSERT(dim_domain_contains_coord(this->r_domain, r_coord)); + + return this->coord_mapping.at_r(r_coord); + } + + bool operator==(DimDomainMapping const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(DimDomainMapping const 
&other) const { + return this->tie() != other.tie(); + } + +public: + bidict, DimCoord> coord_mapping; + DimDomain l_domain; + DimDomain r_domain; + +private: + std::tuple + tie() const { + return std::tie(this->coord_mapping, this->l_domain, this->r_domain); + } + + friend struct ::std::hash>; +}; + +template +std::string format_as(DimDomainMapping const &m) { + CHECK_FMTABLE(L); + CHECK_FMTABLE(R); + + return fmt::format( + "", + m.l_domain, + m.r_domain, + m.coord_mapping); +} + +template +std::ostream &operator<<(std::ostream &s, DimDomainMapping const &m) { + CHECK_FMTABLE(L); + CHECK_FMTABLE(R); + + return (s << fmt::to_string(m)); +} + +template +DimDomainMapping empty_dim_domain_mapping() { + return DimDomainMapping{ + /*coord_mapping=*/{ + {DimCoord{{}}, DimCoord{{}}}, + }, + /*l_domain=*/empty_dim_domain(), + /*r_domain=*/empty_dim_domain(), + }; +} + +template +DimDomainMapping + dim_domain_mapping_identity_map(DimDomain const &l_domain, + DimDomain const &r_domain, + DimOrdering const &l_dim_ordering, + DimOrdering const &r_dim_ordering) { + DimProjection projection = dim_projection_identity_map( + l_domain, r_domain, l_dim_ordering, r_dim_ordering); + + return dim_domain_mapping_from_projection( + /*projection=*/projection, + /*l_domain=*/l_domain, + /*r_domain=*/r_domain, + /*l_dim_ordering=*/l_dim_ordering, + /*r_dim_ordering=*/r_dim_ordering); +} + +template +DimDomainMapping invert_dim_domain_mapping( + DimDomainMapping const &dim_domain_mapping) { + + return DimDomainMapping{ + /*coord_mapping=*/dim_domain_mapping.coord_mapping.reversed(), + /*l_domain=*/dim_domain_mapping.r_domain, + /*r_domain=*/dim_domain_mapping.l_domain, + }; +} + +template +DimDomainMapping + compose_dim_domain_mappings(DimDomainMapping const &lhs, + DimDomainMapping const &rhs) { + + ASSERT(lhs.r_domain == rhs.l_domain); + + return DimDomainMapping{ + /*coord_mapping=*/exhaustive_relational_join(lhs.coord_mapping, + rhs.coord_mapping), + /*l_domain=*/lhs.l_domain, + 
/*r_domain=*/rhs.r_domain, + }; +} + +template +DimDomainMapping + dim_domain_mapping_from_projection(DimProjection const &projection, + DimDomain const &l_domain, + DimDomain const &r_domain, + DimOrdering const &l_dim_ordering, + DimOrdering const &r_dim_ordering) { + + return DimDomainMapping{ + /*coord_mapping=*/generate_bidict( + get_coords_in_dim_domain(l_domain), + [&](DimCoord const &l_coord) { + return compute_dim_projection( + /*projection=*/projection, + /*input_coord=*/l_coord, + /*input_domain=*/l_domain, + /*output_domain=*/r_domain, + /*input_dim_ordering=*/l_dim_ordering, + /*output_dim_ordering=*/r_dim_ordering); + }), + /*l_domain=*/l_domain, + /*r_domain=*/r_domain, + }; +} + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::DimDomainMapping> { + size_t operator()( + ::FlexFlow::DimDomainMapping const &dim_domain_mapping) const { + return get_std_hash(dim_domain_mapping.tie()); + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/orthotope/dim_ordering.dtg.toml b/lib/utils/include/utils/orthotope/dim_ordering.dtg.toml new file mode 100644 index 0000000000..26709724de --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_ordering.dtg.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "DimOrdering" +type = "struct" +features = [] + +template_params = [ + "T", +] + +includes = [ + "", +] + +[[fields]] +name = "lt" +type = "std::function" diff --git a/lib/utils/include/utils/orthotope/dim_ordering.h b/lib/utils/include/utils/orthotope/dim_ordering.h new file mode 100644 index 0000000000..774f1798a0 --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_ordering.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_ORDERING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_ORDERING_H + +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/orthotope/dim_ordering.dtg.h" + +namespace FlexFlow { + +template +DimOrdering 
make_default_dim_ordering() { + return DimOrdering{ + [](T const &lhs, T const &rhs) -> bool { return lhs < rhs; }, + }; +} + +template +DimOrdering make_reversed_dim_ordering() { + return DimOrdering{ + [](T const &lhs, T const &rhs) -> bool { return rhs < lhs; }, + }; +} + +template +DimOrdering make_dim_ordering_from_vector(std::vector const &v) { + bidict v_as_map = bidict_from_enumerating(v); + + return DimOrdering{ + [=](T const &lhs, T const &rhs) -> bool { + return v_as_map.at_r(lhs) <= v_as_map.at_r(rhs); + }, + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/dim_projection.dtg.toml b/lib/utils/include/utils/orthotope/dim_projection.dtg.toml new file mode 100644 index 0000000000..a530adac5d --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_projection.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "DimProjection" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "L", "R" +] + +includes = [ + "utils/orthotope/up_projection.dtg.h", + "utils/orthotope/eq_projection.dtg.h", + "utils/orthotope/down_projection.dtg.h", +] + +[[values]] +type = "::FlexFlow::UpProjection" +key = "up_proj" + +[[values]] +type = "::FlexFlow::EqProjection" +key = "eq_proj" + +[[values]] +type = "::FlexFlow::DownProjection" +key = "down_proj" diff --git a/lib/utils/include/utils/orthotope/dim_projection.h b/lib/utils/include/utils/orthotope/dim_projection.h new file mode 100644 index 0000000000..fa47edd897 --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_projection.h @@ -0,0 +1,235 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_PROJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_PROJECTION_H + +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/orthotope/dim_coord.h" +#include "utils/orthotope/dim_projection.dtg.h" +#include "utils/orthotope/down_projection.h" +#include "utils/orthotope/eq_projection.h" +#include 
"utils/orthotope/up_projection.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +DimProjection + dim_projection_identity_map(DimDomain const &input_domain, + DimDomain const &output_domain, + DimOrdering const &input_dim_ordering, + DimOrdering const &output_dim_ordering) { + + std::vector input_dims = + sorted_by(get_domain_dims(input_domain), input_dim_ordering.lt); + + std::vector output_dims = + sorted_by(get_domain_dims(output_domain), output_dim_ordering.lt); + + return DimProjection{ + EqProjection{ + bidict_from_keys_and_values(input_dims, output_dims), + }, + }; +} + +template +std::unordered_set + input_dims_of_projection(DimProjection const &projection) { + return projection.template visit>(overload{ + [](UpProjection const &p) { + return input_dims_of_up_projection(p); + }, + [](EqProjection const &p) { + return input_dims_of_eq_projection(p); + }, + [](DownProjection const &p) { + return input_dims_of_down_projection(p); + }, + }); +} + +template +std::unordered_set + output_dims_of_projection(DimProjection const &projection) { + return projection.template visit>(overload{ + [](UpProjection const &p) { + return output_dims_of_up_projection(p); + }, + [](EqProjection const &p) { + return output_dims_of_eq_projection(p); + }, + [](DownProjection const &p) { + return output_dims_of_down_projection(p); + }, + }); +}; + +template +DimProjection + invert_dim_projection(DimProjection const &projection) { + return projection.template visit>(overload{ + [](UpProjection const &p) { + return DimProjection{ + invert_up_projection(p), + }; + }, + [](EqProjection const &p) { + return DimProjection{ + invert_eq_projection(p), + }; + }, + [](DownProjection const &p) { + return DimProjection{ + invert_down_projection(p), + }; + }, + }); +} + +template +DimCoord compute_dim_projection(DimProjection const &projection, + DimCoord const &input_coord, + DimDomain const &input_domain, + DimDomain const &output_domain, + DimOrdering const 
&input_dim_ordering, + DimOrdering const &output_dim_ordering) { + DimCoord lifted_input_coord = + lift_dim_coord(input_coord, get_domain_dims(input_domain)); + + ASSERT(dim_domain_contains_coord(input_domain, input_coord), + input_domain, + input_coord); + + { + std::unordered_set nontrivial_input_domain_dims = + get_nontrivial_domain_dims(input_domain); + std::unordered_set projection_input_dims = + input_dims_of_projection(projection); + std::unordered_set all_input_domain_dims = get_domain_dims(input_domain); + + ASSERT(is_subseteq_of(nontrivial_input_domain_dims, projection_input_dims), + nontrivial_input_domain_dims, + projection_input_dims); + ASSERT(is_subseteq_of(projection_input_dims, all_input_domain_dims), + projection_input_dims, + all_input_domain_dims); + } + + { + std::unordered_set nontrivial_output_domain_dims = + get_nontrivial_domain_dims(output_domain); + std::unordered_set projection_output_dims = + output_dims_of_projection(projection); + std::unordered_set all_output_domain_dims = + get_domain_dims(output_domain); + + ASSERT( + is_subseteq_of(nontrivial_output_domain_dims, projection_output_dims), + nontrivial_output_domain_dims, + projection_output_dims); + ASSERT(is_subseteq_of(projection_output_dims, all_output_domain_dims), + projection_output_dims, + all_output_domain_dims); + } + + DimCoord output_coord = projection.template visit>(overload{ + [&](UpProjection const &p) -> DimCoord { + return compute_up_projection( + p, lifted_input_coord, output_domain, output_dim_ordering); + }, + [&](EqProjection const &p) -> DimCoord { + return compute_eq_projection(p, lifted_input_coord); + }, + [&](DownProjection const &p) -> DimCoord { + return compute_down_projection( + p, lifted_input_coord, input_domain, input_dim_ordering); + }, + }); + + DimCoord lifted_output_coord = + lift_dim_coord(output_coord, get_domain_dims(output_domain)); + + ASSERT(dim_domain_contains_coord(output_domain, lifted_output_coord), + output_domain, + 
lifted_output_coord, + input_domain, + lifted_input_coord); + + return lifted_output_coord; +} + +template +DimProjection + right_compose_eq_projection(DimProjection const &lhs, + EqProjection const &rhs) { + return lhs.template visit>(overload{ + [&](UpProjection const &lhs_up_proj) { + return DimProjection{ + compose_up_projections(lhs_up_proj, up_from_eq_proj(rhs)), + }; + }, + [&](EqProjection const &lhs_eq_proj) { + return DimProjection{ + compose_eq_projections(lhs_eq_proj, rhs), + }; + }, + [&](DownProjection const &lhs_down_proj) { + return DimProjection{ + compose_down_projections(lhs_down_proj, down_from_eq_proj(rhs)), + }; + }, + }); +} + +template +DimProjection + left_compose_eq_projection(EqProjection const &lhs, + DimProjection const &rhs) { + return rhs.template visit>(overload{ + [&](UpProjection const &rhs_up_proj) { + return DimProjection{ + compose_up_projections(up_from_eq_proj(lhs), rhs_up_proj), + }; + }, + [&](EqProjection const &rhs_eq_proj) { + return DimProjection{ + compose_eq_projections(lhs, rhs_eq_proj), + }; + }, + [&](DownProjection const &rhs_down_proj) { + return DimProjection{ + compose_down_projections(down_from_eq_proj(lhs), rhs_down_proj), + }; + }, + }); +} + +template +DimProjection + compose_dim_projections(DimProjection const &lhs, + DimProjection const &rhs) { + + if (lhs.is_eq_proj()) { + return DimProjection{ + left_compose_eq_projection(lhs.require_eq_proj(), rhs), + }; + } else if (rhs.is_eq_proj()) { + return DimProjection{ + right_compose_eq_projection(lhs, rhs.require_eq_proj()), + }; + } else if (lhs.is_up_proj() && rhs.is_up_proj()) { + return DimProjection{ + compose_up_projections(lhs.require_up_proj(), rhs.require_up_proj()), + }; + } else if (lhs.is_down_proj() && rhs.is_down_proj()) { + return DimProjection{ + compose_down_projections(lhs.require_down_proj(), + rhs.require_down_proj()), + }; + } else { + PANIC("Cannot compose projections", lhs, rhs); + } +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/orthotope/down_projection.dtg.toml b/lib/utils/include/utils/orthotope/down_projection.dtg.toml new file mode 100644 index 0000000000..9a642d2b9f --- /dev/null +++ b/lib/utils/include/utils/orthotope/down_projection.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "DownProjection" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "L", "R" +] + +includes = [ + "utils/many_to_one/many_to_one.h", +] + +[[fields]] +name = "dim_mapping" +type = "::FlexFlow::ManyToOne" diff --git a/lib/utils/include/utils/orthotope/down_projection.h b/lib/utils/include/utils/orthotope/down_projection.h new file mode 100644 index 0000000000..f46a0f16c8 --- /dev/null +++ b/lib/utils/include/utils/orthotope/down_projection.h @@ -0,0 +1,103 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DOWN_PROJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DOWN_PROJECTION_H + +#include "utils/many_to_one/exhaustive_relational_join.h" +#include "utils/many_to_one/invert_many_to_one.h" +#include "utils/many_to_one/many_to_one_from_bidict.h" +#include "utils/orthotope/dim_coord.h" +#include "utils/orthotope/dim_domain.dtg.h" +#include "utils/orthotope/dim_ordering.dtg.h" +#include "utils/orthotope/down_projection.dtg.h" +#include "utils/orthotope/eq_projection.dtg.h" +#include "utils/orthotope/orthotope.dtg.h" +#include "utils/orthotope/orthotope.h" +#include "utils/orthotope/orthotope_coord.dtg.h" +#include "utils/orthotope/up_projection.dtg.h" + +namespace FlexFlow { + +template +DownProjection make_empty_down_projection() { + return DownProjection{ManyToOne{}}; +} + +template +std::unordered_set + input_dims_of_down_projection(DownProjection const &projection) { + return projection.dim_mapping.left_values(); +} + +template +std::unordered_set + output_dims_of_down_projection(DownProjection const &projection) { + return projection.dim_mapping.right_values(); +} + +template +DimCoord 
compute_down_projection(DownProjection const &projection, + DimCoord const &coord, + DimDomain const &input_domain, + DimOrdering const &input_dim_ordering) { + std::unordered_set input_dims = input_dims_of_down_projection(projection); + std::unordered_set coord_dims = get_coord_dims(coord); + ASSERT(input_dims == coord_dims, + "compute_down_projection expected coord dimensions to match " + "projection input dimensions"); + + std::unordered_set output_dims = + output_dims_of_down_projection(projection); + + return DimCoord{ + generate_map( + output_dims, + [&](R const &output_dim) { + std::unordered_set src_dims = + projection.dim_mapping.at_r(output_dim); + + DimCoord src_coord = restrict_coord_to_dims(coord, src_dims); + DimDomain src_domain = + restrict_domain_to_dims(input_domain, src_dims); + + return flatten_dim_coord(src_coord, src_domain, input_dim_ordering); + }), + }; +} + +template +void project_dims(DownProjection &proj, + std::unordered_set const &from, + R const &onto) { + ASSERT(from.size() > 0); + + for (L const &l : from) { + proj.dim_mapping.insert({l, onto}); + } +} + +template +UpProjection + invert_down_projection(DownProjection const &down_proj) { + return UpProjection{ + /*dim_mapping=*/invert_many_to_one(down_proj.dim_mapping), + }; +} + +template +DownProjection + compose_down_projections(DownProjection const &fst, + DownProjection const &snd) { + return DownProjection{ + exhaustive_relational_join(fst.dim_mapping, snd.dim_mapping), + }; +} + +template +DownProjection down_from_eq_proj(EqProjection const &eq) { + return DownProjection{ + many_to_one_from_bidict(eq.dim_mapping), + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/eq_projection.dtg.toml b/lib/utils/include/utils/orthotope/eq_projection.dtg.toml new file mode 100644 index 0000000000..972952f907 --- /dev/null +++ b/lib/utils/include/utils/orthotope/eq_projection.dtg.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "EqProjection" 
+type = "struct" +features = [ + "eq", + "hash", + "fmt", + "rapidcheck", +] + +template_params = [ + "L", "R" +] + +includes = [ + "utils/bidict/bidict.h", +] + +[[fields]] +name = "dim_mapping" +type = "::FlexFlow::bidict" diff --git a/lib/utils/include/utils/orthotope/eq_projection.h b/lib/utils/include/utils/orthotope/eq_projection.h new file mode 100644 index 0000000000..6b394ac04b --- /dev/null +++ b/lib/utils/include/utils/orthotope/eq_projection.h @@ -0,0 +1,61 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_EQ_PROJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_EQ_PROJECTION_H + +#include "utils/bidict/algorithms/exhaustive_relational_join.h" +#include "utils/containers/map_keys.h" +#include "utils/orthotope/dim_coord.dtg.h" +#include "utils/orthotope/dim_domain.dtg.h" +#include "utils/orthotope/eq_projection.dtg.h" + +namespace FlexFlow { + +template +EqProjection make_empty_eq_projection() { + return EqProjection{bidict{}}; +} + +template +std::unordered_set + input_dims_of_eq_projection(EqProjection const &projection) { + return projection.dim_mapping.left_values(); +} + +template +std::unordered_set + output_dims_of_eq_projection(EqProjection const &projection) { + return projection.dim_mapping.right_values(); +} + +template +void project_dims(EqProjection &proj, L const &from, R const &to) { + proj.dim_mapping.equate(from, to); +} + +template +EqProjection invert_eq_projection(EqProjection const &input) { + return EqProjection{ + input.dim_mapping.reversed(), + }; +} + +template +EqProjection compose_eq_projections(EqProjection const &fst, + EqProjection const &snd) { + return EqProjection{ + exhaustive_relational_join(fst.dim_mapping, snd.dim_mapping)}; +} + +template +DimCoord compute_eq_projection(EqProjection const &projection, + DimCoord const &coord) { + return DimCoord{ + map_keys(coord.raw, + [&](L const &input_dim) -> R { + return projection.dim_mapping.at_l(input_dim); + }), + }; +}; + +} // namespace FlexFlow + +#endif 
diff --git a/lib/utils/include/utils/orthotope/minimal_dim_domain.dtg.toml b/lib/utils/include/utils/orthotope/minimal_dim_domain.dtg.toml new file mode 100644 index 0000000000..18d2aad26f --- /dev/null +++ b/lib/utils/include/utils/orthotope/minimal_dim_domain.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "MinimalDimDomain" +type = "struct" +features = [ + "eq", + "ord", + "fmt", + "hash", + "json", +] + +template_params = [ + "T", +] + +includes = [ + "", + "utils/int_ge_two/int_ge_two.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "dims" +type = "std::unordered_map" diff --git a/lib/utils/include/utils/orthotope/minimal_dim_domain.h b/lib/utils/include/utils/orthotope/minimal_dim_domain.h new file mode 100644 index 0000000000..3934e2af62 --- /dev/null +++ b/lib/utils/include/utils/orthotope/minimal_dim_domain.h @@ -0,0 +1,114 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_MINIMAL_DIM_DOMAIN_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_MINIMAL_DIM_DOMAIN_H + +#include "utils/containers/are_disjoint.h" +#include "utils/containers/binary_merge_disjoint_maps.h" +#include "utils/containers/filtermap_values.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/map_from_keys_and_values.h" +#include "utils/containers/map_values.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/sorted_by.h" +#include "utils/containers/transform.h" +#include "utils/nonnegative_int/num_elements.h" +#include "utils/orthotope/dim_domain.dtg.h" +#include "utils/orthotope/dim_ordering.dtg.h" +#include "utils/orthotope/minimal_dim_domain.dtg.h" +#include "utils/orthotope/minimal_orthotope.dtg.h" + +namespace FlexFlow { + +template +MinimalDimDomain empty_minimal_dim_domain() { + return MinimalDimDomain{{}}; +} + +template +nonnegative_int minimal_dim_domain_num_dims(MinimalDimDomain const &domain) { + return 
num_elements(domain.dims); +} + +template +DimDomain + lift_minimal_dim_domain(MinimalDimDomain const &minimal_dim_domain) { + return DimDomain{ + map_values(minimal_dim_domain.dims, + [](int_ge_two component) { + return component.positive_int_from_int_ge_two(); + }), + }; +} + +template +MinimalDimDomain + require_dim_domain_is_minimal(DimDomain const &dim_domain) { + return MinimalDimDomain{ + map_values(dim_domain.dims, + [](positive_int dim_size) { return int_ge_two{dim_size}; }), + }; +} + +template +MinimalDimDomain + minimal_dim_domain_from_dim_domain(DimDomain const &dim_domain) { + return MinimalDimDomain{ + filtermap_values(dim_domain.dims, try_int_ge_two_from_positive_int)}; +} + +template +DimDomain dim_domain_from_minimal_dim_domain( + MinimalDimDomain const &minimal_dim_domain, + std::unordered_set const &trivial_dims) { + std::unordered_set nontrivial_dims = + get_minimal_domain_dims(minimal_dim_domain); + + ASSERT(are_disjoint(nontrivial_dims, trivial_dims)); + + return DimDomain{ + /*dims=*/binary_merge_disjoint_maps( + map_values( + minimal_dim_domain.dims, + [](int_ge_two x) { return x.positive_int_from_int_ge_two(); }), + generate_map(trivial_dims, [](T const &) { return 1_p; })), + }; +} + +template +std::unordered_set + get_minimal_domain_dims(MinimalDimDomain const &domain) { + return keys(domain.dims); +} + +template +MinimalDimDomain + restrict_minimal_domain_to_dims(MinimalDimDomain const &domain, + std::unordered_set const &allowed) { + return MinimalDimDomain{restrict_keys(domain.dims, allowed)}; +} + +template +MinimalOrthotope minimal_orthotope_from_minimal_dim_domain( + MinimalDimDomain const &domain, DimOrdering const &dim_ordering) { + + return MinimalOrthotope{ + transform(sorted_by(get_minimal_domain_dims(domain), dim_ordering.lt), + [&](T const &t) { return domain.dims.at(t); }), + }; +} + +template +MinimalDimDomain minimal_dim_domain_from_minimal_orthotope( + MinimalOrthotope const &orthotope, + std::unordered_set const &dims, 
+ DimOrdering const &dim_ordering) { + + return MinimalDimDomain{ + map_from_keys_and_values(sorted_by(dims, dim_ordering.lt), + orthotope.dims), + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/minimal_dim_domain_mapping.h b/lib/utils/include/utils/orthotope/minimal_dim_domain_mapping.h new file mode 100644 index 0000000000..1ebff61701 --- /dev/null +++ b/lib/utils/include/utils/orthotope/minimal_dim_domain_mapping.h @@ -0,0 +1,263 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_MINIMAL_DIM_DOMAIN_MAPPING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_MINIMAL_DIM_DOMAIN_MAPPING_H + +#include "utils/bidict/algorithms/exhaustive_relational_join.h" +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/algorithms/transform_keys.h" +#include "utils/bidict/algorithms/transform_values.h" +#include "utils/bidict/bidict.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/hash/tuple.h" +#include "utils/orthotope/dim_coord.dtg.h" +#include "utils/orthotope/dim_coord.h" +#include "utils/orthotope/dim_domain_mapping.h" +#include "utils/orthotope/dim_ordering.dtg.h" +#include "utils/orthotope/dim_projection.h" +#include "utils/orthotope/minimal_dim_domain.dtg.h" + +namespace FlexFlow { + +template +struct MinimalDimDomainMapping { +public: + explicit MinimalDimDomainMapping( + bidict, DimCoord> const &coord_mapping, + MinimalDimDomain const &l_domain, + MinimalDimDomain const &r_domain) + : coord_mapping(coord_mapping), l_domain(l_domain), r_domain(r_domain) { + ASSERT(get_coords_in_minimal_dim_domain(l_domain) == + left_entries(coord_mapping)); + ASSERT(get_coords_in_minimal_dim_domain(r_domain) == + right_entries(coord_mapping)); + } + + DimCoord at_l(DimCoord const &l_coord) const { + ASSERT(minimal_dim_domain_contains_coord(this->l_domain, l_coord)); + + return this->coord_mapping.at_l(l_coord); + } + + DimCoord at_r(DimCoord const 
&r_coord) const { + ASSERT(minimal_dim_domain_contains_coord(this->r_domain, r_coord)); + + return this->coord_mapping.at_r(r_coord); + } + + bool operator==(MinimalDimDomainMapping const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(MinimalDimDomainMapping const &other) const { + return this->tie() != other.tie(); + } + +public: + bidict, DimCoord> coord_mapping; + MinimalDimDomain l_domain; + MinimalDimDomain r_domain; + +private: + std::tuple + tie() const { + return std::tie(this->coord_mapping, this->l_domain, this->r_domain); + } + + friend struct ::std::hash>; +}; + +template +std::string format_as(MinimalDimDomainMapping const &m) { + CHECK_FMTABLE(L); + CHECK_FMTABLE(R); + + return fmt::format( + "", + m.l_domain, + m.r_domain, + m.coord_mapping); +} + +template +std::ostream &operator<<(std::ostream &s, + MinimalDimDomainMapping const &m) { + CHECK_FMTABLE(L); + CHECK_FMTABLE(R); + + return (s << fmt::to_string(m)); +} + +template +MinimalDimDomainMapping + minimal_mapping_from_dim_domain_mapping(DimDomainMapping const &m) { + + std::unordered_set l_nontrivial_dims = + get_nontrivial_domain_dims(m.l_domain); + + std::unordered_set r_nontrivial_dims = + get_nontrivial_domain_dims(m.r_domain); + + return MinimalDimDomainMapping{ + /*coord_mapping=*/ + transform_keys(transform_values(m.coord_mapping, + [&](DimCoord const &r_coord) { + return restrict_coord_to_dims( + r_coord, r_nontrivial_dims); + }), + [&](DimCoord const &l_coord) { + return restrict_coord_to_dims(l_coord, + l_nontrivial_dims); + }), + /*l_domain=*/minimal_dim_domain_from_dim_domain(m.l_domain), + /*r_domain=*/minimal_dim_domain_from_dim_domain(m.r_domain), + }; +} + +template +DimDomainMapping dim_domain_mapping_from_minimal_dim_domain( + MinimalDimDomainMapping const &m, + std::unordered_set const &l_trivial_dims, + std::unordered_set const &r_trivial_dims) { + + DimDomain l_domain = + dim_domain_from_minimal_dim_domain(m.l_domain, l_trivial_dims); + 
DimDomain r_domain = + dim_domain_from_minimal_dim_domain(m.r_domain, r_trivial_dims); + + std::unordered_set all_l_dims = get_domain_dims(l_domain); + std::unordered_set all_r_dims = get_domain_dims(r_domain); + + return DimDomainMapping{ + /*coord_mapping=*/ + transform_keys(transform_values(m.coord_mapping, + [&](DimCoord const &r_coord) { + return lift_dim_coord(r_coord, + all_r_dims); + }), + [&](DimCoord const &l_coord) { + return lift_dim_coord(l_coord, all_l_dims); + }), + /*l_domain=*/l_domain, + /*r_domain=*/r_domain, + }; +} + +template +MinimalDimDomainMapping empty_minimal_dim_domain_mapping() { + return MinimalDimDomainMapping{ + /*coord_mapping=*/{}, + /*l_domain=*/empty_minimal_dim_domain(), + /*r_domain=*/empty_minimal_dim_domain(), + }; +} + +template +MinimalDimDomainMapping minimal_dim_domain_mapping_identity_map( + MinimalDimDomain const &l_domain, + MinimalDimDomain const &r_domain, + DimOrdering const &l_dim_ordering, + DimOrdering const &r_dim_ordering) { + DimProjection projection = + dim_projection_identity_map(lift_minimal_dim_domain(l_domain), + lift_minimal_dim_domain(r_domain), + l_dim_ordering, + r_dim_ordering); + + return minimal_dim_domain_mapping_from_projection( + /*projection=*/projection, + /*l_domain=*/l_domain, + /*r_domain=*/r_domain, + /*l_dim_ordering=*/l_dim_ordering, + /*r_dim_ordering=*/r_dim_ordering); +} + +template +MinimalDimDomainMapping invert_minimal_dim_domain_mapping( + MinimalDimDomainMapping const &minimal_dim_domain_mapping) { + + return MinimalDimDomainMapping{ + /*coord_mapping=*/minimal_dim_domain_mapping.coord_mapping.reversed(), + /*l_domain=*/minimal_dim_domain_mapping.r_domain, + /*r_domain=*/minimal_dim_domain_mapping.l_domain, + }; +} + +template +MinimalDimDomainMapping compose_minimal_dim_domain_mappings( + MinimalDimDomainMapping const &lhs, + MinimalDimDomainMapping const &rhs) { + + ASSERT(lhs.r_domain == rhs.l_domain); + + return MinimalDimDomainMapping{ + 
/*coord_mapping=*/exhaustive_relational_join(lhs.coord_mapping, + rhs.coord_mapping), + /*l_domain=*/lhs.l_domain, + /*r_domain=*/rhs.r_domain, + }; +} + +template +DimDomainMapping compose_dim_domain_mappings_through_minimal( + DimDomainMapping const &lhs, DimDomainMapping const &rhs) { + + MinimalDimDomainMapping minimal_lhs = + minimal_mapping_from_dim_domain_mapping(lhs); + + std::unordered_set t1_trivial_dims = + get_trivial_domain_dims(lhs.l_domain); + + MinimalDimDomainMapping minimal_rhs = + minimal_mapping_from_dim_domain_mapping(rhs); + + std::unordered_set t3_trivial_dims = + get_trivial_domain_dims(rhs.r_domain); + + return dim_domain_mapping_from_minimal_dim_domain( + compose_minimal_dim_domain_mappings(minimal_lhs, minimal_rhs), + t1_trivial_dims, + t3_trivial_dims); +} + +template +MinimalDimDomainMapping minimal_dim_domain_mapping_from_projection( + DimProjection const &projection, + MinimalDimDomain const &l_domain, + MinimalDimDomain const &r_domain, + DimOrdering const &l_dim_ordering, + DimOrdering const &r_dim_ordering) { + + return MinimalDimDomainMapping{ + /*coord_mapping=*/generate_bidict( + get_coords_in_minimal_dim_domain(l_domain), + [&](DimCoord const &l_coord) { + return compute_dim_projection( + /*projection=*/projection, + /*input_coord=*/l_coord, + /*input_domain=*/lift_minimal_dim_domain(l_domain), + /*output_domain=*/lift_minimal_dim_domain(r_domain), + /*input_dim_ordering=*/l_dim_ordering, + /*output_dim_ordering=*/r_dim_ordering); + }), + /*l_domain=*/l_domain, + /*r_domain=*/r_domain, + }; +} + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::MinimalDimDomainMapping> { + size_t operator()(::FlexFlow::MinimalDimDomainMapping const + &minimal_dim_domain_mapping) const { + return get_std_hash(minimal_dim_domain_mapping.tie()); + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/orthotope/minimal_orthotope.dtg.toml 
b/lib/utils/include/utils/orthotope/minimal_orthotope.dtg.toml new file mode 100644 index 0000000000..ef959d23e6 --- /dev/null +++ b/lib/utils/include/utils/orthotope/minimal_orthotope.dtg.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "MinimalOrthotope" +type = "struct" +features = [ + "eq", + "ord", + "fmt", + "hash", + "json", + "rapidcheck", +] + +includes = [ + "", + "utils/int_ge_two/int_ge_two.h", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "dims" +type = "std::vector<::FlexFlow::int_ge_two>" diff --git a/lib/utils/include/utils/orthotope/minimal_orthotope.h b/lib/utils/include/utils/orthotope/minimal_orthotope.h new file mode 100644 index 0000000000..481713c1e6 --- /dev/null +++ b/lib/utils/include/utils/orthotope/minimal_orthotope.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_MINIMAL_ORTHOTOPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_MINIMAL_ORTHOTOPE_H + +#include "utils/orthotope/minimal_orthotope.dtg.h" +#include "utils/orthotope/orthotope.dtg.h" + +namespace FlexFlow { + +nonnegative_int minimal_orthotope_get_num_dims(MinimalOrthotope const &); +positive_int minimal_orthotope_get_volume(MinimalOrthotope const &); + +MinimalOrthotope require_orthotope_is_minimal(Orthotope const &); +Orthotope orthotope_from_minimal_orthotope(MinimalOrthotope const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope.dtg.toml b/lib/utils/include/utils/orthotope/orthotope.dtg.toml new file mode 100644 index 0000000000..1b608178f1 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "Orthotope" +type = "struct" +features = [ + "eq", + "ord", + "fmt", + "hash", + "json", +] + +includes = [ + "", + "utils/positive_int/positive_int.h", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "dims" +type = 
"std::vector<::FlexFlow::positive_int>" diff --git a/lib/utils/include/utils/orthotope/orthotope.h b/lib/utils/include/utils/orthotope/orthotope.h new file mode 100644 index 0000000000..509497ff00 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_H + +#include "utils/orthotope/orthotope.dtg.h" +#include "utils/orthotope/orthotope_coord.dtg.h" + +namespace FlexFlow { + +nonnegative_int orthotope_get_num_dims(Orthotope const &); + +positive_int orthotope_get_volume(Orthotope const &); + +std::unordered_set + get_all_coords_in_orthotope(Orthotope const &); + +bool orthotope_contains_coord(Orthotope const &, OrthotopeCoord const &); + +Orthotope restrict_orthotope_to_dims(Orthotope const &, + std::set const &); + +nonnegative_int flatten_orthotope_coord(OrthotopeCoord const &, + Orthotope const &); + +OrthotopeCoord orthotope_get_maximum_coord(Orthotope const &); + +nonnegative_int orthotope_get_maximum_offset(Orthotope const &); + +OrthotopeCoord unflatten_orthotope_coord(nonnegative_int, Orthotope const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_coord.dtg.toml b/lib/utils/include/utils/orthotope/orthotope_coord.dtg.toml new file mode 100644 index 0000000000..76e2b6c76d --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_coord.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "OrthotopeCoord" +type = "struct" +features = [ + "eq", + "ord", + "fmt", + "hash", + "json", +] + +includes = [ + "", + "utils/nonnegative_int/nonnegative_int.h", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "raw" +type = "std::vector<::FlexFlow::nonnegative_int>" diff --git a/lib/utils/include/utils/orthotope/orthotope_coord.h b/lib/utils/include/utils/orthotope/orthotope_coord.h new file mode 100644 index 
0000000000..97d9afa03c --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_coord.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_COORD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_COORD_H + +#include "utils/orthotope/orthotope_coord.dtg.h" + +namespace FlexFlow { + +nonnegative_int orthotope_coord_num_dims(OrthotopeCoord const &); + +OrthotopeCoord restrict_orthotope_coord_to_dims( + OrthotopeCoord const &coord, std::set const &allowed_dims); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/up_projection.dtg.toml b/lib/utils/include/utils/orthotope/up_projection.dtg.toml new file mode 100644 index 0000000000..c99e6eec93 --- /dev/null +++ b/lib/utils/include/utils/orthotope/up_projection.dtg.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "UpProjection" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "L", "R" +] + +includes = [ + "utils/one_to_many/one_to_many.h", +] + +[[fields]] +name = "dim_mapping" +type = "::FlexFlow::OneToMany" diff --git a/lib/utils/include/utils/orthotope/up_projection.h b/lib/utils/include/utils/orthotope/up_projection.h new file mode 100644 index 0000000000..7fa7c0339c --- /dev/null +++ b/lib/utils/include/utils/orthotope/up_projection.h @@ -0,0 +1,106 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_UP_PROJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_UP_PROJECTION_H + +#include "utils/containers/flatmap.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/keys.h" +#include "utils/containers/values.h" +#include "utils/one_to_many/exhaustive_relational_join.h" +#include "utils/one_to_many/invert_one_to_many.h" +#include "utils/one_to_many/one_to_many_from_bidict.h" +#include "utils/orthotope/dim_coord.h" +#include "utils/orthotope/dim_domain.dtg.h" +#include "utils/orthotope/down_projection.dtg.h" +#include "utils/orthotope/eq_projection.dtg.h" 
+#include "utils/orthotope/up_projection.dtg.h" + +namespace FlexFlow { + +template +UpProjection make_empty_up_projection() { + return UpProjection{OneToMany{}}; +} + +template +std::unordered_set + input_dims_of_up_projection(UpProjection const &projection) { + return projection.dim_mapping.left_values(); +} + +template +std::unordered_set + output_dims_of_up_projection(UpProjection const &projection) { + return projection.dim_mapping.right_values(); +} + +template +DimCoord compute_up_projection(UpProjection const &projection, + DimCoord const &coord, + DimDomain const &output_domain, + DimOrdering const &output_dim_ordering) { + std::unordered_set input_dims = input_dims_of_up_projection(projection); + std::unordered_set coord_dims = get_coord_dims(coord); + ASSERT(input_dims == coord_dims, + "compute_up_projection expected coord dimensions to match projection " + "input dimensions"); + + std::unordered_set output_dims = output_dims_of_up_projection(projection); + std::unordered_set output_domain_dims = get_domain_dims(output_domain); + ASSERT(is_subseteq_of(output_dims, output_domain_dims)); + + DimCoord unlifted = DimCoord{ + flatmap(coord.raw, + [&](L const &input_dim, nonnegative_int input_dim_val) { + std::unordered_set dst_dims = + projection.dim_mapping.at_l(input_dim); + + DimDomain dst_domain = + restrict_domain_to_dims(output_domain, dst_dims); + + DimCoord dst_coord = unflatten_dim_coord( + input_dim_val, dst_domain, output_dim_ordering); + + return dst_coord.raw; + }), + }; + + return unlifted; +} + +template +void project_dims(UpProjection &proj, + L const &onto, + std::unordered_set const &from) { + ASSERT(from.size() > 0); + + for (R const &r : from) { + proj.dim_mapping.insert({onto, r}); + } +} + +template +DownProjection invert_up_projection(UpProjection const &up_proj) { + return DownProjection{ + /*dim_mapping=*/invert_one_to_many(up_proj.dim_mapping), + }; +} + +template +UpProjection compose_up_projections(UpProjection const &fst, + 
UpProjection const &snd) { + return UpProjection{ + /*dim_mapping=*/exhaustive_relational_join(fst.dim_mapping, + snd.dim_mapping), + }; +} + +template +UpProjection up_from_eq_proj(EqProjection const &eq) { + return UpProjection{ + one_to_many_from_bidict(eq.dim_mapping), + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/positive_int/positive_int.h b/lib/utils/include/utils/positive_int/positive_int.h index 6ddddadf50..6d97e21cd6 100644 --- a/lib/utils/include/utils/positive_int/positive_int.h +++ b/lib/utils/include/utils/positive_int/positive_int.h @@ -56,6 +56,8 @@ struct positive_int { positive_int &operator+=(positive_int other); positive_int &operator+=(nonnegative_int other); + friend positive_int operator+(nonnegative_int lhs, positive_int rhs); + positive_int operator*(positive_int other) const; positive_int &operator*=(positive_int other); nonnegative_int operator*(nonnegative_int other) const; diff --git a/lib/utils/include/utils/random_utils.h b/lib/utils/include/utils/random_utils.h index 99da9646a1..efc038f851 100644 --- a/lib/utils/include/utils/random_utils.h +++ b/lib/utils/include/utils/random_utils.h @@ -5,9 +5,9 @@ #include #include -float randf() { - return static_cast(std::rand()) / static_cast(RAND_MAX); -} +namespace FlexFlow { + +float randf(); template T select_random(std::vector const &values) { @@ -45,4 +45,6 @@ T select_random(std::vector const &values, return select_random_determistic(values, weights, randf()); } +} // namespace FlexFlow + #endif // _RANDOM_UTILS_H diff --git a/lib/utils/include/utils/required.h b/lib/utils/include/utils/required.h index d16b67ba86..863ee36fc6 100644 --- a/lib/utils/include/utils/required.h +++ b/lib/utils/include/utils/required.h @@ -28,31 +28,16 @@ struct adl_serializer<::FlexFlow::req> { }; } // namespace nlohmann -/* namespace fmt { */ - -/* template */ -/* struct formatter<::FlexFlow::req> : formatter { */ -/* template */ -/* auto format(::FlexFlow::req const &t, 
FormatContext &ctx) */ -/* -> decltype(ctx.out()) { */ -/* return formatter::format(static_cast(t), ctx); */ -/* } */ -/* }; */ - -/* } // namespace fmt */ - namespace FlexFlow { static_assert(is_json_serializable>::value, ""); static_assert(is_json_deserializable>::value, ""); static_assert(is_jsonable>::value, ""); CHECK_FMTABLE(req); CHECK_FMTABLE(std::vector); -// CHECK_FMTABLE(required_inheritance_impl>); static_assert( std::is_base_of>, req>>::value, ""); -// CHECK_FMTABLE(req>); } // namespace FlexFlow diff --git a/lib/utils/include/utils/required_core.h b/lib/utils/include/utils/required_core.h index 8ac772439f..83110af416 100644 --- a/lib/utils/include/utils/required_core.h +++ b/lib/utils/include/utils/required_core.h @@ -1,9 +1,8 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_CORE_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_CORE_H -#include "hash-utils.h" -#include "test_types.h" -#include "type_traits_core.h" +#include "utils/hash-utils.h" +#include "utils/type_traits_core.h" #include #include @@ -15,12 +14,6 @@ struct enable_if_valid {}; template struct enable_if_valid, Args...> : type_identity {}; -/* required_wrapper_impl() + std::declval())>> */ -/* operator+(required_wrapper_impl const &lhs, required_wrapper_impl const - * &rhs) { */ -/* /1* return 1; *1/ */ -/* } */ - template struct required_wrapper_impl { public: @@ -52,14 +45,6 @@ struct required_wrapper_impl { return static_cast(r); } - /* T const &operator*() const { */ - /* return this->m_value; */ - /* } */ - - /* T const *operator->() const { */ - /* return &this->m_value; */ - /* } */ - template enable_if_t::value, bool> operator==(required_wrapper_impl const &rhs) const { @@ -72,54 +57,12 @@ struct required_wrapper_impl { return this->m_value == rhs; } - /* friend enable_if_t::value, bool> */ - /* operator==(required_wrapper_impl const &lhs, T const &rhs) { */ - /* return lhs.m_value == rhs; */ - /* } */ - - /* friend enable_if_t::value, bool> */ - /* operator==(T const &lhs, 
required_wrapper_impl const &rhs) { */ - /* return lhs == rhs.m_value; */ - /* } */ - template enable_if_t::value, bool> operator!=(required_wrapper_impl const &rhs) const { return this->m_value != rhs.m_value; } - /* friend enable_if_t::value, - * required_wrapper_impl() + std::declval())>> */ - /* operator+(required_wrapper_impl const &lhs, required_wrapper_impl - * const &rhs) { */ - /* /1* return 1; *1/ */ - /* } */ - /* required_wrapper_impl */ - /* operator+(required_wrapper_impl const &rhs) { */ - /* Out o = this->m_value + rhs.m_value; */ - /* return required_wrapper_impl{o}; */ - /* } */ - - /* template ::value> = true> */ - /* required_wrapper_impl operator-(required_wrapper_impl const &rhs) { */ - /* return {this->m_value - rhs.m_value}; */ - /* } */ - - /* template ::value> = true> */ - /* required_wrapper_impl operator*(required_wrapper_impl const &rhs) { */ - /* return {this->m_value * rhs.m_value}; */ - /* } */ - - /* bool operator<(T const &other) const { */ - /* return this->m_value < other; */ - /* } */ - - /* bool operator>(T const &other) const { */ - /* return this->m_value > other; */ - /* } */ - private: T m_value; }; @@ -211,10 +154,6 @@ using remove_req_t = typename remove_req::type; static_assert( is_equal_comparable>>::value, ""); -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH( - required_inheritance_impl); -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH( - required_wrapper_impl); /* static_assert(std::is_same>() * + std::declval>()), diff --git a/lib/utils/include/utils/sequence.h b/lib/utils/include/utils/sequence.h index 07e4554299..26ed4a55f9 100644 --- a/lib/utils/include/utils/sequence.h +++ b/lib/utils/include/utils/sequence.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_SEQUENCE_H #include "utils/tuple.h" -#include "utils/visitable_core.h" #include #include @@ -79,13 +78,6 @@ using seq_enumerate_t = typename seq_enumerate::type; template struct seq_transform_type; -template -struct seq_transform_type> - : tuple_prepend_type< - 
visit_struct::traits::clean_t()( - std::declval>()))>, - typename seq_transform_type>::type> {}; - template struct seq_transform_type> { using type = std::tuple<>; diff --git a/lib/utils/include/utils/stack_map.h b/lib/utils/include/utils/stack_map.h index f26deee92d..cdb6defb66 100644 --- a/lib/utils/include/utils/stack_map.h +++ b/lib/utils/include/utils/stack_map.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_UTILS_STACK_MAP_H #define _FLEXFLOW_UTILS_STACK_MAP_H +#include "utils/containers/sorted_by.h" #include "utils/stack_vector/stack_vector.h" namespace std { diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index c1ab3f4570..bb62e2f90b 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -1,11 +1,11 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_STACK_STRING_H #define _FLEXFLOW_UTILS_INCLUDE_STACK_STRING_H -#include "fmt/core.h" -#include "stack_vector/stack_vector.h" #include "utils/fmt.h" +#include "utils/stack_vector/stack_vector.h" #include "utils/type_traits.h" #include +#include #include #include #include diff --git a/lib/utils/include/utils/stack_vector/stack_vector.h b/lib/utils/include/utils/stack_vector/stack_vector.h index 64d005a10e..f38c88d9bd 100644 --- a/lib/utils/include/utils/stack_vector/stack_vector.h +++ b/lib/utils/include/utils/stack_vector/stack_vector.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_UTILS_STACK_VECTOR_H #define _FLEXFLOW_UTILS_STACK_VECTOR_H +#include "utils/check_fmtable.h" #include "utils/hash-utils.h" #include "utils/join_strings.h" -#include "utils/test_types.h" #include "utils/type_traits.h" #include #include diff --git a/lib/utils/include/utils/strong_typedef.h b/lib/utils/include/utils/strong_typedef.h deleted file mode 100644 index cdadef8b96..0000000000 --- a/lib/utils/include/utils/strong_typedef.h +++ /dev/null @@ -1,236 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_STRONG_TYPEDEF_H -#define _FLEXFLOW_UTILS_INCLUDE_STRONG_TYPEDEF_H - -#include "utils/fmt.h" 
-#include -#include -#include - -namespace FlexFlow { - -// derived from https://www.foonathan.net/2016/10/strong-typedefs/ -template -class strong_typedef { -public: - strong_typedef() = delete; - - explicit strong_typedef(T const &value) : value_(value) {} - - explicit strong_typedef(T &&value) noexcept( - std::is_nothrow_move_constructible::value) - : value_(std::move(value)) {} - - explicit operator T &() noexcept { - return value_; - } - - explicit operator T const &() const noexcept { - return value_; - } - - template ::value && - !std::is_same::value), - bool>::type = true> - explicit operator TT() const { - return static_cast(this->value_); - } - - template ::value && - !std::is_same::value), - bool>::type = true> - operator TT() const { - return (this->value_); - } - - friend void swap(strong_typedef &a, strong_typedef &b) noexcept { - using std::swap; - swap(static_cast(a), static_cast(b)); - } - - friend bool operator==(strong_typedef const &lhs, strong_typedef const &rhs) { - return lhs.value() == rhs.value(); - } - - friend bool operator!=(strong_typedef const &lhs, strong_typedef const &rhs) { - return lhs.value() != rhs.value(); - } - - friend bool operator<(strong_typedef const &lhs, strong_typedef const &rhs) { - return lhs.value() < rhs.value(); - } - - T const &value() const noexcept { - return value_; - } - - T &value() noexcept { - return value_; - } - - template - strong_typedef fmap(F const &f) { - static_assert( - std::is_same()(std::declval())), - T>::value, - "Function must return an value of the underlying type"); - - return strong_typedef(f(this->value_)); - } - -private: - T value_; -}; - -template -T underlying_type_impl(strong_typedef); - -template -using underlying_type_t = decltype(underlying_type_impl(std::declval())); -// derived from -// https://github.com/foonathan/type_safe/blob/3612e2828b4b4e0d1cc689373e63a6d59d4bfd79/include/type_safe/strong_typedef.hpp -template -struct hashable : std::hash> { - using underlying_ty = 
underlying_type_t; - using underlying_hash = std::hash; - - std::size_t operator()(StrongTypedef const &lhs) const - noexcept(noexcept(underlying_hash{}(std::declval()))) { - return underlying_hash{}(static_cast(lhs)); - } -}; - -template -struct numerical_typedef : strong_typedef { - using strong_typedef::strong_typedef; - - friend StrongTypedef &operator+=(StrongTypedef &lhs, T const &rhs) { - static_cast(lhs) += static_cast(rhs); - return lhs; - } - - friend StrongTypedef &operator++(StrongTypedef &lhs) { - static_cast(lhs) += static_cast(1); - return lhs; - } - - friend StrongTypedef operator++(StrongTypedef &lhs, int) { - StrongTypedef tmp = lhs; - ++lhs; - return tmp; - } - - friend StrongTypedef operator+(StrongTypedef const &lhs, T const &rhs) { - return StrongTypedef(lhs.value() + rhs); - } - - friend StrongTypedef operator+(T const &lhs, StrongTypedef const &rhs) { - return (rhs + lhs); - } - - friend StrongTypedef operator-=(StrongTypedef &lhs, T const &rhs) { - static_cast(lhs) -= static_cast(rhs); - } - - friend StrongTypedef &operator--(StrongTypedef &lhs) { - static_cast(lhs) -= static_cast(1); - return lhs; - } - - friend StrongTypedef operator--(StrongTypedef &lhs, int) { - StrongTypedef tmp = lhs; - --lhs; - return tmp; - } - - friend StrongTypedef operator-(StrongTypedef const &lhs, T const &rhs) { - return StrongTypedef(lhs.value() + rhs); - } - - friend bool operator<(StrongTypedef const &lhs, T const &rhs) { - return lhs.value() < rhs; - } - - friend bool operator==(StrongTypedef const &lhs, T const &rhs) { - return lhs.value() == rhs; - } - - friend bool operator>(StrongTypedef const &lhs, T const &rhs) { - return lhs.value() > rhs; - } - - friend bool operator>=(StrongTypedef const &lhs, T const &rhs) { - return lhs.value() >= rhs; - } - - friend bool operator!=(StrongTypedef const &lhs, T const &rhs) { - return lhs.value() != rhs; - } - - friend bool operator<=(StrongTypedef const &lhs, T const &rhs) { - return lhs.value() <= rhs; - } - - 
friend bool operator<(T const &lhs, StrongTypedef const &rhs) { - return lhs < rhs.value(); - } - - friend bool operator==(T const &lhs, StrongTypedef const &rhs) { - return lhs == rhs.value(); - } - - friend bool operator>(T const &lhs, StrongTypedef const &rhs) { - return lhs > rhs.value(); - } - - friend bool operator<=(T const &lhs, StrongTypedef const &rhs) { - return lhs <= rhs.value(); - } - - friend bool operator!=(T const &lhs, StrongTypedef const &rhs) { - return lhs != rhs.value(); - } - - friend bool operator>=(T const &lhs, StrongTypedef const &rhs) { - return lhs >= rhs.value(); - } -}; - -} // namespace FlexFlow - -#define MAKE_TYPEDEF_HASHABLE(TYPEDEF_NAME) \ - namespace std { \ - template <> \ - struct hash : ::FlexFlow::hashable {}; \ - } \ - static_assert(true, "") - -#define MAKE_TYPEDEF_PRINTABLE(TYPEDEF_NAME, TYPEDEF_SHORTNAME) \ - namespace fmt { \ - template <> \ - struct formatter : formatter<::std::string> { \ - template \ - auto format(TYPEDEF_NAME const &x, FormatContext &ctx) const \ - -> decltype(ctx.out()) { \ - ::std::string s = fmt::format("{}({})", (TYPEDEF_SHORTNAME), x.value()); \ - return formatter<::std::string>::format(s, ctx); \ - } \ - }; \ - } \ - static_assert(true, "") - -#define FF_TYPEDEF_HASHABLE(TYPEDEF_NAME) \ - } \ - MAKE_TYPEDEF_HASHABLE(::FlexFlow::TYPEDEF_NAME); \ - namespace FlexFlow { \ - static_assert(true, ""); - -#define FF_TYPEDEF_PRINTABLE(TYPEDEF_NAME, TYPEDEF_SHORTNAME) \ - } \ - MAKE_TYPEDEF_PRINTABLE(::FlexFlow::TYPEDEF_NAME, TYPEDEF_SHORTNAME); \ - namespace FlexFlow { \ - DELEGATE_OSTREAM(TYPEDEF_NAME); \ - static_assert(true, ""); - -#endif diff --git a/lib/utils/include/utils/test_types.h b/lib/utils/include/utils/test_types.h deleted file mode 100644 index 984c0bc60d..0000000000 --- a/lib/utils/include/utils/test_types.h +++ /dev/null @@ -1,159 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_TEST_TYPES_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_TEST_TYPES_H - -#include "type_traits_core.h" - 
-namespace FlexFlow { - -namespace test_types { - -enum capability { - HASHABLE, - EQ, - CMP, - DEFAULT_CONSTRUCTIBLE, - MOVE_CONSTRUCTIBLE, - MOVE_ASSIGNABLE, - COPY_CONSTRUCTIBLE, - COPY_ASSIGNABLE, - PLUS, - PLUSEQ, - FMT -}; - -template -struct capability_implies : std::false_type {}; - -template <> -struct capability_implies : std::true_type {}; - -template -struct capability_implies : std::true_type {}; - -template -struct has_capability; - -template -struct has_capability - : disjunction, - has_capability> {}; - -template -struct has_capability : std::false_type {}; - -template -struct test_type_t { - template - using supports = conjunction...>; - - template ::value, bool>::type = true> - test_type_t(); - - template ::value, bool>::type = true> - test_type_t() = delete; - - template ::value, bool>::type = true> - test_type_t(test_type_t const &); - - template ::value, bool>::type = true> - test_type_t(test_type_t const &) = delete; - - template ::value, bool>::type = true> - test_type_t &operator=(test_type_t const &); - - template ::value, bool>::type = true> - test_type_t &operator=(test_type_t const &) = delete; - - template ::value, bool>::type = true> - test_type_t(test_type_t &&); - - template ::value, bool>::type = true> - test_type_t(test_type_t &&) = delete; - - template ::value, bool>::type = true> - test_type_t &operator=(test_type_t &&); - - template ::value, bool>::type = true> - test_type_t &operator=(test_type_t &&) = delete; - - template - typename std::enable_if::value, bool>::type - operator==(test_type_t const &) const; - - template - typename std::enable_if::value, bool>::type - operator!=(test_type_t const &) const; - - template - typename std::enable_if::value, bool>::type - operator<(test_type_t const &) const; - - template - typename std::enable_if::value, bool>::type - operator>(test_type_t const &) const; - - template - typename std::enable_if::value, bool>::type - operator<=(test_type_t const &) const; - - template - typename 
std::enable_if::value, bool>::type - operator>=(test_type_t const &) const; - - template - typename std::enable_if::value, test_type_t>::type - operator+(test_type_t const &); - - template - typename std::enable_if::value, test_type_t>::type - operator+=(test_type_t const &); -}; - -template -enable_if_t::value, std::string> - format_as(test_type_t); - -using no_eq = test_type_t<>; -using eq = test_type_t; -using cmp = test_type_t; -using hash_cmp = test_type_t; -using plusable = test_type_t; -using fmtable = test_type_t; -using well_behaved_value_type = test_type_t; - -} // namespace test_types -} // namespace FlexFlow - -namespace std { - -template < - ::FlexFlow::test_types:: - capability... CAPABILITIES> //, typename = typename - // std::enable_if<::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE>::value, - // bool>::type> -struct hash<::FlexFlow::test_types::test_type_t> { - typename std::enable_if< - ::FlexFlow::test_types::has_capability<::FlexFlow::test_types::HASHABLE, - CAPABILITIES...>::value, - size_t>::type - operator()( - ::FlexFlow::test_types::test_type_t const &) const; -}; - -} // namespace std - -#endif diff --git a/lib/utils/include/utils/type_traits.h b/lib/utils/include/utils/type_traits.h index 7abb3ffd5b..2f07050527 100644 --- a/lib/utils/include/utils/type_traits.h +++ b/lib/utils/include/utils/type_traits.h @@ -3,7 +3,6 @@ #include "utils/metafunction.h" #include "utils/type_traits_core.h" -#include "utils/visitable_core.h" #include #include @@ -81,13 +80,6 @@ struct elements_satisfy { "than 1 argument"); }; -template