From cea0d1100cbbcb6223d816ebdb57a8f90a019130 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 14 May 2026 16:35:21 -0700 Subject: [PATCH] Add HTML and JSON-based dot interface, fix bugs in graph querying --- .../src/export-model-arch/main.cc | 2 +- .../src/substitution-to-dot/main.cc | 2 +- flake.nix | 7 +- .../unity_algorithm/graph_optimize_state.h | 9 - ...racted_tensor_set_movement_across_split.cc | 14 +- .../machine_mapping/machine_mapping.cc | 47 +++-- lib/compiler/src/compiler/search_result.cc | 2 +- ...ion_graph_series_parallel_decomposition.cc | 16 +- .../unity_algorithm/graph_optimize_state.cc | 60 +----- .../op-attrs/computation_graph_op_attrs.h | 5 +- lib/op-attrs/include/op-attrs/ops/broadcast.h | 2 - lib/op-attrs/include/op-attrs/ops/cast.h | 2 - lib/op-attrs/include/op-attrs/ops/combine.h | 3 - lib/op-attrs/include/op-attrs/ops/embedding.h | 2 - lib/op-attrs/include/op-attrs/ops/linear.h | 2 - lib/op-attrs/include/op-attrs/ops/reduction.h | 3 - .../include/op-attrs/ops/repartition.h | 3 - lib/op-attrs/include/op-attrs/ops/replicate.h | 3 - lib/op-attrs/include/op-attrs/ops/weight.h | 2 - .../include/op-attrs/pcg_operator_attrs.h | 6 +- .../op-attrs/computation_graph_op_attrs.cc | 15 +- lib/op-attrs/src/op-attrs/ops/broadcast.cc | 17 -- lib/op-attrs/src/op-attrs/ops/cast.cc | 14 -- lib/op-attrs/src/op-attrs/ops/combine.cc | 14 -- lib/op-attrs/src/op-attrs/ops/embedding.cc | 16 -- lib/op-attrs/src/op-attrs/ops/linear.cc | 16 -- lib/op-attrs/src/op-attrs/ops/reduction.cc | 14 -- lib/op-attrs/src/op-attrs/ops/repartition.cc | 15 -- lib/op-attrs/src/op-attrs/ops/replicate.cc | 14 -- lib/op-attrs/src/op-attrs/ops/weight.cc | 12 -- .../src/op-attrs/parallel_tensor_shape.cc | 18 ++ .../src/op-attrs/pcg_operator_attrs.cc | 23 +- .../mapped_operator_task_group.h | 8 + ...mapped_parallel_computation_graph.dtg.toml | 9 +- .../mapped_parallel_computation_graph.h | 26 +++ .../mapped_parallel_layer_attrs.dtg.toml | 35 ++++ .../mapped_parallel_layer_attrs.h | 17 ++ .../parallel_computation_graph.h | 2 +- lib/pcg/src/pcg/computation_graph.cc | 50 ++--- .../mapped_operator_task_group.cc | 37 +++- .../mapped_parallel_computation_graph.cc | 153 +++++++++++++- .../mapped_parallel_layer_attrs.cc | 20 ++ lib/pcg/src/pcg/optimizer_attrs.cc | 1 - .../parallel_computation_graph.cc | 49 +++-- .../parallel_computation_graph_builder.cc | 2 +- .../parallel_layer_attrs.cc | 2 +- .../mapped_parallel_computation_graph.cc | 145 +++++++++++++ .../test/src/realm-execution/test_e2e.cc | 147 ++++++------- .../dynamic_open_dataflow_graph.h | 5 + .../dynamic_open_dataflow_graph.cc | 83 ++++++++ ...mic_open_dataflow_graph_from_mapped_pcg.cc | 16 +- lib/utils/include/utils/{ => dot}/dot_file.h | 52 ++++- .../include/utils/dot/dot_html_from_json.h | 14 ++ .../include/utils/dot/dot_html_table.dtg.toml | 34 +++ .../utils/dot/dot_html_table_cell.dtg.toml | 30 +++ .../utils/dot/dot_html_table_cell_contents.h | 36 ++++ .../utils/dot/dot_html_table_row.dtg.toml | 22 ++ .../dot/render_dot_html_table_to_string.h | 12 ++ .../include/utils/full_binary_tree/as_dot.h | 28 ++- lib/utils/include/utils/graph/algorithms.h | 2 +- .../graph/dataflow_graph/algorithms/as_dot.h | 16 -- .../algorithms/dataflow_graph_as_dot.h | 30 +++ .../algorithms/dataflow_graph_data.dtg.toml | 33 +++ .../view_from_dataflow_graph_data.h | 30 +++ .../dataflow_graph/dataflow_edge.dtg.toml | 1 + .../dataflow_graph/dataflow_input.dtg.toml | 1 + .../dataflow_graph/dataflow_output.dtg.toml | 1 + .../graph/digraph/algorithms/digraph_as_dot.h | 3 +- .../unordered_set_kwarg_dataflow_graph.h | 130 ++++++++++++ ...raph_data_from_kwarg_dataflow_graph_data.h | 98 +++++++++ ...dataflow_graph_from_kwarg_dataflow_graph.h | 29 +++ .../get_all_kwarg_dataflow_inputs.h | 21 ++ ...t_incoming_kwarg_dataflow_edges_for_node.h | 2 +- .../algorithms/get_incoming_slots_for_node.h | 18 ++ ...t_kwarg_dataflow_edges_from_node_to_node.h | 4 +- .../get_kwarg_dataflow_graph_data.h | 23 ++ ...t_kwarg_dataflow_subgraph_incoming_edges.h | 6 +- ...t_kwarg_dataflow_subgraph_outgoing_edges.h | 6 +- ...t_outgoing_kwarg_dataflow_edges_for_node.h | 2 +- ...outgoing_kwarg_dataflow_outputs_for_node.h | 2 +- .../algorithms/get_outgoing_slots_for_node.h | 18 ++ .../algorithms/kwarg_dataflow_graph_as_dot.h | 60 ++++++ .../kwarg_dataflow_graph_data.dtg.toml | 36 ++++ .../algorithms/kwarg_dataflow_graph_data.h | 43 ++++ .../kwarg_dataflow_graphs_are_isomorphic.h | 20 ++ .../view_from_kwarg_dataflow_graph_data.h | 56 +++++ .../kwarg_dataflow_graph_view.h | 2 +- ...abelled_kwarg_dataflow_graph_view_as_dot.h | 39 ++++ ...ed_open_kwarg_dataflow_graph_view_as_dot.h | 24 --- ...alize_labelled_kwarg_dataflow_graph_view.h | 23 ++ .../algorithms/as_dot.h | 29 --- .../labelled_open_dataflow_graph_as_dot.h | 38 ++++ ...ed_open_kwarg_dataflow_graph_view_as_dot.h | 44 ++++ .../open_dataflow_graph/algorithms/as_dot.h | 17 -- .../algorithms/open_dataflow_graph_as_dot.h | 21 ++ ...hisms_between_open_kwarg_dataflow_graphs.h | 11 +- ...oming_open_kwarg_dataflow_edges_for_node.h | 4 +- .../get_open_kwarg_dataflow_graph_subgraph.h | 28 +-- ...n_kwarg_dataflow_subgraph_incoming_edges.h | 8 +- .../get_open_kwarg_dataflow_value_uses.h | 7 +- .../open_kwarg_dataflow_graph_as_dot.h | 140 +++++++++++++ .../open_kwarg_dataflow_graph_data.h | 22 ++ ...g_dataflow_graph_by_materializing_inputs.h | 116 +++++++++++ lib/utils/include/utils/graph/query_set.h | 51 +++-- .../nonempty_unordered_set.h | 115 ++++++++++ .../include/utils/one_to_many/one_to_many.h | 27 ++- .../one_to_many_transform_values.h | 23 ++ lib/utils/include/utils/orientation.dtg.toml | 15 ++ .../include/utils/orthotope/up_projection.h | 3 +- lib/utils/include/utils/record_formatter.h | 64 +++++- lib/utils/src/utils/dot/dot_file.cc | 10 + lib/utils/src/utils/dot/dot_html_from_json.cc | 137 ++++++++++++ .../utils/dot/dot_html_table_cell_contents.cc | 62 ++++++ .../dot/render_dot_html_table_to_string.cc | 66 ++++++ lib/utils/src/utils/dot_file.cc | 1 - lib/utils/src/utils/graph/algorithms.cc | 24 ++- .../utils/graph/dataflow_graph/algorithms.cc | 2 +- .../graph/dataflow_graph/algorithms/as_dot.cc | 77 ------- .../algorithms/dataflow_graph_as_dot.cc | 152 ++++++++++++++ .../get_dataflow_edges_from_node_to_node.cc | 4 +- .../algorithms/get_incoming_edges.cc | 5 +- .../algorithms/get_outgoing_edges.cc | 5 +- .../algorithms/get_subgraph_incoming_edges.cc | 6 +- .../algorithms/get_subgraph_outgoing_edges.cc | 6 +- .../view_from_dataflow_graph_data.cc | 40 ++++ .../dataflow_graph/dataflow_edge_query.cc | 16 +- .../dataflow_graph/dataflow_output_query.cc | 4 +- .../get_cbc_decomposition.cc | 6 +- .../digraph/algorithms/digraph_as_dot.cc | 10 +- .../digraph/algorithms/digraph_has_edge.cc | 4 +- .../get_edges_from_subgraph_to_subgraph.cc | 5 +- .../digraph/algorithms/get_incoming_edges.cc | 20 +- .../digraph/algorithms/get_outgoing_edges.cc | 20 +- .../algorithms/get_subgraph_outgoing_edges.cc | 7 +- .../graph/digraph/directed_edge_query.cc | 6 +- .../unordered_set_kwarg_dataflow_graph.cc | 10 + ...aph_data_from_kwarg_dataflow_graph_data.cc | 13 ++ ...ataflow_graph_from_kwarg_dataflow_graph.cc | 13 ++ .../get_all_kwarg_dataflow_inputs.cc | 11 + .../algorithms/get_incoming_slots_for_node.cc | 11 + .../get_kwarg_dataflow_graph_data.cc | 11 + .../algorithms/get_outgoing_slots_for_node.cc | 11 + .../algorithms/kwarg_dataflow_graph_as_dot.cc | 17 ++ .../algorithms/kwarg_dataflow_graph_data.cc | 11 + .../kwarg_dataflow_graphs_are_isomorphic.cc | 12 ++ .../view_from_kwarg_dataflow_graph_data.cc | 11 + ...belled_kwarg_dataflow_graph_view_as_dot.cc | 19 ++ ...lize_labelled_kwarg_dataflow_graph_view.cc | 16 ++ .../algorithms/as_dot.cc | 1 - .../labelled_open_dataflow_graph_as_dot.cc | 17 ++ ...d_open_kwarg_dataflow_graph_view_as_dot.cc | 11 +- .../algorithms/get_incoming_edges.cc | 25 ++- .../algorithms/get_outgoing_edges.cc | 26 ++- lib/utils/src/utils/graph/node/algorithms.cc | 6 +- lib/utils/src/utils/graph/node/node_query.cc | 8 +- .../open_dataflow_graph/algorithms/as_dot.cc | 71 ------- .../algorithms/get_incoming_edges.cc | 4 +- .../algorithms/get_subgraph.cc | 4 +- .../algorithms/get_subgraph_incoming_edges.cc | 7 +- .../algorithms/open_dataflow_graph_as_dot.cc | 86 ++++++++ .../dataflow_input_edge_query.cc | 12 +- .../open_kwarg_dataflow_graph_as_dot.cc | 19 ++ .../open_kwarg_dataflow_graph_data.cc | 3 + ..._dataflow_graph_by_materializing_inputs.cc | 14 ++ lib/utils/src/utils/graph/render_dot.cc | 6 +- .../sp_ization/escribano_algo.cc | 18 +- .../sp_ization/work_duplicating_sp_ization.cc | 2 +- .../algorithms/get_neighboring_nodes.cc | 4 +- lib/utils/src/utils/graph/views/views.cc | 26 ++- .../nonempty_unordered_set.cc | 10 + .../src/utils/one_to_many/one_to_many.cc | 2 +- .../one_to_many_transform_values.cc | 14 ++ lib/utils/src/utils/record_formatter.cc | 44 +++- .../test/src/utils/{ => dot}/dot_file.cc | 6 +- .../test/src/utils/dot/dot_html_from_json.cc | 196 ++++++++++++++++++ .../dot/render_dot_html_table_to_string.cc | 34 +++ .../algorithms/dataflow_graph_as_dot.cc | 75 +++++++ .../graph/digraph/directed_edge_query.cc | 41 ++-- .../graph/instances/adjacency_digraph.cc | 58 ++++-- .../graph/instances/adjacency_multidigraph.cc | 25 ++- .../unordered_set_kwarg_dataflow_graph.cc | 140 +++++++++++++ ...aph_data_from_kwarg_dataflow_graph_data.cc | 135 ++++++++++++ ...ataflow_graph_from_kwarg_dataflow_graph.cc | 92 ++++++++ .../view_from_kwarg_dataflow_graph_data.cc | 90 ++++++++ .../utils/graph/multidigraph/multidigraph.cc | 130 +++++++++--- .../algorithms/permute_node_ids.cc | 14 +- ...rg_dataflow_graphs_are_isomorphic_under.cc | 57 +++++ ..._dataflow_graph_by_materializing_inputs.cc | 150 ++++++++++++++ lib/utils/test/src/utils/graph/query_set.cc | 45 ++++ .../graph/undirected/undirected_graph.cc | 69 ++++-- .../nonempty_unordered_set.cc | 44 ++++ .../test/src/utils/one_to_many/one_to_many.cc | 4 +- .../one_to_many_transform_values.cc | 37 ++++ lib/utils/test/src/utils/record_formatter.cc | 8 +- 194 files changed, 4772 insertions(+), 956 deletions(-) create mode 100644 lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.dtg.toml create mode 100644 lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.h create mode 100644 lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.cc create mode 100644 lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc rename lib/utils/include/utils/{ => dot}/dot_file.h (83%) create mode 100644 lib/utils/include/utils/dot/dot_html_from_json.h create mode 100644 lib/utils/include/utils/dot/dot_html_table.dtg.toml create mode 100644 lib/utils/include/utils/dot/dot_html_table_cell.dtg.toml create mode 100644 lib/utils/include/utils/dot/dot_html_table_cell_contents.h create mode 100644 lib/utils/include/utils/dot/dot_html_table_row.dtg.toml create mode 100644 lib/utils/include/utils/dot/render_dot_html_table_to_string.h delete mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/as_dot.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_data.dtg.toml create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/view_from_dataflow_graph_data.h create mode 100644 lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.h create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_inputs.h create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.h create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.h create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.toml create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.h create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graphs_are_isomorphic.h create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.h create mode 100644 lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h delete mode 100644 lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h create mode 100644 lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h delete mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_as_dot.h create mode 100644 lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h delete mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/as_dot.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_as_dot.h create mode 100644 lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.h create mode 100644 lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.h create mode 100644 lib/utils/include/utils/nonempty_unordered_set/nonempty_unordered_set.h create mode 100644 lib/utils/include/utils/one_to_many/one_to_many_transform_values.h create mode 100644 lib/utils/include/utils/orientation.dtg.toml create mode 100644 lib/utils/src/utils/dot/dot_file.cc create mode 100644 lib/utils/src/utils/dot/dot_html_from_json.cc create mode 100644 lib/utils/src/utils/dot/dot_html_table_cell_contents.cc create mode 100644 lib/utils/src/utils/dot/render_dot_html_table_to_string.cc delete mode 100644 lib/utils/src/utils/dot_file.cc delete mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/as_dot.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/view_from_dataflow_graph_data.cc create mode 100644 lib/utils/src/utils/graph/instances/unordered_set_kwarg_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.cc create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_inputs.cc create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.cc create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.cc create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.cc create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.cc create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.cc create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graphs_are_isomorphic.cc create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.cc create mode 100644 lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.cc create mode 100644 lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.cc delete mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_as_dot.cc rename lib/utils/src/utils/graph/{labelled_kwarg_dataflow_graph => labelled_open_kwarg_dataflow_graph}/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.cc (52%) delete mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/as_dot.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_as_dot.cc create mode 100644 lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.cc create mode 100644 lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.cc create mode 100644 lib/utils/src/utils/nonempty_unordered_set/nonempty_unordered_set.cc create mode 100644 lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc rename lib/utils/test/src/utils/{ => dot}/dot_file.cc (93%) create mode 100644 lib/utils/test/src/utils/dot/dot_html_from_json.cc create mode 100644 lib/utils/test/src/utils/dot/render_dot_html_table_to_string.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc create mode 100644 lib/utils/test/src/utils/graph/instances/unordered_set_kwarg_dataflow_graph.cc create mode 100644 lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.cc create mode 100644 lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.cc create mode 100644 lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.cc create mode 100644 lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graphs_are_isomorphic_under.cc create mode 100644 lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.cc create mode 100644 lib/utils/test/src/utils/graph/query_set.cc create mode 100644 lib/utils/test/src/utils/nonempty_unordered_set/nonempty_unordered_set.cc create mode 100644 lib/utils/test/src/utils/one_to_many/one_to_many_transform_values.cc diff --git a/bin/export-model-arch/src/export-model-arch/main.cc b/bin/export-model-arch/src/export-model-arch/main.cc index e62809dda5..c42a59f0ce 100644 --- a/bin/export-model-arch/src/export-model-arch/main.cc +++ b/bin/export-model-arch/src/export-model-arch/main.cc @@ -14,7 +14,7 @@ #include "utils/cli/cli_parse.h" #include "utils/cli/cli_parse_result.h" #include "utils/cli/cli_spec.h" -#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_as_dot.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" #include "utils/graph/series_parallel/get_series_parallel_decomposition.h" diff --git a/bin/substitution-to-dot/src/substitution-to-dot/main.cc b/bin/substitution-to-dot/src/substitution-to-dot/main.cc index 1b5f715bcd..9ff3f0109f 100644 --- a/bin/substitution-to-dot/src/substitution-to-dot/main.cc +++ b/bin/substitution-to-dot/src/substitution-to-dot/main.cc @@ -1,5 +1,5 @@ #include "substitution-generator/legacy_rules.h" -#include "utils/dot_file.h" +#include "utils/dot/dot_file.h" #include #include diff --git a/flake.nix b/flake.nix index 3e5c477dea..da162eba26 100644 --- a/flake.nix +++ b/flake.nix @@ -5,10 +5,10 @@ bash-prompt-prefix = "(ff) "; extra-substituters = [ "https://ff.cachix.org" - "https://cuda-maintainers.cachix.org/" + #"https://cuda-maintainers.cachix.org/" ]; extra-trusted-public-keys = [ - "cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=" + #"cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=" "ff.cachix.org-1:IRdsNEnht4YKGUasP6SX5DfpaOTBckhpJDEODz7wMFM=" ]; }; @@ -119,6 +119,7 @@ gbenchmark libtorch-bin graphviz # for documentation + texliveBasic # for documentation ]) (with proj-repo.packages.${system}; [ proj @@ -162,6 +163,8 @@ jq gh expect + universal-ctags + ninja ]) (with pkgs.python3Packages; [ gitpython diff --git a/lib/compiler/include/compiler/unity_algorithm/graph_optimize_state.h b/lib/compiler/include/compiler/unity_algorithm/graph_optimize_state.h index c0952c0684..24a9f5d206 100644 --- a/lib/compiler/include/compiler/unity_algorithm/graph_optimize_state.h +++ b/lib/compiler/include/compiler/unity_algorithm/graph_optimize_state.h @@ -26,15 +26,6 @@ struct GraphOptimizeState { std::string format_as(GraphOptimizeState const &); std::ostream &operator<<(std::ostream &, GraphOptimizeState const &); -// TODO(@lockshaw)(#pr): Delete this if still unused -// std::optional -// graph_optimize_state_from_machine_mapping_result(ParallelComputationGraph -// const &, -// PCGBinarySPDecomposition -// const &, -// MachineMappingResult const -// &); - } // namespace FlexFlow namespace std { diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index df02655ccc..151008f65f 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -101,13 +101,13 @@ AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( }; return AbstractedTensorSetMovement{ - transform( - edges_by_tensor.right_groups(), - [&](std::unordered_set const &edges) { - return merge_abstracted_single_tensor_movements( - transform(unordered_multiset_of(edges), - to_abstracted_single_tensor_movement)); - }), + transform(edges_by_tensor.right_groups(), + [&](nonempty_unordered_set const + &edges) { + return merge_abstracted_single_tensor_movements(transform( + unordered_multiset_of(edges.unwrap_as_unordered_set()), + to_abstracted_single_tensor_movement)); + }), }; } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index 8a16ff9dda..a2307716ba 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -2,6 +2,9 @@ #include "compiler/machine_mapping/machine_view.h" #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "pcg/machine_compute_resource_slice.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/binary_merge_disjoint_maps.h" @@ -15,30 +18,32 @@ MappedParallelComputationGraph std::unordered_set pcg_layers = get_parallel_layers(pcg); + std::unordered_set mapped_layers = keys(mapping.machine_views); - ASSERT(pcg_layers == mapped_layers); - - return MappedParallelComputationGraph{ - /*pcg=*/pcg, - /*mapped_tasks=*/ - generate_map( - get_parallel_layers(pcg), - [&](parallel_layer_guid_t l) -> MappedOperatorTaskGroup { - ComputationGraphOpAttrs op_attrs = - compgraph_op_attrs_from_pcg_op_attrs(pcg_get_op_attrs(pcg, l)) - .value(); - - std::unordered_map - inputs_dim_degrees = get_incoming_input_degrees(pcg, l); - - ASSERT(contains_key(mapping.machine_views, l)); - MachineView machine_view = mapping.machine_views.at(l); - - return mapped_operator_task_group_from_machine_view( - op_attrs, inputs_dim_degrees, machine_view); - }), + + ASSERT(mapped_layers == pcg_layers); + + auto mapping_for_layer = + [&](parallel_layer_guid_t l) -> MappedOperatorTaskGroup { + ComputationGraphOpAttrs op_attrs = assert_unwrap( + compgraph_op_attrs_from_pcg_op_attrs(pcg_get_op_attrs(pcg, l))); + + std::unordered_map + inputs_dim_degrees = get_incoming_input_degrees(pcg, l); + + ASSERT(contains_key(mapping.machine_views, l)); + MachineView machine_view = mapping.machine_views.at(l); + + return mapped_operator_task_group_from_machine_view( + op_attrs, inputs_dim_degrees, machine_view); }; + + std::unordered_map + mapped_op_task_groups = generate_map(mapped_layers, mapping_for_layer); + + return mapped_pcg_from_pcg_and_mapped_op_task_groups(pcg, + mapped_op_task_groups); } MachineMapping combine_disjoint_mappings(MachineMapping const &m1, diff --git a/lib/compiler/src/compiler/search_result.cc b/lib/compiler/src/compiler/search_result.cc index 28eec7f247..00de9b4b34 100644 --- a/lib/compiler/src/compiler/search_result.cc +++ b/lib/compiler/src/compiler/search_result.cc @@ -10,7 +10,7 @@ MappedParallelComputationGraph std::string format_as(SearchResult const &r) { return fmt::format("", - as_dot(r.pcg), + pcg_as_dot(r.pcg), r.machine_mapping); } diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc index 144036f970..50b95f3c3e 100644 --- a/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc @@ -35,23 +35,15 @@ std::string render_preprocessed_computation_graph_for_sp_decomposition( preprocessed_digraph.add_edge(DirectedEdge{fake_node, dst.raw_node}); } - std::function get_node_label = - [&](Node const &n) -> std::string { + std::function get_node_label = + [&](Node const &n) -> nlohmann::json { if (n == fake_node) { return "FAKE"; } - LayerAttrs a = cg.raw_graph.at(n); - RecordFormatter r = as_dot(a.op_attrs); - if (a.name.has_value()) { - RecordFormatter rr; - rr << "Name" << a.name.value(); - r << rr; - } + nlohmann::json result = cg.raw_graph.at(n); - std::ostringstream oss; - oss << r; - return oss.str(); + return result; }; std::string preprocessed_dot = digraph_as_dot( transitive_reduction(preprocessed_digraph), get_node_label); diff --git a/lib/compiler/src/compiler/unity_algorithm/graph_optimize_state.cc b/lib/compiler/src/compiler/unity_algorithm/graph_optimize_state.cc index 7e7a80018e..6883098dab 100644 --- a/lib/compiler/src/compiler/unity_algorithm/graph_optimize_state.cc +++ b/lib/compiler/src/compiler/unity_algorithm/graph_optimize_state.cc @@ -84,71 +84,13 @@ bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { std::string format_as(GraphOptimizeState const &s) { return fmt::format( - "", s.runtime, as_dot(s.pcg)); + "", s.runtime, pcg_as_dot(s.pcg)); } std::ostream &operator<<(std::ostream &s, GraphOptimizeState const &x) { return (s << fmt::to_string(x)); } -// TODO(@lockshaw)(#pr): Delete this if still unused -// std::optional -// graph_optimize_state_from_machine_mapping_result(ParallelComputationGraph -// const &pcg, -// PCGBinarySPDecomposition -// const -// &binary_sp_decomposition, -// MachineMappingResult const -// &machine_mapping_result) { -// -// FeasibleMachineMappingResult feasible_mapping = ({ -// if (is_infeasible(machine_mapping_result)) { -// return std::nullopt; -// } -// -// require_feasible(machine_mapping_result); -// }); -// -// bidict path_to_leaf_map = -// bidict_from_map(pcg_sp_tree_get_path_to_leaf_map(binary_sp_decomposition)); -// -// std::unordered_map -// mapped_tasks_by_path = zip_values_strict_with( -// path_to_leaf_map.as_unordered_map(), -// feasible_mapping.machine_mapping.raw_mapping, -// [&](parallel_layer_guid_t const &layer_guid, MachineView const &mv) -// -> MappedOperatorTaskGroup -// { -// ComputationGraphOpAttrs comp_graph_op_attrs = -// assert_unwrap(compgraph_op_attrs_from_pcg_op_attrs(pcg_get_op_attrs(pcg, -// layer_guid))); -// -// return mapped_operator_task_group_from_machine_view( -// comp_graph_op_attrs, -// get_incoming_input_degrees(pcg, layer_guid), -// mv); -// }); -// -// std::unordered_map -// mapped_tasks = map_keys(mapped_tasks_by_path, -// [&](BinaryTreePath const &path) -> -// parallel_layer_guid_t { -// return path_to_leaf_map.at_l(path); -// }); -// -// GraphOptimizeResult result = GraphOptimizeResult{ -// MappedParallelComputationGraph{ -// pcg, -// mapped_tasks, -// }, -// }; -// -// return GraphOptimizeState{ -// result, -// feasible_mapping.runtime, -// }; -// } - } // namespace FlexFlow namespace std { diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h index fd0707aa2e..6ea98c797b 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h @@ -3,12 +3,13 @@ #include "op-attrs/computation_graph_op_attrs.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" -#include "utils/record_formatter.h" namespace FlexFlow { OperatorType get_op_type(ComputationGraphOpAttrs const &); -RecordFormatter as_dot(ComputationGraphOpAttrs const &); + +nlohmann::json cg_op_attrs_as_dot_json(ComputationGraphOpAttrs const &); + std::optional compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 9b6bd49418..6c7f8407f9 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -9,8 +9,6 @@ namespace FlexFlow { -RecordFormatter as_dot(BroadcastAttrs const &); - tl::expected get_output_shape(BroadcastAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(BroadcastAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index 38a1e87a76..0daa6a30a1 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -9,8 +9,6 @@ namespace FlexFlow { -RecordFormatter as_dot(CastAttrs const &); - tl::expected get_output_shape(CastAttrs const &, TensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index 6839bc12e1..a60e19bc29 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -3,13 +3,10 @@ #include "op-attrs/ops/combine_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" -#include "utils/record_formatter.h" #include namespace FlexFlow { -RecordFormatter as_dot(CombineAttrs const &); - tl::expected get_output_shape(CombineAttrs const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index ff4aecae98..7ae8350dee 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -11,8 +11,6 @@ namespace FlexFlow { -RecordFormatter as_dot(EmbeddingAttrs const &); - tl::expected get_output_shape(EmbeddingAttrs const &, TensorShape const &); tl::expected get_weights_shape(EmbeddingAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index fb44e5cb4f..b5010c7186 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -20,8 +20,6 @@ namespace FlexFlow { std::unordered_map get_linear_incoming_tensor_roles(LinearAttrs const &); -RecordFormatter as_dot(LinearAttrs const &); - tl::expected get_projection_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected get_bias_shape(LinearAttrs const &attrs, diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index b107178744..14076e6c23 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -3,13 +3,10 @@ #include "op-attrs/ops/reduction_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" -#include "utils/record_formatter.h" #include namespace FlexFlow { -RecordFormatter as_dot(ReductionAttrs const &); - tl::expected get_output_shape(ReductionAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index 7733bc6989..48e4c554ae 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -3,13 +3,10 @@ #include "op-attrs/ops/repartition_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" -#include "utils/record_formatter.h" #include namespace FlexFlow { -RecordFormatter as_dot(RepartitionAttrs const &); - tl::expected get_output_shape(RepartitionAttrs const &, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 6a6ecd3d1e..ef526ef2a9 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -3,12 +3,9 @@ #include "op-attrs/ops/replicate_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" -#include "utils/record_formatter.h" namespace FlexFlow { -RecordFormatter as_dot(ReplicateAttrs const &); - ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/weight.h b/lib/op-attrs/include/op-attrs/ops/weight.h index 3d488ef24c..90a831aa59 100644 --- a/lib/op-attrs/include/op-attrs/ops/weight.h +++ b/lib/op-attrs/include/op-attrs/ops/weight.h @@ -10,8 +10,6 @@ namespace FlexFlow { -RecordFormatter as_dot(WeightAttrs const &); - TensorShape get_output_shape(WeightAttrs const &); ParallelTensorShape get_output_parallel_tensor_shape(WeightAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h index 6e300836d4..650730341b 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h @@ -4,14 +4,16 @@ #include "op-attrs/computation_graph_op_attrs.dtg.h" #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/pcg_operator_attrs.dtg.h" +#include namespace FlexFlow { bool is_parallel_op(PCGOperatorAttrs const &); -OperatorType get_op_type(PCGOperatorAttrs const &); +OperatorType pcg_op_attrs_get_op_type(PCGOperatorAttrs const &); PCGOperatorAttrs pcg_op_attrs_from_compgraph_op_attrs(ComputationGraphOpAttrs const &); -RecordFormatter as_dot(PCGOperatorAttrs const &); + +nlohmann::json pcg_op_attrs_as_dot_json(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc index 9d1a9f68d4..166cce0ab5 100644 --- a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -14,19 +14,8 @@ OperatorType get_op_type(ComputationGraphOpAttrs const &attrs) { [](auto const &x) { return get_op_type(x); }); } -RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { - RecordFormatter result = attrs.visit(overload{ - [](LinearAttrs const &l) { return as_dot(l); }, - [](CastAttrs const &a) { return as_dot(a); }, - [](EmbeddingAttrs const &a) { return as_dot(a); }, - [](WeightAttrs const &a) { return as_dot(a); }, - [](BroadcastAttrs const &a) { return as_dot(a); }, - [&](auto const &) { return RecordFormatter{}; }, - }); - - RecordFormatter rr; - rr << "Op Type" << fmt::to_string(get_op_type(attrs)); - result << rr; +nlohmann::json cg_op_attrs_as_dot_json(ComputationGraphOpAttrs const &attrs) { + nlohmann::json result = attrs; return result; } diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.cc index 927d4fd913..861824266f 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.cc @@ -6,23 +6,6 @@ namespace FlexFlow { -RecordFormatter as_dot(BroadcastAttrs const &attrs) { - RecordFormatter r; - - auto kv = [](std::string const &label, auto const &val) { - RecordFormatter rr; - rr << label << fmt::to_string(val); - return rr; - }; - - for (ff_dim_t dim_idx : tensor_dims_range(get_num_dims(attrs.target_dims))) { - r << kv(fmt::format("target_dims[{}]", dim_idx.value), - dim_at_idx(attrs.target_dims, dim_idx)); - } - - return r; -} - tl::expected get_output_shape(BroadcastAttrs const &attrs, TensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/ops/cast.cc b/lib/op-attrs/src/op-attrs/ops/cast.cc index fdf840973d..4b173badcf 100644 --- a/lib/op-attrs/src/op-attrs/ops/cast.cc +++ b/lib/op-attrs/src/op-attrs/ops/cast.cc @@ -3,20 +3,6 @@ namespace FlexFlow { -RecordFormatter as_dot(CastAttrs const &attrs) { - RecordFormatter r; - - auto kv = [](std::string const &label, auto const &val) { - RecordFormatter rr; - rr << label << fmt::to_string(val); - return rr; - }; - - r << kv("to", attrs.dtype); - - return r; -} - tl::expected get_output_shape(CastAttrs const &attrs, TensorShape const &input) { diff --git a/lib/op-attrs/src/op-attrs/ops/combine.cc b/lib/op-attrs/src/op-attrs/ops/combine.cc index 64e9316ea2..f470f281fc 100644 --- a/lib/op-attrs/src/op-attrs/ops/combine.cc +++ b/lib/op-attrs/src/op-attrs/ops/combine.cc @@ -4,20 +4,6 @@ namespace FlexFlow { -RecordFormatter as_dot(CombineAttrs const &attrs) { - RecordFormatter r; - - auto kv = [](std::string const &label, auto const &val) { - RecordFormatter rr; - rr << label << fmt::to_string(val); - return rr; - }; - - r << kv("dim", attrs.combine_dim) << kv("degree", attrs.combine_degree); - - return r; -} - tl::expected get_output_shape(CombineAttrs const &attrs, ParallelTensorShape const &input) { diff --git a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc index 451468ba28..b400c6263a 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc @@ -10,22 +10,6 @@ namespace FlexFlow { -RecordFormatter as_dot(EmbeddingAttrs const &attrs) { - RecordFormatter r; - - auto kv = [](std::string const &label, auto const &val) { - RecordFormatter rr; - rr << label << fmt::to_string(val); - return rr; - }; - - r << kv("num_entries", attrs.num_entries) - << kv("out_channels", attrs.out_channels) << kv("aggr", attrs.aggr) - << kv("output_type", attrs.data_type); - - return r; -} - static std::optional basic_check(EmbeddingAttrs const &attrs, TensorShape const &input) { if (input.data_type != DataType::INT32 && diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index a9f8fdf02a..9099c7dac6 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -40,22 +40,6 @@ std::unordered_map return result; } -RecordFormatter as_dot(LinearAttrs const &attrs) { - RecordFormatter r; - - auto kv = [](std::string const &label, auto const &val) { - RecordFormatter rr; - rr << label << fmt::to_string(val); - return rr; - }; - - r << kv("out_channels", attrs.out_channels) << kv("use_bias", attrs.use_bias) - << kv("data_type", attrs.data_type) << kv("activation", attrs.activation) - << kv("regularizer", attrs.regularizer); - - return r; -} - tl::expected get_projection_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/ops/reduction.cc b/lib/op-attrs/src/op-attrs/ops/reduction.cc index 580d47b1e9..45df46e9cc 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduction.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduction.cc @@ -3,20 +3,6 @@ namespace FlexFlow { -RecordFormatter as_dot(ReductionAttrs const &attrs) { - RecordFormatter r; - - auto kv = [](std::string const &label, auto const &val) { - RecordFormatter rr; - rr << label << fmt::to_string(val); - return rr; - }; - - r << kv("degree", attrs.reduction_degree); - - return r; -} - tl::expected get_output_shape(ReductionAttrs const &attrs, ParallelTensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/ops/repartition.cc b/lib/op-attrs/src/op-attrs/ops/repartition.cc index d57a198585..5bda589eb3 100644 --- a/lib/op-attrs/src/op-attrs/ops/repartition.cc +++ b/lib/op-attrs/src/op-attrs/ops/repartition.cc @@ -2,21 +2,6 @@ namespace FlexFlow { -RecordFormatter as_dot(RepartitionAttrs const &attrs) { - RecordFormatter r; - - auto kv = [](std::string const &label, auto const &val) { - RecordFormatter rr; - rr << label << fmt::to_string(val); - return rr; - }; - - r << kv("dim", attrs.repartition_dim) - << kv("degree", attrs.repartition_degree); - - return r; -} - tl::expected get_output_shape(RepartitionAttrs const &attrs, ParallelTensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/ops/replicate.cc b/lib/op-attrs/src/op-attrs/ops/replicate.cc index d3f1e87841..9e163cb55a 100644 --- a/lib/op-attrs/src/op-attrs/ops/replicate.cc +++ b/lib/op-attrs/src/op-attrs/ops/replicate.cc @@ -2,20 +2,6 @@ namespace FlexFlow { -RecordFormatter as_dot(ReplicateAttrs const &attrs) { - RecordFormatter r; - - auto kv = [](std::string const &label, auto const &val) { - RecordFormatter rr; - rr << label << fmt::to_string(val); - return rr; - }; - - r << kv("degree", attrs.replicate_degree); - - return r; -} - ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, ParallelTensorShape const &input_shape) { ParallelTensorShape output_shape = input_shape; diff --git a/lib/op-attrs/src/op-attrs/ops/weight.cc b/lib/op-attrs/src/op-attrs/ops/weight.cc index ba63eca6ee..ee206ab37e 100644 --- a/lib/op-attrs/src/op-attrs/ops/weight.cc +++ b/lib/op-attrs/src/op-attrs/ops/weight.cc @@ -5,18 +5,6 @@ namespace FlexFlow { -RecordFormatter as_dot(WeightAttrs const &attrs) { - RecordFormatter r; - - for (positive_int dim : attrs.tensor_shape.dims.ff_ordered) { - r << fmt::to_string(dim); - } - - r << fmt::to_string(attrs.tensor_shape.data_type); - - return r; -} - TensorShape get_output_shape(WeightAttrs const &attrs) { return attrs.tensor_shape; } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 91d3d0b1aa..cc88692124 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -1,4 +1,5 @@ #include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ff_ordered/enumerate.h" #include "op-attrs/parallel_tensor_dims.h" #include "op-attrs/tensor_dims.h" #include "utils/containers/extend.h" @@ -9,6 +10,7 @@ #include "utils/hash-utils.h" #include "utils/nonnegative_int/nonnegative_range.h" #include "utils/overload.h" +#include "utils/record_formatter.h" #include namespace FlexFlow { @@ -150,4 +152,20 @@ std::unordered_set return indices; } +RecordFormatter dot_for_parallel_tensor_shape(ParallelTensorShape const &s) { + RecordFormatter result = mk_empty_record(Orientation::VERTICAL); + + result << mk_kv_record("sum_degree", get_sum_degree(s)) + << mk_kv_record("discard_copy_degree", get_discard_copy_degree(s)); + + for (auto const &[idx, dim] : enumerate(s.dims.shard_dims)) { + result << mk_kv_record(fmt::to_string(idx), + fmt::format("{}/{}", dim.size, dim.degree)); + } + + result << mk_kv_record("data_type", s.data_type); + + return result; +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc index b2e4ae5a58..c838b50f1a 100644 --- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc @@ -18,28 +18,15 @@ bool is_parallel_op(PCGOperatorAttrs const &attrs) { attrs.has() || attrs.has()); } -OperatorType get_op_type(PCGOperatorAttrs const &attrs) { +OperatorType pcg_op_attrs_get_op_type(PCGOperatorAttrs const &attrs) { return attrs.visit( [](auto const &x) { return get_op_type(x); }); } -RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { - return attrs.visit(overload{ - [](LinearAttrs const &l) { return as_dot(l); }, - [](CastAttrs const &a) { return as_dot(a); }, - [](EmbeddingAttrs const &a) { return as_dot(a); }, - [](WeightAttrs const &a) { return as_dot(a); }, - [](BroadcastAttrs const &a) { return as_dot(a); }, - [](RepartitionAttrs const &a) { return as_dot(a); }, - [](CombineAttrs const &a) { return as_dot(a); }, - [](ReplicateAttrs const &a) { return as_dot(a); }, - [](ReductionAttrs const &a) { return as_dot(a); }, - [&](auto const &) { - RecordFormatter r; - r << fmt::to_string(get_op_type(attrs)); - return r; - }, - }); +nlohmann::json pcg_op_attrs_as_dot_json(PCGOperatorAttrs const &attrs) { + nlohmann::json result = attrs; + + return result; } PCGOperatorAttrs pcg_op_attrs_from_compgraph_op_attrs( diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h index b15b91e0e3..aded1eb657 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h @@ -20,6 +20,11 @@ struct MappedOperatorTaskGroup { [[nodiscard]] bool operator==(MappedOperatorTaskGroup const &) const; [[nodiscard]] bool operator!=(MappedOperatorTaskGroup const &) const; + [[nodiscard]] bool operator<(MappedOperatorTaskGroup const &) const; + [[nodiscard]] bool operator>(MappedOperatorTaskGroup const &) const; + [[nodiscard]] bool operator<=(MappedOperatorTaskGroup const &) const; + [[nodiscard]] bool operator>=(MappedOperatorTaskGroup const &) const; + [[nodiscard]] bidict const & get_shard_bindings() const; @@ -37,6 +42,9 @@ bidict get_tensor_bindings_for_slot_name(MappedOperatorTaskGroup const &, TensorSlotName const &); +nlohmann::json + mapped_operator_task_group_as_dot_json(MappedOperatorTaskGroup const &); + std::string format_as(::FlexFlow::MappedOperatorTaskGroup const &); std::ostream &operator<<(std::ostream &, ::FlexFlow::MappedOperatorTaskGroup const &); diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.toml b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.toml index 8786cfe889..a4f53166d2 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.toml +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.toml @@ -6,13 +6,10 @@ features = [] includes = [ "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", "pcg/parallel_computation_graph/parallel_computation_graph.h", + "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.dtg.h", "", ] [[fields]] -name = "pcg" -type = "::FlexFlow::ParallelComputationGraph" - -[[fields]] -name = "mapped_tasks" -type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MappedOperatorTaskGroup>" +name = "raw_graph" +type = "::FlexFlow::LabelledKwargDataflowGraphView<::FlexFlow::MappedParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs, ::FlexFlow::TensorSlotName>" diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h index 0e3db03a91..12c7921282 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h @@ -2,13 +2,39 @@ #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_PARALLEL_COMPUTATION_GRAPH_H #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" namespace FlexFlow { +std::unordered_set + mpcg_get_parallel_layers(MappedParallelComputationGraph const &); +MappedOperatorTaskGroup + mpcg_get_mapping_for_layer(MappedParallelComputationGraph const &, + parallel_layer_guid_t); + +ParallelComputationGraph pcg_from_mpcg(MappedParallelComputationGraph const &); + +std::unordered_set + mpcg_get_edges(MappedParallelComputationGraph const &); + +MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( + ParallelComputationGraph const &pcg, + std::unordered_map const + &mapped_op_task_groups); + +MappedParallelComputationGraph + mapped_pcg_without_layer_names(MappedParallelComputationGraph const &); + std::string format_as(MappedParallelComputationGraph const &); std::ostream &operator<<(std::ostream &, MappedParallelComputationGraph const &); +bool mapped_pcgs_are_isomorphic(MappedParallelComputationGraph const &, + MappedParallelComputationGraph const &); + +std::string mapped_pcg_as_dot(MappedParallelComputationGraph const &); +void debug_print_mapped_pcg_as_dot(MappedParallelComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.dtg.toml b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.dtg.toml new file mode 100644 index 0000000000..20320ad54f --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.dtg.toml @@ -0,0 +1,35 @@ +namespace = "FlexFlow" +name = "MappedParallelLayerAttrs" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", + "op-attrs/pcg_operator_attrs.dtg.h", + "utils/stack_string.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "name" +type = "std::optional<::FlexFlow::stack_string>" + +[[fields]] +name = "mapping" +type = "::FlexFlow::MappedOperatorTaskGroup" diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.h new file mode 100644 index 0000000000..09ce4b6c1e --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_PARALLEL_LAYER_ATTRS_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_PARALLEL_LAYER_ATTRS_H + +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" + +namespace FlexFlow { + +ParallelLayerAttrs + unmapped_parallel_layer_attrs_from_mapped(MappedParallelLayerAttrs const &); + +MappedParallelLayerAttrs mapped_parallel_layer_attrs_without_layer_name( + MappedParallelLayerAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 917200af68..0368be62bc 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -121,7 +121,7 @@ ParallelComputationGraph without_layer_names(ParallelComputationGraph const &); bool pcgs_are_isomorphic(ParallelComputationGraph const &, ParallelComputationGraph const &); -std::string as_dot(ParallelComputationGraph const &); +std::string pcg_as_dot(ParallelComputationGraph const &); void debug_print_dot(ParallelComputationGraph const &); } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index 4eac3d1cfa..56bfb98856 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -29,10 +29,10 @@ #include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" #include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" -#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/view_as_labelled_open_kwarg_dataflow_graph.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_as_dot.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h" #include "utils/graph/node/algorithms.h" #include "utils/record_formatter.h" @@ -326,43 +326,45 @@ bool computation_graphs_are_isomorphic(ComputationGraph const &lhs, } std::string as_dot(ComputationGraph const &cg) { - std::function get_node_label = - [](LayerAttrs const &a) -> std::string { - RecordFormatter r = as_dot(a.op_attrs); - - if (a.name.has_value()) { - RecordFormatter rr; - rr << "Name" << a.name.value(); - r << rr; - } - - std::ostringstream oss; - oss << r; - return oss.str(); + std::function get_node_label = + [](LayerAttrs const &a) -> nlohmann::json { + nlohmann::json result = a; + + return result; }; - std::function get_input_label = - [](TensorAttrs const &a) -> std::string { - RecordFormatter r; + std::function get_input_label = + [](TensorAttrs const &a) -> nlohmann::json { + nlohmann::json result = a; + + return result; + }; - r << fmt::to_string(a.shape); + std::function render_slot_name = + [](TensorSlotName const &s) -> nlohmann::json { + nlohmann::json result = fmt::to_string(s); - std::ostringstream oss; - oss << r; - return oss.str(); + return result; }; + std::function( + std::unordered_set const &)> + order_slots = [](std::unordered_set const &unordered) + -> nlohmann::json { return sorted(unordered); }; + return labelled_open_kwarg_dataflow_graph_view_as_dot( view_as_labelled_open_kwarg_dataflow_graph(cg.raw_graph), get_node_label, - get_input_label); + get_input_label, + render_slot_name, + order_slots); } void debug_print_dot(ComputationGraph const &cg) { - std::cout << as_dot(cg) << std::endl; + std::cerr << as_dot(cg) << std::endl; } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc index 8f9db7eac7..d0fd3300f5 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc @@ -7,6 +7,7 @@ #include "utils/bidict/generate_bidict.h" #include "utils/containers/are_all_distinct.h" #include "utils/containers/require_all_same.h" +#include "utils/containers/sorted.h" #include "utils/containers/transform.h" #include "utils/containers/vector_of.h" #include "utils/hash/tuple.h" @@ -56,7 +57,27 @@ bool MappedOperatorTaskGroup::operator==( bool MappedOperatorTaskGroup::operator!=( MappedOperatorTaskGroup const &other) const { - return this->tie() == other.tie(); + return this->tie() != other.tie(); +} + +bool MappedOperatorTaskGroup::operator<( + MappedOperatorTaskGroup const &other) const { + return this->tie() < other.tie(); +} + +bool MappedOperatorTaskGroup::operator>( + MappedOperatorTaskGroup const &other) const { + return this->tie() > other.tie(); +} + +bool MappedOperatorTaskGroup::operator<=( + MappedOperatorTaskGroup const &other) const { + return this->tie() <= other.tie(); +} + +bool MappedOperatorTaskGroup::operator>=( + MappedOperatorTaskGroup const &other) const { + return this->tie() >= other.tie(); } std::tuple< @@ -82,6 +103,20 @@ bidict .reversed(); } +nlohmann::json + mapped_operator_task_group_as_dot_json(MappedOperatorTaskGroup const &m) { + + std::vector coordinates = + sorted(m.get_shard_bindings().left_values()); + + return nlohmann::json{ + transform(coordinates, + [&](MachineSpaceCoordinate const &c) -> std::string { + return fmt::format("({}, {})", c.node_idx, c.device_idx); + }), + }; +} + std::string format_as(::FlexFlow::MappedOperatorTaskGroup const &m) { return fmt::format("", m.get_shard_bindings()); diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index 17ac533162..f4fa946a66 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -1,12 +1,104 @@ #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/bidict/algorithms/transform_keys.h" +#include "utils/containers/transform.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" namespace FlexFlow { +std::unordered_set + mpcg_get_parallel_layers(MappedParallelComputationGraph const &mpcg) { + return get_parallel_layers(pcg_from_mpcg(mpcg)); +} + +MappedOperatorTaskGroup + mpcg_get_mapping_for_layer(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t l) { + MappedParallelLayerAttrs layer_attrs = mpcg.raw_graph.at(l.raw_graph_node); + + return layer_attrs.mapping; +} + +ParallelComputationGraph + pcg_from_mpcg(MappedParallelComputationGraph const &mpcg) { + LabelledKwargDataflowGraphView + raw_view = rewrite_labelled_kwarg_dataflow_graph_node_labels( + mpcg.raw_graph, + [](Node const &, + MappedParallelLayerAttrs const &a) -> ParallelLayerAttrs { + return unmapped_parallel_layer_attrs_from_mapped(a); + }); + + LabelledKwargDataflowGraph + raw_graph = materialize_labelled_kwarg_dataflow_graph_view(raw_view); + + return ParallelComputationGraph{ + raw_graph, + }; +} + +MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( + ParallelComputationGraph const &pcg, + std::unordered_map const + &mapped_op_task_groups) { + auto mapping_for_layer = + [&](parallel_layer_guid_t l) -> MappedOperatorTaskGroup { + OperatorType op_type = pcg_op_attrs_get_op_type(pcg_get_op_attrs(pcg, l)); + + return mapped_op_task_groups.at(l); + }; + + auto mpcg_layer_attrs_from_pcg_layer_attrs = + [&](Node const &node, ParallelLayerAttrs const &pcg_layer_attrs) + -> MappedParallelLayerAttrs { + parallel_layer_guid_t l = parallel_layer_guid_t{node}; + + return MappedParallelLayerAttrs{ + /*op_attrs=*/pcg_layer_attrs.op_attrs, + /*name=*/pcg_layer_attrs.name, + /*mapping=*/mapping_for_layer(l), + }; + }; + + LabelledKwargDataflowGraphView + result = rewrite_labelled_kwarg_dataflow_graph_node_labels( + pcg.raw_graph, mpcg_layer_attrs_from_pcg_layer_attrs); + + return MappedParallelComputationGraph{ + result, + }; +} + +MappedParallelComputationGraph + mapped_pcg_without_layer_names(MappedParallelComputationGraph const &mpcg) { + LabelledKwargDataflowGraphView + result = rewrite_labelled_kwarg_dataflow_graph_node_labels( + mpcg.raw_graph, + [&](Node const &, MappedParallelLayerAttrs const &with_name) + -> MappedParallelLayerAttrs { + return mapped_parallel_layer_attrs_without_layer_name(with_name); + }); + + return MappedParallelComputationGraph{ + result, + }; +} + std::string format_as(MappedParallelComputationGraph const &mapped_pcg) { - return fmt::format( - "", - as_dot(mapped_pcg.pcg), - mapped_pcg.mapped_tasks); + return mapped_pcg_as_dot(mapped_pcg); } std::ostream &operator<<(std::ostream &s, @@ -14,4 +106,57 @@ std::ostream &operator<<(std::ostream &s, return (s << fmt::to_string(mapped_pcg)); } +bool mapped_pcgs_are_isomorphic(MappedParallelComputationGraph const &src, + MappedParallelComputationGraph const &dst) { + std::optional> maybe_isomorphism = + find_isomorphism_between_kwarg_dataflow_graphs( + mapped_pcg_without_layer_names(src).raw_graph, + mapped_pcg_without_layer_names(dst).raw_graph); + + return maybe_isomorphism.has_value(); +} + +std::string mapped_pcg_as_dot(MappedParallelComputationGraph const &mpcg) { + + std::function + render_node_label = + [](MappedParallelLayerAttrs const &a) -> nlohmann::json { + nlohmann::json result = pcg_op_attrs_as_dot_json(a.op_attrs); + + if (a.name.has_value()) { + result["Name"] = a.name.value(); + } + + result["Mapping"] = mapped_operator_task_group_as_dot_json(a.mapping); + + return result; + }; + + std::function + render_input_label = [](ParallelTensorAttrs const &a) -> nlohmann::json { + nlohmann::json result = a; + return result; + }; + + std::function render_slot_name = + [](TensorSlotName const &slot_name) -> nlohmann::json { + return fmt::to_string(slot_name); + }; + + std::function( + std::unordered_set const &)> + order_slots = [](std::unordered_set const &slot_names) + -> std::vector { return sorted(slot_names); }; + + return labelled_kwarg_dataflow_graph_view_as_dot(mpcg.raw_graph, + render_node_label, + render_input_label, + render_slot_name, + order_slots); +} + +void debug_print_mapped_pcg_as_dot(MappedParallelComputationGraph const &mpcg) { + std::cerr << mapped_pcg_as_dot(mpcg) << std::endl; +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.cc new file mode 100644 index 0000000000..b449170ba4 --- /dev/null +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.cc @@ -0,0 +1,20 @@ +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.h" + +namespace FlexFlow { + +ParallelLayerAttrs unmapped_parallel_layer_attrs_from_mapped( + MappedParallelLayerAttrs const &mapped) { + return ParallelLayerAttrs{ + /*op_attrs=*/mapped.op_attrs, + /*name=*/mapped.name, + }; +} + +MappedParallelLayerAttrs mapped_parallel_layer_attrs_without_layer_name( + MappedParallelLayerAttrs const &m) { + MappedParallelLayerAttrs result = m; + result.name = std::nullopt; + return result; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/optimizer_attrs.cc b/lib/pcg/src/pcg/optimizer_attrs.cc index 4651292f6e..46192c2103 100644 --- a/lib/pcg/src/pcg/optimizer_attrs.cc +++ b/lib/pcg/src/pcg/optimizer_attrs.cc @@ -30,7 +30,6 @@ std::unordered_set -> std::unordered_set { if (sgd_attrs.momentum > 0.0f) { return {OptimizerSlotName::SGD_V}; - ; } else { return {}; } diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index a7d61d0644..a548ceb65a 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -30,7 +30,7 @@ #include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" -#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/node/node.dtg.h" @@ -429,25 +429,22 @@ bool pcgs_are_isomorphic(ParallelComputationGraph const &lhs, .has_value(); } -std::string as_dot(ParallelComputationGraph const &cg) { - std::function get_node_label = - [](ParallelLayerAttrs const &a) -> std::string { - RecordFormatter r = as_dot(a.op_attrs); +std::string pcg_as_dot(ParallelComputationGraph const &cg) { + + std::function render_node_label = + [](ParallelLayerAttrs const &a) -> nlohmann::json { + nlohmann::json result = pcg_op_attrs_as_dot_json(a.op_attrs); if (a.name.has_value()) { - RecordFormatter rr; - rr << "Name" << a.name.value(); - r << rr; + result["Name"] = a.name.value(); } - std::ostringstream oss; - oss << r; - return oss.str(); + return result; }; - std::function get_input_label = - [](ParallelTensorAttrs const &a) -> std::string { - RecordFormatter r; + std::function + render_input_label = [](ParallelTensorAttrs const &a) -> nlohmann::json { + RecordFormatter r = mk_empty_record(Orientation::HORIZONTAL); r << fmt::to_string(a.shape); @@ -456,17 +453,25 @@ std::string as_dot(ParallelComputationGraph const &cg) { return oss.str(); }; - return labelled_open_kwarg_dataflow_graph_view_as_dot( - view_as_labelled_open_kwarg_dataflow_graph(cg.raw_graph), - get_node_label, - get_input_label); + std::function render_slot_name = + [](TensorSlotName const &slot_name) -> nlohmann::json { + return fmt::to_string(slot_name); + }; + + std::function( + std::unordered_set const &)> + order_slots = [](std::unordered_set const &slot_names) + -> std::vector { return sorted(slot_names); }; + + return labelled_kwarg_dataflow_graph_view_as_dot(cg.raw_graph, + render_node_label, + render_input_label, + render_slot_name, + order_slots); } void debug_print_dot(ParallelComputationGraph const &cg) { - std::cout << as_dot(cg) << std::endl; + std::cout << pcg_as_dot(cg) << std::endl; } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index d18fc17621..1d6713dcdb 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -42,7 +42,7 @@ static std::string get_default_name(OperatorType op_type) { } static std::string get_default_name(PCGOperatorAttrs const &attrs) { - return get_default_name(get_op_type(attrs)); + return get_default_name(pcg_op_attrs_get_op_type(attrs)); } ParallelComputationGraphBuilder::ParallelComputationGraphBuilder() diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.cc index d88f88d4ca..b66d0fc7a2 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.cc @@ -5,7 +5,7 @@ namespace FlexFlow { OperatorType get_op_type(ParallelLayerAttrs const &a) { - return get_op_type(a.op_attrs); + return pcg_op_attrs_get_op_type(a.op_attrs); } ParallelLayerAttrs diff --git a/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc new file mode 100644 index 0000000000..7856d89f27 --- /dev/null +++ b/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -0,0 +1,145 @@ +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" +#include "op-attrs/initializer_attrs.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "utils/containers/require_only_key.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("mapped_pcgs_are_isomorphic") { + auto make_mpcg = []() -> MappedParallelComputationGraph { + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 8_p, + 3_p, + 6_p, + }, + }, + DataType::FLOAT, + }; + + ParallelComputationGraphBuilder b; + + std::string input1_name = "input1"; + std::string input2_name = "input2"; + std::string partition1_name = "partition1"; + std::string partition2_name = "partition2"; + std::string add_name = "add"; + + parallel_tensor_guid_t t1 = + b.create_input_tensor(input_shape, input1_name); + t1 = b.parallel_partition(t1, ff_dim_t{0_n}, 2_p, partition1_name); + parallel_tensor_guid_t t2 = + b.create_input_tensor(input_shape, input2_name); + t2 = b.parallel_partition(t2, ff_dim_t{0_n}, 2_p, partition2_name); + + parallel_tensor_guid_t t3 = b.add(t1, t2, add_name); + + ParallelComputationGraph pcg = b.pcg; + + parallel_layer_guid_t l_input1 = + get_parallel_layer_by_name(pcg, input1_name); + parallel_layer_guid_t l_input2 = + get_parallel_layer_by_name(pcg, input2_name); + parallel_layer_guid_t l_partition1 = + get_parallel_layer_by_name(pcg, partition1_name); + parallel_layer_guid_t l_partition2 = + get_parallel_layer_by_name(pcg, partition2_name); + parallel_layer_guid_t l_add = get_parallel_layer_by_name(pcg, add_name); + + auto machine_coord = [](nonnegative_int x) -> MachineSpaceCoordinate { + return MachineSpaceCoordinate{ + /*node_idx=*/0_n, + /*device_idx=*/x, + /*device_type=*/DeviceType::GPU, + }; + }; + + auto ptensor_coord = + [](nonnegative_int x) -> ParallelTensorSpaceCoordinate { + return ParallelTensorSpaceCoordinate{ + /*sum_component=*/0_n, + /*discard_copy_component=*/0_n, + /*shard_components=*/FFOrdered{x, 0_n, 0_n}, + }; + }; + + MappedOperatorTaskGroup input_mapping = MappedOperatorTaskGroup{ + bidict{ + {machine_coord(0_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::OUTPUT, ptensor_coord(0_n)}, + }, + }}, + }, + }; + + MappedOperatorTaskGroup partition_mapping = MappedOperatorTaskGroup{ + bidict{ + {machine_coord(0_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::OUTPUT, ptensor_coord(0_n)}, + }, + }}, + {machine_coord(1_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::OUTPUT, ptensor_coord(1_n)}, + }, + }}, + }, + }; + + std::unordered_map + mapped_tasks = { + { + l_input1, + input_mapping, + }, + { + l_input2, + input_mapping, + }, + { + l_partition1, + partition_mapping, + }, + { + l_partition2, + partition_mapping, + }, + {l_add, + MappedOperatorTaskGroup{ + bidict{ + {machine_coord(0_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::LHS_INPUT, ptensor_coord(0_n)}, + {TensorSlotName::RHS_INPUT, ptensor_coord(0_n)}, + }, + }}, + {machine_coord(1_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::LHS_INPUT, ptensor_coord(1_n)}, + {TensorSlotName::RHS_INPUT, ptensor_coord(1_n)}, + }, + }}, + }, + }}, + }; + + return mapped_pcg_from_pcg_and_mapped_op_task_groups(pcg, mapped_tasks); + }; + + MappedParallelComputationGraph mpcg1 = make_mpcg(); + MappedParallelComputationGraph mpcg2 = make_mpcg(); + + CHECK(mapped_pcgs_are_isomorphic(mpcg1, mpcg2)); + } +} diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 4a8edb3b6c..e7ef0da483 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -9,6 +9,7 @@ #include "op-attrs/tensor_slot_name.dtg.h" #include "pcg/device_type.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" #include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" @@ -150,42 +151,43 @@ TEST_SUITE(FF_TEST_SUITE) { MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg{ - pcg, - { - {inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {weights_layer_1.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {weights_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu1, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {linear_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::WEIGHT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, - {linear_operator_2.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::WEIGHT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, - }, - }; + MappedParallelComputationGraph mpcg = + mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {weights_layer_1.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {weights_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu1, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {linear_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::WEIGHT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {linear_operator_2.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::WEIGHT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + }); + MappedOperatorTaskGroup loss_mapping{ {{cpu0, OperatorAtomicTaskShardBinding{{ @@ -362,42 +364,43 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg{ - pcg, - { - {inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {weights_layer_1.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {weights_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {linear_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::WEIGHT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, - {linear_operator_2.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::WEIGHT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, - }, - }; + MappedParallelComputationGraph mpcg = + mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {weights_layer_1.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {weights_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {linear_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::WEIGHT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {linear_operator_2.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::WEIGHT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + }); + MappedOperatorTaskGroup loss_mapping{ {{gpu0, OperatorAtomicTaskShardBinding{{ diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h index ddd97a258a..4ca62db5b1 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h @@ -61,6 +61,11 @@ std::pair, + bidict> + labelled_result = + labelled_open_kwarg_dataflow_graph_from_dynamic_open_dataflow_graph( + g); + + LabelledOpenKwargDataflowGraph + labelled_g = labelled_result.first; + + bidict invocations = labelled_result.second; + + auto dot_for_training_operation_attrs = + [](TrainingOperationAttrs const &training_attrs) -> nlohmann::json { + nlohmann::json result = training_attrs; + + return result; + }; + + std::function render_node_label = + [](DynamicNodeAttrs const &a) -> nlohmann::json { + nlohmann::json result = dynamic_node_attrs_to_serializable(a); + + return result; + }; + + auto render_parallel_tensor_space_coord = + [](ParallelTensorSpaceCoordinate const &c) -> std::string { + std::vector replica_dim_entries = { + fmt::format("+/{}", c.sum_component), + fmt::format("=/{}", c.discard_copy_component), + }; + + std::vector shard_entries = transform( + vector_of(c.shard_components), + [](nonnegative_int x) -> std::string { return fmt::to_string(x); }); + + return ( + "(" + + join_strings(concat_vectors(replica_dim_entries, shard_entries), ", ") + + ")"); + }; + + std::function render_value_label = + [&](DynamicValueAttrs const &a) -> nlohmann::json { + nlohmann::json result = dynamic_value_attrs_to_serializable(a); + return result; + }; + + std::function render_slot_name = + [](DynamicTensorSlot const &slot_name) -> nlohmann::json { + nlohmann::json result = slot_name; + return result; + }; + + std::function( + std::unordered_set const &)> + order_slots = [](std::unordered_set const &slot_names) + -> std::vector { return sorted(slot_names); }; + + return labelled_open_kwarg_dataflow_graph_view_as_dot(labelled_g, + render_node_label, + render_value_label, + render_slot_name, + order_slots); +} + +void debug_print_dynamic_open_dataflow_graph_as_dot( + DynamicOpenDataflowGraph const &g) { + std::cerr << dynamic_open_dataflow_graph_as_dot(g) << std::endl; +} + } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 246f9a3242..380c2d17a1 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -1,6 +1,7 @@ #include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/pcg_operator_attrs.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h" @@ -17,23 +18,24 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( MappedParallelComputationGraph const &mpcg) { DynamicOpenDataflowGraph result = make_empty_dynamic_open_dataflow_graph(); - for (auto const &[layer, attrs] : - get_parallel_layer_attrs_mapping(mpcg.pcg)) { + ParallelComputationGraph pcg = pcg_from_mpcg(mpcg); + + for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(pcg)) { DynamicNodeAttrs result_attrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, - /*mapping=*/mpcg.mapped_tasks.at(layer), + /*mapping=*/mpcg_get_mapping_for_layer(mpcg, layer), /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, /*per_device_op_state=*/std::nullopt, }; std::unordered_map result_inputs = - transform(get_incoming_tensors(mpcg.pcg, layer), + transform(get_incoming_tensors(pcg, layer), [&](TensorSlotName const &slot_name, parallel_tensor_guid_t const &tensor) { ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(mpcg.pcg, tensor); + get_parallel_tensor_attrs(pcg, tensor); return std::pair{ DynamicTensorSlot{ /*slot_name=*/slot_name, @@ -50,11 +52,11 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( }; }); std::unordered_map result_outputs = - transform(get_outgoing_tensors(mpcg.pcg, layer), + transform(get_outgoing_tensors(pcg, layer), [&](TensorSlotName const &slot_name, parallel_tensor_guid_t const &tensor) { ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(mpcg.pcg, tensor); + get_parallel_tensor_attrs(pcg, tensor); return std::pair{ DynamicTensorSlot{ /*slot_name=*/slot_name, diff --git a/lib/utils/include/utils/dot_file.h b/lib/utils/include/utils/dot/dot_file.h similarity index 83% rename from lib/utils/include/utils/dot_file.h rename to lib/utils/include/utils/dot/dot_file.h index 214e6eeddc..eaef0832ae 100644 --- a/lib/utils/include/utils/dot_file.h +++ b/lib/utils/include/utils/dot/dot_file.h @@ -1,7 +1,10 @@ -#ifndef _DOT_FILE_H -#define _DOT_FILE_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_DOT_DOT_FILE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_DOT_DOT_FILE_H -#include "record_formatter.h" +#include "utils/containers/flatmap.h" +#include "utils/dot/dot_file.h" +#include "utils/dot/render_dot_html_table_to_string.h" +#include "utils/record_formatter.h" #include #include #include @@ -12,6 +15,13 @@ #include #include +namespace FlexFlow { + +/** + * \brief A helper interface for generating DOT/graphviz output + * + * \note This is very old code and should not be emulated stylistically + */ template class DotFile { private: @@ -23,14 +33,17 @@ class DotFile { std::unordered_map> subgraph_parents; std::optional owned_fstream = std::nullopt; std::ostream *out = nullptr; + std::string get_node_name(size_t node_id) const { std::ostringstream s; s << "node" << node_id; return s.str(); } + bool has_ostream() const { return this->owned_fstream.has_value() || this->out != nullptr; } + std::ostream &get_ostream() { bool has_owned_stream = this->owned_fstream.has_value(); bool has_stream_ref = (this->out != nullptr); @@ -43,15 +56,18 @@ class DotFile { throw std::runtime_error("No ostream value set"); } } + void start_output() { this->get_ostream() << "digraph taskgraph {" << std::endl; } public: DotFile() {} + DotFile(std::string const &filename) : owned_fstream(filename) { this->start_output(); } + DotFile(std::ostream &s) : node_id(0), out(&s) { this->start_output(); } @@ -60,11 +76,13 @@ class DotFile { this->owned_fstream = std::ofstream(filename); this->start_output(); } + void reserve_node(T const &t) { if (this->node_ids.find(t) == this->node_ids.end()) { this->node_ids[t] = this->node_id++; } } + void add_node(T const &t, std::map const ¶ms) { this->reserve_node(t); this->get_ostream() << " " << this->get_node_name(this->node_ids.at(t)) @@ -77,12 +95,33 @@ class DotFile { } this->get_ostream() << "];" << std::endl; } + void add_record_node(T const &t, RecordFormatter const &rf) { std::ostringstream oss; - oss << "\"" << rf << "\""; + + oss << "\""; + if (rf.orientation == Orientation::HORIZONTAL) { + oss << "{ " << rf << " }"; + } else { + oss << rf; + } + oss << "\""; + this->add_node(t, {{"shape", "record"}, {"label", oss.str()}}); } + void add_html_node(T const &t, DotHtmlTable const &table) { + std::ostringstream oss; + + oss << "<" << render_dot_html_table_to_string(table) << ">"; + + this->add_node(t, + { + {"label", oss.str()}, + {"shape", "plaintext"}, + }); + } + void dump_subgraph(size_t subgraph) { this->get_ostream() << "subgraph cluster_" << subgraph << " {" << std::endl; for (size_t node_id : this->subgraphs.at(subgraph)) { @@ -118,6 +157,7 @@ class DotFile { << " -> " << dst_name << get_field_suffix(dst_field) << ";" << std::endl; } + void close() { for (size_t subgraph = 0; subgraph < this->subgraph_id; subgraph++) { if (!this->subgraph_parents.at(subgraph).has_value()) { @@ -157,4 +197,6 @@ class DotFile { } }; -#endif // _DOT_FILE_H +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/dot/dot_html_from_json.h b/lib/utils/include/utils/dot/dot_html_from_json.h new file mode 100644 index 0000000000..5eed3be33c --- /dev/null +++ b/lib/utils/include/utils/dot/dot_html_from_json.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_DOT_DOT_HTML_FROM_JSON_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_DOT_DOT_HTML_FROM_JSON_H + +#include "utils/dot/dot_html_table.dtg.h" +#include + +namespace FlexFlow { + +DotHtmlTable dot_html_table_from_json(nlohmann::json const &); +DotHtmlTableCell dot_html_cell_from_json(nlohmann::json const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/dot/dot_html_table.dtg.toml b/lib/utils/include/utils/dot/dot_html_table.dtg.toml new file mode 100644 index 0000000000..ae33fd18f3 --- /dev/null +++ b/lib/utils/include/utils/dot/dot_html_table.dtg.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "DotHtmlTable" +type = "struct" +features = [ + "eq", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", + "utils/dot/dot_html_table_row.dtg.h", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/ord/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "border" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "cellborder" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "cellspacing" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "rows" +type = "std::vector<::FlexFlow::DotHtmlTableRow>" diff --git a/lib/utils/include/utils/dot/dot_html_table_cell.dtg.toml b/lib/utils/include/utils/dot/dot_html_table_cell.dtg.toml new file mode 100644 index 0000000000..a2600cdb36 --- /dev/null +++ b/lib/utils/include/utils/dot/dot_html_table_cell.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "DotHtmlTableCell" +type = "struct" +features = [ + "eq", + "fmt", +] + +includes = [ + "utils/dot/dot_html_table_cell_contents.h", + "utils/positive_int/positive_int.h", + "", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "content" +type = "::FlexFlow::DotHtmlTableCellContents" + +[[fields]] +name = "port" +type = "std::optional" + +[[fields]] +name = "colspan" +type = "std::optional<::FlexFlow::positive_int>" diff --git a/lib/utils/include/utils/dot/dot_html_table_cell_contents.h b/lib/utils/include/utils/dot/dot_html_table_cell_contents.h new file mode 100644 index 0000000000..54dc529f49 --- /dev/null +++ b/lib/utils/include/utils/dot/dot_html_table_cell_contents.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_DOT_DOT_HTML_TABLE_CELL_CONTENTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_DOT_DOT_HTML_TABLE_CELL_CONTENTS_H + +#include +#include +#include + +namespace FlexFlow { + +struct DotHtmlTable; + +struct DotHtmlTableCellContents { +public: + DotHtmlTableCellContents() = delete; + explicit DotHtmlTableCellContents(std::string const &); + explicit DotHtmlTableCellContents(DotHtmlTable const &); + + DotHtmlTable const &require_nested() const; + std::string const &require_simple() const; + + bool is_simple() const; + bool is_nested() const; + + bool operator==(DotHtmlTableCellContents const &) const; + bool operator!=(DotHtmlTableCellContents const &) const; + +private: + std::variant> value; +}; + +std::string format_as(DotHtmlTableCellContents const &); +std::ostream &operator<<(std::ostream &, DotHtmlTableCellContents const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/dot/dot_html_table_row.dtg.toml b/lib/utils/include/utils/dot/dot_html_table_row.dtg.toml new file mode 100644 index 0000000000..708fd88090 --- /dev/null +++ b/lib/utils/include/utils/dot/dot_html_table_row.dtg.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "DotHtmlTableRow" +type = "struct" +features = [ + "eq", + "fmt", +] + +includes = [ + "", + "utils/dot/dot_html_table_cell.dtg.h", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/ord/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "cells" +type = "std::vector<::FlexFlow::DotHtmlTableCell>" diff --git a/lib/utils/include/utils/dot/render_dot_html_table_to_string.h b/lib/utils/include/utils/dot/render_dot_html_table_to_string.h new file mode 100644 index 0000000000..f9d6db54ac --- /dev/null +++ b/lib/utils/include/utils/dot/render_dot_html_table_to_string.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_DOT_RENDER_DOT_HTML_TABLE_TO_STRING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_DOT_RENDER_DOT_HTML_TABLE_TO_STRING_H + +#include "utils/dot/dot_html_table.dtg.h" + +namespace FlexFlow { + +std::string render_dot_html_table_to_string(DotHtmlTable const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/as_dot.h b/lib/utils/include/utils/full_binary_tree/as_dot.h index e104d05e06..14631974df 100644 --- a/lib/utils/include/utils/full_binary_tree/as_dot.h +++ b/lib/utils/include/utils/full_binary_tree/as_dot.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_AS_DOT_H #include "utils/containers/get_only.h" -#include "utils/dot_file.h" +#include "utils/dot/dot_file.h" #include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" #include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" #include "utils/full_binary_tree/visit.h" @@ -14,7 +14,7 @@ #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_as_dot.h" #include #include #include @@ -67,13 +67,29 @@ std::string LabelledDataflowGraphView g = as_labelled_dataflow_graph(tree, impl, get_parent_label, get_leaf_label); - std::function get_node_label = + std::function render_node_label = [](std::string const &s) { return s; }; - std::function get_input_label = + + std::function render_value_label = [](std::monostate const &) { return ""; }; - return as_dot( - view_as_labelled_open_dataflow_graph(g), get_node_label, get_input_label); + std::function + render_dataflow_graph_input = + [](DataflowGraphInput const &) { return ""; }; + + std::function render_dataflow_input = + [](DataflowInput const &) { return ""; }; + + std::function render_dataflow_output = + [](DataflowOutput const &) { return ""; }; + + return labelled_open_dataflow_graph_as_dot( + view_as_labelled_open_dataflow_graph(g), + render_node_label, + render_value_label, + render_dataflow_graph_input, + render_dataflow_input, + render_dataflow_output); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 2d466a5d9f..f78a0225d5 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_GRAPH_ALGORITHMS_H #define _FLEXFLOW_UTILS_GRAPH_ALGORITHMS_H -#include "utils/dot_file.h" +#include "utils/dot/dot_file.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/graph_split.dtg.h" #include "utils/graph/node/graph.h" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/as_dot.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/as_dot.h deleted file mode 100644 index 6c9626ce00..0000000000 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms/as_dot.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H - -#include "utils/dot_file.h" -#include "utils/graph/dataflow_graph/dataflow_graph_view.h" - -namespace FlexFlow { - -std::string as_dot(DataflowGraphView const &); -void as_dot(DotFile &, - DataflowGraphView const &, - std::function const &get_node_label); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.h new file mode 100644 index 0000000000..1c2fbcf671 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_DATAFLOW_GRAPH_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_DATAFLOW_GRAPH_AS_DOT_H + +#include "utils/dot/dot_file.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::string dataflow_graph_as_dot( + DataflowGraphView const &, + std::optional> const + &get_node_label = std::nullopt, + std::optional> const + &get_input_label = std::nullopt, + std::optional> const + &get_output_label = std::nullopt); + +void dataflow_graph_as_dot( + DotFile &, + DataflowGraphView const &, + std::optional> const + &get_node_label = std::nullopt, + std::optional> const + &get_input_label = std::nullopt, + std::optional> const + &get_output_label = std::nullopt); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_data.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_data.dtg.toml new file mode 100644 index 0000000000..a957ad82c7 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_data.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "DataflowGraphData" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/dataflow_graph/dataflow_edge.dtg.h", + "utils/graph/dataflow_graph/dataflow_output.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "nodes" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::DataflowEdge>" + +[[fields]] +name = "outputs" +type = "std::unordered_set<::FlexFlow::DataflowOutput>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/view_from_dataflow_graph_data.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/view_from_dataflow_graph_data.h new file mode 100644 index 0000000000..39a38e88aa --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/view_from_dataflow_graph_data.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_VIEW_FROM_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_VIEW_FROM_DATAFLOW_GRAPH_DATA_H + +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_data.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" +#include "utils/graph/dataflow_graph/i_dataflow_graph_view.h" + +namespace FlexFlow { + +struct ViewFromDataflowGraphData final : virtual public IDataflowGraphView { + +public: + explicit ViewFromDataflowGraphData(DataflowGraphData const &); + + std::unordered_set query_nodes(NodeQuery const &query) const override; + std::unordered_set + query_edges(DataflowEdgeQuery const &query) const override; + std::unordered_set + query_outputs(DataflowOutputQuery const &query) const override; + ViewFromDataflowGraphData *clone() const override; + +private: + DataflowGraphData data; +}; + +DataflowGraphView view_from_dataflow_graph_data(DataflowGraphData const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.dtg.toml index 7a73c1a8aa..b8aec00152 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.dtg.toml +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.toml index 8169d1f736..3fb0af86d0 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.toml +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.toml index dee7152aa2..69aae6d17e 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.toml +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h b/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h index ee533a1180..443147a709 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h @@ -2,12 +2,13 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_DIGRAPH_AS_DOT_H #include "utils/graph/digraph/digraph_view.h" +#include namespace FlexFlow { std::string digraph_as_dot( DiGraphView const &, - std::function const &get_node_label); + std::function const &get_node_label); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..418346bb36 --- /dev/null +++ b/lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h @@ -0,0 +1,130 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_KWARG_DATAFLOW_GRAPH_H + +#include "utils/containers/generate_map.h" +#include "utils/containers/set_union.h" +#include "utils/containers/values.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h" +#include "utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node_source.h" + +namespace FlexFlow { + +template +struct UnorderedSetKwargDataflowGraph final + : public IKwargDataflowGraph { + UnorderedSetKwargDataflowGraph() = default; + + KwargNodeAddedResult add_node( + std::unordered_map> const &inputs, + std::unordered_set const &output_slots) override { + + Node new_node = this->node_source.new_node(); + + std::unordered_map> outputs = + generate_map( + output_slots, + [&](SlotName const &output_slot) -> KwargDataflowOutput { + KwargDataflowOutput output = + KwargDataflowOutput{ + /*node=*/new_node, + /*slot_name=*/output_slot, + }; + + this->outputs.insert(output); + + return output; + }); + + this->add_node_unsafe(new_node, inputs, outputs); + + return KwargNodeAddedResult{ + /*node=*/new_node, + /*outputs=*/outputs, + }; + } + + void add_node_unsafe( + Node const &node, + std::unordered_map> const &inputs, + std::unordered_map> const + &outputs) override { + this->nodes.insert(node); + + for (auto const &[input_slot_name, src] : inputs) { + KwargDataflowInput dst = KwargDataflowInput{ + node, + input_slot_name, + }; + + KwargDataflowEdge in_edge = KwargDataflowEdge{ + /*src=*/src, + /*dst=*/dst, + }; + + this->edges.insert(in_edge); + } + + this->outputs = set_union(this->outputs, unordered_set_of(values(outputs))); + } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return filter(this->nodes, + [&](Node const &n) { return includes(q.nodes, n); }); + } + + std::unordered_set> + query_edges(KwargDataflowEdgeQuery const &q) const override { + return filter(this->edges, [&](KwargDataflowEdge const &e) { + return kwarg_dataflow_edge_query_includes(q, e); + }); + } + + std::unordered_set> query_outputs( + KwargDataflowOutputQuery const &q) const override { + return filter(this->outputs, + [&](KwargDataflowOutput const &output) { + return kwarg_dataflow_output_query_includes(q, output); + }); + } + + void inplace_materialize_from( + KwargDataflowGraphView const &v) override { + this->nodes = get_nodes(v); + this->edges = get_all_kwarg_dataflow_edges(v); + this->outputs = get_all_kwarg_dataflow_outputs(v); + }; + + UnorderedSetKwargDataflowGraph *clone() const override { + return new UnorderedSetKwargDataflowGraph{ + this->node_source, + this->nodes, + this->edges, + this->outputs, + }; + } + +private: + UnorderedSetKwargDataflowGraph( + NodeSource const &node_source, + std::unordered_set const &nodes, + std::unordered_set> const &edges, + std::unordered_set> const &outputs) + : node_source(node_source), nodes(nodes), edges(edges), outputs(outputs) { + } + +private: + NodeSource node_source; + + std::unordered_set nodes; + std::unordered_set> edges; + std::unordered_set> outputs; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.h new file mode 100644 index 0000000000..c739bf0e82 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.h @@ -0,0 +1,98 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_DATAFLOW_GRAPH_DATA_FROM_KWARG_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_DATAFLOW_GRAPH_DATA_FROM_KWARG_DATAFLOW_GRAPH_DATA_H + +#include "utils/containers/group_by.h" +#include "utils/containers/index_of.h" +#include "utils/containers/map_values.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_data.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.h" +#include "utils/nonempty_unordered_set/nonempty_unordered_set.h" +#include "utils/one_to_many/one_to_many_transform_values.h" + +namespace FlexFlow { + +template +DataflowGraphData dataflow_graph_data_from_kwarg_dataflow_graph_data( + KwargDataflowGraphData const &kwarg_data, + std::function( + std::unordered_set const &)> const &order_slots) { + std::unordered_set> all_inputs = transform( + kwarg_data.edges, + [](KwargDataflowEdge const &e) -> KwargDataflowInput { + return e.dst; + }); + + std::unordered_set> all_outputs = + kwarg_data.outputs; + + OneToMany incoming_slots_by_node = + one_to_many_transform_values( + group_by(all_inputs, + [](KwargDataflowInput const &i) -> Node { + return i.node; + }), + [](KwargDataflowInput const &i) -> SlotName { + return i.slot_name; + }); + + OneToMany outgoing_slots_by_node = + one_to_many_transform_values( + group_by(all_outputs, + [](KwargDataflowOutput const &o) -> Node { + return o.node; + }), + [](KwargDataflowOutput const &o) -> SlotName { + return o.slot_name; + }); + + auto dataflow_input_from_kwarg_input = + [&](KwargDataflowInput const &i) -> DataflowInput { + std::vector slot_ordering = order_slots( + incoming_slots_by_node.at_l(i.node).unwrap_as_unordered_set()); + + return DataflowInput{ + /*node=*/i.node, + /*idx=*/ + nonnegative_int{ + index_of(slot_ordering, i.slot_name).value(), + }, + }; + }; + + auto dataflow_output_from_kwarg_output = + [&](KwargDataflowOutput const &o) -> DataflowOutput { + std::vector slot_ordering = order_slots( + outgoing_slots_by_node.at_l(o.node).unwrap_as_unordered_set()); + + return DataflowOutput{ + /*node=*/o.node, + /*idx=*/ + nonnegative_int{ + index_of(slot_ordering, o.slot_name).value(), + }, + }; + }; + + auto dataflow_edge_from_kwarg_dataflow_edge = + [&](KwargDataflowEdge const &kwarg_edge) -> DataflowEdge { + return DataflowEdge{ + /*src=*/dataflow_output_from_kwarg_output(kwarg_edge.src), + /*dst=*/dataflow_input_from_kwarg_input(kwarg_edge.dst), + }; + }; + + DataflowGraphData result_data = DataflowGraphData{ + /*nodes=*/kwarg_data.nodes, + /*edges=*/ + transform(kwarg_data.edges, dataflow_edge_from_kwarg_dataflow_edge), + /*outputs=*/ + transform(kwarg_data.outputs, dataflow_output_from_kwarg_output), + }; + + return result_data; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.h new file mode 100644 index 0000000000..d07250ae5b --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_DATAFLOW_GRAPH_FROM_KWARG_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_DATAFLOW_GRAPH_FROM_KWARG_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/algorithms/view_from_dataflow_graph_data.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +DataflowGraphView dataflow_graph_from_kwarg_dataflow_graph( + KwargDataflowGraphView const &kwarg_dg, + std::function( + std::unordered_set const &)> const &order_slots) { + KwargDataflowGraphData kwarg_data = + get_kwarg_dataflow_graph_data(kwarg_dg); + + DataflowGraphData result_data = + dataflow_graph_data_from_kwarg_dataflow_graph_data(kwarg_data, + order_slots); + + return view_from_dataflow_graph_data(result_data); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_inputs.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_inputs.h new file mode 100644 index 0000000000..97d13b498f --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_inputs.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_ALL_KWARG_DATAFLOW_INPUTS_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_all_kwarg_dataflow_inputs(KwargDataflowGraphView const &v) { + return transform( + v.query_edges(kwarg_dataflow_edge_query_all()), + [](KwargDataflowEdge const &e) -> KwargDataflowInput { + return e.dst; + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_edges_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_edges_for_node.h index 2c57970736..8e7d111ecf 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_edges_for_node.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_edges_for_node.h @@ -13,7 +13,7 @@ std::unordered_map> KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ /*src_nodes=*/query_set::matchall(), /*src_slots=*/query_set::matchall(), - /*dst_nodes=*/query_set{n}, + /*dst_nodes=*/query_set::match_single_value(n), /*dst_slots=*/query_set::matchall(), }; diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h new file mode 100644 index 0000000000..32848f38a6 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_SLOTS_FOR_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_SLOTS_FOR_NODE_H + +#include "utils/containers/keys.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_edges_for_node.h" + +namespace FlexFlow { + +template +std::unordered_set + get_incoming_slots_for_node(KwargDataflowGraphView const &g, + Node n) { + return keys(get_incoming_kwarg_dataflow_edges_for_node(g, n)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h index 3fe2d48c6a..45a9fddc5b 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h @@ -12,9 +12,9 @@ std::unordered_set> Node const &src, Node const &dst) { KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ - /*src_nodes=*/query_set{src}, + /*src_nodes=*/query_set::match_single_value(src), /*src_slots=*/query_set::matchall(), - /*dst_nodes=*/query_set{dst}, + /*dst_nodes=*/query_set::match_single_value(dst), /*dst_slots=*/query_set::matchall(), }; diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.h new file mode 100644 index 0000000000..365c09486c --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_GRAPH_DATA_H + +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +template +KwargDataflowGraphData + get_kwarg_dataflow_graph_data(KwargDataflowGraphView const &g) { + return KwargDataflowGraphData{ + /*nodes=*/get_nodes(g), + /*edges=*/get_all_kwarg_dataflow_edges(g), + /*outputs=*/get_all_kwarg_dataflow_outputs(g), + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_incoming_edges.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_incoming_edges.h index 8e9feaf3b5..908b805e58 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_incoming_edges.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_incoming_edges.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_SUBGRAPH_INCOMING_EDGES_H #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" #include "utils/graph/node/algorithms.h" @@ -13,12 +14,13 @@ std::unordered_set> KwargDataflowGraphView const &g, std::unordered_set const &subgraph) { std::unordered_set all_nodes = get_nodes(g); - query_set src_query = query_set{set_minus(all_nodes, subgraph)}; + query_set src_query = + query_set::match_values_in(set_of(set_minus(all_nodes, subgraph))); KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ /*src_nodes=*/src_query, /*src_slots=*/query_set::matchall(), - /*dst_nodes=*/query_set{subgraph}, + /*dst_nodes=*/query_set::match_values_in(set_of(subgraph)), /*dst_slots=*/query_set::matchall(), }; diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_outgoing_edges.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_outgoing_edges.h index 532e37f8ec..5b86f9492f 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_outgoing_edges.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_subgraph_outgoing_edges.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_SUBGRAPH_OUTGOING_EDGES_H #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" #include "utils/graph/node/algorithms.h" @@ -13,10 +14,11 @@ std::unordered_set> KwargDataflowGraphView const &g, std::unordered_set const &subgraph) { std::unordered_set all_nodes = get_nodes(g); - query_set dst_query = query_set{set_minus(all_nodes, subgraph)}; + query_set dst_query = + query_set::match_values_in(set_of(set_minus(all_nodes, subgraph))); KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ - /*src_nodes=*/query_set{subgraph}, + /*src_nodes=*/query_set::match_values_in(set_of(subgraph)), /*src_slots=*/query_set::matchall(), /*dst_nodes=*/dst_query, /*dst_slots=*/query_set::matchall(), diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h index 7ab7b80199..3a1b90d06f 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h @@ -12,7 +12,7 @@ OneToMany> get_outgoing_kwarg_dataflow_edges_for_node( KwargDataflowGraphView const &g, Node const &n) { KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ - /*src_nodes=*/query_set{n}, + /*src_nodes=*/query_set::match_single_value(n), /*src_slots=*/query_set::matchall(), /*dst_nodes=*/query_set::matchall(), /*dst_slots=*/query_set::matchall(), diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h index 72eb9c810d..8b70dd80ff 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h @@ -10,7 +10,7 @@ std::unordered_map> get_outgoing_kwarg_dataflow_outputs_for_node( KwargDataflowGraphView const &g, Node const &n) { KwargDataflowOutputQuery query = KwargDataflowOutputQuery{ - /*nodes=*/query_set{n}, + /*nodes=*/query_set::match_single_value(n), /*output_idxs=*/query_set::matchall(), }; diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h new file mode 100644 index 0000000000..372dfed1e8 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_SLOTS_FOR_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_SLOTS_FOR_NODE_H + +#include "utils/containers/keys.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" + +namespace FlexFlow { + +template +std::unordered_set + get_outgoing_slots_for_node(KwargDataflowGraphView const &g, + Node n) { + return keys(get_outgoing_kwarg_dataflow_outputs_for_node(g, n)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.h new file mode 100644 index 0000000000..aacdadee4f --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.h @@ -0,0 +1,60 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPH_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPH_AS_DOT_H + +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::string kwarg_dataflow_graph_as_dot( + KwargDataflowGraphView const &g, + std::function const &render_node, + std::function const &)> const + &render_value, + std::function const &render_slot_name, + std::function( + std::unordered_set const &)> const &order_slots) { + std::function get_input_label = + [&](DataflowInput const &i) -> nlohmann::json { + std::vector slot_ordering = + order_slots(get_incoming_slots_for_node(g, i.node)); + + SlotName slot_name = slot_ordering.at(i.idx.unwrap_nonnegative()); + + return render_slot_name(slot_name); + }; + + std::function get_output_label = + [&](DataflowOutput const &o) -> nlohmann::json { + std::vector slot_ordering = + order_slots(get_outgoing_slots_for_node(g, o.node)); + + SlotName slot_name = slot_ordering.at(o.idx.unwrap_nonnegative()); + + nlohmann::json result; + + result["slot"] = render_slot_name(slot_name); + + KwargDataflowOutput kwarg_o = KwargDataflowOutput{ + /*node=*/o.node, + /*slot_name=*/slot_name, + }; + result["value"] = render_value(kwarg_o); + + return result; + }; + + return dataflow_graph_as_dot( + dataflow_graph_from_kwarg_dataflow_graph(g, order_slots), + render_node, + /*get_input_label=*/get_input_label, + /*get_output_label=*/get_output_label); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.toml b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.toml new file mode 100644 index 0000000000..e3429f16d0 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.toml @@ -0,0 +1,36 @@ +namespace = "FlexFlow" +name = "KwargDataflowGraphData" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "SlotName", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.h", + "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "nodes" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::KwargDataflowEdge>" + +[[fields]] +name = "outputs" +type = "std::unordered_set<::FlexFlow::KwargDataflowOutput>" diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.h new file mode 100644 index 0000000000..62644a5c6b --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPH_DATA_H + +#include "utils/containers/flatmap.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/transform.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.h" +#include + +namespace FlexFlow { + +template +void require_kwarg_dataflow_graph_data_is_valid( + KwargDataflowGraphData const &data) { + + std::unordered_set nodes_from_edges = flatmap( + data.edges, + [](KwargDataflowEdge const &e) -> std::unordered_set { + return std::unordered_set{ + e.src.node, + e.dst.node, + }; + }); + + ASSERT(is_subseteq_of(nodes_from_edges, data.nodes)); + + std::unordered_set nodes_from_outputs = transform( + data.outputs, + [](KwargDataflowOutput const &o) -> Node { return o.node; }); + + ASSERT(is_subseteq_of(nodes_from_outputs, data.nodes)); + + std::unordered_set> outputs_from_edges = + transform(data.edges, + [](KwargDataflowEdge const &e) + -> KwargDataflowOutput { return e.src; }); + + ASSERT(is_subseteq_of(outputs_from_edges, data.outputs)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graphs_are_isomorphic.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graphs_are_isomorphic.h new file mode 100644 index 0000000000..e3ba0c4609 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graphs_are_isomorphic.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_KWARG_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H + +#include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" + +namespace FlexFlow { + +template +bool kwarg_dataflow_graphs_are_isomorphic( + KwargDataflowGraphView const &lhs, + KwargDataflowGraphView const &rhs) { + std::optional> found = + find_isomorphism_between_kwarg_dataflow_graphs(lhs, rhs); + + return found.has_value(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.h new file mode 100644 index 0000000000..8e6daadd3c --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.h @@ -0,0 +1,56 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_FROM_KWARG_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_FROM_KWARG_DATAFLOW_GRAPH_DATA_H + +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.h" +#include "utils/graph/kwarg_dataflow_graph/i_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/view_from_open_kwarg_dataflow_graph_data.h" + +namespace FlexFlow { + +template +struct ViewFromKwargDataflowGraphData final + : virtual public IKwargDataflowGraphView { + explicit ViewFromKwargDataflowGraphData( + KwargDataflowGraphData const &data) + : data(data) {} + + std::unordered_set query_nodes(NodeQuery const &query) const override { + return apply_node_query(query, this->data.nodes); + } + + std::unordered_set> query_edges( + KwargDataflowEdgeQuery const &query) const override { + return filter(this->data.edges, [&](KwargDataflowEdge const &e) { + return kwarg_dataflow_edge_query_includes(query, e); + }); + } + + std::unordered_set> query_outputs( + KwargDataflowOutputQuery const &query) const override { + return filter(this->data.outputs, + [&](KwargDataflowOutput const &o) { + return kwarg_dataflow_output_query_includes(query, o); + }); + } + + ViewFromKwargDataflowGraphData *clone() const override { + return new ViewFromKwargDataflowGraphData{this->data}; + } + +private: + KwargDataflowGraphData data; +}; + +template +KwargDataflowGraphView view_from_kwarg_dataflow_graph_data( + KwargDataflowGraphData const &data) { + require_kwarg_dataflow_graph_data_is_valid(data); + + return KwargDataflowGraphView::template create< + ViewFromKwargDataflowGraphData>(data); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h index 9726b0eb34..70edc3d9dd 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h @@ -32,7 +32,7 @@ struct KwargDataflowGraphView : virtual public DiGraphView { std::is_base_of, T>::value, KwargDataflowGraphView>::type create(Args &&...args) { - return DataflowGraphView(make_cow_ptr(std::forward(args)...)); + return KwargDataflowGraphView(make_cow_ptr(std::forward(args)...)); } protected: diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h new file mode 100644 index 0000000000..2bbd88103e --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_KWARG_DATAFLOW_GRAPH_VIEW_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_KWARG_DATAFLOW_GRAPH_VIEW_AS_DOT_H + +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h" + +namespace FlexFlow { + +template +std::string labelled_kwarg_dataflow_graph_view_as_dot( + LabelledKwargDataflowGraphView const &g, + std::function const &render_node_label, + std::function const &render_value_label, + std::function const &render_slot_name, + std::function( + std::unordered_set const &)> const &order_slots) { + std::function render_node = + [&](Node const &n) -> nlohmann::json { + return render_node_label(g.at(n)); + }; + + std::function const &)> + render_value = + [&](KwargDataflowOutput const &v) -> nlohmann::json { + return render_value_label(g.at(v)); + }; + + return kwarg_dataflow_graph_as_dot( + static_cast>(g), + render_node, + render_value, + render_slot_name, + order_slots); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h deleted file mode 100644 index 1364e9ceb0..0000000000 --- a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_AS_DOT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_AS_DOT_H - -#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" - -namespace FlexFlow { - -template -std::string labelled_open_kwarg_dataflow_graph_view_as_dot( - LabelledOpenKwargDataflowGraphView const &g, - std::function const &, - std::function const &) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h new file mode 100644 index 0000000000..d6fa70fd0f --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_MATERIALIZE_LABELLED_KWARG_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_KWARG_DATAFLOW_GRAPH_ALGORITHMS_MATERIALIZE_LABELLED_KWARG_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h" +#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h" +namespace FlexFlow { + +template +LabelledKwargDataflowGraph + materialize_labelled_kwarg_dataflow_graph_view( + LabelledKwargDataflowGraphView const + &view) { + return LabelledKwargDataflowGraph:: + template create_copy_of< + UnorderedSetLabelledOpenKwargDataflowGraph>(view); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h deleted file mode 100644 index 6faddcdfcb..0000000000 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H - -#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" -#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h" - -namespace FlexFlow { - -template -std::string as_dot( - LabelledOpenDataflowGraphView const &g, - std::function const &get_node_label, - std::function const &get_input_label) { - std::function unlabelled_get_node_label = - [&](Node const &n) -> std::string { return get_node_label(g.at(n)); }; - - std::function - unlabelled_get_input_label = [&](DataflowGraphInput const &i) { - return get_input_label(g.at(OpenDataflowValue{i})); - }; - - return as_dot(static_cast(g), - unlabelled_get_node_label, - unlabelled_get_input_label); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_as_dot.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_as_dot.h new file mode 100644 index 0000000000..b1177c51a0 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_as_dot.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_DATAFLOW_GRAPH_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_DATAFLOW_GRAPH_AS_DOT_H + +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_as_dot.h" + +namespace FlexFlow { + +template +std::string labelled_open_dataflow_graph_as_dot( + LabelledOpenDataflowGraphView const &g, + std::function const &render_node_label, + std::function const &render_value_label, + std::function const + &render_dataflow_graph_input, + std::function const + &render_dataflow_input, + std::function const + &render_dataflow_output) { + std::function render_node = + [&](Node const &n) -> std::string { return render_node_label(g.at(n)); }; + + std::function + render_unlabelled_dataflow_graph_input = + [&](DataflowGraphInput const &i) { + return render_value_label(g.at(OpenDataflowValue{i})); + }; + + return open_dataflow_graph_as_dot(static_cast(g), + render_node, + render_unlabelled_dataflow_graph_input, + render_dataflow_input, + render_dataflow_output); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h new file mode 100644 index 0000000000..120833020f --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_VIEW_AS_DOT_H + +#include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.h" + +namespace FlexFlow { + +template +std::string labelled_open_kwarg_dataflow_graph_view_as_dot( + LabelledOpenKwargDataflowGraphView const &g, + std::function const &render_node_label, + std::function const &render_value_label, + std::function const &render_slot_name, + std::function( + std::unordered_set const &)> const &order_slots) { + std::function render_node = + [&](Node const &n) -> nlohmann::json { + return render_node_label(g.at(n)); + }; + + std::function const &)> + render_value = + [&](OpenKwargDataflowValue const &v) + -> nlohmann::json { return render_value_label(g.at(v)); }; + + return open_kwarg_dataflow_graph_as_dot( + static_cast>(g), + render_node, + render_value, + render_slot_name, + order_slots); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/as_dot.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/as_dot.h deleted file mode 100644 index 4c600637aa..0000000000 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/as_dot.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H - -#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" - -namespace FlexFlow { - -std::string as_dot(OpenDataflowGraphView const &); -std::string - as_dot(OpenDataflowGraphView const &, - std::function const &get_node_label, - std::function const - &get_input_label); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_as_dot.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_as_dot.h new file mode 100644 index 0000000000..7cdaeb35cb --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_as_dot.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_OPEN_DATAFLOW_GRAPH_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_OPEN_DATAFLOW_GRAPH_AS_DOT_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::string open_dataflow_graph_as_dot(OpenDataflowGraphView const &); +std::string open_dataflow_graph_as_dot( + OpenDataflowGraphView const &, + std::function const &render_node, + std::function const + &render_dataflow_graph_input, + std::function const + &render_dataflow_input, + std::function const + &render_dataflow_output); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/find_isomorphisms_between_open_kwarg_dataflow_graphs.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/find_isomorphisms_between_open_kwarg_dataflow_graphs.h index 72eca240ae..32710e75bf 100644 --- a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/find_isomorphisms_between_open_kwarg_dataflow_graphs.h +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/find_isomorphisms_between_open_kwarg_dataflow_graphs.h @@ -11,6 +11,7 @@ #include "utils/graph/digraph/algorithms/get_terminal_nodes.h" #include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h" #include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_unused_open_kwarg_dataflow_graph_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.h" #include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_isomorphism.dtg.h" #include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graphs_are_isomorphic_under.h" #include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h" @@ -186,7 +187,7 @@ std::optional> dst_incoming_edges = get_incoming_open_kwarg_dataflow_edges_for_node(dst_g, dst_node); - if (src_incoming_edges.size() != dst_incoming_edges.size()) { + if (keys(src_incoming_edges) != keys(dst_incoming_edges)) { fail(); return; } @@ -249,7 +250,13 @@ std::unordered_set> if (found.has_value()) { ASSERT(open_kwarg_dataflow_graphs_are_isomorphic_under( - src, dst, found.value())); + src, dst, found.value()), + fmt::format("src=\n{}\ndst=\n{}\n", + open_kwarg_dataflow_graph_as_dot(src), + open_kwarg_dataflow_graph_as_dot(dst)), + found, + get_open_kwarg_dataflow_graph_data(src), + get_open_kwarg_dataflow_graph_data(dst)); result.insert(found.value()); } diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h index 661105b3d6..20e078fc52 100644 --- a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_edges_for_node.h @@ -17,14 +17,14 @@ std::unordered_map> /*input_edge_query=*/ KwargDataflowInputEdgeQuery{ /*srcs=*/query_set::matchall(), - /*dst_nodes=*/query_set{n}, + /*dst_nodes=*/query_set::match_single_value(n), /*dst_slots=*/query_set::matchall(), }, /*standard_edge_query=*/ KwargDataflowEdgeQuery{ /*src_nodes=*/query_set::matchall(), /*src_slots=*/query_set::matchall(), - /*dst_nodes=*/query_set{n}, + /*dst_nodes=*/query_set::match_single_value(n), /*dst_slots=*/query_set::matchall(), }, }; diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_subgraph.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_subgraph.h index d45fdfa640..a6a5391ed6 100644 --- a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_subgraph.h +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_subgraph.h @@ -97,20 +97,20 @@ OpenKwargDataflowGraphData }); OpenKwargDataflowEdgeQuery - subgraph_interior_edges_query = - OpenKwargDataflowEdgeQuery{ - KwargDataflowInputEdgeQuery{ - /*srcs=*/query_set::match_none(), - /*dst_nodes=*/query_set::match_none(), - /*dst_slots=*/query_set::match_none(), - }, - KwargDataflowEdgeQuery{ - /*srcs=*/query_set{subgraph_nodes}, - /*src_slots=*/query_set::matchall(), - /*dsts=*/query_set{subgraph_nodes}, - /*dst_slots=*/query_set::matchall(), - }, - }; + subgraph_interior_edges_query = OpenKwargDataflowEdgeQuery{ + KwargDataflowInputEdgeQuery{ + /*srcs=*/query_set::match_none(), + /*dst_nodes=*/query_set::match_none(), + /*dst_slots=*/query_set::match_none(), + }, + KwargDataflowEdgeQuery{ + /*srcs=*/query_set::match_values_in(set_of(subgraph_nodes)), + /*src_slots=*/query_set::matchall(), + /*dsts=*/query_set::match_values_in(set_of(subgraph_nodes)), + /*dst_slots=*/query_set::matchall(), + }, + }; std::unordered_set> subgraph_interior_edges = g.query_edges(subgraph_interior_edges_query); diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_incoming_edges.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_incoming_edges.h index c711f79100..975180a12e 100644 --- a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_incoming_edges.h +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_subgraph_incoming_edges.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_KWARG_DATAFLOW_SUBGRAPH_INCOMING_EDGES_H #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" @@ -13,21 +14,22 @@ std::unordered_set> OpenKwargDataflowGraphView const &g, std::unordered_set const &subgraph) { std::unordered_set all_nodes = get_nodes(g); - query_set src_query = query_set{set_minus(all_nodes, subgraph)}; + query_set src_query = + query_set::match_values_in(set_of(set_minus(all_nodes, subgraph))); OpenKwargDataflowEdgeQuery query = OpenKwargDataflowEdgeQuery{ /*input_edge_query=*/KwargDataflowInputEdgeQuery{ /*srcs=*/query_set::matchall(), - /*dst_nodes=*/query_set{subgraph}, + /*dst_nodes=*/query_set::match_values_in(set_of(subgraph)), /*dst_slots=*/query_set::matchall(), }, /*standard_edge_query=*/ KwargDataflowEdgeQuery{ /*src_nodes=*/src_query, /*src_slots=*/query_set::matchall(), - /*dst_nodes=*/query_set{subgraph}, + /*dst_nodes=*/query_set::match_values_in(set_of(subgraph)), /*dst_slots=*/query_set::matchall(), }, }; diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_value_uses.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_value_uses.h index 94f729abbf..2c80a5ab8d 100644 --- a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_value_uses.h +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_value_uses.h @@ -22,8 +22,9 @@ std::unordered_set> return OpenKwargDataflowEdgeQuery{ kwarg_dataflow_input_edge_query_none(), KwargDataflowEdgeQuery{ - /*src_nodes=*/query_set{o.node}, - /*src_slots=*/query_set{o.slot_name}, + /*src_nodes=*/query_set::match_single_value(o.node), + /*src_slots=*/ + query_set::match_single_value(o.slot_name), /*dst_nodes=*/query_set::matchall(), /*dst_slots=*/query_set::matchall(), }, @@ -32,7 +33,7 @@ std::unordered_set> [&](KwargDataflowGraphInput const &i) { return OpenKwargDataflowEdgeQuery{ KwargDataflowInputEdgeQuery{ - /*srcs=*/query_set{i.name}, + /*srcs=*/query_set::match_single_value(i.name), /*dst_nodes=*/query_set::matchall(), /*dst_slots=*/query_set::matchall(), }, diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.h new file mode 100644 index 0000000000..423f7a9a2c --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.h @@ -0,0 +1,140 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_OPEN_KWARG_DATAFLOW_GRAPH_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_OPEN_KWARG_DATAFLOW_GRAPH_AS_DOT_H + +#include "utils/containers/filtrans.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_value.dtg.h" + +namespace FlexFlow { + +template +std::string open_kwarg_dataflow_graph_as_dot( + OpenKwargDataflowGraphView const &g) { + std::function render_node = [](Node const &n) { + nlohmann::json j = fmt::to_string(n); + return j; + }; + + std::function const &)> + render_value = + [](OpenKwargDataflowValue const &v) { + nlohmann::json j = fmt::to_string(v); + return j; + }; + + std::function render_slot_name = + [](SlotName const &s) { + nlohmann::json j = fmt::to_string(s); + return j; + }; + + std::function(std::unordered_set const &)> + order_slots = [](std::unordered_set const &unordered) { + return sorted(unordered); + }; + + return open_kwarg_dataflow_graph_as_dot( + g, render_node, render_value, render_slot_name, order_slots); +} + +template +std::string open_kwarg_dataflow_graph_as_dot( + OpenKwargDataflowGraphView const &g, + std::function const &render_node, + std::function const &)> const + &render_value, + std::function const &render_slot_name, + std::function( + std::unordered_set const &)> const &order_slots) { + std::pair>, + bidict, Node>> + closed_g_and_mapping = + view_as_closed_kwarg_dataflow_graph_by_materializing_inputs(g); + + KwargDataflowGraphView> closed_g = + closed_g_and_mapping.first; + bidict, Node> closed_mapping = + closed_g_and_mapping.second; + + std::function closed_render_node = + [&](Node const &n) -> nlohmann::json { + if (closed_mapping.contains_r(n)) { + return render_value(OpenKwargDataflowValue{ + closed_mapping.at_r(n)}); + } else { + return render_node(n); + } + }; + + std::function> const &)> + closed_render_value = + [&](KwargDataflowOutput> const &o) + -> nlohmann::json { + if (closed_mapping.contains_r(o.node)) { + ASSERT(!o.slot_name.has_value()); + + nlohmann::json j = "graph_input"; + return j; + } else { + KwargDataflowOutput open_o = KwargDataflowOutput{ + /*node=*/o.node, + /*slot_name=*/assert_unwrap(o.slot_name), + }; + + return render_value( + OpenKwargDataflowValue{open_o}); + } + }; + + std::function const &)> + closed_render_slot_name = + [&](std::optional const &s) -> nlohmann::json { + if (s.has_value()) { + return render_slot_name(s.value()); + } else { + nlohmann::json j = "graph_input"; + return j; + } + }; + + std::function>( + std::unordered_set> const &)> + closed_order_slots = + [&](std::unordered_set> const &unsorted) + -> std::vector> { + std::unordered_set not_nullopt = filtrans( + unsorted, + [](std::optional const &s) -> std::optional { + return s; + }); + + std::vector sorted_not_nullopt = order_slots(not_nullopt); + + std::vector> result = transform( + sorted_not_nullopt, + [](SlotName const &s) -> std::optional { return s; }); + + if (contains(unsorted, std::nullopt)) { + result.push_back(std::nullopt); + } + + return result; + }; + + return kwarg_dataflow_graph_as_dot( + /*g=*/closed_g, + /*render_node=*/closed_render_node, + /*render_value=*/closed_render_value, + /*render_slot_name=*/closed_render_slot_name, + /*order_slots=*/closed_order_slots); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.h index 46a3577a06..c7328a8f2a 100644 --- a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.h +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.h @@ -4,6 +4,8 @@ #include "utils/containers/filtrans.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/transform.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.h" #include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.dtg.h" namespace FlexFlow { @@ -25,6 +27,26 @@ void require_open_kwarg_dataflow_graph_data_is_valid( }); ASSERT(is_subseteq_of(inputs_from_edges, data.inputs)); + + require_kwarg_dataflow_graph_data_is_valid( + kwarg_dataflow_graph_data_from_open(data)); +} + +template +KwargDataflowGraphData kwarg_dataflow_graph_data_from_open( + OpenKwargDataflowGraphData const &open_data) { + + return KwargDataflowGraphData{ + /*nodes=*/open_data.nodes, + /*edges=*/ + filtrans( + open_data.edges, + [](OpenKwargDataflowEdge const &open_edge) + -> std::optional> { + return open_edge.try_require_internal_edge(); + }), + /*outputs=*/open_data.outputs, + }; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.h b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.h new file mode 100644 index 0000000000..1626c5833b --- /dev/null +++ b/lib/utils/include/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.h @@ -0,0 +1,116 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_AS_CLOSED_KWARG_DATAFLOW_GRAPH_BY_MATERIALIZING_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_VIEW_AS_CLOSED_KWARG_DATAFLOW_GRAPH_BY_MATERIALIZING_INPUTS_H + +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge.dtg.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" +#include "utils/graph/node/node_source.h" +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_open_kwarg_dataflow_graph_data.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph_view.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +std::pair>, + bidict, Node>> + view_as_closed_kwarg_dataflow_graph_by_materializing_inputs( + OpenKwargDataflowGraphView const &g) { + OpenKwargDataflowGraphData open_g_data = + get_open_kwarg_dataflow_graph_data(g); + + NodeSource n; + + bidict, Node> graph_input_nodes = + generate_bidict( + open_g_data.inputs, + [&](KwargDataflowGraphInput const &) -> Node { + return n.new_node(); + }); + + auto kwarg_dataflow_output_for_graph_input = + [&](KwargDataflowGraphInput const &i) + -> KwargDataflowOutput> { + return KwargDataflowOutput>{ + /*node=*/graph_input_nodes.at_l(i), + /*slot_name=*/std::nullopt, + }; + }; + + auto convert_kwarg_dataflow_output = + [&](KwargDataflowOutput const &o) + -> KwargDataflowOutput> { + return KwargDataflowOutput>{ + /*node=*/o.node, + /*slot_name=*/o.slot_name, + }; + }; + + auto convert_kwarg_dataflow_input = [&](KwargDataflowInput const &i) + -> KwargDataflowInput> { + return KwargDataflowInput>{ + /*node=*/i.node, + /*slot_name=*/i.slot_name, + }; + }; + + auto convert_standard_edge = [&](KwargDataflowEdge const &e) + -> KwargDataflowEdge> { + return KwargDataflowEdge>{ + /*src=*/convert_kwarg_dataflow_output(e.src), + /*dst=*/convert_kwarg_dataflow_input(e.dst), + }; + }; + + auto convert_input_edge = + [&](KwargDataflowInputEdge const &e) + -> KwargDataflowEdge> { + return KwargDataflowEdge>{ + /*src=*/kwarg_dataflow_output_for_graph_input(e.src), + /*dst=*/convert_kwarg_dataflow_input(e.dst), + }; + }; + + auto convert_edge = + [&](OpenKwargDataflowEdge const &open_edge) + -> KwargDataflowEdge> { + return open_edge.template visit>>( + overload{ + convert_standard_edge, + convert_input_edge, + }); + }; + + KwargDataflowGraphData> closed_g_data = + KwargDataflowGraphData>{ + /*nodes=*/set_union(open_g_data.nodes, + right_entries(graph_input_nodes)), + /*edges=*/transform(open_g_data.edges, convert_edge), + /*outputs=*/ + set_union( + transform(open_g_data.outputs, convert_kwarg_dataflow_output), + transform(open_g_data.inputs, + kwarg_dataflow_output_for_graph_input)), + }; + + ASSERT(closed_g_data.edges.size() == open_g_data.edges.size()); + + KwargDataflowGraphView> closed_g = + view_from_kwarg_dataflow_graph_data(closed_g_data); + + ASSERT(closed_g_data.edges == get_all_kwarg_dataflow_edges(closed_g)); + + return std::pair{ + closed_g, + graph_input_nodes, + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index 4dc1f037e4..c4234888cb 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -6,6 +6,7 @@ #include "utils/containers/filter.h" #include "utils/containers/filter_keys.h" #include "utils/containers/intersection.h" +#include "utils/containers/set_of.h" #include "utils/containers/set_union.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" @@ -23,18 +24,33 @@ namespace FlexFlow { template struct query_set { query_set() = delete; - query_set(T const &t) : query(std::set{t}) {} - query_set(std::unordered_set const &query) - : query(std::set{query.cbegin(), query.cend()}) {} + static query_set matchall() { + std::optional> query_val = std::nullopt; + return query_set{ + query_val, + }; + } - query_set(std::optional> const &query) - : query(transform(query, [](std::unordered_set const &s) { - return std::set{s.cbegin(), s.cend()}; - })) {} + static query_set match_none() { + std::set to_match = {}; + + return query_set{ + std::optional>{to_match}, + }; + } - query_set(std::initializer_list const &l) - : query_set(std::unordered_set{l}) {} + static query_set match_values_in(std::set const &values) { + return query_set{ + std::optional>{values}, + }; + } + + static query_set match_single_value(T const &val) { + std::set vals = {val}; + + return query_set::match_values_in(vals); + } friend bool operator==(query_set const &lhs, query_set const &rhs) { return lhs.query == rhs.query; @@ -58,18 +74,13 @@ struct query_set { return std::unordered_set{query_value.begin(), query_value.end()}; } - static query_set matchall() { - return {std::nullopt}; - } - - static query_set match_none() { - return {std::unordered_set{}}; - } - std::optional> const &value() const { return this->query; } +private: + explicit query_set(std::optional> const &query) : query(query) {} + private: std::optional> query; }; @@ -134,7 +145,8 @@ query_set query_intersection(query_set const &lhs, } else if (is_matchall(rhs)) { return lhs; } else { - return intersection(allowed_values(lhs), allowed_values(rhs)); + return query_set::match_values_in( + set_of(intersection(allowed_values(lhs), allowed_values(rhs)))); } } @@ -143,7 +155,8 @@ query_set query_union(query_set const &lhs, query_set const &rhs) { if (is_matchall(lhs) || is_matchall(rhs)) { return query_set::matchall(); } else { - return set_union(allowed_values(lhs), allowed_values(rhs)); + return query_set::match_values_in( + set_of(set_union(allowed_values(lhs), allowed_values(rhs)))); } } diff --git a/lib/utils/include/utils/nonempty_unordered_set/nonempty_unordered_set.h b/lib/utils/include/utils/nonempty_unordered_set/nonempty_unordered_set.h new file mode 100644 index 0000000000..2bc070fb5e --- /dev/null +++ b/lib/utils/include/utils/nonempty_unordered_set/nonempty_unordered_set.h @@ -0,0 +1,115 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONEMPTY_UNORDERED_SET_NONEMPTY_UNORDERED_SET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONEMPTY_UNORDERED_SET_NONEMPTY_UNORDERED_SET_H + +#include "utils/fmt/unordered_set.h" +#include "utils/hash-utils.h" +#include "utils/hash/unordered_set.h" +#include "utils/positive_int/positive_int.h" +#include +#include + +namespace FlexFlow { + +template +struct nonempty_unordered_set { +public: + nonempty_unordered_set() = delete; + + nonempty_unordered_set(std::initializer_list const &vs) : raw(vs) { + ASSERT(this->raw.size() > 0); + } + + explicit nonempty_unordered_set(std::unordered_set const &s) : raw(s) { + ASSERT(this->raw.size() > 0); + } + + bool operator==(nonempty_unordered_set const &other) const { + return this->unwrap_as_unordered_set() == other.unwrap_as_unordered_set(); + } + + bool operator!=(nonempty_unordered_set const &other) const { + return this->unwrap_as_unordered_set() != other.unwrap_as_unordered_set(); + } + + bool operator==(std::unordered_set const &other) const { + return this->unwrap_as_unordered_set() == other; + } + + bool operator!=(std::unordered_set const &other) const { + return this->unwrap_as_unordered_set() != other; + } + + void insert(T const &t) { + this->raw.insert(t); + } + + size_t size() const { + return this->raw.size(); + }; + + positive_int num_elements() const { + return positive_int{this->raw.size()}; + }; + + std::unordered_set const &unwrap_as_unordered_set() const { + return this->raw; + } + + using value_type = T; + + typename std::unordered_set::const_iterator begin() const { + return this->raw.cbegin(); + } + + typename std::unordered_set::const_iterator cbegin() const { + return this->raw.cbegin(); + } + + typename std::unordered_set::const_iterator end() const { + return this->raw.cend(); + } + + typename std::unordered_set::const_iterator cend() const { + return this->raw.cend(); + } + +private: + std::unordered_set raw; +}; + +template +bool operator==(std::unordered_set const &lhs, + nonempty_unordered_set const &rhs) { + return lhs == rhs.unwrap_as_unordered_set(); +} + +template +bool operator!=(std::unordered_set const &lhs, + nonempty_unordered_set const &rhs) { + return lhs != rhs.unwrap_as_unordered_set(); +} + +template +std::unordered_set format_as(nonempty_unordered_set const &s) { + return s.unwrap_as_unordered_set(); +} + +template +std::ostream &operator<<(std::ostream &s, nonempty_unordered_set const &m) { + return (s << fmt::to_string(m)); +} + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::nonempty_unordered_set> { + size_t operator()(::FlexFlow::nonempty_unordered_set const &x) const { + return ::FlexFlow::get_std_hash(x.unwrap_as_unordered_set()); + }; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many.h b/lib/utils/include/utils/one_to_many/one_to_many.h index 798ae2fb87..30d84d34c3 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many.h +++ b/lib/utils/include/utils/one_to_many/one_to_many.h @@ -2,7 +2,9 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_H #include "utils/containers/generate_map.h" +#include "utils/containers/items.h" #include "utils/containers/keys.h" +#include "utils/containers/transform.h" #include "utils/containers/try_at.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" @@ -15,6 +17,7 @@ #include "utils/hash/unordered_set.h" #include "utils/json/check_is_json_deserializable.h" #include "utils/json/check_is_json_serializable.h" +#include "utils/nonempty_unordered_set/nonempty_unordered_set.h" #include #include #include @@ -58,7 +61,12 @@ struct OneToMany { if (!found_l.has_value()) { this->m_r_to_l.insert({r, l}); - this->m_l_to_r[l].insert(r); + + if (contains_key(this->m_l_to_r, l)) { + this->m_l_to_r.at(l).insert(r); + } else { + this->m_l_to_r.insert({l, nonempty_unordered_set{{r}}}); + } } else if (found_l.value() == l) { return; } else { @@ -71,7 +79,14 @@ struct OneToMany { } } - std::unordered_set const &at_l(L const &l) const { + std::unordered_set> relation() const { + return transform(items(this->m_r_to_l), + [](std::pair const &p) -> std::pair { + return {p.second, p.first}; + }); + } + + nonempty_unordered_set const &at_l(L const &l) const { return this->m_l_to_r.at(l); } @@ -87,11 +102,11 @@ struct OneToMany { return keys(this->m_r_to_l); } - std::unordered_set> right_groups() const { + std::unordered_set> right_groups() const { return unordered_set_of(values(this->m_l_to_r)); } - std::unordered_map> const &l_to_r() const { + std::unordered_map> const &l_to_r() const { return this->m_l_to_r; } @@ -100,7 +115,7 @@ struct OneToMany { } private: - std::unordered_map> m_l_to_r; + std::unordered_map> m_l_to_r; std::unordered_map m_r_to_l; private: @@ -113,7 +128,7 @@ struct OneToMany { }; template -std::unordered_map> +std::unordered_map> format_as(OneToMany const &m) { return generate_map(m.left_values(), [&](L const &l) { return m.at_l(l); }); } diff --git a/lib/utils/include/utils/one_to_many/one_to_many_transform_values.h b/lib/utils/include/utils/one_to_many/one_to_many_transform_values.h new file mode 100644 index 0000000000..a9afe98988 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/one_to_many_transform_values.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_TRANSFORM_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_TRANSFORM_VALUES_H + +#include "utils/containers/transform.h" +#include "utils/one_to_many/one_to_many_from_unstructured_relation.h" + +namespace FlexFlow { + +template > +OneToMany one_to_many_transform_values(OneToMany const &input, + F f) { + return one_to_many_from_unstructured_relation(transform( + input.relation(), [&](std::pair const &p) -> std::pair { + return {p.first, f(p.second)}; + })); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orientation.dtg.toml b/lib/utils/include/utils/orientation.dtg.toml new file mode 100644 index 0000000000..e19b5b53b5 --- /dev/null +++ b/lib/utils/include/utils/orientation.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "Orientation" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "HORIZONTAL" + +[[values]] +name = "VERTICAL" diff --git a/lib/utils/include/utils/orthotope/up_projection.h b/lib/utils/include/utils/orthotope/up_projection.h index 7fa7c0339c..e485419fbb 100644 --- a/lib/utils/include/utils/orthotope/up_projection.h +++ b/lib/utils/include/utils/orthotope/up_projection.h @@ -52,7 +52,8 @@ DimCoord compute_up_projection(UpProjection const &projection, flatmap(coord.raw, [&](L const &input_dim, nonnegative_int input_dim_val) { std::unordered_set dst_dims = - projection.dim_mapping.at_l(input_dim); + projection.dim_mapping.at_l(input_dim) + .unwrap_as_unordered_set(); DimDomain dst_domain = restrict_domain_to_dims(output_domain, dst_dims); diff --git a/lib/utils/include/utils/record_formatter.h b/lib/utils/include/utils/record_formatter.h index c1dab45e67..9d2527ab3f 100644 --- a/lib/utils/include/utils/record_formatter.h +++ b/lib/utils/include/utils/record_formatter.h @@ -1,10 +1,29 @@ -#ifndef _RECORD_FORMATTER_H -#define _RECORD_FORMATTER_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RECORD_FORMATTER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RECORD_FORMATTER_H +#include "utils/containers/keys.h" +#include "utils/containers/sorted.h" +#include "utils/orientation.dtg.h" +#include #include #include +namespace FlexFlow { + +/** + * \brief Helper interface for generating + * DOT/graphviz records. + * + * \note This is very old code and should not be emulated stylistically. + * + * \see \ref DotFile + * \see \ref mk_empty_record + */ class RecordFormatter { +public: + RecordFormatter() = delete; + explicit RecordFormatter(Orientation, std::vector const &pieces); + friend RecordFormatter &operator<<(RecordFormatter &r, std::string const &tok); friend RecordFormatter &operator<<(RecordFormatter &r, int tok); @@ -15,8 +34,45 @@ class RecordFormatter { std::ostringstream &oss); friend std::ostream &operator<<(std::ostream &s, RecordFormatter const &r); -private: +public: + Orientation orientation; std::vector pieces; }; -#endif // _RECORD_FORMATTER_H +RecordFormatter mk_empty_record(Orientation); + +template +RecordFormatter mk_kv_record(std::string const &k, T const &v) { + RecordFormatter rr = mk_empty_record(Orientation::HORIZONTAL); + rr << k << fmt::to_string(v); + return rr; +} + +template <> +RecordFormatter mk_kv_record(std::string const &, RecordFormatter const &); + +template +RecordFormatter mk_kv_record(std::string const &k, std::optional const &v) { + if (v.has_value()) { + return mk_kv_record(k, v.value()); + } else { + RecordFormatter rr = mk_empty_record(Orientation::HORIZONTAL); + rr << k << "(none)"; + return rr; + } +} + +template +RecordFormatter mk_record_for_map(std::unordered_map const &m) { + RecordFormatter result = mk_empty_record(Orientation::VERTICAL); + + for (K const &k : sorted(keys(m))) { + result << mk_kv_record(fmt::to_string(k), m.at(k)); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/dot/dot_file.cc b/lib/utils/src/utils/dot/dot_file.cc new file mode 100644 index 0000000000..e93e47f1e3 --- /dev/null +++ b/lib/utils/src/utils/dot/dot_file.cc @@ -0,0 +1,10 @@ +#include "utils/dot/dot_file.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using T = ordered_value_type<0>; + +template class DotFile; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/dot/dot_html_from_json.cc b/lib/utils/src/utils/dot/dot_html_from_json.cc new file mode 100644 index 0000000000..476ac59749 --- /dev/null +++ b/lib/utils/src/utils/dot/dot_html_from_json.cc @@ -0,0 +1,137 @@ +#include "utils/dot/dot_html_from_json.h" +#include + +namespace FlexFlow { + +DotHtmlTable dot_html_table_from_json(nlohmann::json const &j) { + auto mk_table_from_rows = + [](std::vector const &rows) -> DotHtmlTable { + return DotHtmlTable{ + /*border=*/0_n, + /*cellborder=*/1_n, + /*cellspacing=*/0_n, + /*rows=*/rows, + }; + }; + + auto mk_singleton_table = [&](std::string const &s) -> DotHtmlTable { + return mk_table_from_rows({ + DotHtmlTableRow{ + /*cells=*/{ + DotHtmlTableCell{ + /*content=*/DotHtmlTableCellContents{ + s, + }, + /*port=*/std::nullopt, + /*colspan=*/std::nullopt, + }, + }, + }, + }); + }; + + auto mk_singleton_row = [&](nlohmann::json const &v) -> DotHtmlTableRow { + return DotHtmlTableRow{ + /*cells=*/{ + dot_html_cell_from_json(v), + }, + }; + }; + + auto mk_kv_row = [&](std::string const &k, + nlohmann::json const &v) -> DotHtmlTableRow { + return DotHtmlTableRow{ + /*cells=*/{ + DotHtmlTableCell{ + /*content=*/DotHtmlTableCellContents{ + k, + }, + /*port=*/std::nullopt, + /*colspan=*/std::nullopt, + }, + dot_html_cell_from_json(v), + }, + }; + }; + + switch (j.type()) { + case nlohmann::json::value_t::null: + return mk_singleton_table("(none)"); + case nlohmann::json::value_t::boolean: + return mk_singleton_table(fmt::to_string(j.get())); + case nlohmann::json::value_t::string: + return mk_singleton_table(j.get()); + case nlohmann::json::value_t::number_integer: + return mk_singleton_table(fmt::to_string(j.get())); + case nlohmann::json::value_t::number_unsigned: + return mk_singleton_table(fmt::to_string(j.get())); + case nlohmann::json::value_t::number_float: + return mk_singleton_table(fmt::to_string(j.get())); + case nlohmann::json::value_t::object: { + std::vector rows; + for (auto const &[k, v] : j.items()) { + rows.push_back(mk_kv_row(k, v)); + } + return mk_table_from_rows(rows); + } + case nlohmann::json::value_t::array: { + std::vector rows; + for (auto const &v : j) { + rows.push_back(mk_singleton_row(v)); + } + return mk_table_from_rows(rows); + } + + case nlohmann::json::value_t::binary: + case nlohmann::json::value_t::discarded: + default: + PANIC("Unhandled value_t", j.type()); + } +} + +DotHtmlTableCell dot_html_cell_from_json(nlohmann::json const &j) { + + auto mk_cell_from_string = [](std::string const &s) -> DotHtmlTableCell { + return DotHtmlTableCell{ + DotHtmlTableCellContents{ + s, + }, + /*port=*/std::nullopt, + /*colspan=*/std::nullopt, + }; + }; + + auto mk_cell_from_table = [](DotHtmlTable const &t) -> DotHtmlTableCell { + return DotHtmlTableCell{ + DotHtmlTableCellContents{ + t, + }, + /*port=*/std::nullopt, + /*colspan=*/std::nullopt, + }; + }; + + switch (j.type()) { + case nlohmann::json::value_t::null: + return mk_cell_from_string("(none)"); + case nlohmann::json::value_t::boolean: + return mk_cell_from_string(fmt::to_string(j.get())); + case nlohmann::json::value_t::string: + return mk_cell_from_string(j.get()); + case nlohmann::json::value_t::number_integer: + return mk_cell_from_string(fmt::to_string(j.get())); + case nlohmann::json::value_t::number_unsigned: + return mk_cell_from_string(fmt::to_string(j.get())); + case nlohmann::json::value_t::number_float: + return mk_cell_from_string(fmt::to_string(j.get())); + case nlohmann::json::value_t::object: + case nlohmann::json::value_t::array: + return mk_cell_from_table(dot_html_table_from_json(j)); + case nlohmann::json::value_t::binary: + case nlohmann::json::value_t::discarded: + default: + PANIC("Unhandled value_t", j.type()); + } +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/dot/dot_html_table_cell_contents.cc b/lib/utils/src/utils/dot/dot_html_table_cell_contents.cc new file mode 100644 index 0000000000..aba3c5de65 --- /dev/null +++ b/lib/utils/src/utils/dot/dot_html_table_cell_contents.cc @@ -0,0 +1,62 @@ +#include "utils/dot/dot_html_table_cell_contents.h" +#include "utils/dot/dot_html_table.dtg.h" + +namespace FlexFlow { + +DotHtmlTableCellContents::DotHtmlTableCellContents(std::string const &s) + : value(s) {} + +DotHtmlTableCellContents::DotHtmlTableCellContents(DotHtmlTable const &t) + : value(std::make_shared(t)) {} + +DotHtmlTable const &DotHtmlTableCellContents::require_nested() const { + return *std::get>(this->value); +} + +std::string const &DotHtmlTableCellContents::require_simple() const { + return std::get(this->value); +} + +bool DotHtmlTableCellContents::is_simple() const { + return std::holds_alternative(this->value); +} + +bool DotHtmlTableCellContents::is_nested() const { + return std::holds_alternative>(this->value); +} + +bool DotHtmlTableCellContents::operator==( + DotHtmlTableCellContents const &other) const { + if (this->is_simple() && other.is_simple()) { + return this->require_simple() == other.require_simple(); + } else if (this->is_nested() && other.is_nested()) { + return this->require_nested() == other.require_nested(); + } else { + return false; + } +} + +bool DotHtmlTableCellContents::operator!=( + DotHtmlTableCellContents const &other) const { + if (this->is_simple() && other.is_simple()) { + return this->require_simple() != other.require_simple(); + } else if (this->is_nested() && other.is_nested()) { + return this->require_nested() != other.require_nested(); + } else { + return true; + } +} + +std::string format_as(DotHtmlTableCellContents const &c) { + if (c.is_simple()) { + return c.require_simple(); + } else { + return fmt::to_string(c.require_nested()); + } +} + +std::ostream &operator<<(std::ostream &s, DotHtmlTableCellContents const &x) { + return (s << fmt::to_string(x)); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/dot/render_dot_html_table_to_string.cc b/lib/utils/src/utils/dot/render_dot_html_table_to_string.cc new file mode 100644 index 0000000000..135542f3b0 --- /dev/null +++ b/lib/utils/src/utils/dot/render_dot_html_table_to_string.cc @@ -0,0 +1,66 @@ +#include "utils/dot/render_dot_html_table_to_string.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/transform.h" +#include "utils/join_strings.h" + +namespace FlexFlow { + +static std::string escape_html_string(std::string const &s) { + auto escape_dot_char = [](char c) -> std::string { + switch (c) { + case '<': + return std::string{"<"}; + case '>': + return std::string{">"}; + default: + return std::string{c}; + } + }; + + return flatmap(s, escape_dot_char); +} + +static std::string render_dot_html_cell_contents_to_string( + DotHtmlTableCellContents const &cell_contents) { + if (cell_contents.is_simple()) { + return escape_html_string(cell_contents.require_simple()); + } else { + return render_dot_html_table_to_string(cell_contents.require_nested()); + } +} + +static std::string + render_dot_html_cell_to_string(DotHtmlTableCell const &cell) { + std::ostringstream oss; + + oss << "" << render_dot_html_cell_contents_to_string(cell.content) + << ""; + + return oss.str(); +} + +static std::string render_dot_html_row_to_string(DotHtmlTableRow const &row) { + return fmt::format( + "{}", + join_strings(transform(row.cells, render_dot_html_cell_to_string), + std::string{"\n"})); +} + +std::string render_dot_html_table_to_string(DotHtmlTable const &table) { + return fmt::format( + "{}
", + table.border, + table.cellborder, + table.cellspacing, + join_strings(transform(table.rows, render_dot_html_row_to_string), + std::string{"\n"})); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/dot_file.cc b/lib/utils/src/utils/dot_file.cc deleted file mode 100644 index 297ba484e3..0000000000 --- a/lib/utils/src/utils/dot_file.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/dot_file.h" diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index ae93fa38b5..f09db77282 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -4,6 +4,7 @@ #include "utils/containers/intersection.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/set_difference.h" +#include "utils/containers/set_of.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" @@ -57,7 +58,11 @@ struct GetNodesFunctor { std::unordered_set query_nodes(GraphView const &g, std::unordered_set const &nodes) { - return g.query_nodes(NodeQuery{nodes}); + NodeQuery query = NodeQuery{ + query_set::match_values_in(set_of(nodes)), + }; + + return g.query_nodes(query); } void remove_node(DiGraph &g, Node const &n) { @@ -131,12 +136,17 @@ void add_edges(UndirectedGraph &g, } bool contains_edge(DiGraphView const &g, DirectedEdge const &e) { - return contains(g.query_edges(DirectedEdgeQuery{e.src, e.dst}), e); + DirectedEdgeQuery query = DirectedEdgeQuery{ + query_set::match_single_value(e.src), + query_set::match_single_value(e.dst), + }; + + return contains(g.query_edges(query), e); } bool contains_edge(UndirectedGraphView const &g, UndirectedEdge const &e) { - UndirectedEdgeQuery q = - UndirectedEdgeQuery{{e.endpoints.max(), e.endpoints.min()}}; + UndirectedEdgeQuery q = UndirectedEdgeQuery{ + query_set::match_values_in({e.endpoints.max(), e.endpoints.min()})}; return contains(g.query_edges(q), e); } @@ -159,7 +169,11 @@ void remove_edges(UndirectedGraph &g, std::unordered_set get_node_edges(UndirectedGraphView const &g, Node const &n) { - return g.query_edges(UndirectedEdgeQuery{n}); + UndirectedEdgeQuery query = UndirectedEdgeQuery{ + query_set::match_single_value(n), + }; + + return g.query_edges(query); } std::vector get_unchecked_dfs_ordering( diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc index 7069146057..072104ae35 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc @@ -26,7 +26,7 @@ std::vector get_dataflow_inputs(DataflowGraphView const &g, std::vector get_outputs(DataflowGraphView const &g, Node const &n) { return sorted_by(g.query_outputs(DataflowOutputQuery{ - query_set{n}, + query_set::match_single_value(n), query_set::matchall(), }), [](DataflowOutput const &l, DataflowOutput const &r) { diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/as_dot.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/as_dot.cc deleted file mode 100644 index 2ae903fa0b..0000000000 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/as_dot.cc +++ /dev/null @@ -1,77 +0,0 @@ -#include "utils/graph/dataflow_graph/algorithms/as_dot.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/map_keys.h" -#include "utils/dot_file.h" -#include "utils/graph/dataflow_graph/algorithms.h" -#include "utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" -#include "utils/graph/node/algorithms.h" -#include "utils/graph/render_dot.h" -#include "utils/record_formatter.h" - -namespace FlexFlow { - -std::string as_dot(DataflowGraphView const &g) { - auto get_node_attrs = [](Node const &) { - return std::unordered_map{}; - }; - - std::unordered_map> - node_labels = generate_map(get_nodes(g), get_node_attrs); - - auto get_output_label = [](DataflowOutput const &o) { - return fmt::to_string(o.idx); - }; - - std::unordered_map output_labels = - generate_map(get_all_dataflow_outputs(g), get_output_label); - std::unordered_map value_labels = - map_keys(output_labels, - [](DataflowOutput const &o) { return OpenDataflowValue{o}; }); - - return render_dot(with_labelling( - view_as_open_dataflow_graph(g), node_labels, value_labels)); -} - -void as_dot(DotFile &dot, - DataflowGraphView const &g, - std::function const &get_node_label) { - auto get_node_name = [](Node n) { return fmt::format("n{}", n.raw_uid); }; - - auto get_input_field = [](nonnegative_int idx) { - return fmt::format("i{}", idx); - }; - - auto get_output_field = [](nonnegative_int idx) { - return fmt::format("o{}", idx); - }; - - for (Node const &n : get_nodes(g)) { - std::vector n_inputs = get_dataflow_inputs(g, n); - std::vector n_outputs = get_outputs(g, n); - - RecordFormatter inputs_record; - for (DataflowInput const &i : n_inputs) { - inputs_record << fmt::format("<{}>{}", get_input_field(i.idx), i.idx); - } - - RecordFormatter outputs_record; - for (DataflowOutput const &o : n_outputs) { - outputs_record << fmt::format("<{}>{}", get_output_field(o.idx), o.idx); - } - - RecordFormatter rec; - rec << inputs_record << get_node_label(n) << outputs_record; - - dot.add_record_node(get_node_name(n), rec); - } - - for (DataflowEdge const &e : get_edges(g)) { - dot.add_edge(get_node_name(e.src.node), - get_node_name(e.dst.node), - get_output_field(e.src.idx), - get_input_field(e.dst.idx)); - } -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc new file mode 100644 index 0000000000..f617c52593 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc @@ -0,0 +1,152 @@ +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/map_keys.h" +#include "utils/dot/dot_file.h" +#include "utils/dot/dot_html_from_json.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/render_dot.h" +#include "utils/nonnegative_int/num_elements.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +std::string dataflow_graph_as_dot( + DataflowGraphView const &g, + std::optional> const + &get_node_label, + std::optional> const + &get_input_label, + std::optional> const + &get_output_label) { + std::ostringstream oss; + DotFile dot{oss}; + + dataflow_graph_as_dot( + dot, g, get_node_label, get_input_label, get_output_label); + + dot.close(); + + return oss.str(); +} + +void dataflow_graph_as_dot( + DotFile &dot, + DataflowGraphView const &g, + std::optional> const + &get_node_label, + std::optional> const + &get_input_label, + std::optional> const + &get_output_label) { + + auto get_node_name = [](Node n) { return fmt::format("n{}", n.raw_uid); }; + + std::function resolved_get_node_label = + get_node_label.value_or(get_node_name); + + std::function + resolved_get_input_label = get_input_label.value_or( + [](DataflowInput const &i) { return fmt::to_string(i.idx); }); + + std::function + resolved_get_output_label = get_output_label.value_or( + [](DataflowOutput const &o) { return fmt::to_string(o.idx); }); + + auto get_input_field = [](nonnegative_int idx) { + return fmt::format("i{}", idx); + }; + + auto get_output_field = [](nonnegative_int idx) { + return fmt::format("o{}", idx); + }; + + for (Node const &n : get_nodes(g)) { + std::vector n_inputs = get_dataflow_inputs(g, n); + std::vector n_outputs = get_outputs(g, n); + + auto make_io_cell = [](nlohmann::json const &j, + std::string const &port, + positive_int colspan) -> DotHtmlTableCell { + DotHtmlTableCell cell = dot_html_cell_from_json(j); + cell.port = port; + cell.colspan = colspan; + return cell; + }; + + positive_int num_input_columns = + positive_int{std::max(num_elements(n_inputs), 1_n)}; + positive_int num_output_columns = + positive_int{std::max(num_elements(n_outputs), 1_n)}; + + std::vector inputs = + transform(n_inputs, [&](DataflowInput const &i) -> DotHtmlTableCell { + return make_io_cell(resolved_get_input_label(i), + get_input_field(i.idx), + num_output_columns); + }); + + if (inputs.size() == 0) { + inputs.push_back(DotHtmlTableCell{ + /*content=*/DotHtmlTableCellContents{ + "(no inputs)", + }, + /*port=*/std::nullopt, + /*colspan=*/num_output_columns, + }); + } + + DotHtmlTableCell body = dot_html_cell_from_json(resolved_get_node_label(n)); + body.colspan = num_input_columns * num_output_columns; + + std::vector outputs = + transform(n_outputs, [&](DataflowOutput const &o) -> DotHtmlTableCell { + return make_io_cell(resolved_get_output_label(o), + get_output_field(o.idx), + num_input_columns); + }); + + if (outputs.size() == 0) { + outputs.push_back(DotHtmlTableCell{ + /*content=*/DotHtmlTableCellContents{ + "(no outputs)", + }, + /*port=*/std::nullopt, + /*colspan=*/num_input_columns, + }); + } + + DotHtmlTable table = DotHtmlTable{ + /*border=*/0_n, + /*cellborder=*/1_n, + /*cellspacing=*/0_n, + /*rows=*/ + { + DotHtmlTableRow{ + inputs, + }, + DotHtmlTableRow{ + /*cells=*/{ + body, + }, + }, + DotHtmlTableRow{ + outputs, + }, + }, + }; + + dot.add_html_node(get_node_name(n), table); + } + + for (DataflowEdge const &e : get_edges(g)) { + dot.add_edge(get_node_name(e.src.node), + get_node_name(e.dst.node), + get_output_field(e.src.idx), + get_input_field(e.dst.idx)); + } +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc index 73afc11acc..ceca6982a6 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc @@ -5,9 +5,9 @@ namespace FlexFlow { std::unordered_set get_dataflow_edges_from_node_to_node( DataflowGraphView const &g, Node const &src, Node const &dst) { return g.query_edges(DataflowEdgeQuery{ - /*src_nodes=*/query_set{src}, + /*src_nodes=*/query_set::match_single_value(src), /*src_idxs=*/query_set::matchall(), - /*dst_nodes=*/query_set{dst}, + /*dst_nodes=*/query_set::match_single_value(dst), /*dst_idxs=*/query_set::matchall(), }); } diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_incoming_edges.cc index c4947f967a..42bba1892f 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_incoming_edges.cc @@ -1,4 +1,5 @@ #include "utils/graph/dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/containers/set_of.h" #include "utils/containers/sorted_by.h" namespace FlexFlow { @@ -8,7 +9,7 @@ std::vector get_incoming_edges(DataflowGraphView const &g, return sorted_by(g.query_edges(DataflowEdgeQuery{ query_set::matchall(), query_set::matchall(), - {n}, + query_set::match_single_value(n), query_set::matchall(), }), [](DataflowEdge const &l, DataflowEdge const &r) { @@ -22,7 +23,7 @@ std::unordered_set DataflowEdgeQuery query = DataflowEdgeQuery{ query_set::matchall(), query_set::matchall(), - query_set{ns}, + query_set::match_values_in(set_of(ns)), query_set::matchall(), }; return g.query_edges(query); diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.cc index 16b2b82b2d..f958b8e085 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_outgoing_edges.cc @@ -1,4 +1,5 @@ #include "utils/graph/dataflow_graph/algorithms/get_outgoing_edges.h" +#include "utils/containers/set_of.h" #include "utils/containers/sorted_by.h" namespace FlexFlow { @@ -6,7 +7,7 @@ namespace FlexFlow { std::unordered_set get_outgoing_edges(DataflowGraphView const &g, Node const &n) { return g.query_edges(DataflowEdgeQuery{ - {n}, + query_set::match_single_value(n), query_set::matchall(), query_set::matchall(), query_set::matchall(), @@ -17,7 +18,7 @@ std::unordered_set get_outgoing_edges(DataflowGraphView const &g, std::unordered_set const &ns) { DataflowEdgeQuery query = DataflowEdgeQuery{ - query_set{ns}, + query_set::match_values_in(set_of(ns)), query_set::matchall(), query_set::matchall(), query_set::matchall(), diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc index a06ec1ab31..e89189776e 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -1,5 +1,6 @@ #include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -9,12 +10,13 @@ std::unordered_set std::unordered_set const &ns) { std::unordered_set all_nodes = get_nodes(g); - query_set src_query = query_set{set_minus(all_nodes, ns)}; + query_set src_query = + query_set::match_values_in(set_of(set_minus(all_nodes, ns))); DataflowEdgeQuery query = DataflowEdgeQuery{ src_query, query_set::matchall(), - query_set{ns}, + query_set::match_values_in(set_of(ns)), query_set::matchall(), }; diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc index f94dd94e11..c958a2e248 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc @@ -1,5 +1,6 @@ #include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -9,10 +10,11 @@ std::unordered_set std::unordered_set const &ns) { std::unordered_set all_nodes = get_nodes(g); - query_set dst_query = query_set{set_minus(all_nodes, ns)}; + query_set dst_query = + query_set::match_values_in(set_of(set_minus(all_nodes, ns))); DataflowEdgeQuery query = DataflowEdgeQuery{ - query_set{ns}, + query_set::match_values_in(set_of(ns)), query_set::matchall(), dst_query, query_set::matchall(), diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_from_dataflow_graph_data.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_from_dataflow_graph_data.cc new file mode 100644 index 0000000000..90f2f5134c --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_from_dataflow_graph_data.cc @@ -0,0 +1,40 @@ +#include "utils/graph/dataflow_graph/algorithms/view_from_dataflow_graph_data.h" +#include "utils/containers/filter.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.h" +#include "utils/graph/node/node_query.h" + +namespace FlexFlow { + +ViewFromDataflowGraphData::ViewFromDataflowGraphData( + DataflowGraphData const &data) + : data(data) {} + +std::unordered_set + ViewFromDataflowGraphData::query_nodes(NodeQuery const &query) const { + return apply_node_query(query, this->data.nodes); +} + +std::unordered_set ViewFromDataflowGraphData::query_edges( + DataflowEdgeQuery const &query) const { + return filter(this->data.edges, [&](DataflowEdge const &e) { + return dataflow_edge_query_includes_dataflow_edge(query, e); + }); +} + +std::unordered_set ViewFromDataflowGraphData::query_outputs( + DataflowOutputQuery const &query) const { + return filter(this->data.outputs, [&](DataflowOutput const &o) { + return dataflow_output_query_includes_dataflow_output(query, o); + }); +} + +ViewFromDataflowGraphData *ViewFromDataflowGraphData::clone() const { + return new ViewFromDataflowGraphData{this->data}; +} + +DataflowGraphView view_from_dataflow_graph_data(DataflowGraphData const &data) { + return DataflowGraphView::create(data); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc index 982969f3a5..88bbd887d5 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc @@ -29,18 +29,18 @@ bool dataflow_edge_query_includes_dataflow_edge(DataflowEdgeQuery const &q, DataflowEdgeQuery dataflow_edge_query_for_edge(DataflowEdge const &e) { return DataflowEdgeQuery{ - query_set{e.src.node}, - query_set{e.src.idx}, - query_set{e.dst.node}, - query_set{e.dst.idx}, + query_set::match_single_value(e.src.node), + query_set::match_single_value(e.src.idx), + query_set::match_single_value(e.dst.node), + query_set::match_single_value(e.dst.idx), }; } DataflowEdgeQuery dataflow_edge_query_all_outgoing_from(DataflowOutput const &src) { return DataflowEdgeQuery{ - query_set{src.node}, - query_set{src.idx}, + query_set::match_single_value(src.node), + query_set::match_single_value(src.idx), query_set::matchall(), query_set::matchall(), }; @@ -51,8 +51,8 @@ DataflowEdgeQuery return DataflowEdgeQuery{ query_set::matchall(), query_set::matchall(), - query_set{dst.node}, - query_set{dst.idx}, + query_set::match_single_value(dst.node), + query_set::match_single_value(dst.idx), }; } diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc index ceaad2bfdf..eb1dfacd8f 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc @@ -23,8 +23,8 @@ bool dataflow_output_query_includes_dataflow_output( DataflowOutputQuery dataflow_output_query_for_output(DataflowOutput const &o) { return DataflowOutputQuery{ - query_set{o.node}, - query_set{o.idx}, + query_set::match_single_value(o.node), + query_set::match_single_value(o.idx), }; } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 75b0e09891..4900cdaa10 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -2,6 +2,7 @@ #include "utils/containers/are_disjoint.h" #include "utils/containers/extend.h" #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/containers/values.h" #include "utils/containers/vector_of.h" #include "utils/graph/algorithms.h" @@ -53,7 +54,10 @@ std::optional } std::unordered_set from_head_to_tail = - g.query_edges(DirectedEdgeQuery{head, tail}); + g.query_edges(DirectedEdgeQuery{ + query_set::match_values_in(set_of(head)), + query_set::match_values_in(set_of(tail)), + }); DiGraphView subgraph = get_subgraph(g, set_union(head, tail)); if (!is_complete_bipartite_digraph(subgraph, head)) { diff --git a/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc b/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc index d019c59e23..57787b8ce7 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc @@ -1,5 +1,6 @@ #include "utils/graph/digraph/algorithms/digraph_as_dot.h" -#include "utils/dot_file.h" +#include "utils/dot/dot_file.h" +#include "utils/dot/dot_html_from_json.h" #include "utils/graph/digraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" @@ -7,7 +8,7 @@ namespace FlexFlow { std::string digraph_as_dot( DiGraphView const &g, - std::function const &get_node_label) { + std::function const &get_node_label) { std::ostringstream oss; DotFile dot = DotFile{oss}; @@ -16,9 +17,8 @@ std::string digraph_as_dot( }; for (Node const &n : get_nodes(g)) { - RecordFormatter rec; - rec << get_node_label(n); - dot.add_record_node(get_node_name(n), rec); + dot.add_html_node(get_node_name(n), + dot_html_table_from_json(get_node_label(n))); } for (DirectedEdge const &e : get_edges(g)) { diff --git a/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc b/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc index 5c790abb8c..a9c255309d 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc @@ -4,8 +4,8 @@ namespace FlexFlow { bool digraph_has_edge(DiGraphView const &g, DirectedEdge const &e) { return !g.query_edges(DirectedEdgeQuery{ - query_set{e.src}, - query_set{e.dst}, + query_set::match_single_value(e.src), + query_set::match_single_value(e.dst), }) .empty(); } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc index 2c6606a06b..e8d450d3a2 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -1,5 +1,6 @@ #include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" #include "utils/containers/are_disjoint.h" +#include "utils/containers/set_of.h" namespace FlexFlow { @@ -17,8 +18,8 @@ std::unordered_set get_edges_from_subgraph_to_subgraph( } return g.query_edges(DirectedEdgeQuery{ - /*srcs=*/query_set{src_subgraph}, - /*dsts=*/query_set{dst_subgraph}, + /*srcs=*/query_set::match_values_in(set_of(src_subgraph)), + /*dsts=*/query_set::match_values_in(set_of(dst_subgraph)), }); } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc index 1b021e7e79..db09dd07d6 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc @@ -1,5 +1,7 @@ #include "utils/graph/digraph/algorithms/get_incoming_edges.h" #include "utils/containers/group_by.h" +#include "utils/containers/map_values.h" +#include "utils/containers/set_of.h" namespace FlexFlow { @@ -7,7 +9,7 @@ std::unordered_set get_incoming_edges(DiGraphView const &g, Node const &n) { return g.query_edges(DirectedEdgeQuery{ query_set::matchall(), - query_set{n}, + query_set::match_single_value(n), }); } @@ -15,12 +17,16 @@ std::unordered_map> get_incoming_edges(DiGraphView const &g, std::unordered_set const &ns) { std::unordered_map> result = - group_by(g.query_edges(DirectedEdgeQuery{ - query_set::matchall(), - query_set{ns}, - }), - [](DirectedEdge const &e) { return e.dst; }) - .l_to_r(); + map_values(group_by(g.query_edges(DirectedEdgeQuery{ + query_set::matchall(), + query_set::match_values_in(set_of(ns)), + }), + [](DirectedEdge const &e) { return e.dst; }) + .l_to_r(), + [](nonempty_unordered_set const &s) + -> std::unordered_set { + return s.unwrap_as_unordered_set(); + }); for (Node const &n : ns) { result[n]; diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc index 9569fb1ae3..c2057472cf 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc @@ -1,5 +1,7 @@ #include "utils/graph/digraph/algorithms/get_outgoing_edges.h" #include "utils/containers/group_by.h" +#include "utils/containers/map_values.h" +#include "utils/containers/set_of.h" namespace FlexFlow { @@ -7,12 +9,16 @@ std::unordered_map> get_outgoing_edges(DiGraphView const &g, std::unordered_set const &ns) { std::unordered_map> result = - group_by(g.query_edges(DirectedEdgeQuery{ - query_set{ns}, - query_set::matchall(), - }), - [](DirectedEdge const &e) { return e.src; }) - .l_to_r(); + map_values(group_by(g.query_edges(DirectedEdgeQuery{ + query_set::match_values_in(set_of(ns)), + query_set::matchall(), + }), + [](DirectedEdge const &e) { return e.src; }) + .l_to_r(), + [](nonempty_unordered_set const &s) + -> std::unordered_set { + return s.unwrap_as_unordered_set(); + }); for (Node const &n : ns) { result[n]; @@ -24,7 +30,7 @@ std::unordered_map> std::unordered_set get_outgoing_edges(DiGraphView const &g, Node const &n) { return g.query_edges(DirectedEdgeQuery{ - query_set{n}, + query_set::match_single_value(n), query_set::matchall(), }); } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc index f19deb3046..3067e25d2a 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc @@ -1,5 +1,6 @@ #include "utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h" #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -8,8 +9,10 @@ std::unordered_set get_subgraph_outgoing_edges( DiGraphView const &g, std::unordered_set const &subgraph_nodes) { std::unordered_set external_nodes = set_minus(get_nodes(g), subgraph_nodes); - DirectedEdgeQuery query = DirectedEdgeQuery{query_set{subgraph_nodes}, - query_set{external_nodes}}; + DirectedEdgeQuery query = DirectedEdgeQuery{ + query_set::match_values_in(set_of(subgraph_nodes)), + query_set::match_values_in(set_of(external_nodes)), + }; return g.query_edges(query); } diff --git a/lib/utils/src/utils/graph/digraph/directed_edge_query.cc b/lib/utils/src/utils/graph/digraph/directed_edge_query.cc index b12098bd96..b7aabb14be 100644 --- a/lib/utils/src/utils/graph/digraph/directed_edge_query.cc +++ b/lib/utils/src/utils/graph/digraph/directed_edge_query.cc @@ -1,4 +1,5 @@ #include "utils/graph/digraph/directed_edge_query.h" +#include "utils/containers/set_of.h" namespace FlexFlow { @@ -30,7 +31,10 @@ DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &lhs, result_dsts = allowed_values(query_intersection(lhs.dsts, rhs.dsts)); } - return DirectedEdgeQuery{result_srcs, result_dsts}; + return DirectedEdgeQuery{ + query_set::match_values_in(set_of(result_srcs)), + query_set::match_values_in(set_of(result_dsts)), + }; } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/instances/unordered_set_kwarg_dataflow_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_kwarg_dataflow_graph.cc new file mode 100644 index 0000000000..fc0ed61fbf --- /dev/null +++ b/lib/utils/src/utils/graph/instances/unordered_set_kwarg_dataflow_graph.cc @@ -0,0 +1,10 @@ +#include "utils/graph/instances/unordered_set_kwarg_dataflow_graph.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template struct UnorderedSetKwargDataflowGraph; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.cc new file mode 100644 index 0000000000..1e55bf407e --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.cc @@ -0,0 +1,13 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template DataflowGraphData dataflow_graph_data_from_kwarg_dataflow_graph_data( + KwargDataflowGraphData const &, + std::function< + std::vector(std::unordered_set const &)> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.cc new file mode 100644 index 0000000000..bdbb6c1d6a --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.cc @@ -0,0 +1,13 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template DataflowGraphView dataflow_graph_from_kwarg_dataflow_graph( + KwargDataflowGraphView const &, + std::function< + std::vector(std::unordered_set const &)> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_inputs.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_inputs.cc new file mode 100644 index 0000000000..343bc1c228 --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_inputs.cc @@ -0,0 +1,11 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_inputs.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template std::unordered_set> + get_all_kwarg_dataflow_inputs(KwargDataflowGraphView const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.cc new file mode 100644 index 0000000000..816e529be4 --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.cc @@ -0,0 +1,11 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template std::unordered_set + get_incoming_slots_for_node(KwargDataflowGraphView const &, Node); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.cc new file mode 100644 index 0000000000..1d9e1d395f --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.cc @@ -0,0 +1,11 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_graph_data.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template KwargDataflowGraphData + get_kwarg_dataflow_graph_data(KwargDataflowGraphView const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.cc new file mode 100644 index 0000000000..98c2e3895e --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.cc @@ -0,0 +1,11 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template std::unordered_set + get_outgoing_slots_for_node(KwargDataflowGraphView const &, Node); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.cc new file mode 100644 index 0000000000..b9585b562a --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.cc @@ -0,0 +1,17 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template std::string kwarg_dataflow_graph_as_dot( + KwargDataflowGraphView const &, + std::function const &, + std::function const &)> const + &, + std::function const &, + std::function< + std::vector(std::unordered_set const &)> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.cc new file mode 100644 index 0000000000..7c6577c5ea --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.cc @@ -0,0 +1,11 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_data.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template void require_kwarg_dataflow_graph_data_is_valid( + KwargDataflowGraphData const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graphs_are_isomorphic.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..528d830261 --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,12 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graphs_are_isomorphic.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template bool kwarg_dataflow_graphs_are_isomorphic( + KwargDataflowGraphView const &, + KwargDataflowGraphView const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.cc new file mode 100644 index 0000000000..9d1fa0a31e --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.cc @@ -0,0 +1,11 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template KwargDataflowGraphView view_from_kwarg_dataflow_graph_data( + KwargDataflowGraphData const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.cc b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.cc new file mode 100644 index 0000000000..f1b9c13e17 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.cc @@ -0,0 +1,19 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using NodeLabel = value_type<0>; +using ValueLabel = value_type<1>; +using SlotName = ordered_value_type<2>; + +template std::string labelled_kwarg_dataflow_graph_view_as_dot( + LabelledKwargDataflowGraphView const &, + std::function const &, + std::function const &, + std::function const &, + std::function< + std::vector(std::unordered_set const &)> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.cc b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.cc new file mode 100644 index 0000000000..1de29d648c --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.cc @@ -0,0 +1,16 @@ +#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using NodeLabel = value_type<0>; +using ValueLabel = value_type<1>; +using SlotName = ordered_value_type<2>; + +template LabelledKwargDataflowGraph + materialize_labelled_kwarg_dataflow_graph_view( + LabelledKwargDataflowGraphView const + &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.cc deleted file mode 100644 index 78dbed5262..0000000000 --- a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_as_dot.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_as_dot.cc new file mode 100644 index 0000000000..b8f5b15417 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_as_dot.cc @@ -0,0 +1,17 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_as_dot.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using NodeLabel = value_type<0>; +using ValueLabel = value_type<1>; + +template std::string labelled_open_dataflow_graph_as_dot( + LabelledOpenDataflowGraphView const &, + std::function const &, + std::function const &, + std::function const &, + std::function const &, + std::function const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.cc b/lib/utils/src/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.cc similarity index 52% rename from lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.cc rename to lib/utils/src/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.cc index ef35eaa96a..d4a580eaab 100644 --- a/lib/utils/src/utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.cc +++ b/lib/utils/src/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.cc @@ -1,4 +1,4 @@ -#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h" +#include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h" #include "utils/archetypes/ordered_value_type.h" #include "utils/archetypes/value_type.h" @@ -13,8 +13,11 @@ template std::string labelled_open_kwarg_dataflow_graph_view_as_dot( LabelledOpenKwargDataflowGraphView const &g, - std::function const &, - std::function const &); + SlotName> const &, + std::function const &, + std::function const &, + std::function const &, + std::function< + std::vector(std::unordered_set const &)> const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index 4acefbe3f5..db181fbe73 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -1,5 +1,7 @@ #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" #include "utils/containers/group_by.h" +#include "utils/containers/map_values.h" +#include "utils/containers/set_of.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidiedge_query.dtg.h" @@ -10,17 +12,30 @@ namespace FlexFlow { std::unordered_set get_incoming_edges(MultiDiGraphView const &g, Node const &n) { - return g.query_edges(MultiDiEdgeQuery{query_set::matchall(), {n}}); + MultiDiEdgeQuery query = MultiDiEdgeQuery{ + query_set::matchall(), + query_set::match_single_value(n), + }; + + return g.query_edges(query); } std::unordered_map> get_incoming_edges(MultiDiGraphView const &g, std::unordered_set const &ns) { - std::unordered_map> result = - group_by(g.query_edges(MultiDiEdgeQuery{query_set::matchall(), - query_set{ns}}), + MultiDiEdgeQuery query = MultiDiEdgeQuery{ + query_set::matchall(), + query_set::match_values_in(set_of(ns)), + }; + + std::unordered_map> result = map_values( + group_by(g.query_edges(query), [&](MultiDiEdge const &e) { return g.get_multidiedge_dst(e); }) - .l_to_r(); + .l_to_r(), + [](nonempty_unordered_set const &s) + -> std::unordered_set { + return s.unwrap_as_unordered_set(); + }); for (Node const &n : ns) { result[n]; diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index 438bfb0b93..28e181ebb9 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -1,23 +1,39 @@ #include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" #include "utils/containers/group_by.h" +#include "utils/containers/map_values.h" +#include "utils/containers/set_of.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" #include + namespace FlexFlow { std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, Node const &n) { - return g.query_edges(MultiDiEdgeQuery{{n}, query_set::matchall()}); + MultiDiEdgeQuery query = MultiDiEdgeQuery{ + query_set::match_single_value(n), + query_set::matchall(), + }; + + return g.query_edges(query); } std::unordered_map> get_outgoing_edges(MultiDiGraphView const &g, std::unordered_set const &ns) { - std::unordered_map> result = - group_by(g.query_edges(MultiDiEdgeQuery{query_set{ns}, - query_set::matchall()}), + MultiDiEdgeQuery query = MultiDiEdgeQuery{ + query_set::match_values_in(set_of(ns)), + query_set::matchall(), + }; + + std::unordered_map> result = map_values( + group_by(g.query_edges(query), [&](MultiDiEdge const &e) { return g.get_multidiedge_src(e); }) - .l_to_r(); + .l_to_r(), + [](nonempty_unordered_set const &s) + -> std::unordered_set { + return s.unwrap_as_unordered_set(); + }); for (Node const &n : ns) { result[n]; diff --git a/lib/utils/src/utils/graph/node/algorithms.cc b/lib/utils/src/utils/graph/node/algorithms.cc index 1d2be55e5e..61a4d9d9af 100644 --- a/lib/utils/src/utils/graph/node/algorithms.cc +++ b/lib/utils/src/utils/graph/node/algorithms.cc @@ -8,7 +8,11 @@ std::unordered_set get_nodes(GraphView const &g) { } bool has_node(GraphView const &g, Node const &n) { - return !g.query_nodes(NodeQuery{{n}}).empty(); + NodeQuery query = NodeQuery{ + query_set::match_single_value(n), + }; + + return !g.query_nodes(query).empty(); } size_t num_nodes(GraphView const &g) { diff --git a/lib/utils/src/utils/graph/node/node_query.cc b/lib/utils/src/utils/graph/node/node_query.cc index 834086a733..aa24da42ae 100644 --- a/lib/utils/src/utils/graph/node/node_query.cc +++ b/lib/utils/src/utils/graph/node/node_query.cc @@ -1,4 +1,5 @@ #include "utils/graph/node/node_query.h" +#include "utils/containers/set_of.h" namespace FlexFlow { @@ -18,10 +19,9 @@ NodeQuery query_intersection(NodeQuery const &lhs, NodeQuery const &rhs) { nodes = allowed_values(query_intersection(lhs.nodes, rhs.nodes)); } - NodeQuery intersection_result = node_query_all(); - intersection_result.nodes = nodes; - - return intersection_result; + return NodeQuery{ + query_set::match_values_in(set_of(nodes)), + }; } NodeQuery query_union(NodeQuery const &lhs, NodeQuery const &rhs) { diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/as_dot.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/as_dot.cc deleted file mode 100644 index 72c2d9d3c7..0000000000 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/as_dot.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h" -#include "utils/dot_file.h" -#include "utils/graph/dataflow_graph/algorithms.h" -#include "utils/graph/dataflow_graph/algorithms/as_dot.h" -#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" -#include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" - -namespace FlexFlow { - -std::string as_dot(OpenDataflowGraphView const &g) { - - std::function get_node_label = [](Node const &n) { - return fmt::format("n{}", n.raw_uid); - }; - - std::function get_input_label = - [](DataflowGraphInput const &i) { return fmt::format("i{}", i.idx); }; - - return as_dot(g, get_node_label, get_input_label); -} - -/* WARN(@lockshaw): doing this all with string ids is ugly and error prone, - * as it requires duplicating the stringification logic across functions. - * - * Fixing this is tracked in issue - * https://github.com/flexflow/FlexFlow/issues/1476 - */ -std::string - as_dot(OpenDataflowGraphView const &g, - std::function const &get_node_label, - std::function const - &get_input_label) { - std::ostringstream oss; - DotFile dot = DotFile{oss}; - - as_dot(dot, static_cast(g), get_node_label); - - auto get_node_name = [](Node n) { return fmt::format("n{}", n.raw_uid); }; - - auto get_input_field = [](nonnegative_int idx) { - return fmt::format("i{}", idx); - }; - - auto get_output_field = [](nonnegative_int idx) { - return fmt::format("o{}", idx); - }; - - auto get_graph_input_name = [](DataflowGraphInput i) { - return fmt::format("gi{}", i.idx); - }; - - for (DataflowGraphInput const &i : get_open_dataflow_graph_inputs(g)) { - dot.add_node(get_graph_input_name(i), - {{"style", "dashed"}, {"label", get_input_label(i)}}); - } - - for (DataflowInputEdge const &e : get_incoming_edges(g)) { - dot.add_edge(get_graph_input_name(e.src), - get_node_name(e.dst.node), - std::nullopt, - get_input_field(e.dst.idx)); - } - - dot.close(); - return oss.str(); -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc index 728dc75678..0228fdd8e9 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc @@ -26,13 +26,13 @@ std::vector get_incoming_edges(OpenDataflowGraphView const &g, return sorted_by(g.query_edges(OpenDataflowEdgeQuery{ DataflowInputEdgeQuery{ query_set::matchall(), - {n}, + query_set::match_single_value(n), query_set::matchall(), }, DataflowEdgeQuery{ query_set::matchall(), query_set::matchall(), - {n}, + query_set::match_single_value(n), query_set::matchall(), }, }), diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc index 36f027f792..e8989c9ee1 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc @@ -78,9 +78,9 @@ OpenDataflowGraphData query_set::match_none(), }, DataflowEdgeQuery{ - query_set{subgraph_nodes}, + query_set::match_values_in(set_of(subgraph_nodes)), query_set::matchall(), - query_set{subgraph_nodes}, + query_set::match_values_in(set_of(subgraph_nodes)), query_set::matchall(), }, }; diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.cc index 6448da9c73..fe2b46cd64 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -1,5 +1,6 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h" #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -12,13 +13,13 @@ std::unordered_set OpenDataflowEdgeQuery query = OpenDataflowEdgeQuery{ DataflowInputEdgeQuery{ query_set::matchall(), - query_set{ns}, + query_set::match_values_in(set_of(ns)), query_set::matchall(), }, DataflowEdgeQuery{ - query_set{nodes_not_in_ns}, + query_set::match_values_in(set_of(nodes_not_in_ns)), query_set::matchall(), - query_set{ns}, + query_set::match_values_in(set_of(ns)), query_set::matchall(), }, }; diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_as_dot.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_as_dot.cc new file mode 100644 index 0000000000..2227b8ef8d --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_as_dot.cc @@ -0,0 +1,86 @@ +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_as_dot.h" +#include "utils/dot/dot_file.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" + +namespace FlexFlow { + +std::string open_dataflow_graph_as_dot(OpenDataflowGraphView const &g) { + + std::function get_node_label = [](Node const &n) { + return fmt::format("n{}", n.raw_uid); + }; + + std::function get_graph_input_label = + [](DataflowGraphInput const &i) { return fmt::format("gi{}", i.idx); }; + + std::function get_input_label = + [](DataflowInput const &i) { return fmt::format("i{}", i.idx); }; + + std::function get_output_label = + [](DataflowOutput const &o) { return fmt::format("o{}", o.idx); }; + + return open_dataflow_graph_as_dot(g, + get_node_label, + get_graph_input_label, + get_input_label, + get_output_label); +} + +/* WARN(@lockshaw): doing this all with string ids is ugly and error prone, + * as it requires duplicating the stringification logic across functions. + * + * Fixing this is tracked in issue + * https://github.com/flexflow/FlexFlow/issues/1476 + */ +std::string open_dataflow_graph_as_dot( + OpenDataflowGraphView const &g, + std::function const &get_node_label, + std::function const + &get_graph_input_label, + std::function const &get_input_label, + std::function const + &get_output_label) { + std::ostringstream oss; + DotFile dot = DotFile{oss}; + + dataflow_graph_as_dot(dot, static_cast(g), get_node_label); + + auto get_node_name = [](Node n) -> std::string { + return fmt::format("n{}", n.raw_uid); + }; + + auto get_input_field = [](nonnegative_int idx) -> std::string { + return fmt::format("i{}", idx); + }; + + auto get_output_field = [](nonnegative_int idx) -> std::string { + return fmt::format("o{}", idx); + }; + + auto get_graph_input_name = [](DataflowGraphInput i) -> std::string { + return fmt::format("gi{}", i.idx); + }; + + for (DataflowGraphInput const &i : get_open_dataflow_graph_inputs(g)) { + dot.add_node(get_graph_input_name(i), + {{"style", "dashed"}, {"label", get_graph_input_label(i)}}); + } + + for (DataflowInputEdge const &e : get_incoming_edges(g)) { + dot.add_edge(get_graph_input_name(e.src), + get_node_name(e.dst.node), + std::nullopt, + get_input_field(e.dst.idx)); + } + + dot.close(); + return oss.str(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc index 34adea6b09..71cf98e93b 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc @@ -26,16 +26,16 @@ bool dataflow_input_edge_query_includes(DataflowInputEdgeQuery const &q, DataflowInputEdgeQuery dataflow_input_edge_query_for_edge(DataflowInputEdge const &e) { return DataflowInputEdgeQuery{ - query_set{e.src}, - query_set{e.dst.node}, - query_set{e.dst.idx}, + query_set::match_single_value(e.src), + query_set::match_single_value(e.dst.node), + query_set::match_single_value(e.dst.idx), }; } DataflowInputEdgeQuery dataflow_input_edge_query_all_outgoing_from(DataflowGraphInput const &src) { return DataflowInputEdgeQuery{ - query_set{src}, + query_set::match_single_value(src), query_set::matchall(), query_set::matchall(), }; @@ -45,8 +45,8 @@ DataflowInputEdgeQuery dataflow_input_edge_query_all_incoming_to(DataflowInput const &dst) { return DataflowInputEdgeQuery{ query_set::matchall(), - query_set{dst.node}, - query_set{dst.idx}, + query_set::match_single_value(dst.node), + query_set::match_single_value(dst.idx), }; } diff --git a/lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.cc b/lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.cc new file mode 100644 index 0000000000..27113566d5 --- /dev/null +++ b/lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.cc @@ -0,0 +1,19 @@ +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_as_dot.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using GraphInputName = ordered_value_type<0>; +using SlotName = ordered_value_type<1>; + +template std::string open_kwarg_dataflow_graph_as_dot( + OpenKwargDataflowGraphView const &, + std::function const &, + std::function const &)> const &, + std::function const &, + std::function< + std::vector(std::unordered_set const &)> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.cc b/lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.cc index 3ea7dce3da..1290bb698f 100644 --- a/lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.cc +++ b/lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_data.cc @@ -9,4 +9,7 @@ using SlotName = ordered_value_type<1>; template void require_open_kwarg_dataflow_graph_data_is_valid( OpenKwargDataflowGraphData const &); +template KwargDataflowGraphData kwarg_dataflow_graph_data_from_open( + OpenKwargDataflowGraphData const &); + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.cc b/lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.cc new file mode 100644 index 0000000000..491ff1b600 --- /dev/null +++ b/lib/utils/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.cc @@ -0,0 +1,14 @@ +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using GraphInputName = ordered_value_type<0>; +using SlotName = ordered_value_type<1>; + +template std::pair>, + bidict, Node>> + view_as_closed_kwarg_dataflow_graph_by_materializing_inputs( + OpenKwargDataflowGraphView const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/render_dot.cc b/lib/utils/src/utils/graph/render_dot.cc index 8bdc001c80..d5a1562ac1 100644 --- a/lib/utils/src/utils/graph/render_dot.cc +++ b/lib/utils/src/utils/graph/render_dot.cc @@ -38,17 +38,17 @@ std::string render_node_label( std::vector n_inputs = get_dataflow_inputs(g, n); std::vector n_outputs = get_outputs(g, n); - RecordFormatter inputs_record; + RecordFormatter inputs_record = mk_empty_record(Orientation::HORIZONTAL); for (DataflowInput const &i : n_inputs) { inputs_record << fmt::format("{}", i.idx, i.idx); } - RecordFormatter outputs_record; + RecordFormatter outputs_record = mk_empty_record(Orientation::HORIZONTAL); for (DataflowOutput const &o : n_outputs) { outputs_record << fmt::format("{}", o.idx, g.at(o)); } - RecordFormatter rec; + RecordFormatter rec = mk_empty_record(Orientation::VERTICAL); rec << inputs_record << try_at(g.at(n), std::string{"label"}) .value_or(fmt::to_string(n.raw_uid)) diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc index 378bb15bb0..36b8e7294b 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc @@ -147,7 +147,7 @@ static std::unordered_set return filter_out_sync_nodes(forest, node_roles); } -static std::pair, std::unordered_set> +static std::pair, nonempty_unordered_set> get_up_and_down_sets( DiGraph const &g, std::unordered_set const &forest, @@ -155,12 +155,12 @@ static std::pair, std::unordered_set> nonnegative_int max_depth = get_max_depth(g, depth_map); - auto grouped_by_depth = + OneToMany grouped_by_depth = group_by(forest, [&](Node const &n) { return depth_map.at(n); }); - return make_pair(grouped_by_depth.at_l( - nonnegative_int{max_depth.unwrap_nonnegative() - 1}), - grouped_by_depth.at_l(max_depth)); + return std::make_pair(grouped_by_depth.at_l(nonnegative_int{ + max_depth.unwrap_nonnegative() - 1}), + grouped_by_depth.at_l(max_depth)); } static std::unordered_set @@ -228,7 +228,13 @@ SeriesParallelDecomposition escribano_sp_ization(DiGraph g) { Node handle = get_only(get_lowest_common_ancestors(sp, component).value()); std::unordered_set forest = get_forest_escribano(sp, handle, component, node_roles); - auto [up, down] = get_up_and_down_sets(sp, forest, depth_map); + + std::pair, nonempty_unordered_set> + up_down_sets = get_up_and_down_sets(sp, forest, depth_map); + + std::unordered_set up = up_down_sets.first.unwrap_as_unordered_set(); + std::unordered_set down = + up_down_sets.second.unwrap_as_unordered_set(); remove_edges(sp, edges_to_remove(sp, up, down)); diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc index ca21427d9e..7423437b1c 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc @@ -59,7 +59,7 @@ static NonNormalSPDecomposition parallel_composition_with_coalescing( for (auto const &[head, strands_with_head] : strands_grouped_by_head.l_to_r()) { std::unordered_set tails = - transform(strands_with_head, cut_off_head); + transform(strands_with_head.unwrap_as_unordered_set(), cut_off_head); NonNormalSPDecomposition parallel_comp = parallel_composition_with_coalescing(tails); diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc index 726fda8af7..d28818c5b4 100644 --- a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc @@ -5,8 +5,8 @@ namespace FlexFlow { std::unordered_set get_neighboring_nodes(UndirectedGraphView const &g, Node const &n) { - std::unordered_set edges = - g.query_edges(UndirectedEdgeQuery{query_set{n}}); + std::unordered_set edges = g.query_edges( + UndirectedEdgeQuery{query_set::match_single_value(n)}); std::unordered_set result = set_union(transform(vector_of(edges), [](UndirectedEdge const &e) { diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index 8248328d74..6efbb45a17 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -1,5 +1,6 @@ #include "utils/graph/views/views.h" #include "utils/containers/flatmap.h" +#include "utils/containers/set_of.h" #include "utils/containers/transform.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/directed_edge_query.h" @@ -20,15 +21,19 @@ UndirectedSubgraphView *UndirectedSubgraphView::clone() const { std::unordered_set UndirectedSubgraphView::query_edges( UndirectedEdgeQuery const &query) const { - UndirectedEdgeQuery subgraph_query = - UndirectedEdgeQuery{this->subgraph_nodes}; + UndirectedEdgeQuery subgraph_query = UndirectedEdgeQuery{ + query_set::match_values_in(set_of(this->subgraph_nodes)), + }; return this->g.query_edges(query_intersection(query, subgraph_query)); } std::unordered_set UndirectedSubgraphView::query_nodes(NodeQuery const &query) const { - return this->g.query_nodes( - query_intersection(query, NodeQuery{this->subgraph_nodes})); + NodeQuery subgraph_query = NodeQuery{ + query_set::match_values_in(set_of(this->subgraph_nodes)), + }; + + return this->g.query_nodes(query_intersection(query, subgraph_query)); } DiSubgraphView::DiSubgraphView(DiGraphView const &g, @@ -37,15 +42,20 @@ DiSubgraphView::DiSubgraphView(DiGraphView const &g, std::unordered_set DiSubgraphView::query_edges(DirectedEdgeQuery const &query) const { - DirectedEdgeQuery subgraph_query = - DirectedEdgeQuery{this->subgraph_nodes, this->subgraph_nodes}; + DirectedEdgeQuery subgraph_query = DirectedEdgeQuery{ + query_set::match_values_in(set_of(this->subgraph_nodes)), + query_set::match_values_in(set_of(this->subgraph_nodes)), + }; return this->g.query_edges(query_intersection(query, subgraph_query)); } std::unordered_set DiSubgraphView::query_nodes(NodeQuery const &query) const { - return this->g.query_nodes( - query_intersection(query, NodeQuery{this->subgraph_nodes})); + NodeQuery subgraph_query = NodeQuery{ + query_set::match_values_in(set_of(this->subgraph_nodes)), + }; + + return this->g.query_nodes(query_intersection(query, subgraph_query)); } DiSubgraphView *DiSubgraphView::clone() const { diff --git a/lib/utils/src/utils/nonempty_unordered_set/nonempty_unordered_set.cc b/lib/utils/src/utils/nonempty_unordered_set/nonempty_unordered_set.cc new file mode 100644 index 0000000000..3738faaaa0 --- /dev/null +++ b/lib/utils/src/utils/nonempty_unordered_set/nonempty_unordered_set.cc @@ -0,0 +1,10 @@ +#include "utils/nonempty_unordered_set/nonempty_unordered_set.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template struct nonempty_unordered_set; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/one_to_many.cc b/lib/utils/src/utils/one_to_many/one_to_many.cc index 7cc939585d..158d2e10c9 100644 --- a/lib/utils/src/utils/one_to_many/one_to_many.cc +++ b/lib/utils/src/utils/one_to_many/one_to_many.cc @@ -12,7 +12,7 @@ using R = value_type<1>; template struct OneToMany; -template std::unordered_map> +template std::unordered_map> format_as(OneToMany const &); template std::ostream &operator<<(std::ostream &, OneToMany const &); diff --git a/lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc b/lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc new file mode 100644 index 0000000000..141db7f1da --- /dev/null +++ b/lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc @@ -0,0 +1,14 @@ +#include "utils/one_to_many/one_to_many_transform_values.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R1 = value_type<1>; +using R2 = value_type<2>; +using F = std::function; + +template OneToMany one_to_many_transform_values(OneToMany const &, + F); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/record_formatter.cc b/lib/utils/src/utils/record_formatter.cc index c98bea0ab8..b44ed87338 100644 --- a/lib/utils/src/utils/record_formatter.cc +++ b/lib/utils/src/utils/record_formatter.cc @@ -1,4 +1,16 @@ #include "utils/record_formatter.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +RecordFormatter::RecordFormatter(Orientation orientation, + std::vector const &pieces) + : orientation(orientation), pieces(pieces) {} + +RecordFormatter mk_empty_record(Orientation o) { + return RecordFormatter{o, std::vector{}}; +} RecordFormatter &operator<<(RecordFormatter &r, std::string const &tok) { r.pieces.push_back(tok); @@ -27,7 +39,12 @@ RecordFormatter &operator<<(RecordFormatter &r, float tok) { RecordFormatter &operator<<(RecordFormatter &r, RecordFormatter const &sub_r) { std::ostringstream oss; - oss << sub_r; + + if (r.orientation == sub_r.orientation) { + oss << "{ " << sub_r << " }"; + } else { + oss << sub_r; + } r << oss.str(); return r; @@ -51,3 +68,28 @@ std::ostream &operator<<(std::ostream &s, RecordFormatter const &r) { return s; } + +template <> +RecordFormatter mk_kv_record(std::string const &k, RecordFormatter const &v) { + RecordFormatter rr = mk_empty_record(Orientation::HORIZONTAL); + rr << k << v; + return rr; +} + +} // namespace FlexFlow + +namespace FlexFlow { + +using T = value_type<0>; + +template RecordFormatter mk_kv_record(std::string const &, T const &); + +template RecordFormatter mk_kv_record(std::string const &, + std::optional const &); + +using K = ordered_value_type<0>; +using V = value_type<0>; + +template RecordFormatter mk_record_for_map(std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/dot_file.cc b/lib/utils/test/src/utils/dot/dot_file.cc similarity index 93% rename from lib/utils/test/src/utils/dot_file.cc rename to lib/utils/test/src/utils/dot/dot_file.cc index e409572511..05720593dd 100644 --- a/lib/utils/test/src/utils/dot_file.cc +++ b/lib/utils/test/src/utils/dot/dot_file.cc @@ -1,7 +1,9 @@ -#include "utils/dot_file.h" +#include "utils/dot/dot_file.h" #include #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("DotFile") { std::ostringstream oss; @@ -35,7 +37,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("add_record_node") { - RecordFormatter rf; + RecordFormatter rf = mk_empty_record(Orientation::VERTICAL); rf << "Field1"; rf << 42; diff --git a/lib/utils/test/src/utils/dot/dot_html_from_json.cc b/lib/utils/test/src/utils/dot/dot_html_from_json.cc new file mode 100644 index 0000000000..83888eed90 --- /dev/null +++ b/lib/utils/test/src/utils/dot/dot_html_from_json.cc @@ -0,0 +1,196 @@ +#include "utils/dot/dot_html_from_json.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("dot_html_table_from_json") { + SUBCASE("json is int") { + nlohmann::json j = 5; + + DotHtmlTable result = dot_html_table_from_json(j); + DotHtmlTable correct = DotHtmlTable{ + /*border=*/0_n, + /*cellborder=*/1_n, + /*cellspacing=*/0_n, + /*rows=*/ + { + DotHtmlTableRow{ + /*cells=*/{ + DotHtmlTableCell{ + /*contents=*/DotHtmlTableCellContents{ + std::string{"5"}, + }, + /*port=*/std::nullopt, + /*colspan=*/std::nullopt, + }, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("json is array") { + nlohmann::json j = std::vector{ + 3, + 5, + 4, + 3, + 2, + }; + + auto mk_row = [](std::string const &x) -> DotHtmlTableRow { + return DotHtmlTableRow{ + /*cells=*/{ + DotHtmlTableCell{ + /*contents=*/DotHtmlTableCellContents{ + x, + }, + /*port=*/std::nullopt, + /*colspan=*/std::nullopt, + }, + }, + }; + }; + + DotHtmlTable result = dot_html_table_from_json(j); + DotHtmlTable correct = DotHtmlTable{ + /*border=*/0_n, + /*cellborder=*/1_n, + /*cellspacing=*/0_n, + /*rows=*/ + { + mk_row("3"), + mk_row("5"), + mk_row("4"), + mk_row("3"), + mk_row("2"), + }, + }; + + CHECK(result == correct); + } + + SUBCASE("json is object") { + nlohmann::json j; + + j["hello"] = 3; + j["world"] = "yes"; + + auto mk_kv_row = [](std::string const &k, + std::string const &v) -> DotHtmlTableRow { + return DotHtmlTableRow{ + /*cells=*/{ + DotHtmlTableCell{ + /*contents=*/DotHtmlTableCellContents{ + k, + }, + /*port=*/std::nullopt, + /*colspan=*/std::nullopt, + }, + DotHtmlTableCell{ + /*contents=*/DotHtmlTableCellContents{ + v, + }, + /*port=*/std::nullopt, + /*colspan=*/std::nullopt, + }, + }, + }; + }; + + DotHtmlTable result = dot_html_table_from_json(j); + DotHtmlTable correct = DotHtmlTable{ + /*border=*/0_n, + /*cellborder=*/1_n, + /*cellspacing=*/0_n, + /*rows=*/ + { + mk_kv_row("hello", "3"), + mk_kv_row("world", "yes"), + }, + }; + + CHECK(result == correct); + } + + SUBCASE("json is nested objects") { + nlohmann::json j; + + j["hello"] = 3; + j["world"] = "yes"; + j["two"] = nlohmann::json{ + {"abc", 5}, + {"def", "no"}, + }; + j["red"] = nlohmann::json{ + {"blue", "green"}, + }; + + auto mk_kv_row = + [](std::string const &k, + DotHtmlTableCellContents const &v) -> DotHtmlTableRow { + return DotHtmlTableRow{ + /*cells=*/{ + DotHtmlTableCell{ + /*contents=*/DotHtmlTableCellContents{ + k, + }, + /*port=*/std::nullopt, + /*colspan=*/std::nullopt, + }, + DotHtmlTableCell{ + /*contents=*/v, + /*port=*/std::nullopt, + /*colspan=*/std::nullopt, + }, + }, + }; + }; + + DotHtmlTable result = dot_html_table_from_json(j); + DotHtmlTable correct = DotHtmlTable{ + /*border=*/0_n, + /*cellborder=*/1_n, + /*cellspacing=*/0_n, + /*rows=*/ + { + mk_kv_row("hello", DotHtmlTableCellContents{"3"}), + mk_kv_row( + "red", + DotHtmlTableCellContents{ + DotHtmlTable{ + /*border=*/0_n, + /*cellborder=*/1_n, + /*cellspacing=*/0_n, + /*rows=*/ + { + mk_kv_row("blue", + DotHtmlTableCellContents{"green"}), + }, + }, + }), + mk_kv_row( + "two", + DotHtmlTableCellContents{ + DotHtmlTable{ + /*border=*/0_n, + /*cellborder=*/1_n, + /*cellspacing=*/0_n, + /*rows=*/ + { + mk_kv_row("abc", DotHtmlTableCellContents{"5"}), + mk_kv_row("def", DotHtmlTableCellContents{"no"}), + }, + }, + }), + mk_kv_row("world", DotHtmlTableCellContents{"yes"}), + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/dot/render_dot_html_table_to_string.cc b/lib/utils/test/src/utils/dot/render_dot_html_table_to_string.cc new file mode 100644 index 0000000000..77e58ef6a0 --- /dev/null +++ b/lib/utils/test/src/utils/dot/render_dot_html_table_to_string.cc @@ -0,0 +1,34 @@ +#include "utils/dot/render_dot_html_table_to_string.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("render_dot_html_table_to_string") { + DotHtmlTable input = DotHtmlTable{ + /*border=*/0_n, + /*cellborder=*/1_n, + /*cellspacing=*/0_n, + /*rows=*/ + { + DotHtmlTableRow{ + /*cells=*/{ + DotHtmlTableCell{ + /*contents=*/DotHtmlTableCellContents{ + std::string{"5"}, + }, + /*port=*/std::nullopt, + /*colspan=*/std::nullopt, + }, + }, + }, + }, + }; + + std::string result = render_dot_html_table_to_string(input); + std::string correct = "
5
"; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc new file mode 100644 index 0000000000..37d7dcd712 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc @@ -0,0 +1,75 @@ +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("dataflow_graph_as_dot") { + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1_n); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({}, 1_n); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({}, 1_n); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o1, o2, o3}, 1_n); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + auto get_node_label = [&](Node n) -> std::string { + if (n == n1) { + return "n1"; + } else if (n == n2) { + return "n2"; + } else if (n == n3) { + return "n3"; + } else { + ASSERT(n == n4); + return "n4"; + } + }; + + auto get_input_label = [&](DataflowInput const &i) -> std::string { + return fmt::format("{}_{}", get_node_label(i.node), i.idx); + }; + + auto get_output_label = [&](DataflowOutput const &o) -> std::string { + return fmt::format("{}_{}", get_node_label(o.node), o.idx); + }; + + std::string result = dataflow_graph_as_dot( + g, get_node_label, get_input_label, get_output_label); + + std::string correct = R"EXPECTED_OUTPUT(digraph taskgraph { + node0 [label=< + +
(no inputs)
n1
n1_0
>,shape=plaintext]; + node1 [label=< + +
(no inputs)
n2
n2_0
>,shape=plaintext]; + node2 [label=< + +
(no inputs)
n3
n3_0
>,shape=plaintext]; + node3 [label=< + + + +
n4_0n4_1n4_2
n4
n4_0
>,shape=plaintext]; + node1:o0 -> node3:i1; + node2:o0 -> node3:i2; + node0:o0 -> node3:i0; +})EXPECTED_OUTPUT"; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc index ee7ead009e..87a734a05a 100644 --- a/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc +++ b/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc @@ -24,7 +24,10 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge e1 = DirectedEdge{n1, n2}; DirectedEdge e2 = DirectedEdge{n2, n3}; - DirectedEdgeQuery query = DirectedEdgeQuery{query_set{n1}, query_set{n2}}; + DirectedEdgeQuery query = DirectedEdgeQuery{ + query_set::match_single_value(n1), + query_set::match_single_value(n2), + }; CHECK(matches_edge(query, e1)); CHECK_FALSE(matches_edge(query, e2)); @@ -40,27 +43,39 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge e3 = DirectedEdge{n3, n4}; SUBCASE("standard intersection") { - DirectedEdgeQuery q1 = - DirectedEdgeQuery{query_set{n1, n2}, query_set{n2, n3}}; - DirectedEdgeQuery q2 = - DirectedEdgeQuery{query_set{n2, n3}, query_set{n3, n4}}; + DirectedEdgeQuery q1 = DirectedEdgeQuery{ + query_set::match_values_in(std::set{n1, n2}), + query_set::match_values_in(std::set{n2, n3}), + }; + DirectedEdgeQuery q2 = DirectedEdgeQuery{ + query_set::match_values_in(std::set{n2, n3}), + query_set::match_values_in(std::set{n3, n4}), + }; DirectedEdgeQuery result = query_intersection(q1, q2); - DirectedEdgeQuery expected = - DirectedEdgeQuery{query_set{n2}, query_set{n3}}; + DirectedEdgeQuery expected = DirectedEdgeQuery{ + query_set::match_single_value(n2), + query_set::match_single_value(n3), + }; CHECK(result == expected); } SUBCASE("intersection with matchall") { - DirectedEdgeQuery q1 = - DirectedEdgeQuery{query_set{n1, n2}, matchall()}; - DirectedEdgeQuery q2 = - DirectedEdgeQuery{matchall(), query_set{n3, n4}}; + DirectedEdgeQuery q1 = DirectedEdgeQuery{ + query_set::match_values_in(std::set{n1, n2}), + query_set::matchall(), + }; + DirectedEdgeQuery q2 = DirectedEdgeQuery{ + query_set::matchall(), + query_set::match_values_in(std::set{n3, n4}), + }; DirectedEdgeQuery result = query_intersection(q1, q2); - DirectedEdgeQuery expected = - DirectedEdgeQuery{query_set{n1, n2}, query_set{n3, n4}}; + DirectedEdgeQuery expected = DirectedEdgeQuery{ + query_set::match_values_in(std::set{n1, n2}), + query_set::match_values_in(std::set{n3, n4}), + }; CHECK(result == expected); } diff --git a/lib/utils/test/src/utils/graph/instances/adjacency_digraph.cc b/lib/utils/test/src/utils/graph/instances/adjacency_digraph.cc index 37c957df18..d07e5d5703 100644 --- a/lib/utils/test/src/utils/graph/instances/adjacency_digraph.cc +++ b/lib/utils/test/src/utils/graph/instances/adjacency_digraph.cc @@ -31,26 +31,56 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("query_nodes") { - CHECK(g.query_nodes(node_query_all()) == - std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); + SUBCASE("query_all") { + std::unordered_set result = g.query_nodes(node_query_all()); + std::unordered_set correct = {n[0], n[1], n[2], n[3], n[4]}; + + CHECK(result == correct); + } - CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == - std::unordered_set{n[0], n[2]}); + SUBCASE("set of nodes") { + NodeQuery query = NodeQuery{ + query_set::match_values_in(std::set{n[0], n[2]}), + }; + + std::unordered_set result = g.query_nodes(query); + std::unordered_set correct = {n[0], n[2]}; + + CHECK(result == correct); + } + } - SUBCASE("query_edges") { + SUBCASE("query_edges") { - std::unordered_set queried_edges = + SUBCASE("query_all") { + std::unordered_set result = g.query_edges(directed_edge_query_all()); - std::unordered_set expected = { - e[0], e[1], e[2], e[3], e[4]}; - CHECK(queried_edges == expected); - - queried_edges = g.query_edges(DirectedEdgeQuery{ - query_set{{n[0]}}, query_set{{n[1]}}}); - expected = std::unordered_set{e[0]}; - CHECK(queried_edges == expected); + + std::unordered_set correct = { + e.at(0), + e.at(1), + e.at(2), + e.at(3), + e.at(4), + }; + + CHECK(result == correct); + } + + SUBCASE("query for specific edge") { + DirectedEdgeQuery query = DirectedEdgeQuery{ + query_set::match_single_value(n.at(0)), + query_set::match_single_value(n.at(1)), + + }; + + std::unordered_set result = g.query_edges(query); + std::unordered_set correct = + std::unordered_set{e[0]}; + CHECK(result == correct); } } + SUBCASE("remove_node_unsafe") { g.remove_node_unsafe(n[0]); diff --git a/lib/utils/test/src/utils/graph/instances/adjacency_multidigraph.cc b/lib/utils/test/src/utils/graph/instances/adjacency_multidigraph.cc index 116b8ffbd3..d69e1ee71e 100644 --- a/lib/utils/test/src/utils/graph/instances/adjacency_multidigraph.cc +++ b/lib/utils/test/src/utils/graph/instances/adjacency_multidigraph.cc @@ -74,30 +74,43 @@ TEST_SUITE(FF_TEST_SUITE) { check_state({n1, n2}, {e1, e2, e3, e4}); { - MultiDiEdgeQuery input = - MultiDiEdgeQuery{{n1}, query_set::matchall()}; + MultiDiEdgeQuery input = MultiDiEdgeQuery{ + query_set::match_single_value(n1), + query_set::matchall(), + }; + std::unordered_set result = g.query_edges(input); std::unordered_set correct = {e1, e2, e3}; CHECK(result == correct); } { - MultiDiEdgeQuery input = - MultiDiEdgeQuery{query_set::matchall(), {n1}}; + MultiDiEdgeQuery input = MultiDiEdgeQuery{ + query_set::matchall(), + query_set::match_single_value(n1), + }; + std::unordered_set result = g.query_edges(input); std::unordered_set correct = {e1, e2, e4}; CHECK(result == correct); } { - MultiDiEdgeQuery input = MultiDiEdgeQuery{{n1}, {n2}}; + MultiDiEdgeQuery input = MultiDiEdgeQuery{ + query_set::match_single_value(n1), + query_set::match_single_value(n2), + }; std::unordered_set result = g.query_edges(input); std::unordered_set correct = {e3}; CHECK(result == correct); } { - MultiDiEdgeQuery input = MultiDiEdgeQuery{{n1}, {n1}}; + MultiDiEdgeQuery input = MultiDiEdgeQuery{ + query_set::match_single_value(n1), + query_set::match_single_value(n1), + }; + std::unordered_set result = g.query_edges(input); std::unordered_set correct = {e1, e2}; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/graph/instances/unordered_set_kwarg_dataflow_graph.cc b/lib/utils/test/src/utils/graph/instances/unordered_set_kwarg_dataflow_graph.cc new file mode 100644 index 0000000000..1709940262 --- /dev/null +++ b/lib/utils/test/src/utils/graph/instances/unordered_set_kwarg_dataflow_graph.cc @@ -0,0 +1,140 @@ +#include "utils/graph/instances/unordered_set_kwarg_dataflow_graph.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_edge_query.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h" +#include "utils/graph/node/node_query.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("UnorderedSetKwargDataflowGraph") { + KwargDataflowGraph g = KwargDataflowGraph::create< + UnorderedSetKwargDataflowGraph>(); + + { + std::unordered_set result = g.query_nodes(node_query_all()); + std::unordered_set correct = {}; + REQUIRE(result == correct); + } + + { + std::unordered_set> result = + g.query_edges(kwarg_dataflow_edge_query_all()); + std::unordered_set> correct = {}; + REQUIRE(result == correct); + } + + { + std::unordered_set> result = + g.query_outputs(kwarg_dataflow_output_query_all()); + std::unordered_set> correct = {}; + REQUIRE(result == correct); + } + + KwargNodeAddedResult added = g.add_node( + /*inputs=*/{}, + /*output_slots=*/{ + "output_1", + "output_2", + "output_3", + }); + + KwargDataflowOutput added_output_1 = + added.outputs.at("output_1"); + + KwargDataflowOutput added_output_2 = + added.outputs.at("output_2"); + + KwargDataflowOutput added_output_3 = + added.outputs.at("output_3"); + + { + std::unordered_set result = g.query_nodes(node_query_all()); + std::unordered_set correct = {added.node}; + REQUIRE(result == correct); + } + + { + std::unordered_set> result = + g.query_edges(kwarg_dataflow_edge_query_all()); + std::unordered_set> correct = {}; + REQUIRE(result == correct); + } + + { + std::unordered_set> result = + g.query_outputs(kwarg_dataflow_output_query_all()); + std::unordered_set> correct = + unordered_set_of(values(added.outputs)); + REQUIRE(result == correct); + } + + KwargNodeAddedResult added2 = g.add_node( + /*inputs=*/ + { + { + "input_1", + added_output_1, + }, + { + "input_2", + added_output_3, + }, + }, + /*output_slots=*/{ + "output_1", + }); + + KwargDataflowOutput added2_output_1 = + KwargDataflowOutput{ + added2.outputs.at("output_1"), + }; + + { + std::unordered_set result = g.query_nodes(node_query_all()); + std::unordered_set correct = {added.node, added2.node}; + REQUIRE(result == correct); + } + + { + std::unordered_set> result = + g.query_edges(kwarg_dataflow_edge_query_all()); + + auto mk_edge = + [](KwargDataflowOutput const &src, + Node const &dst_node, + std::string const &dst_slot) -> KwargDataflowEdge { + return KwargDataflowEdge{ + /*src=*/src, + /*dst=*/ + KwargDataflowInput{ + dst_node, + dst_slot, + }, + }; + }; + + std::unordered_set> correct = { + mk_edge(added_output_1, added2.node, "input_1"), + mk_edge(added_output_3, added2.node, "input_2"), + }; + + REQUIRE(result == correct); + } + + { + std::unordered_set> result = + g.query_outputs(kwarg_dataflow_output_query_all()); + + auto get_output_set = [](KwargNodeAddedResult const &r) { + return unordered_set_of(values(r.outputs)); + }; + + std::unordered_set> correct = + set_union(get_output_set(added), get_output_set(added2)); + + REQUIRE(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.cc b/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.cc new file mode 100644 index 0000000000..c2fb348075 --- /dev/null +++ b/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.cc @@ -0,0 +1,135 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_data_from_kwarg_dataflow_graph_data.h" +#include "utils/containers/reversed.h" +#include "utils/containers/sorted.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("dataflow_graph_data_from_kwarg_dataflow_graph_data") { + Node n0 = Node{0}; + Node n1 = Node{1}; + Node n2 = Node{2}; + + KwargDataflowOutput o0 = KwargDataflowOutput{ + /*node=*/n0, + /*slot_name=*/"a", + }; + + KwargDataflowOutput o1 = KwargDataflowOutput{ + /*node=*/n1, + /*slot_name=*/"b", + }; + + KwargDataflowOutput o2 = KwargDataflowOutput{ + /*node=*/n1, + /*slot_name=*/"c", + }; + + KwargDataflowOutput o3 = KwargDataflowOutput{ + /*node=*/n2, + /*slot_name=*/"d", + }; + + auto mk_kwarg_edge = + [](KwargDataflowOutput const &src, + Node dst_node, + std::string dst_slot) -> KwargDataflowEdge { + return KwargDataflowEdge{ + src, + KwargDataflowInput{ + dst_node, + dst_slot, + }, + }; + }; + + KwargDataflowGraphData input = + KwargDataflowGraphData{ + /*nodes=*/{ + n0, + n1, + n2, + }, + /*edges=*/ + { + mk_kwarg_edge(o1, n2, "z"), + mk_kwarg_edge(o2, n2, "y"), + mk_kwarg_edge(o0, n2, "x"), + }, + /*outputs=*/ + { + o0, + o1, + o2, + o3, + }, + }; + + std::function( + std::unordered_set const &)> + slot_ordering = [](std::unordered_set const &slots) + -> std::vector { return reversed(sorted(slots)); }; + + DataflowGraphData result = + dataflow_graph_data_from_kwarg_dataflow_graph_data(input, + slot_ordering); + + DataflowGraphData correct = [&]() { + DataflowOutput correct_o0 = DataflowOutput{ + /*node=*/n0, + /*idx=*/0_n, + }; + + DataflowOutput correct_o1 = DataflowOutput{ + /*node=*/n1, + /*idx=*/1_n, + }; + + DataflowOutput correct_o2 = DataflowOutput{ + /*node=*/n1, + /*idx=*/0_n, + }; + + DataflowOutput correct_o3 = DataflowOutput{ + /*node=*/n2, + /*idx=*/0_n, + }; + + auto mk_edge = [](DataflowOutput const &src, + Node dst_node, + nonnegative_int dst_idx) -> DataflowEdge { + return DataflowEdge{ + src, + DataflowInput{ + dst_node, + dst_idx, + }, + }; + }; + + return DataflowGraphData{ + /*nodes=*/{ + n0, + n1, + n2, + }, + /*edges=*/ + { + mk_edge(correct_o0, n2, 2_n), + mk_edge(correct_o1, n2, 0_n), + mk_edge(correct_o2, n2, 1_n), + }, + /*outputs=*/ + { + correct_o0, + correct_o1, + correct_o2, + correct_o3, + }, + }; + }(); + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.cc b/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.cc new file mode 100644 index 0000000000..9ac43c2ee2 --- /dev/null +++ b/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.cc @@ -0,0 +1,92 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/dataflow_graph_from_kwarg_dataflow_graph.h" +#include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" +#include "utils/containers/reversed.h" +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.h" +#include "utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_kwarg_dataflow_graph.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_as_dot.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("dataflow_graph_from_kwarg_dataflow_graph") { + + KwargDataflowGraphView input = [] { + KwargDataflowGraph g = + KwargDataflowGraph::template create< + UnorderedSetKwargDataflowGraph>(); + + KwargNodeAddedResult n0_added = g.add_node( + /*inputs=*/std::unordered_map>{}, + /*outputs=*/std::unordered_set{ + "a", + }); + + KwargDataflowOutput o0 = + require_only_key(n0_added.outputs, std::string{"a"}); + + KwargNodeAddedResult n1_added = g.add_node( + /*inputs=*/std::unordered_map>{}, + /*outputs=*/std::unordered_set{ + "b", + "c", + }); + + KwargDataflowOutput o1 = n1_added.outputs.at("b"); + KwargDataflowOutput o2 = n1_added.outputs.at("c"); + + KwargNodeAddedResult n2_added = g.add_node( + /*inputs=*/ + std::unordered_map>{ + {"z", o1}, + {"y", o2}, + {"x", o0}, + }, + /*outputs=*/std::unordered_set{ + "d", + }); + + return g; + }(); + + std::function( + std::unordered_set const &)> + slot_ordering = [](std::unordered_set const &slots) + -> std::vector { return reversed(sorted(slots)); }; + + DataflowGraphView result = + dataflow_graph_from_kwarg_dataflow_graph(input, slot_ordering); + + DataflowGraphView correct = [] { + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n0_added = g.add_node( + /*inputs=*/{}, + /*num_outputs=*/1_n); + + DataflowOutput o0 = get_only(n0_added.outputs); + + NodeAddedResult n1_added = g.add_node( + /*inputs=*/{}, + /*num_outputs=*/2_n); + + DataflowOutput o1 = n1_added.outputs.at(0); + DataflowOutput o2 = n1_added.outputs.at(1); + + NodeAddedResult n2_added = g.add_node( + /*inputs=*/{o2, o1, o0}, + /*num_outputs=*/1_n); + + return g; + }(); + + CHECK(dataflow_graphs_are_isomorphic(result, correct)); + } +} diff --git a/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.cc b/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.cc new file mode 100644 index 0000000000..c4667bf746 --- /dev/null +++ b/lib/utils/test/src/utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.cc @@ -0,0 +1,90 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/view_from_kwarg_dataflow_graph_data.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h" +#include "utils/graph/node/algorithms.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("view_from_kwarg_dataflow_graph_data") { + auto mk_edge = [](Node src, + std::optional src_idx, + Node dst, + std::optional dst_idx) + -> KwargDataflowEdge> { + return KwargDataflowEdge>{ + /*src=*/KwargDataflowOutput>{ + /*node=*/src, + /*slot_name=*/src_idx, + }, + /*dst=*/ + KwargDataflowInput>{ + /*node=*/dst, + /*slot_name=*/dst_idx, + }, + }; + }; + + auto mk_output = [](Node src, std::optional src_idx) + -> KwargDataflowOutput> { + return KwargDataflowOutput>{ + /*node=*/src, + /*slot_name=*/src_idx, + }; + }; + + Node n0 = Node{0}; + Node n1 = Node{1}; + Node n2 = Node{2}; + + std::unordered_set all_nodes = {n0, n1, n2}; + + std::unordered_set>> all_edges = { + mk_edge(n0, 1, n1, 0), + mk_edge(n0, 1, n1, std::nullopt), + mk_edge(n1, 2, n2, 3), + mk_edge(n0, std::nullopt, n2, 1), + }; + + std::unordered_set>> all_outputs = { + mk_output(n0, 1), + mk_output(n0, std::nullopt), + mk_output(n0, 4), + mk_output(n1, 2), + mk_output(n2, 4), + }; + + KwargDataflowGraphData> data = + KwargDataflowGraphData>{ + /*nodes=*/{n0, n1, n2}, + /*edges=*/all_edges, + /*outputs=*/all_outputs, + }; + + KwargDataflowGraphView> g = + view_from_kwarg_dataflow_graph_data(data); + + SUBCASE("get_nodes") { + std::unordered_set result = get_nodes(g); + std::unordered_set correct = all_nodes; + ASSERT(result == correct); + } + + SUBCASE("get_all_kwarg_dataflow_edges") { + std::unordered_set>> result = + get_all_kwarg_dataflow_edges(g); + std::unordered_set>> correct = + all_edges; + ASSERT(result == correct); + } + + SUBCASE("get_all_kwarg_dataflow_outputs") { + std::unordered_set>> result = + get_all_kwarg_dataflow_outputs(g); + std::unordered_set>> correct = + all_outputs; + ASSERT(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc index ce4d7a373b..f427a78b64 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc @@ -26,7 +26,12 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("add_node") { Node n3 = g.add_node(); - std::unordered_set result = g.query_nodes(NodeQuery{{n3}}); + + NodeQuery query = NodeQuery{ + query_set::match_single_value(n3), + }; + + std::unordered_set result = g.query_nodes(query); std::unordered_set correct = {n3}; CHECK(result == correct); } @@ -34,8 +39,12 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("add_edge") { SUBCASE("non-duplicate edge") { MultiDiEdge e7 = g.add_edge(n2, n1); - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n2}, {n1})); + MultiDiEdgeQuery query = MultiDiEdgeQuery{ + query_set::match_single_value(n2), + query_set::match_single_value(n1), + }; + + std::unordered_set result = g.query_edges(query); std::unordered_set correct = {e7}; CHECK(result == correct); } @@ -44,8 +53,12 @@ TEST_SUITE(FF_TEST_SUITE) { MultiDiEdge e7 = g.add_edge(n2, n1); MultiDiEdge e8 = g.add_edge(n2, n1); - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n2}, {n1})); + MultiDiEdgeQuery query = MultiDiEdgeQuery{ + query_set::match_single_value(n2), + query_set::match_single_value(n1), + }; + + std::unordered_set result = g.query_edges(query); std::unordered_set correct = {e7, e8}; CHECK(result == correct); } @@ -54,35 +67,47 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("remove_node") { g.remove_node(n0); - std::unordered_set node_result = g.query_nodes(NodeQuery{{n0}}); + NodeQuery node_query = NodeQuery{ + query_set::match_single_value(n0), + }; + std::unordered_set node_result = g.query_nodes(node_query); std::unordered_set node_correct = {}; CHECK(node_result == node_correct); - std::unordered_set edge_result = - g.query_edges(MultiDiEdgeQuery({n0}, {n1, n2})); + MultiDiEdgeQuery edge_query = MultiDiEdgeQuery{ + query_set::match_single_value(n0), + query_set::match_values_in(std::set{n1, n2}), + }; + std::unordered_set edge_result = g.query_edges(edge_query); std::unordered_set edge_correct = {}; CHECK(edge_result == edge_correct); } SUBCASE("remove_edge") { g.remove_edge(e3); - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n1}, {n2})); + std::unordered_set result = g.query_edges(MultiDiEdgeQuery{ + query_set::match_single_value(n1), + query_set::match_single_value(n2), + }); std::unordered_set correct = {e4}; CHECK(result == correct); SUBCASE("remove non-duplicate edge") { g.remove_edge(e0); - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n0}, {n2})); + std::unordered_set result = g.query_edges(MultiDiEdgeQuery{ + query_set::match_single_value(n0), + query_set::match_single_value(n2), + }); std::unordered_set correct = {}; CHECK(result == correct); } SUBCASE("remove duplicate edge") { g.remove_edge(e1); - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n1}, {n0})); + std::unordered_set result = g.query_edges(MultiDiEdgeQuery{ + query_set::match_single_value(n1), + query_set::match_single_value(n0), + }); std::unordered_set correct = {e2}; CHECK(result == correct); } @@ -90,29 +115,44 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("query_nodes") { SUBCASE("all nodes") { - std::unordered_set result = - g.query_nodes(NodeQuery{{n0, n1, n2}}); + NodeQuery query = NodeQuery{ + query_set::match_values_in(std::set{n0, n1, n2}), + }; + + std::unordered_set result = g.query_nodes(query); std::unordered_set correct = {n0, n1, n2}; CHECK(result == correct); } SUBCASE("specific nodes") { - std::unordered_set result = g.query_nodes(NodeQuery{{n0, n2}}); + NodeQuery query = NodeQuery{ + query_set::match_values_in(std::set{n0, n2}), + }; + + std::unordered_set result = g.query_nodes(query); std::unordered_set correct = {n0, n2}; CHECK(result == correct); } SUBCASE("matchall") { - std::unordered_set result = - g.query_nodes(NodeQuery{matchall()}); + NodeQuery query = NodeQuery{ + query_set::matchall(), + }; + + std::unordered_set result = g.query_nodes(query); std::unordered_set correct = {n0, n1, n2}; CHECK(result == correct); } SUBCASE("nodes not in graph") { - Node n3 = Node(3); - Node n4 = Node(4); - std::unordered_set result = g.query_nodes(NodeQuery{{n3, n4}}); + Node n3 = Node{3}; + Node n4 = Node{4}; + + NodeQuery query = NodeQuery{ + query_set::match_values_in(std::set{n3, n4}), + }; + + std::unordered_set result = g.query_nodes(query); std::unordered_set correct = {}; CHECK(result == correct); } @@ -120,42 +160,64 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("query_edges") { SUBCASE("all edges") { - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n0, n1, n2})); + MultiDiEdgeQuery query = MultiDiEdgeQuery{ + query_set::match_values_in(std::set{n0, n1, n2}), + query_set::match_values_in(std::set{n0, n1, n2}), + }; + + std::unordered_set result = g.query_edges(query); std::unordered_set correct = {e0, e1, e2, e3, e4, e5, e6}; CHECK(result == correct); } SUBCASE("edges from n1") { - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n1}, {n0, n1, n2})); + MultiDiEdgeQuery query = MultiDiEdgeQuery{ + query_set::match_single_value(n1), + query_set::match_values_in(std::set{n0, n1, n2}), + }; + + std::unordered_set result = g.query_edges(query); std::unordered_set correct = {e1, e2, e3, e4}; CHECK(result == correct); } SUBCASE("edges to n2") { - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n2})); + MultiDiEdgeQuery query = MultiDiEdgeQuery{ + query_set::match_values_in(std::set{n0, n1, n2}), + query_set::match_single_value(n2), + }; + + std::unordered_set result = g.query_edges(query); std::unordered_set correct = {e0, e3, e4, e6}; CHECK(result == correct); } SUBCASE("matchall") { - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery(matchall(), matchall())); + MultiDiEdgeQuery query = MultiDiEdgeQuery{ + query_set::matchall(), + query_set::matchall(), + }; + + std::unordered_set result = g.query_edges(query); std::unordered_set correct = {e0, e1, e2, e3, e4, e5, e6}; CHECK(result == correct); } SUBCASE("nodes that don't exist") { - Node n3 = Node(3); - Node n4 = Node(4); - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n1, n3}, {n4})); + Node n3 = Node{3}; + Node n4 = Node{4}; + + MultiDiEdgeQuery query = MultiDiEdgeQuery{ + query_set::match_values_in(std::set{n1, n3}), + query_set::match_single_value(n4), + }; + + std::unordered_set result = g.query_edges(query); std::unordered_set correct = {}; CHECK(result == correct); } } + SUBCASE("get_multidiedge_src") { Node result = g.get_multidiedge_src(e0); Node correct = n0; diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc index 1e7ad87d88..7466c23943 100644 --- a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc @@ -91,15 +91,21 @@ TEST_SUITE(FF_TEST_SUITE) { // queries to check the through-node-permutation querying logic SUBCASE("query_nodes(NodeQuery)") { SUBCASE("check access to old nodes") { - std::unordered_set result_nodes = - result.query_nodes(NodeQuery{n0}); + NodeQuery query = NodeQuery{ + query_set::match_single_value(n0), + }; + + std::unordered_set result_nodes = result.query_nodes(query); std::unordered_set correct = {}; CHECK(result_nodes == correct); } SUBCASE("check access to new nodes") { - std::unordered_set result_nodes = - result.query_nodes(NodeQuery{new_node0}); + NodeQuery query = NodeQuery{ + query_set::match_single_value(new_node0), + }; + + std::unordered_set result_nodes = result.query_nodes(query); std::unordered_set correct = {new_node0}; CHECK(result_nodes == correct); } diff --git a/lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graphs_are_isomorphic_under.cc b/lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graphs_are_isomorphic_under.cc new file mode 100644 index 0000000000..e7b4e176f1 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graphs_are_isomorphic_under.cc @@ -0,0 +1,57 @@ +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graphs_are_isomorphic_under.h" +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/containers/get_all_permutations.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("open_kwarg_dataflow_graphs_are_isomorphic_under") { + auto mk_graph = [] { + OpenKwargDataflowGraph g = + OpenKwargDataflowGraph::template create< + UnorderedSetOpenKwargDataflowGraph>(); + + KwargNodeAddedResult n1_added = g.add_node( + /*inputs=*/{}, + /*outputs=*/{}); + + KwargNodeAddedResult n2_added = g.add_node( + /*inputs=*/{}, + /*outputs=*/{}); + + KwargNodeAddedResult n3_added = g.add_node( + /*inputs=*/{}, + /*outputs=*/{}); + + KwargNodeAddedResult n4_added = g.add_node( + /*inputs=*/{}, + /*outputs=*/{}); + + return g; + }; + + OpenKwargDataflowGraphView lhs = mk_graph(); + OpenKwargDataflowGraphView rhs = mk_graph(); + + std::unordered_set lhs_nodes = get_nodes(lhs); + std::unordered_set rhs_nodes = get_nodes(rhs); + + std::vector ordered_lhs_nodes = vector_of(lhs_nodes); + + for (std::vector ordered_rhs_nodes : + get_all_permutations(rhs_nodes)) { + OpenKwargDataflowGraphIsomorphism iso = + OpenKwargDataflowGraphIsomorphism{ + /*node_mapping=*/bidict_from_keys_and_values(ordered_lhs_nodes, + ordered_rhs_nodes), + /*input_mapping=*/{}, + }; + + CHECK(open_kwarg_dataflow_graphs_are_isomorphic_under(lhs, rhs, iso)); + }; + } +} diff --git a/lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.cc b/lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.cc new file mode 100644 index 0000000000..e96468ac7a --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.cc @@ -0,0 +1,150 @@ +#include "utils/graph/open_kwarg_dataflow_graph/algorithms/view_as_closed_kwarg_dataflow_graph_by_materializing_inputs.h" +#include "utils/containers/require_only_key.h" +#include "utils/graph/instances/unordered_set_kwarg_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_inputs.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/kwarg_dataflow_graphs_are_isomorphic.h" +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph.h" +#include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("view_as_closed_kwarg_dataflow_graph_by_materializing_inputs") { + OpenKwargDataflowGraphView open_g = [] { + OpenKwargDataflowGraph g = + OpenKwargDataflowGraph::template create< + UnorderedSetOpenKwargDataflowGraph>(); + + KwargDataflowGraphInput input1 = g.add_input("input1"); + KwargDataflowGraphInput input2 = g.add_input("input2"); + + KwargNodeAddedResult n1_added = g.add_node( + /*inputs=*/ + std::unordered_map>{ + { + 1, + OpenKwargDataflowValue{input1}, + }, + { + 3, + OpenKwargDataflowValue{input2}, + }, + { + 8, + OpenKwargDataflowValue{input1}, + }, + }, + /*outputs=*/std::unordered_set{ + 5, + }); + + KwargDataflowOutput n1_output = + require_only_key(n1_added.outputs, 5); + + KwargNodeAddedResult n2_added = g.add_node( + /*inputs=*/ + std::unordered_map>{ + { + 4, + OpenKwargDataflowValue{input2}, + }, + { + 1, + OpenKwargDataflowValue{n1_output}, + }, + }, + /*outputs=*/std::unordered_set{ + 5, + }); + + KwargDataflowOutput n2_output = + require_only_key(n2_added.outputs, 5); + + return g; + }(); + + std::pair>, + bidict, Node>> + result = + view_as_closed_kwarg_dataflow_graph_by_materializing_inputs(open_g); + + KwargDataflowGraphView> result_g = result.first; + + KwargDataflowGraphView> correct = [] { + KwargDataflowGraph> g = + KwargDataflowGraph>::template create< + UnorderedSetKwargDataflowGraph>>(); + + KwargNodeAddedResult> input1_added = g.add_node( + /*inputs=*/{}, + /*outputs=*/std::unordered_set>{ + std::nullopt, + }); + + KwargDataflowOutput> input1 = require_only_key( + input1_added.outputs, std::optional{std::nullopt}); + + KwargNodeAddedResult> input2_added = g.add_node( + /*inputs=*/{}, + /*outputs=*/std::unordered_set>{ + std::nullopt, + }); + + KwargDataflowOutput> input2 = require_only_key( + input2_added.outputs, std::optional{std::nullopt}); + + KwargNodeAddedResult> n1_added = g.add_node( + /*inputs=*/ + std::unordered_map, + KwargDataflowOutput>>{ + { + 1, + input1, + }, + { + 3, + input2, + }, + { + 8, + input1, + }, + }, + /*outputs=*/std::unordered_set>{ + 5, + }); + + KwargDataflowOutput> n1_output = + require_only_key(n1_added.outputs, std::optional{5}); + + KwargNodeAddedResult> n2_added = g.add_node( + /*inputs=*/ + std::unordered_map, + KwargDataflowOutput>>{ + { + 4, + input2, + }, + { + 1, + n1_output, + }, + }, + /*outputs=*/std::unordered_set>{ + 5, + }); + + return g; + }(); + + ASSERT(get_nodes(result_g).size() == 4); + ASSERT(get_all_kwarg_dataflow_edges(result_g).size() == + get_all_open_kwarg_dataflow_edges(open_g).size()); + ASSERT(get_all_kwarg_dataflow_inputs(result_g).size() == 5); + ASSERT(get_all_kwarg_dataflow_outputs(result_g).size() == 4); + CHECK(kwarg_dataflow_graphs_are_isomorphic(result_g, correct)); + } +} diff --git a/lib/utils/test/src/utils/graph/query_set.cc b/lib/utils/test/src/utils/graph/query_set.cc new file mode 100644 index 0000000000..9045ebd76b --- /dev/null +++ b/lib/utils/test/src/utils/graph/query_set.cc @@ -0,0 +1,45 @@ +#include "utils/graph/query_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("query_set") { + SUBCASE("handles optional values correctly") { + std::optional nopt = std::nullopt; + std::optional three = 3; + std::optional five = 5; + + query_set> q1 = + query_set>::matchall(); + + query_set> q2 = + query_set>::match_values_in(std::set{ + nopt, + three, + }); + + query_set> q3 = + query_set>::match_single_value(nopt); + + query_set> q4 = + query_set>::match_none(); + + CHECK(includes(q1, nopt)); + CHECK(includes(q1, three)); + CHECK(includes(q1, five)); + + CHECK(includes(q2, nopt)); + CHECK(includes(q2, three)); + CHECK_FALSE(includes(q2, five)); + + CHECK(includes(q3, nopt)); + CHECK_FALSE(includes(q3, three)); + CHECK_FALSE(includes(q3, five)); + + CHECK_FALSE(includes(q4, nopt)); + CHECK_FALSE(includes(q4, three)); + CHECK_FALSE(includes(q4, five)); + } + } +} diff --git a/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc b/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc index c1537f1d9b..898c8aa154 100644 --- a/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc +++ b/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc @@ -25,26 +25,63 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("query_nodes") { - CHECK(g.query_nodes(node_query_all()) == - std::unordered_set{ - n.at(0), n.at(1), n.at(2), n.at(3), n.at(4)}); - - CHECK(g.query_nodes(NodeQuery{query_set{{n.at(0), n.at(2)}}}) == - std::unordered_set{n.at(0), n.at(2)}); + SUBCASE("query_all") { + std::unordered_set result = g.query_nodes(node_query_all()); + std::unordered_set correct = std::unordered_set{ + n.at(0), + n.at(1), + n.at(2), + n.at(3), + n.at(4), + }; + + CHECK(result == correct); + } + + SUBCASE("query for specific nodes") { + NodeQuery query = NodeQuery{ + query_set::match_values_in(std::set{n.at(0), n.at(2)}), + }; + + std::unordered_set result = g.query_nodes(query); + std::unordered_set correct = std::unordered_set{ + n.at(0), + n.at(2), + }; + + CHECK(result == correct); + } } SUBCASE("query_edges") { + SUBCASE("query_all") { + std::unordered_set result = + g.query_edges(undirected_edge_query_all()); + + std::unordered_set correct = { + e.at(0), + e.at(1), + e.at(2), + e.at(3), + e.at(4), + }; + + CHECK(result == correct); + } + + SUBCASE("query for specific edge") { + UndirectedEdgeQuery query = UndirectedEdgeQuery{ + query_set::match_values_in(std::set{n.at(0), n.at(1)}), + }; + + std::unordered_set result = g.query_edges(query); + std::unordered_set correct = + std::unordered_set{ + e.at(0), + }; - std::unordered_set queried_edges = - g.query_edges(undirected_edge_query_all()); - std::unordered_set expected = { - e.at(0), e.at(1), e.at(2), e.at(3), e.at(4)}; - CHECK(queried_edges == expected); - - queried_edges = g.query_edges( - UndirectedEdgeQuery{query_set{{n.at(0), n.at(1)}}}); - expected = std::unordered_set{e.at(0)}; - CHECK(queried_edges == expected); + CHECK(result == correct); + } } SUBCASE("remove_node_unsafe") { diff --git a/lib/utils/test/src/utils/nonempty_unordered_set/nonempty_unordered_set.cc b/lib/utils/test/src/utils/nonempty_unordered_set/nonempty_unordered_set.cc new file mode 100644 index 0000000000..4a29466b50 --- /dev/null +++ b/lib/utils/test/src/utils/nonempty_unordered_set/nonempty_unordered_set.cc @@ -0,0 +1,44 @@ +#include "utils/nonempty_unordered_set/nonempty_unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("nonempty_unordered_set") { + SUBCASE("construct from initializer_list") { + SUBCASE("does not throw if nonempty") { + nonempty_unordered_set s{1, 2, 3}; + + CHECK(s.num_elements() == 3_p); + } + + SUBCASE("throws if empty") { + auto init_with_empty = []() -> void { + nonempty_unordered_set s{std::initializer_list{}}; + }; + + CHECK_THROWS(init_with_empty()); + } + } + + SUBCASE("construct from unordered_set") { + SUBCASE("does not throw if nonempty") { + nonempty_unordered_set s{ + std::unordered_set{1, 2, 3}, + }; + + CHECK(s.num_elements() == 3_p); + } + + SUBCASE("throws if empty") { + auto init_with_empty = []() -> void { + nonempty_unordered_set s{ + std::unordered_set{}, + }; + }; + + CHECK_THROWS(init_with_empty()); + } + } + } +} diff --git a/lib/utils/test/src/utils/one_to_many/one_to_many.cc b/lib/utils/test/src/utils/one_to_many/one_to_many.cc index 7d6e1c77d1..d2ea7d6a0b 100644 --- a/lib/utils/test/src/utils/one_to_many/one_to_many.cc +++ b/lib/utils/test/src/utils/one_to_many/one_to_many.cc @@ -31,9 +31,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("at_l") { - std::unordered_set result = m.at_l(1); + nonempty_unordered_set result = m.at_l(1); - std::unordered_set correct = {"one", "One", "ONE"}; + nonempty_unordered_set correct = {"one", "One", "ONE"}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/one_to_many/one_to_many_transform_values.cc b/lib/utils/test/src/utils/one_to_many/one_to_many_transform_values.cc new file mode 100644 index 0000000000..003ac7657c --- /dev/null +++ b/lib/utils/test/src/utils/one_to_many/one_to_many_transform_values.cc @@ -0,0 +1,37 @@ +#include "utils/one_to_many/one_to_many_transform_values.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("one_to_many_transform_values") { + OneToMany input = OneToMany{ + { + "a", + {1, 2, 3}, + }, + { + "b", + {4, 6}, + }, + }; + + auto func = [](int x) -> std::string { return fmt::to_string(x); }; + + OneToMany result = + one_to_many_transform_values(input, func); + + OneToMany correct = { + { + "a", + {"1", "2", "3"}, + }, + { + "b", + {"4", "6"}, + }, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/record_formatter.cc b/lib/utils/test/src/utils/record_formatter.cc index f0d396a123..71d11260a5 100644 --- a/lib/utils/test/src/utils/record_formatter.cc +++ b/lib/utils/test/src/utils/record_formatter.cc @@ -1,6 +1,8 @@ #include "utils/record_formatter.h" #include +using namespace ::FlexFlow; + std::string formatRecord(RecordFormatter const &formatter) { std::ostringstream oss; oss << formatter; @@ -9,7 +11,8 @@ std::string formatRecord(RecordFormatter const &formatter) { TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RecordFormatter") { - RecordFormatter formatter; + RecordFormatter formatter = mk_empty_record(Orientation::HORIZONTAL); + SUBCASE("Appending string") { formatter << "Hello"; formatter << "World"; @@ -23,11 +26,10 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Appending another RecordFormatter") { - RecordFormatter subFormatter; + RecordFormatter subFormatter = mk_empty_record(Orientation::VERTICAL); subFormatter << "Sub"; subFormatter << "Formatter"; - RecordFormatter formatter; formatter << "Hello"; formatter << subFormatter;