From 9369ae9085604b7bbb6b8a3954df49ea771844cd Mon Sep 17 00:00:00 2001 From: niko4299 Date: Wed, 6 Jul 2022 14:50:00 +0200 Subject: [PATCH 1/6] My version with map, will test tomorrow --- .gitignore | 2 ++ src/auth/models.cpp | 20 +++++++++++++++++ src/auth/models.hpp | 39 ++++++++++++++++++++++++++++++++++ src/memgraph.cpp | 37 ++++++++++++++++++++++++++++++-- src/query/frontend/ast/ast.lcp | 11 +++++----- src/query/interpreter.cpp | 7 +++--- src/query/interpreter.hpp | 3 ++- 7 files changed, 108 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index e1a4187b0..6cd9e8e46 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ *.swn *.swo *.swp + *~ .DS_Store .gdb_history @@ -26,6 +27,7 @@ src/query/frontend/opencypher/generated/ tags ve/ ve3/ +.cache/ perf.data* TAGS *.apollo_measurements diff --git a/src/auth/models.cpp b/src/auth/models.cpp index 9720c3596..bfb508679 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -185,15 +185,25 @@ bool operator==(const Permissions &first, const Permissions &second) { bool operator!=(const Permissions &first, const Permissions &second) { return !(first == second); } +LabelPermissions::LabelPermissions(const std::unordered_map &permissions) + : permissions_(permissions) {} + +void LabelPermissions::Grant(const std::string &permission) { permissions_[permission] = 1; } + Role::Role(const std::string &rolename) : rolename_(utils::ToLowerCase(rolename)) {} Role::Role(const std::string &rolename, const Permissions &permissions) : rolename_(utils::ToLowerCase(rolename)), permissions_(permissions) {} +Role::Role(const std::string &rolename, const Permissions &permissions, const LabelPermissions &labelPermissions) + : rolename_(utils::ToLowerCase(rolename)), permissions_(permissions), labelPermissions_(labelPermissions) {} + const std::string &Role::rolename() const { return rolename_; } const Permissions &Role::permissions() const { return permissions_; } Permissions &Role::permissions() { return permissions_; } +LabelPermissions &Role::labelPermissions() { return labelPermissions_; } + nlohmann::json Role::Serialize() const { nlohmann::json data = nlohmann::json::object(); data["rolename"] = rolename_; @@ -221,6 +231,13 @@ User::User(const std::string &username) : username_(utils::ToLowerCase(username) User::User(const std::string &username, const std::string &password_hash, const Permissions &permissions) : username_(utils::ToLowerCase(username)), password_hash_(password_hash), permissions_(permissions) {} +User::User(const std::string &username, const std::string &password_hash, const Permissions &permissions, + const LabelPermissions &labelPermissions) + : username_(utils::ToLowerCase(username)), + password_hash_(password_hash), + permissions_(permissions), + labelPermissions_(labelPermissions) {} + bool User::CheckPassword(const std::string &password) { if (password_hash_.empty()) return true; return VerifyPassword(password, password_hash_); @@ -273,6 +290,8 @@ const std::string &User::username() const { return username_; } const Permissions &User::permissions() const { return permissions_; } Permissions &User::permissions() { return permissions_; } +LabelPermissions &User::labelPermissions() { return labelPermissions_; } + const Role *User::role() const { if (role_.has_value()) { return &role_.value(); @@ -304,4 +323,5 @@ bool operator==(const User &first, const User &second) { return first.username_ == second.username_ && first.password_hash_ == second.password_hash_ && first.permissions_ == second.permissions_ && 
first.role_ == second.role_; } + } // namespace memgraph::auth diff --git a/src/auth/models.hpp b/src/auth/models.hpp index c6ed3e0ae..fdd56425f 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -89,16 +89,45 @@ bool operator==(const Permissions &first, const Permissions &second); bool operator!=(const Permissions &first, const Permissions &second); +class LabelPermissions final { + public: + explicit LabelPermissions(const std::unordered_map &permissions_ = {}); + + PermissionLevel Has(const std::string &label) const; + + void Grant(const std::string &label); + + void Revoke(const std::string &label); + + void Deny(const std::string &label); + + nlohmann::json Serialize() const; + + /// @throw AuthException if unable to deserialize. + static LabelPermissions Deserialize(const nlohmann::json &data); + + std::unordered_map permissions() const; + + private: + std::unordered_map permissions_; +}; + +bool operator==(const LabelPermissions &first, const LabelPermissions &second); + class Role final { public: Role(const std::string &rolename); Role(const std::string &rolename, const Permissions &permissions); + Role(const std::string &rolename, const Permissions &permissions, const LabelPermissions &labelPermissions); + const std::string &rolename() const; const Permissions &permissions() const; Permissions &permissions(); + LabelPermissions &labelPermissions(); + nlohmann::json Serialize() const; /// @throw AuthException if unable to deserialize. @@ -109,6 +138,7 @@ class Role final { private: std::string rolename_; Permissions permissions_; + LabelPermissions labelPermissions_; }; bool operator==(const Role &first, const Role &second); @@ -120,6 +150,9 @@ class User final { User(const std::string &username, const std::string &password_hash, const Permissions &permissions); + User(const std::string &username, const std::string &password_hash, const Permissions &permissions, + const LabelPermissions &labelPermissions); + /// @throw AuthException if unable to verify the password. bool CheckPassword(const std::string &password); @@ -139,6 +172,8 @@ class User final { const Role *role() const; + LabelPermissions &labelPermissions(); + nlohmann::json Serialize() const; /// @throw AuthException if unable to deserialize. @@ -151,7 +186,11 @@ class User final { std::string password_hash_; Permissions permissions_; std::optional role_; + LabelPermissions labelPermissions_; }; bool operator==(const User &first, const User &second); + } // namespace memgraph::auth + +// namespace memgraph::auth diff --git a/src/memgraph.cpp b/src/memgraph.cpp index d1a972c7c..8035af77e 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -501,7 +501,7 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { if (first_user) { spdlog::info("{} is first created user. Granting all privileges.", username); - GrantPrivilege(username, memgraph::query::kPrivilegesAll); + GrantPrivilege(username, memgraph::query::kPrivilegesAll, {}); } return user_added; @@ -747,13 +747,18 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { } void GrantPrivilege(const std::string &user_or_role, - const std::vector &privileges) override { + const std::vector &privileges, + const std::vector &labels) override { EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) { // TODO (mferencevic): should we first check that the // privilege is granted/denied/revoked before // unconditionally granting/denying/revoking it? 
permissions->Grant(permission); }); + if (labels.size() > 0) { + EditLabels(user_or_role, labels, + [](auto *labelPermissions, const auto &label) { labelPermissions->Grant(label); }); + } } void DenyPrivilege(const std::string &user_or_role, @@ -810,6 +815,34 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { throw memgraph::query::QueryRuntimeException(e.what()); } } + + template <class TEditFun> + void EditLabels(const std::string &user_or_role, const std::vector<std::string> &labels, const TEditFun &edit_fun) { + if (!std::regex_match(user_or_role, name_regex_)) { + throw memgraph::query::QueryRuntimeException("Invalid user or role name."); + } + try { + auto locked_auth = auth_->Lock(); + auto user = locked_auth->GetUser(user_or_role); + auto role = locked_auth->GetRole(user_or_role); + if (!user && !role) { + throw memgraph::query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role); + } + if (user) { + for (const auto &label : labels) { + edit_fun(&user->labelPermissions(), label); + } + locked_auth->SaveUser(*user); + } else { + for (const auto &label : labels) { + edit_fun(&role->labelPermissions(), label); + } + locked_auth->SaveRole(*role); + } + } catch (const memgraph::auth::AuthException &e) { + throw memgraph::query::QueryRuntimeException(e.what()); + } + } }; class AuthChecker final : public memgraph::query::AuthChecker { diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index a618adf69..55b4b3e37 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -2239,10 +2239,11 @@ cpp<# (user "std::string" :scope :public) (role "std::string" :scope :public) (user-or-role "std::string" :scope :public) - (labels "std::vector<std::string>" :scope :public) + (password "Expression *" :initval "nullptr" :scope :public :slk-save #'slk-save-ast-pointer :slk-load (slk-load-ast-pointer "Expression")) + (labels "std::vector<std::string>" :scope :public) (privileges "std::vector<AuthQuery::Privilege>" :scope :public)) (:public (lcp:define-enum action @@ -2265,14 +2266,14 @@ cpp<# #>cpp AuthQuery(Action action, std::string user, std::string role, std::string user_or_role, Expression *password, - std::vector<Privilege> privileges, std::vector<std::string> labels) + std::vector<std::string> labels ,std::vector<Privilege> privileges) : action_(action), user_(user), role_(role), user_or_role_(user_or_role), password_(password), - privileges_(privileges), - labels_(labels) {} + labels_(labels), + privileges_(privileges){} cpp<#) (:private #>cpp @@ -2297,7 +2298,7 @@ const std::vector<AuthQuery::Privilege> kPrivilegesAll = { AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, - AuthQuery::Privilege::WEBSOCKET + AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::LABELS}; cpp<# diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index c1d206a72..ed5b46e42 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -282,6 +282,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa std::string rolename = auth_query->role_; std::string user_or_role = auth_query->user_or_role_; std::vector<AuthQuery::Privilege> privileges = auth_query->privileges_; + std::vector<std::string> labels = auth_query->labels_; auto password = EvaluateOptionalExpression(auth_query->password_, &evaluator); Callback callback; @@ -311,7 +312,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa // If the license is not valid we create users with admin access if
(!valid_enterprise_license) { spdlog::warn("Granting all the privileges to {}.", username); - auth->GrantPrivilege(username, kPrivilegesAll); + auth->GrantPrivilege(username, kPrivilegesAll, {}); } return std::vector>(); @@ -386,8 +387,8 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa }; return callback; case AuthQuery::Action::GRANT_PRIVILEGE: - callback.fn = [auth, user_or_role, privileges] { - auth->GrantPrivilege(user_or_role, privileges); + callback.fn = [auth, user_or_role, privileges, labels] { + auth->GrantPrivilege(user_or_role, privileges, labels); return std::vector>(); }; return callback; diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 155cb28a6..ff240a327 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -99,7 +99,8 @@ class AuthQueryHandler { virtual std::vector> GetPrivileges(const std::string &user_or_role) = 0; /// @throw QueryRuntimeException if an error ocurred. - virtual void GrantPrivilege(const std::string &user_or_role, const std::vector &privileges) = 0; + virtual void GrantPrivilege(const std::string &user_or_role, const std::vector &privileges, + const std::vector &labels) = 0; /// @throw QueryRuntimeException if an error ocurred. virtual void DenyPrivilege(const std::string &user_or_role, const std::vector &privileges) = 0; From 4366085d89d8c6d64d18e7c407d5e87440ccdf65 Mon Sep 17 00:00:00 2001 From: niko4299 Date: Thu, 7 Jul 2022 11:26:43 +0200 Subject: [PATCH 2/6] GRANT, DENY, REVOKE all saving to rocksdb and working --- config/generate.py | 38 +- query_modules/example.py | 24 +- query_modules/graph_analyzer.py | 202 +++-- query_modules/mgp_networkx.py | 97 +-- query_modules/nxalg.py | 815 +++++++++--------- query_modules/wcc.py | 19 +- release/get_version.py | 34 +- src/auth/models.cpp | 31 +- src/auth/models.hpp | 2 +- src/auth/reference_modules/ldap.py | 64 +- src/memgraph.cpp | 47 +- .../frontend/ast/cypher_main_visitor.cpp | 12 +- src/query/interpreter.cpp | 8 +- src/query/interpreter.hpp | 7 +- .../drivers/python/v4_1/docs_how_to_query.py | 30 +- tests/drivers/python/v4_1/max_query_length.py | 14 +- tests/drivers/python/v4_1/transactions.py | 7 +- tests/e2e/magic_functions/function_example.py | 1 + .../show_while_creating_invalid_state.py | 10 +- tests/e2e/streams/common.py | 100 ++- tests/e2e/streams/conftest.py | 18 +- tests/e2e/streams/kafka_streams_tests.py | 11 +- tests/e2e/streams/pulsar_streams_tests.py | 38 +- tests/e2e/streams/streams_owner_tests.py | 81 +- .../transformations/common_transform.py | 13 +- .../transformations/pulsar_transform.py | 4 +- tests/e2e/triggers/procedures/write.py | 23 +- tests/e2e/write_procedures/common.py | 3 +- tests/e2e/write_procedures/procedures/read.py | 3 +- .../e2e/write_procedures/procedures/write.py | 21 +- tests/e2e/write_procedures/simple_write.py | 129 +-- tests/gql_behave/environment.py | 16 +- tests/gql_behave/run.py | 26 +- tests/gql_behave/steps/binary_tree.py | 8 +- tests/gql_behave/steps/database.py | 3 +- tests/gql_behave/steps/errors.py | 94 +- tests/gql_behave/steps/graph.py | 24 +- tests/gql_behave/steps/parser.py | 85 +- tests/gql_behave/steps/query.py | 91 +- tests/gql_behave/steps/test_parameters.py | 16 +- tests/integration/audit/runner.py | 45 +- tests/integration/auth/runner.py | 237 ++--- tests/integration/durability/runner.py | 80 +- tests/integration/ldap/runner.py | 149 ++-- tests/integration/mg_import_csv/runner.py | 93 +- tests/integration/telemetry/runner.py | 46 +- 
tests/integration/telemetry/server.py | 4 +- tests/macro_benchmark/clients.py | 68 +- tests/macro_benchmark/common.py | 16 +- tests/macro_benchmark/databases.py | 40 +- .../groups/1000_create/vertex_big.run.py | 5 +- .../groups/aggregation/setup.py | 3 +- .../groups/aggregation_parallel/setup.py | 3 +- .../groups/bfs_parallel/bfs.run.py | 11 +- .../groups/bfs_parallel/common.py | 1 - .../groups/bfs_parallel/setup.py | 1 - .../groups/card_fraud/setup.py | 27 +- tests/macro_benchmark/groups/delete/common.py | 15 +- .../groups/expression/common.py | 2 + .../groups/expression/expression.run.py | 32 +- tests/macro_benchmark/groups/match/setup.py | 27 +- .../groups/match/vertex_on_index.run.py | 6 +- .../groups/match/vertex_on_label.run.py | 3 +- .../match/vertex_on_label_property.run.py | 15 +- .../groups/match/vertex_on_property.run.py | 6 +- tests/macro_benchmark/groups/return/setup.py | 3 +- tests/macro_benchmark/jail_faker.py | 107 +-- tests/macro_benchmark/long_running_suite.py | 91 +- tests/macro_benchmark/query_suite.py | 186 ++-- tests/mgbench/benchmark.py | 245 +++--- tests/mgbench/compare_results.py | 79 +- tests/mgbench/datasets.py | 191 ++-- tests/mgbench/helpers.py | 13 +- tests/mgbench/runners.py | 24 +- tests/stress/bipartite.py | 177 ++-- tests/stress/common.py | 85 +- tests/stress/create_match.py | 107 +-- tools/bench-graph-client/main.py | 16 +- tools/gdb-plugins/operator_tree.py | 68 +- tools/gdb-plugins/pretty_printers.py | 66 +- tools/github/clang-tidy/clang-tidy-diff.py | 412 ++++----- tools/github/clang-tidy/run-clang-tidy.py | 534 ++++++------ 82 files changed, 2946 insertions(+), 2662 deletions(-) diff --git a/config/generate.py b/config/generate.py index 863c0017e..b91549cb1 100755 --- a/config/generate.py +++ b/config/generate.py @@ -18,14 +18,16 @@ WIDTH = 80 def wrap_text(s, initial_indent="# "): return "\n#\n".join( - map(lambda x: textwrap.fill(x, WIDTH, initial_indent=initial_indent, - subsequent_indent="# "), s.split("\n"))) + map( + lambda x: textwrap.fill(x, WIDTH, initial_indent=initial_indent, subsequent_indent="# "), + s.split("\n"), + ) + ) def extract_flags(binary_path): ret = {} - data = subprocess.run([binary_path, "--help-xml"], - stdout=subprocess.PIPE).stdout.decode("utf-8") + data = subprocess.run([binary_path, "--help-xml"], stdout=subprocess.PIPE).stdout.decode("utf-8") root = ET.fromstring(data) for child in root: if child.tag == "usage" and child.text.lower().count("warning"): @@ -46,8 +48,7 @@ def apply_config_to_flags(config, flags): for modification in config["modifications"]: name = modification["name"] if name not in flags: - print("WARNING: Flag '" + name + "' missing from binary!", - file=sys.stderr) + print("WARNING: Flag '" + name + "' missing from binary!", file=sys.stderr) continue flags[name]["default"] = modification["value"] flags[name]["override"] = modification["override"] @@ -75,8 +76,9 @@ def extract_sections(flags): else: sections.append((current_section, current_flags)) sections.append(("other", other)) - assert set(sum(map(lambda x: x[1], sections), [])) == set(flags.keys()), \ - "The section extraction algorithm lost some flags!" + assert set(sum(map(lambda x: x[1], sections), [])) == set( + flags.keys() + ), "The section extraction algorithm lost some flags!" 
return sections @@ -89,8 +91,7 @@ def generate_config_file(sections, flags): helpstr = flag["meaning"] + " [" + flag["type"] + "]" ret += wrap_text(helpstr) + "\n" prefix = "# " if not flag["override"] else "" - ret += prefix + "--" + flag["name"].replace("_", "-") + \ - "=" + flag["default"] + "\n\n" + ret += prefix + "--" + flag["name"].replace("_", "-") + "=" + flag["default"] + "\n\n" ret += "\n" ret += wrap_text(config["footer"]) return ret.strip() + "\n" @@ -98,13 +99,16 @@ def generate_config_file(sections, flags): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("memgraph_binary", - help="path to Memgraph binary") - parser.add_argument("output_file", - help="path where to store the generated Memgraph " - "configuration file") - parser.add_argument("--config-file", default=CONFIG_FILE, - help="path to generator configuration file") + parser.add_argument("memgraph_binary", help="path to Memgraph binary") + parser.add_argument( + "output_file", + help="path where to store the generated Memgraph " "configuration file", + ) + parser.add_argument( + "--config-file", + default=CONFIG_FILE, + help="path to generator configuration file", + ) args = parser.parse_args() flags = extract_flags(args.memgraph_binary) diff --git a/query_modules/example.py b/query_modules/example.py index 7a71d6423..08967226c 100644 --- a/query_modules/example.py +++ b/query_modules/example.py @@ -7,13 +7,11 @@ import copy @mgp.read_proc -def procedure(context: mgp.ProcCtx, - required_arg: mgp.Nullable[mgp.Any], - optional_arg: mgp.Nullable[mgp.Any] = None - ) -> mgp.Record(args=list, - vertex_count=int, - avg_degree=mgp.Number, - props=mgp.Nullable[mgp.Map]): +def procedure( + context: mgp.ProcCtx, + required_arg: mgp.Nullable[mgp.Any], + optional_arg: mgp.Nullable[mgp.Any] = None, +) -> mgp.Record(args=list, vertex_count=int, avg_degree=mgp.Number, props=mgp.Nullable[mgp.Map]): """ This example procedure returns 4 fields. @@ -37,7 +35,7 @@ def procedure(context: mgp.ProcCtx, if isinstance(required_arg, (mgp.Edge, mgp.Vertex)): props = dict(required_arg.properties.items()) elif isinstance(required_arg, mgp.Path): - start_vertex, = required_arg.vertices + (start_vertex,) = required_arg.vertices props = dict(start_vertex.properties.items()) # Count the vertices and edges in the database; this may take a while. vertex_count = 0 @@ -51,15 +49,13 @@ def procedure(context: mgp.ProcCtx, # Copy the received arguments to make it equivalent to the C example. args_copy = [copy.deepcopy(required_arg), copy.deepcopy(optional_arg)] # Multiple rows can be produced by returning an iterable of mgp.Record. 
- return mgp.Record(args=args_copy, vertex_count=vertex_count, - avg_degree=avg_degree, props=props) + return mgp.Record(args=args_copy, vertex_count=vertex_count, avg_degree=avg_degree, props=props) @mgp.write_proc -def write_procedure(context: mgp.ProcCtx, - property_name: str, - property_value: mgp.Nullable[mgp.Any] - ) -> mgp.Record(created_vertex=mgp.Vertex): +def write_procedure( + context: mgp.ProcCtx, property_name: str, property_value: mgp.Nullable[mgp.Any] +) -> mgp.Record(created_vertex=mgp.Vertex): """ This example procedure creates a new vertex with the specified property and connects it to all existing vertex which has the same property with diff --git a/query_modules/graph_analyzer.py b/query_modules/graph_analyzer.py index 8e89c2222..abfc46278 100644 --- a/query_modules/graph_analyzer.py +++ b/query_modules/graph_analyzer.py @@ -4,15 +4,17 @@ from collections import OrderedDict from itertools import chain, repeat from inspect import cleandoc from typing import List, Tuple + try: import networkx as nx except ImportError as import_error: - sys.stderr.write(( - '\n' - 'NOTE: Please install networkx to be able to use graph_analyzer ' - 'module. Using Python:\n' - + sys.version + - '\n')) + sys.stderr.write( + ( + "\n" + "NOTE: Please install networkx to be able to use graph_analyzer " + "module. Using Python:\n" + sys.version + "\n" + ) + ) raise import_error # Imported last because it also depends on networkx. from mgp_networkx import MemgraphMultiDiGraph # noqa E402 @@ -23,16 +25,14 @@ _MAX_LIST_SIZE = 10 @mgp.read_proc def help() -> mgp.Record(name=str, value=str): - '''Shows manual page for graph_analyzer.''' + """Shows manual page for graph_analyzer.""" records = [] def make_records(name, doc): - return (mgp.Record(name=n, value=v) for n, v in - zip(chain([name], repeat('')), cleandoc(doc).splitlines())) + return (mgp.Record(name=n, value=v) for n, v in zip(chain([name], repeat("")), cleandoc(doc).splitlines())) for func in (help, analyze, analyze_subgraph): - records.extend(make_records("Procedure '{}'".format(func.__name__), - func.__doc__)) + records.extend(make_records("Procedure '{}'".format(func.__name__), func.__doc__)) for m, v in _get_analysis_mapping().items(): records.extend(make_records("Analysis '{}'".format(m), v.__doc__)) @@ -41,10 +41,8 @@ def help() -> mgp.Record(name=str, value=str): @mgp.read_proc -def analyze(context: mgp.ProcCtx, - analyses: mgp.Nullable[List[str]] = None - ) -> mgp.Record(name=str, value=str): - ''' +def analyze(context: mgp.ProcCtx, analyses: mgp.Nullable[List[str]] = None) -> mgp.Record(name=str, value=str): + """ Shows graph information. In case of multiple results, only the first 10 will be shown. @@ -57,19 +55,20 @@ def analyze(context: mgp.ProcCtx, Example call (with parameter): CALL graph_analyzer.analyze(['nodes', 'edges']) YIELD *; - ''' + """ g = MemgraphMultiDiGraph(ctx=context) recs = _analyze_graph(context, g, analyses) return [mgp.Record(name=name, value=value) for name, value in recs] @mgp.read_proc -def analyze_subgraph(context: mgp.ProcCtx, - vertices: mgp.List[mgp.Vertex], - edges: mgp.List[mgp.Edge], - analyses: mgp.Nullable[List[str]] = None - ) -> mgp.Record(name=str, value=str): - ''' +def analyze_subgraph( + context: mgp.ProcCtx, + vertices: mgp.List[mgp.Vertex], + edges: mgp.List[mgp.Edge], + analyses: mgp.Nullable[List[str]] = None, +) -> mgp.Record(name=str, value=str): + """ Shows subgraph information. In case of multiple results, only the first 10 will be shown. 
@@ -91,36 +90,40 @@ def analyze_subgraph(context: mgp.ProcCtx, CALL graph_analyzer.analyze_subgraph(nodes, edges, ['nodes', 'edges']) YIELD * RETURN name, value; - ''' + """ vertices, edges = map(set, [vertices, edges]) g = nx.subgraph_view( MemgraphMultiDiGraph(ctx=context), lambda n: n in vertices, - lambda n1, n2, e: e in edges) + lambda n1, n2, e: e in edges, + ) recs = _analyze_graph(context, g, analyses) return [mgp.Record(name=name, value=value) for name, value in recs] def _get_analysis_mapping(): - return OrderedDict([ - ('nodes', _number_of_nodes), - ('edges', _number_of_edges), - ('bridges', _bridges), - ('articulation_points', _articulation_points), - ('avg_degree', _avg_degree), - ('sorted_nodes_degree', _sorted_nodes_degree), - ('self_loops', _self_loops), - ('is_bipartite', _is_bipartite), - ('is_planar', _is_planar), - ('is_biconnected: ', _is_biconnected), - ('is_weakly_connected', _is_weakly_connected), - ('number_of_weakly_components', _weakly_components), - ('is_strongly_connected', _is_strongly_connected), - ('strongly_components', _strongly_components), - ('is_dag', _is_dag), - ('is_eulerian', _is_eulerian), - ('is_forest', _is_forest), - ('is_tree', _is_tree)]) + return OrderedDict( + [ + ("nodes", _number_of_nodes), + ("edges", _number_of_edges), + ("bridges", _bridges), + ("articulation_points", _articulation_points), + ("avg_degree", _avg_degree), + ("sorted_nodes_degree", _sorted_nodes_degree), + ("self_loops", _self_loops), + ("is_bipartite", _is_bipartite), + ("is_planar", _is_planar), + ("is_biconnected: ", _is_biconnected), + ("is_weakly_connected", _is_weakly_connected), + ("number_of_weakly_components", _weakly_components), + ("is_strongly_connected", _is_strongly_connected), + ("strongly_components", _strongly_components), + ("is_dag", _is_dag), + ("is_eulerian", _is_eulerian), + ("is_forest", _is_forest), + ("is_tree", _is_tree), + ] + ) def _get_analysis_func(name: str): @@ -132,20 +135,15 @@ def _get_analysis_funcs(): return _get_analysis_mapping().values() -def _analyze_graph(context: mgp.ProcCtx, - g: nx.MultiDiGraph, - analyses: List[str] - ) -> List[Tuple[str, str]]: +def _analyze_graph(context: mgp.ProcCtx, g: nx.MultiDiGraph, analyses: List[str]) -> List[Tuple[str, str]]: - functions = (_get_analysis_funcs() if analyses is None - else [_get_analysis_func(name) for name in analyses]) + functions = _get_analysis_funcs() if analyses is None else [_get_analysis_func(name) for name in analyses] records = [] for index, f in enumerate(functions): context.check_must_abort() if f is None: - raise KeyError('Graph analysis is not supported: ' + - analyses[index]) + raise KeyError("Graph analysis is not supported: " + analyses[index]) name, value = f(g) if isinstance(value, (list, set, tuple)): value = list(value)[:_MAX_LIST_SIZE] @@ -155,126 +153,120 @@ def _analyze_graph(context: mgp.ProcCtx, def _number_of_nodes(g: nx.MultiDiGraph) -> Tuple[str, int]: - '''Returns number of nodes.''' - return 'Number of nodes', nx.number_of_nodes(g) + """Returns number of nodes.""" + return "Number of nodes", nx.number_of_nodes(g) def _number_of_edges(g: nx.MultiDiGraph) -> Tuple[str, int]: - '''Returns number of edges.''' - return 'Number of edges', nx.number_of_edges(g) + """Returns number of edges.""" + return "Number of edges", nx.number_of_edges(g) def _avg_degree(g: nx.MultiDiGraph) -> Tuple[str, float]: - '''Returns average degree.''' + """Returns average degree.""" _, number_of_nodes = _number_of_nodes(g) _, number_of_edges = _number_of_edges(g) - avg_degree = (0 
if number_of_nodes == 0 - else number_of_edges / number_of_nodes) - return 'Average degree', avg_degree + avg_degree = 0 if number_of_nodes == 0 else number_of_edges / number_of_nodes + return "Average degree", avg_degree def _sorted_nodes_degree(g: nx.MultiDiGraph) -> Tuple[str, List[int]]: - '''Returns list of sorted nodes degree. [(node_id, degree), ...]''' + """Returns list of sorted nodes degree. [(node_id, degree), ...]""" nodes_degree = [(n, g.degree(n)) for n in g.nodes()] nodes_degree.sort(key=lambda x: x[1], reverse=True) - return 'Sorted nodes degree', nodes_degree + return "Sorted nodes degree", nodes_degree def _self_loops(g: nx.MultiDiGraph) -> Tuple[str, int]: - '''Returns number of self loops.''' - return 'Self loops', sum((1 if e[0] == e[1] else 0 for e in g.edges())) + """Returns number of self loops.""" + return "Self loops", sum((1 if e[0] == e[1] else 0 for e in g.edges())) def _is_bipartite(g: nx.MultiDiGraph) -> Tuple[str, bool]: - '''Checks if graph is bipartite.''' + """Checks if graph is bipartite.""" _, number_of_nodes = _number_of_nodes(g) - ret = (False if number_of_nodes == 0 - else nx.algorithms.bipartite.basic.is_bipartite(g)) - return 'Is bipartite', ret + ret = False if number_of_nodes == 0 else nx.algorithms.bipartite.basic.is_bipartite(g) + return "Is bipartite", ret def _is_planar(g: nx.MultiDiGraph) -> Tuple[str, bool]: - '''Checks if graph is planar.''' + """Checks if graph is planar.""" _, number_of_nodes = _number_of_nodes(g) - ret = (False if number_of_nodes == 0 - else nx.algorithms.planarity.check_planarity(g)[0]) - return 'Is planar', ret + ret = False if number_of_nodes == 0 else nx.algorithms.planarity.check_planarity(g)[0] + return "Is planar", ret def _is_biconnected(g: nx.MultiDiGraph) -> Tuple[str, bool]: - '''Check if graph is biconnected.''' + """Check if graph is biconnected.""" _, number_of_nodes = _number_of_nodes(g) - ret = (False if number_of_nodes == 0 - else nx.is_biconnected(nx.MultiDiGraph.to_undirected(g))) - return 'Is biconnected', ret + ret = False if number_of_nodes == 0 else nx.is_biconnected(nx.MultiDiGraph.to_undirected(g)) + return "Is biconnected", ret def _is_weakly_connected(g: nx.MultiDiGraph) -> Tuple[str, bool]: - '''Check if graph is weakly connected.''' + """Check if graph is weakly connected.""" _, number_of_nodes = _number_of_nodes(g) ret = False if number_of_nodes == 0 else nx.is_weakly_connected(g) - return 'Is weakly connected', ret + return "Is weakly connected", ret def _is_strongly_connected(g: nx.MultiDiGraph) -> Tuple[str, bool]: - '''Checks if graph is strongly connected.''' + """Checks if graph is strongly connected.""" _, number_of_nodes = _number_of_nodes(g) ret = False if number_of_nodes == 0 else nx.is_strongly_connected(g) - return 'Is strongly connected', ret + return "Is strongly connected", ret def _is_dag(g: nx.MultiDiGraph) -> Tuple[str, bool]: - '''Check if graph is directed acyclic graph (DAG)''' + """Check if graph is directed acyclic graph (DAG)""" _, number_of_nodes = _number_of_nodes(g) - ret = (False if number_of_nodes == 0 - else nx.algorithms.dag.is_directed_acyclic_graph(g)) - return 'Is DAG', ret + ret = False if number_of_nodes == 0 else nx.algorithms.dag.is_directed_acyclic_graph(g) + return "Is DAG", ret def _is_eulerian(g: nx.MultiDiGraph) -> Tuple[str, bool]: - '''Checks if graph is Eulerian.''' + """Checks if graph is Eulerian.""" _, number_of_nodes = _number_of_nodes(g) - ret = (False if number_of_nodes == 0 - else nx.algorithms.euler.is_eulerian(g)) - return 'Is 
eulerian', ret + ret = False if number_of_nodes == 0 else nx.algorithms.euler.is_eulerian(g) + return "Is eulerian", ret def _is_forest(g: nx.MultiDiGraph) -> Tuple[str, bool]: - '''Checks if graph is forest, all components must be trees.''' + """Checks if graph is forest, all components must be trees.""" _, number_of_nodes = _number_of_nodes(g) - ret = (False if number_of_nodes == 0 - else nx.algorithms.tree.recognition.is_forest(g)) - return 'Is forest', ret + ret = False if number_of_nodes == 0 else nx.algorithms.tree.recognition.is_forest(g) + return "Is forest", ret def _is_tree(g: nx.MultiDiGraph) -> Tuple[str, bool]: - '''Checks if graph is tree.''' + """Checks if graph is tree.""" _, number_of_nodes = _number_of_nodes(g) - ret = (False if number_of_nodes == 0 - else nx.algorithms.tree.recognition.is_tree(g)) - return 'Is tree', ret + ret = False if number_of_nodes == 0 else nx.algorithms.tree.recognition.is_tree(g) + return "Is tree", ret def _bridges(g: nx.MultiDiGraph) -> Tuple[str, int]: - '''Returns number of bridges, multiple edges between same nodes are - mapped to one edge.''' - return 'Number of bridges', sum(1 for _ in nx.bridges(nx.Graph(g))) + """Returns number of bridges, multiple edges between same nodes are + mapped to one edge.""" + return "Number of bridges", sum(1 for _ in nx.bridges(nx.Graph(g))) def _articulation_points(g: nx.MultiDiGraph): - '''Returns number of articulation points.''' + """Returns number of articulation points.""" undirected = nx.MultiDiGraph.to_undirected(g) - return ('Number of articulation points', - sum(1 for _ in nx.articulation_points(undirected))) + return ( + "Number of articulation points", + sum(1 for _ in nx.articulation_points(undirected)), + ) def _weakly_components(g: nx.MultiDiGraph): - '''Returns number of weakly components.''' + """Returns number of weakly components.""" comps = nx.algorithms.components.number_weakly_connected_components(g) - return 'Number of weakly connected components', comps + return "Number of weakly connected components", comps def _strongly_components(g: nx.MultiDiGraph): - '''Returns number of strongly connected components.''' + """Returns number of strongly connected components.""" comps = nx.algorithms.components.number_strongly_connected_components(g) - return 'Number of strongly connected components', comps + return "Number of strongly connected components", comps diff --git a/query_modules/mgp_networkx.py b/query_modules/mgp_networkx.py index 0539c6493..6e2919baf 100644 --- a/query_modules/mgp_networkx.py +++ b/query_modules/mgp_networkx.py @@ -1,20 +1,22 @@ import sys import mgp import collections + try: import networkx as nx except ImportError as import_error: - sys.stderr.write(( - '\n' - 'NOTE: Please install networkx to be able to use Memgraph NetworkX ' - 'wrappers. Using Python:\n' - + sys.version + - '\n')) + sys.stderr.write( + ( + "\n" + "NOTE: Please install networkx to be able to use Memgraph NetworkX " + "wrappers. 
Using Python:\n" + sys.version + "\n" + ) + ) raise import_error class MemgraphAdjlistOuterDict(collections.abc.Mapping): - __slots__ = ('_ctx', '_succ', '_multi') + __slots__ = ("_ctx", "_succ", "_multi") def __init__(self, ctx, succ=True, multi=True): self._ctx = ctx @@ -24,8 +26,7 @@ class MemgraphAdjlistOuterDict(collections.abc.Mapping): def __getitem__(self, key): if key not in self: raise KeyError - return MemgraphAdjlistInnerDict(key, succ=self._succ, - multi=self._multi) + return MemgraphAdjlistInnerDict(key, succ=self._succ, multi=self._multi) def __iter__(self): return iter(self._ctx.graph.vertices) @@ -40,7 +41,7 @@ class MemgraphAdjlistOuterDict(collections.abc.Mapping): class MemgraphAdjlistInnerDict(collections.abc.Mapping): - __slots__ = ('_node', '_succ', '_multi', '_neighbors') + __slots__ = ("_node", "_succ", "_multi", "_neighbors") def __init__(self, node, succ=True, multi=True): self._node = node @@ -71,31 +72,26 @@ class MemgraphAdjlistInnerDict(collections.abc.Mapping): def _get_neighbors(self): if not self._neighbors: if self._succ: - self._neighbors = set( - e.to_vertex for e in self._node.out_edges) + self._neighbors = set(e.to_vertex for e in self._node.out_edges) else: - self._neighbors = set( - e.from_vertex for e in self._node.in_edges) + self._neighbors = set(e.from_vertex for e in self._node.in_edges) return self._neighbors def _get_edge(self, neighbor): if self._succ: - edge = list(filter(lambda e: e.to_vertex == neighbor, - self._node.out_edges)) + edge = list(filter(lambda e: e.to_vertex == neighbor, self._node.out_edges)) else: - edge = list(filter(lambda e: e.from_vertex == neighbor, - self._node.in_edges)) + edge = list(filter(lambda e: e.from_vertex == neighbor, self._node.in_edges)) assert len(edge) >= 1 if len(edge) > 1: - raise RuntimeError('Graph contains multiedges but ' - 'is of non-multigraph type: {}'.format(edge)) + raise RuntimeError("Graph contains multiedges but " "is of non-multigraph type: {}".format(edge)) return edge[0] class MemgraphEdgeKeyDict(collections.abc.Mapping): - __slots__ = ('_node', '_neighbor', '_succ', '_edges') + __slots__ = ("_node", "_neighbor", "_succ", "_edges") def __init__(self, node, neighbor, succ=True): self._node = node @@ -122,18 +118,14 @@ class MemgraphEdgeKeyDict(collections.abc.Mapping): def _get_edges(self): if not self._edges: if self._succ: - self._edges = list(filter( - lambda e: e.to_vertex == self._neighbor, - self._node.out_edges)) + self._edges = list(filter(lambda e: e.to_vertex == self._neighbor, self._node.out_edges)) else: - self._edges = list(filter( - lambda e: e.from_vertex == self._neighbor, - self._node.in_edges)) + self._edges = list(filter(lambda e: e.from_vertex == self._neighbor, self._node.in_edges)) return self._edges class UnhashableProperties(collections.abc.Mapping): - __slots__ = ('_properties') + __slots__ = "_properties" def __init__(self, properties): self._properties = properties @@ -155,7 +147,7 @@ class UnhashableProperties(collections.abc.Mapping): class MemgraphNodeDict(collections.abc.Mapping): - __slots__ = ('_ctx',) + __slots__ = ("_ctx",) def __init__(self, ctx): self._ctx = ctx @@ -187,8 +179,7 @@ class MemgraphNodeDict(collections.abc.Mapping): class MemgraphDiGraphBase: - def __init__(self, incoming_graph_data=None, ctx=None, multi=True, - **kwargs): + def __init__(self, incoming_graph_data=None, ctx=None, multi=True, **kwargs): # NOTE: We assume that our graph will never be given any initial data # because we already pull our data from the Memgraph database. 
This # assert is triggered by certain NetworkX procedures because they @@ -201,23 +192,30 @@ class MemgraphDiGraphBase: # modify the graph's internal attributes and don't try to populate it # with initial data or modify it. - self.node_dict_factory = lambda: MemgraphNodeDict(ctx) \ - if ctx else self._error + self.node_dict_factory = lambda: MemgraphNodeDict(ctx) if ctx else self._error self.node_attr_dict_factory = self._error - self.adjlist_outer_dict_factory = \ - lambda: MemgraphAdjlistOuterDict(ctx, multi=multi) \ - if ctx else self._error + self.adjlist_outer_dict_factory = lambda: MemgraphAdjlistOuterDict(ctx, multi=multi) if ctx else self._error self.adjlist_inner_dict_factory = self._error self.edge_key_dict_factory = self._error self.edge_attr_dict_factory = self._error # NOTE: We forbid any mutating operations because our graph is # immutable and pulls its data from the Memgraph database. - for f in ['add_node', 'add_nodes_from', 'remove_node', - 'remove_nodes_from', 'add_edge', 'add_edges_from', - 'add_weighted_edges_from', 'new_edge_key', 'remove_edge', - 'remove_edges_from', 'update', 'clear']: + for f in [ + "add_node", + "add_nodes_from", + "remove_node", + "remove_nodes_from", + "add_edge", + "add_edges_from", + "add_weighted_edges_from", + "new_edge_key", + "remove_edge", + "remove_edges_from", + "update", + "clear", + ]: setattr(self, f, lambda *args, **kwargs: self._error()) super().__init__(None, **kwargs) @@ -231,33 +229,29 @@ class MemgraphDiGraphBase: self._pred = MemgraphAdjlistOuterDict(ctx, succ=False, multi=multi) def _error(self): - raise RuntimeError('Modification operations are not supported') + raise RuntimeError("Modification operations are not supported") class MemgraphMultiDiGraph(MemgraphDiGraphBase, nx.MultiDiGraph): def __init__(self, incoming_graph_data=None, ctx=None, **kwargs): - super().__init__(incoming_graph_data=incoming_graph_data, - ctx=ctx, multi=True, **kwargs) + super().__init__(incoming_graph_data=incoming_graph_data, ctx=ctx, multi=True, **kwargs) def MemgraphMultiGraph(incoming_graph_data=None, ctx=None, **kwargs): - return MemgraphMultiDiGraph(incoming_graph_data=incoming_graph_data, - ctx=ctx, **kwargs).to_undirected(as_view=True) + return MemgraphMultiDiGraph(incoming_graph_data=incoming_graph_data, ctx=ctx, **kwargs).to_undirected(as_view=True) class MemgraphDiGraph(MemgraphDiGraphBase, nx.DiGraph): def __init__(self, incoming_graph_data=None, ctx=None, **kwargs): - super().__init__(incoming_graph_data=incoming_graph_data, - ctx=ctx, multi=False, **kwargs) + super().__init__(incoming_graph_data=incoming_graph_data, ctx=ctx, multi=False, **kwargs) def MemgraphGraph(incoming_graph_data=None, ctx=None, **kwargs): - return MemgraphDiGraph(incoming_graph_data=incoming_graph_data, - ctx=ctx, **kwargs).to_undirected(as_view=True) + return MemgraphDiGraph(incoming_graph_data=incoming_graph_data, ctx=ctx, **kwargs).to_undirected(as_view=True) class PropertiesDictionary(collections.abc.Mapping): - __slots__ = ('_ctx', '_prop', '_len') + __slots__ = ("_ctx", "_prop", "_len") def __init__(self, ctx, prop): self._ctx = ctx @@ -270,8 +264,7 @@ class PropertiesDictionary(collections.abc.Mapping): try: return vertex.properties[self._prop] except KeyError: - raise KeyError(("{} doesn\t have the required " + - "property '{}'").format(vertex, self._prop)) + raise KeyError(("{} doesn\t have the required " + "property '{}'").format(vertex, self._prop)) def __iter__(self): for v in self._ctx.graph.vertices: diff --git a/query_modules/nxalg.py 
b/query_modules/nxalg.py index 2dea251e7..c3282bf07 100644 --- a/query_modules/nxalg.py +++ b/query_modules/nxalg.py @@ -1,45 +1,52 @@ import sys import mgp + try: import networkx as nx import numpy # noqa E401 import scipy # noqa E401 except ImportError as import_error: - sys.stderr.write(( - '\n' - 'NOTE: Please install networkx, numpy, scipy to be able to ' - 'use proxied NetworkX algorithms. E.g., CALL nxalg.pagerank(...).\n' - 'Using Python:\n' - + sys.version + - '\n')) + sys.stderr.write( + ( + "\n" + "NOTE: Please install networkx, numpy, scipy to be able to " + "use proxied NetworkX algorithms. E.g., CALL nxalg.pagerank(...).\n" + "Using Python:\n" + sys.version + "\n" + ) + ) raise import_error # Imported last because it also depends on networkx. -from mgp_networkx import (MemgraphMultiDiGraph, MemgraphDiGraph, # noqa: E402 - MemgraphMultiGraph, MemgraphGraph, - PropertiesDictionary) +from mgp_networkx import ( + MemgraphMultiDiGraph, + MemgraphDiGraph, # noqa: E402 + MemgraphMultiGraph, + MemgraphGraph, + PropertiesDictionary, +) # networkx.algorithms.approximation.connectivity.node_connectivity @mgp.read_proc -def node_connectivity(ctx: mgp.ProcCtx, - source: mgp.Nullable[mgp.Vertex] = None, - target: mgp.Nullable[mgp.Vertex] = None - ) -> mgp.Record(connectivity=int): - return mgp.Record(connectivity=nx.node_connectivity( - MemgraphMultiDiGraph(ctx=ctx), source, target)) +def node_connectivity( + ctx: mgp.ProcCtx, + source: mgp.Nullable[mgp.Vertex] = None, + target: mgp.Nullable[mgp.Vertex] = None, +) -> mgp.Record(connectivity=int): + return mgp.Record(connectivity=nx.node_connectivity(MemgraphMultiDiGraph(ctx=ctx), source, target)) # networkx.algorithms.assortativity.degree_assortativity_coefficient @mgp.read_proc def degree_assortativity_coefficient( - ctx: mgp.ProcCtx, - x: str = 'out', - y: str = 'in', - weight: mgp.Nullable[str] = None, - nodes: mgp.Nullable[mgp.List[mgp.Vertex]] = None + ctx: mgp.ProcCtx, + x: str = "out", + y: str = "in", + weight: mgp.Nullable[str] = None, + nodes: mgp.Nullable[mgp.List[mgp.Vertex]] = None, ) -> mgp.Record(assortativity=float): - return mgp.Record(assortativity=nx.degree_assortativity_coefficient( - MemgraphMultiDiGraph(ctx=ctx), x, y, weight, nodes)) + return mgp.Record( + assortativity=nx.degree_assortativity_coefficient(MemgraphMultiDiGraph(ctx=ctx), x, y, weight, nodes) + ) # networkx.algorithms.asteroidal.is_at_free @@ -51,58 +58,58 @@ def is_at_free(ctx: mgp.ProcCtx) -> mgp.Record(is_at_free=bool): # networkx.algorithms.bipartite.basic.is_bipartite @mgp.read_proc def is_bipartite(ctx: mgp.ProcCtx) -> mgp.Record(is_bipartite=bool): - return mgp.Record(is_bipartite=nx.is_bipartite( - MemgraphMultiDiGraph(ctx=ctx))) + return mgp.Record(is_bipartite=nx.is_bipartite(MemgraphMultiDiGraph(ctx=ctx))) # networkx.algorithms.boundary.node_boundary @mgp.read_proc -def node_boundary(ctx: mgp.ProcCtx, - nbunch1: mgp.List[mgp.Vertex], - nbunch2: mgp.Nullable[mgp.List[mgp.Vertex]] = None - ) -> mgp.Record(boundary=mgp.List[mgp.Vertex]): - return mgp.Record(boundary=list(nx.node_boundary( - MemgraphMultiDiGraph(ctx=ctx), nbunch1, nbunch2))) +def node_boundary( + ctx: mgp.ProcCtx, + nbunch1: mgp.List[mgp.Vertex], + nbunch2: mgp.Nullable[mgp.List[mgp.Vertex]] = None, +) -> mgp.Record(boundary=mgp.List[mgp.Vertex]): + return mgp.Record(boundary=list(nx.node_boundary(MemgraphMultiDiGraph(ctx=ctx), nbunch1, nbunch2))) # networkx.algorithms.bridges.bridges @mgp.read_proc -def bridges(ctx: mgp.ProcCtx, - root: mgp.Nullable[mgp.Vertex] = None - ) -> 
mgp.Record(bridges=mgp.List[mgp.Edge]): +def bridges(ctx: mgp.ProcCtx, root: mgp.Nullable[mgp.Vertex] = None) -> mgp.Record(bridges=mgp.List[mgp.Edge]): g = MemgraphMultiGraph(ctx=ctx) - return mgp.Record( - bridges=[next(iter(g[u][v])) - for u, v in nx.bridges(MemgraphGraph(ctx=ctx), - root=root)]) + return mgp.Record(bridges=[next(iter(g[u][v])) for u, v in nx.bridges(MemgraphGraph(ctx=ctx), root=root)]) # networkx.algorithms.centrality.betweenness_centrality @mgp.read_proc -def betweenness_centrality(ctx: mgp.ProcCtx, - k: mgp.Nullable[int] = None, - normalized: bool = True, - weight: mgp.Nullable[str] = None, - endpoints: bool = False, - seed: mgp.Nullable[int] = None - ) -> mgp.Record(node=mgp.Vertex, - betweenness=mgp.Number): - return [mgp.Record(node=n, betweenness=b) - for n, b in nx.betweenness_centrality( - MemgraphDiGraph(ctx=ctx), k=k, normalized=normalized, - weight=weight, endpoints=endpoints, seed=seed).items()] +def betweenness_centrality( + ctx: mgp.ProcCtx, + k: mgp.Nullable[int] = None, + normalized: bool = True, + weight: mgp.Nullable[str] = None, + endpoints: bool = False, + seed: mgp.Nullable[int] = None, +) -> mgp.Record(node=mgp.Vertex, betweenness=mgp.Number): + return [ + mgp.Record(node=n, betweenness=b) + for n, b in nx.betweenness_centrality( + MemgraphDiGraph(ctx=ctx), + k=k, + normalized=normalized, + weight=weight, + endpoints=endpoints, + seed=seed, + ).items() + ] # networkx.algorithms.chains.chain_decomposition @mgp.read_proc -def chain_decomposition(ctx: mgp.ProcCtx, - root: mgp.Nullable[mgp.Vertex] = None - ) -> mgp.Record(chains=mgp.List[mgp.List[mgp.Edge]]): +def chain_decomposition( + ctx: mgp.ProcCtx, root: mgp.Nullable[mgp.Vertex] = None +) -> mgp.Record(chains=mgp.List[mgp.List[mgp.Edge]]): g = MemgraphMultiGraph(ctx=ctx) return mgp.Record( - chains=[[next(iter(g[u][v])) for u, v in d] - for d in nx.chain_decomposition(MemgraphGraph(ctx=ctx), - root=root)]) + chains=[[next(iter(g[u][v])) for u, v in d] for d in nx.chain_decomposition(MemgraphGraph(ctx=ctx), root=root)] + ) # networkx.algorithms.chordal.is_chordal @@ -113,72 +120,74 @@ def is_chordal(ctx: mgp.ProcCtx) -> mgp.Record(is_chordal=bool): # networkx.algorithms.clique.find_cliques @mgp.read_proc -def find_cliques(ctx: mgp.ProcCtx - ) -> mgp.Record(cliques=mgp.List[mgp.List[mgp.Vertex]]): - return mgp.Record(cliques=list(nx.find_cliques( - MemgraphMultiGraph(ctx=ctx)))) +def find_cliques( + ctx: mgp.ProcCtx, +) -> mgp.Record(cliques=mgp.List[mgp.List[mgp.Vertex]]): + return mgp.Record(cliques=list(nx.find_cliques(MemgraphMultiGraph(ctx=ctx)))) # networkx.algorithms.cluster.clustering @mgp.read_proc -def clustering(ctx: mgp.ProcCtx, - nodes: mgp.Nullable[mgp.List[mgp.Vertex]] = None, - weight: mgp.Nullable[str] = None - ) -> mgp.Record(node=mgp.Vertex, clustering=mgp.Number): - return [mgp.Record(node=n, clustering=c) - for n, c in nx.clustering( - MemgraphDiGraph(ctx=ctx), nodes=nodes, - weight=weight).items()] +def clustering( + ctx: mgp.ProcCtx, + nodes: mgp.Nullable[mgp.List[mgp.Vertex]] = None, + weight: mgp.Nullable[str] = None, +) -> mgp.Record(node=mgp.Vertex, clustering=mgp.Number): + return [ + mgp.Record(node=n, clustering=c) + for n, c in nx.clustering(MemgraphDiGraph(ctx=ctx), nodes=nodes, weight=weight).items() + ] # networkx.algorithms.coloring.greedy_color @mgp.read_proc -def greedy_color(ctx: mgp.ProcCtx, - strategy: str = 'largest_first', - interchange: bool = False - ) -> mgp.Record(node=mgp.Vertex, color=int): - return [mgp.Record(node=n, color=c) for n, c in 
nx.greedy_color( - MemgraphMultiDiGraph(ctx=ctx), strategy, interchange).items()] +def greedy_color( + ctx: mgp.ProcCtx, strategy: str = "largest_first", interchange: bool = False +) -> mgp.Record(node=mgp.Vertex, color=int): + return [ + mgp.Record(node=n, color=c) + for n, c in nx.greedy_color(MemgraphMultiDiGraph(ctx=ctx), strategy, interchange).items() + ] # networkx.algorithms.communicability_alg.communicability @mgp.read_proc -def communicability(ctx: mgp.ProcCtx - ) -> mgp.Record(node1=mgp.Vertex, node2=mgp.Vertex, - communicability=mgp.Number): - return [mgp.Record(node1=n1, node2=n2, communicability=v) - for n1, d in nx.communicability(MemgraphGraph(ctx=ctx)).items() - for n2, v in d.items()] +def communicability( + ctx: mgp.ProcCtx, +) -> mgp.Record(node1=mgp.Vertex, node2=mgp.Vertex, communicability=mgp.Number): + return [ + mgp.Record(node1=n1, node2=n2, communicability=v) + for n1, d in nx.communicability(MemgraphGraph(ctx=ctx)).items() + for n2, v in d.items() + ] # networkx.algorithms.community.kclique.k_clique_communities @mgp.read_proc def k_clique_communities( - ctx: mgp.ProcCtx, - k: int, - cliques: mgp.Nullable[mgp.List[mgp.List[mgp.Vertex]]] = None + ctx: mgp.ProcCtx, + k: int, + cliques: mgp.Nullable[mgp.List[mgp.List[mgp.Vertex]]] = None, ) -> mgp.Record(communities=mgp.List[mgp.List[mgp.Vertex]]): - return mgp.Record(communities=[ - list(s) for s in nx.community.k_clique_communities( - MemgraphMultiGraph(ctx=ctx), k, cliques)]) + return mgp.Record( + communities=[list(s) for s in nx.community.k_clique_communities(MemgraphMultiGraph(ctx=ctx), k, cliques)] + ) # networkx.algorithms.approximation.kcomponents.k_components @mgp.read_proc -def k_components(ctx: mgp.ProcCtx, - density: mgp.Number = 0.95 - ) -> mgp.Record(k=int, - components=mgp.List[mgp.List[mgp.Vertex]]): +def k_components( + ctx: mgp.ProcCtx, density: mgp.Number = 0.95 +) -> mgp.Record(k=int, components=mgp.List[mgp.List[mgp.Vertex]]): kcomps = nx.k_components(MemgraphMultiGraph(ctx=ctx), density) - return [mgp.Record(k=k, components=[list(s) for s in comps]) - for k, comps in kcomps.items()] + return [mgp.Record(k=k, components=[list(s) for s in comps]) for k, comps in kcomps.items()] # networkx.algorithms.components.biconnected_components @mgp.read_proc def biconnected_components( - ctx: mgp.ProcCtx + ctx: mgp.ProcCtx, ) -> mgp.Record(components=mgp.List[mgp.List[mgp.Vertex]]): comps = nx.biconnected_components(MemgraphMultiGraph(ctx=ctx)) return mgp.Record(components=[list(s) for s in comps]) @@ -187,7 +196,7 @@ def biconnected_components( # networkx.algorithms.components.strongly_connected_components @mgp.read_proc def strongly_connected_components( - ctx: mgp.ProcCtx + ctx: mgp.ProcCtx, ) -> mgp.Record(components=mgp.List[mgp.List[mgp.Vertex]]): comps = nx.strongly_connected_components(MemgraphMultiDiGraph(ctx=ctx)) return mgp.Record(components=[list(s) for s in comps]) @@ -199,40 +208,32 @@ def strongly_connected_components( # a *copy* of the graph because the algorithm copies the graph using # __class__() and tries to modify it. 
@mgp.read_proc -def k_edge_components( - ctx: mgp.ProcCtx, - k: int -) -> mgp.Record(components=mgp.List[mgp.List[mgp.Vertex]]): - return mgp.Record(components=[list(s) for s in nx.k_edge_components( - nx.DiGraph(MemgraphDiGraph(ctx=ctx)), k)]) +def k_edge_components(ctx: mgp.ProcCtx, k: int) -> mgp.Record(components=mgp.List[mgp.List[mgp.Vertex]]): + return mgp.Record(components=[list(s) for s in nx.k_edge_components(nx.DiGraph(MemgraphDiGraph(ctx=ctx)), k)]) # networkx.algorithms.core.core_number @mgp.read_proc -def core_number(ctx: mgp.ProcCtx - ) -> mgp.Record(node=mgp.Vertex, core=mgp.Number): - return [mgp.Record(node=n, core=c) - for n, c in nx.core_number(MemgraphDiGraph(ctx=ctx)).items()] +def core_number(ctx: mgp.ProcCtx) -> mgp.Record(node=mgp.Vertex, core=mgp.Number): + return [mgp.Record(node=n, core=c) for n, c in nx.core_number(MemgraphDiGraph(ctx=ctx)).items()] # networkx.algorithms.covering.is_edge_cover @mgp.read_proc -def is_edge_cover(ctx: mgp.ProcCtx, cover: mgp.List[mgp.Edge] - ) -> mgp.Record(is_edge_cover=bool): +def is_edge_cover(ctx: mgp.ProcCtx, cover: mgp.List[mgp.Edge]) -> mgp.Record(is_edge_cover=bool): cover = set([(e.from_vertex, e.to_vertex) for e in cover]) - return mgp.Record(is_edge_cover=nx.is_edge_cover( - MemgraphMultiGraph(ctx=ctx), cover)) + return mgp.Record(is_edge_cover=nx.is_edge_cover(MemgraphMultiGraph(ctx=ctx), cover)) # networkx.algorithms.cycles.find_cycle @mgp.read_proc -def find_cycle(ctx: mgp.ProcCtx, - source: mgp.Nullable[mgp.List[mgp.Vertex]] = None, - orientation: mgp.Nullable[str] = None - ) -> mgp.Record(cycle=mgp.Nullable[mgp.List[mgp.Edge]]): +def find_cycle( + ctx: mgp.ProcCtx, + source: mgp.Nullable[mgp.List[mgp.Vertex]] = None, + orientation: mgp.Nullable[str] = None, +) -> mgp.Record(cycle=mgp.Nullable[mgp.List[mgp.Edge]]): try: - return mgp.Record(cycle=[e for _, _, e in nx.find_cycle( - MemgraphMultiDiGraph(ctx=ctx), source, orientation)]) + return mgp.Record(cycle=[e for _, _, e in nx.find_cycle(MemgraphMultiDiGraph(ctx=ctx), source, orientation)]) except nx.NetworkXNoCycle: return mgp.Record(cycle=None) @@ -243,44 +244,36 @@ def find_cycle(ctx: mgp.ProcCtx, # because the algorithm copies the graph using type() and tries to pass initial # data. 
@mgp.read_proc -def simple_cycles(ctx: mgp.ProcCtx - ) -> mgp.Record(cycles=mgp.List[mgp.List[mgp.Vertex]]): - return mgp.Record(cycles=list(nx.simple_cycles( - nx.MultiDiGraph(MemgraphMultiDiGraph(ctx=ctx)).copy()))) +def simple_cycles( + ctx: mgp.ProcCtx, +) -> mgp.Record(cycles=mgp.List[mgp.List[mgp.Vertex]]): + return mgp.Record(cycles=list(nx.simple_cycles(nx.MultiDiGraph(MemgraphMultiDiGraph(ctx=ctx)).copy()))) # networkx.algorithms.cuts.node_expansion @mgp.read_proc -def node_expansion(ctx: mgp.ProcCtx, s: mgp.List[mgp.Vertex] - ) -> mgp.Record(node_expansion=mgp.Number): - return mgp.Record(node_expansion=nx.node_expansion( - MemgraphMultiDiGraph(ctx=ctx), set(s))) +def node_expansion(ctx: mgp.ProcCtx, s: mgp.List[mgp.Vertex]) -> mgp.Record(node_expansion=mgp.Number): + return mgp.Record(node_expansion=nx.node_expansion(MemgraphMultiDiGraph(ctx=ctx), set(s))) # networkx.algorithms.dag.topological_sort @mgp.read_proc -def topological_sort(ctx: mgp.ProcCtx - ) -> mgp.Record(nodes=mgp.Nullable[mgp.List[mgp.Vertex]]): - return mgp.Record(nodes=list(nx.topological_sort( - MemgraphMultiDiGraph(ctx=ctx)))) +def topological_sort( + ctx: mgp.ProcCtx, +) -> mgp.Record(nodes=mgp.Nullable[mgp.List[mgp.Vertex]]): + return mgp.Record(nodes=list(nx.topological_sort(MemgraphMultiDiGraph(ctx=ctx)))) # networkx.algorithms.dag.ancestors @mgp.read_proc -def ancestors(ctx: mgp.ProcCtx, - source: mgp.Vertex - ) -> mgp.Record(ancestors=mgp.List[mgp.Vertex]): - return mgp.Record(ancestors=list(nx.ancestors( - MemgraphMultiDiGraph(ctx=ctx), source))) +def ancestors(ctx: mgp.ProcCtx, source: mgp.Vertex) -> mgp.Record(ancestors=mgp.List[mgp.Vertex]): + return mgp.Record(ancestors=list(nx.ancestors(MemgraphMultiDiGraph(ctx=ctx), source))) # networkx.algorithms.dag.descendants @mgp.read_proc -def descendants(ctx: mgp.ProcCtx, - source: mgp.Vertex - ) -> mgp.Record(descendants=mgp.List[mgp.Vertex]): - return mgp.Record(descendants=list(nx.descendants( - MemgraphMultiDiGraph(ctx=ctx), source))) +def descendants(ctx: mgp.ProcCtx, source: mgp.Vertex) -> mgp.Record(descendants=mgp.List[mgp.Vertex]): + return mgp.Record(descendants=list(nx.descendants(MemgraphMultiDiGraph(ctx=ctx), source))) # networkx.algorithms.distance_measures.center @@ -301,143 +294,138 @@ def diameter(ctx: mgp.ProcCtx) -> mgp.Record(diameter=int): # networkx.algorithms.distance_regular.is_distance_regular @mgp.read_proc -def is_distance_regular(ctx: mgp.ProcCtx - ) -> mgp.Record(is_distance_regular=bool): - return mgp.Record(is_distance_regular=nx.is_distance_regular( - MemgraphMultiGraph(ctx=ctx))) +def is_distance_regular(ctx: mgp.ProcCtx) -> mgp.Record(is_distance_regular=bool): + return mgp.Record(is_distance_regular=nx.is_distance_regular(MemgraphMultiGraph(ctx=ctx))) # networkx.algorithms.strongly_regular.is_strongly_regular @mgp.read_proc -def is_strongly_regular(ctx: mgp.ProcCtx - ) -> mgp.Record(is_strongly_regular=bool): - return mgp.Record(is_strongly_regular=nx.is_strongly_regular( - MemgraphMultiGraph(ctx=ctx))) +def is_strongly_regular(ctx: mgp.ProcCtx) -> mgp.Record(is_strongly_regular=bool): + return mgp.Record(is_strongly_regular=nx.is_strongly_regular(MemgraphMultiGraph(ctx=ctx))) # networkx.algorithms.dominance.dominance_frontiers @mgp.read_proc -def dominance_frontiers(ctx: mgp.ProcCtx, start: mgp.Vertex, - ) -> mgp.Record(node=mgp.Vertex, - frontier=mgp.List[mgp.Vertex]): - return [mgp.Record(node=n, frontier=list(f)) - for n, f in nx.dominance_frontiers( - MemgraphMultiDiGraph(ctx=ctx), start).items()] +def 
dominance_frontiers( + ctx: mgp.ProcCtx, + start: mgp.Vertex, +) -> mgp.Record(node=mgp.Vertex, frontier=mgp.List[mgp.Vertex]): + return [ + mgp.Record(node=n, frontier=list(f)) + for n, f in nx.dominance_frontiers(MemgraphMultiDiGraph(ctx=ctx), start).items() + ] # networkx.algorithms.dominance.immediate_dominatorss @mgp.read_proc -def immediate_dominators(ctx: mgp.ProcCtx, start: mgp.Vertex, - ) -> mgp.Record(node=mgp.Vertex, - dominator=mgp.Vertex): - return [mgp.Record(node=n, dominator=d) - for n, d in nx.immediate_dominators( - MemgraphMultiDiGraph(ctx=ctx), start).items()] +def immediate_dominators( + ctx: mgp.ProcCtx, + start: mgp.Vertex, +) -> mgp.Record(node=mgp.Vertex, dominator=mgp.Vertex): + return [ + mgp.Record(node=n, dominator=d) + for n, d in nx.immediate_dominators(MemgraphMultiDiGraph(ctx=ctx), start).items() + ] # networkx.algorithms.dominating.dominating_set @mgp.read_proc -def dominating_set(ctx: mgp.ProcCtx, start: mgp.Vertex, - ) -> mgp.Record(dominating_set=mgp.List[mgp.Vertex]): - return mgp.Record(dominating_set=list(nx.dominating_set( - MemgraphMultiDiGraph(ctx=ctx), start))) +def dominating_set( + ctx: mgp.ProcCtx, + start: mgp.Vertex, +) -> mgp.Record(dominating_set=mgp.List[mgp.Vertex]): + return mgp.Record(dominating_set=list(nx.dominating_set(MemgraphMultiDiGraph(ctx=ctx), start))) # networkx.algorithms.efficiency_measures.local_efficiency @mgp.read_proc def local_efficiency(ctx: mgp.ProcCtx) -> mgp.Record(local_efficiency=float): - return mgp.Record(local_efficiency=nx.local_efficiency( - MemgraphMultiGraph(ctx=ctx))) + return mgp.Record(local_efficiency=nx.local_efficiency(MemgraphMultiGraph(ctx=ctx))) # networkx.algorithms.efficiency_measures.global_efficiency @mgp.read_proc def global_efficiency(ctx: mgp.ProcCtx) -> mgp.Record(global_efficiency=float): - return mgp.Record(global_efficiency=nx.global_efficiency( - MemgraphMultiGraph(ctx=ctx))) + return mgp.Record(global_efficiency=nx.global_efficiency(MemgraphMultiGraph(ctx=ctx))) # networkx.algorithms.euler.is_eulerian @mgp.read_proc def is_eulerian(ctx: mgp.ProcCtx) -> mgp.Record(is_eulerian=bool): - return mgp.Record(is_eulerian=nx.is_eulerian( - MemgraphMultiDiGraph(ctx=ctx))) + return mgp.Record(is_eulerian=nx.is_eulerian(MemgraphMultiDiGraph(ctx=ctx))) # networkx.algorithms.euler.is_semieulerian @mgp.read_proc def is_semieulerian(ctx: mgp.ProcCtx) -> mgp.Record(is_semieulerian=bool): - return mgp.Record(is_semieulerian=nx.is_semieulerian( - MemgraphMultiDiGraph(ctx=ctx))) + return mgp.Record(is_semieulerian=nx.is_semieulerian(MemgraphMultiDiGraph(ctx=ctx))) # networkx.algorithms.euler.has_eulerian_path @mgp.read_proc def has_eulerian_path(ctx: mgp.ProcCtx) -> mgp.Record(has_eulerian_path=bool): - return mgp.Record(has_eulerian_path=nx.has_eulerian_path( - MemgraphMultiDiGraph(ctx=ctx))) + return mgp.Record(has_eulerian_path=nx.has_eulerian_path(MemgraphMultiDiGraph(ctx=ctx))) # networkx.algorithms.hierarchy.flow_hierarchy @mgp.read_proc -def flow_hierarchy(ctx: mgp.ProcCtx, - weight: mgp.Nullable[str] = None - ) -> mgp.Record(flow_hierarchy=float): - return mgp.Record(flow_hierarchy=nx.flow_hierarchy( - MemgraphMultiDiGraph(ctx=ctx), weight=weight)) +def flow_hierarchy(ctx: mgp.ProcCtx, weight: mgp.Nullable[str] = None) -> mgp.Record(flow_hierarchy=float): + return mgp.Record(flow_hierarchy=nx.flow_hierarchy(MemgraphMultiDiGraph(ctx=ctx), weight=weight)) # networkx.algorithms.isolate.isolates @mgp.read_proc def isolates(ctx: mgp.ProcCtx) -> mgp.Record(isolates=mgp.List[mgp.Vertex]): - return 
mgp.Record(isolates=list(nx.isolates( - MemgraphMultiDiGraph(ctx=ctx)))) + return mgp.Record(isolates=list(nx.isolates(MemgraphMultiDiGraph(ctx=ctx)))) # networkx.algorithms.isolate.is_isolate @mgp.read_proc -def is_isolate(ctx: mgp.ProcCtx, n: mgp.Vertex - ) -> mgp.Record(is_isolate=bool): - return mgp.Record(is_isolate=nx.is_isolate( - MemgraphMultiDiGraph(ctx=ctx), n)) +def is_isolate(ctx: mgp.ProcCtx, n: mgp.Vertex) -> mgp.Record(is_isolate=bool): + return mgp.Record(is_isolate=nx.is_isolate(MemgraphMultiDiGraph(ctx=ctx), n)) # networkx.algorithms.isomorphism.is_isomorphic @mgp.read_proc -def is_isomorphic(ctx: mgp.ProcCtx, - nodes1: mgp.List[mgp.Vertex], - edges1: mgp.List[mgp.Edge], - nodes2: mgp.List[mgp.Vertex], - edges2: mgp.List[mgp.Edge] - ) -> mgp.Record(is_isomorphic=bool): +def is_isomorphic( + ctx: mgp.ProcCtx, + nodes1: mgp.List[mgp.Vertex], + edges1: mgp.List[mgp.Edge], + nodes2: mgp.List[mgp.Vertex], + edges2: mgp.List[mgp.Edge], +) -> mgp.Record(is_isomorphic=bool): nodes1, edges1, nodes2, edges2 = map(set, [nodes1, edges1, nodes2, edges2]) g = MemgraphMultiDiGraph(ctx=ctx) - g1 = nx.subgraph_view( - g, lambda n: n in nodes1, lambda n1, n2, e: e in edges1) - g2 = nx.subgraph_view( - g, lambda n: n in nodes2, lambda n1, n2, e: e in edges2) + g1 = nx.subgraph_view(g, lambda n: n in nodes1, lambda n1, n2, e: e in edges1) + g2 = nx.subgraph_view(g, lambda n: n in nodes2, lambda n1, n2, e: e in edges2) return mgp.Record(is_isomorphic=nx.is_isomorphic(g1, g2)) # networkx.algorithms.link_analysis.pagerank_alg.pagerank @mgp.read_proc -def pagerank(ctx: mgp.ProcCtx, - alpha: mgp.Number = 0.85, - personalization: mgp.Nullable[str] = None, - max_iter: int = 100, - tol: mgp.Number = 1e-06, - nstart: mgp.Nullable[str] = None, - weight: mgp.Nullable[str] = 'weight', - dangling: mgp.Nullable[str] = None, - ) -> mgp.Record(node=mgp.Vertex, rank=float): +def pagerank( + ctx: mgp.ProcCtx, + alpha: mgp.Number = 0.85, + personalization: mgp.Nullable[str] = None, + max_iter: int = 100, + tol: mgp.Number = 1e-06, + nstart: mgp.Nullable[str] = None, + weight: mgp.Nullable[str] = "weight", + dangling: mgp.Nullable[str] = None, +) -> mgp.Record(node=mgp.Vertex, rank=float): def to_properties_dictionary(prop): return None if prop is None else PropertiesDictionary(ctx, prop) - pg = nx.pagerank(MemgraphDiGraph(ctx=ctx), alpha=alpha, - personalization=to_properties_dictionary(personalization), - max_iter=max_iter, tol=tol, - nstart=to_properties_dictionary(nstart), weight=weight, - dangling=to_properties_dictionary(dangling)) + pg = nx.pagerank( + MemgraphDiGraph(ctx=ctx), + alpha=alpha, + personalization=to_properties_dictionary(personalization), + max_iter=max_iter, + tol=tol, + nstart=to_properties_dictionary(nstart), + weight=weight, + dangling=to_properties_dictionary(dangling), + ) return [mgp.Record(node=k, rank=v) for k, v in pg.items()] @@ -445,29 +433,24 @@ def pagerank(ctx: mgp.ProcCtx, # networkx.algorithms.link_prediction.jaccard_coefficient @mgp.read_proc def jaccard_coefficient( - ctx: mgp.ProcCtx, - ebunch: mgp.Nullable[mgp.List[mgp.List[mgp.Vertex]]] = None -) -> mgp.Record(u=mgp.Vertex, v=mgp.Vertex, - coef=float): - return [mgp.Record(u=u, v=v, coef=c) for u, v, c - in nx.jaccard_coefficient(MemgraphGraph(ctx=ctx), ebunch)] + ctx: mgp.ProcCtx, ebunch: mgp.Nullable[mgp.List[mgp.List[mgp.Vertex]]] = None +) -> mgp.Record(u=mgp.Vertex, v=mgp.Vertex, coef=float): + return [mgp.Record(u=u, v=v, coef=c) for u, v, c in nx.jaccard_coefficient(MemgraphGraph(ctx=ctx), ebunch)] # 
networkx.algorithms.lowest_common_ancestors.lowest_common_ancestor @mgp.read_proc -def lowest_common_ancestor(ctx: mgp.ProcCtx, node1: mgp.Vertex, - node2: mgp.Vertex - ) -> mgp.Record(ancestor=mgp.Nullable[mgp.Vertex]): - return mgp.Record(ancestor=nx.lowest_common_ancestor( - MemgraphDiGraph(ctx=ctx), node1, node2)) +def lowest_common_ancestor( + ctx: mgp.ProcCtx, node1: mgp.Vertex, node2: mgp.Vertex +) -> mgp.Record(ancestor=mgp.Nullable[mgp.Vertex]): + return mgp.Record(ancestor=nx.lowest_common_ancestor(MemgraphDiGraph(ctx=ctx), node1, node2)) # networkx.algorithms.matching.maximal_matching @mgp.read_proc def maximal_matching(ctx: mgp.ProcCtx) -> mgp.Record(edges=mgp.List[mgp.Edge]): g = MemgraphMultiDiGraph(ctx=ctx) - return mgp.Record(edges=list( - next(iter(g[u][v])) for u, v in nx.maximal_matching(g))) + return mgp.Record(edges=list(next(iter(g[u][v])) for u, v in nx.maximal_matching(g))) # networkx.algorithms.planarity.check_planarity @@ -475,27 +458,23 @@ def maximal_matching(ctx: mgp.ProcCtx) -> mgp.Record(edges=mgp.List[mgp.Edge]): # NOTE: Returns a graph. @mgp.read_proc def check_planarity(ctx: mgp.ProcCtx) -> mgp.Record(is_planar=bool): - return mgp.Record(is_planar=nx.check_planarity( - MemgraphMultiDiGraph(ctx=ctx))[0]) + return mgp.Record(is_planar=nx.check_planarity(MemgraphMultiDiGraph(ctx=ctx))[0]) # networkx.algorithms.non_randomness.non_randomness @mgp.read_proc -def non_randomness(ctx: mgp.ProcCtx, - k: mgp.Nullable[int] = None - ) -> mgp.Record(non_randomness=float, - relative_non_randomness=float): - nn, rnn = nx.non_randomness( - MemgraphGraph(ctx=ctx), k=k) +def non_randomness( + ctx: mgp.ProcCtx, k: mgp.Nullable[int] = None +) -> mgp.Record(non_randomness=float, relative_non_randomness=float): + nn, rnn = nx.non_randomness(MemgraphGraph(ctx=ctx), k=k) return mgp.Record(non_randomness=nn, relative_non_randomness=rnn) # networkx.algorithms.reciprocity.reciprocity @mgp.read_proc -def reciprocity(ctx: mgp.ProcCtx, - nodes: mgp.Nullable[mgp.List[mgp.Vertex]] = None - ) -> mgp.Record(node=mgp.Nullable[mgp.Vertex], - reciprocity=mgp.Nullable[float]): +def reciprocity( + ctx: mgp.ProcCtx, nodes: mgp.Nullable[mgp.List[mgp.Vertex]] = None +) -> mgp.Record(node=mgp.Nullable[mgp.Vertex], reciprocity=mgp.Nullable[float]): rp = nx.reciprocity(MemgraphMultiDiGraph(ctx=ctx), nodes=nodes) if nodes is None: return mgp.Record(node=None, reciprocity=rp) @@ -505,15 +484,20 @@ def reciprocity(ctx: mgp.ProcCtx, # networkx.algorithms.shortest_paths.generic.shortest_path @mgp.read_proc -def shortest_path(ctx: mgp.ProcCtx, - source: mgp.Nullable[mgp.Vertex] = None, - target: mgp.Nullable[mgp.Vertex] = None, - weight: mgp.Nullable[str] = None, - method: str = 'dijkstra' - ) -> mgp.Record(source=mgp.Vertex, target=mgp.Vertex, - path=mgp.List[mgp.Vertex]): - sp = nx.shortest_path(MemgraphMultiDiGraph(ctx=ctx), source=source, - target=target, weight=weight, method=method) +def shortest_path( + ctx: mgp.ProcCtx, + source: mgp.Nullable[mgp.Vertex] = None, + target: mgp.Nullable[mgp.Vertex] = None, + weight: mgp.Nullable[str] = None, + method: str = "dijkstra", +) -> mgp.Record(source=mgp.Vertex, target=mgp.Vertex, path=mgp.List[mgp.Vertex]): + sp = nx.shortest_path( + MemgraphMultiDiGraph(ctx=ctx), + source=source, + target=target, + weight=weight, + method=method, + ) if source and target: sp = {source: {target: sp}} @@ -522,22 +506,25 @@ def shortest_path(ctx: mgp.ProcCtx, elif not source and target: sp = {source: {target: p} for source, p in sp.items()} - return [mgp.Record(source=s, 
target=t, path=p) - for s, d in sp.items() - for t, p in d.items()] + return [mgp.Record(source=s, target=t, path=p) for s, d in sp.items() for t, p in d.items()] # networkx.algorithms.shortest_paths.generic.shortest_path_length @mgp.read_proc -def shortest_path_length(ctx: mgp.ProcCtx, - source: mgp.Nullable[mgp.Vertex] = None, - target: mgp.Nullable[mgp.Vertex] = None, - weight: mgp.Nullable[str] = None, - method: str = 'dijkstra' - ) -> mgp.Record(source=mgp.Vertex, target=mgp.Vertex, - length=mgp.Number): - sp = nx.shortest_path_length(MemgraphMultiDiGraph(ctx=ctx), source=source, - target=target, weight=weight, method=method) +def shortest_path_length( + ctx: mgp.ProcCtx, + source: mgp.Nullable[mgp.Vertex] = None, + target: mgp.Nullable[mgp.Vertex] = None, + weight: mgp.Nullable[str] = None, + method: str = "dijkstra", +) -> mgp.Record(source=mgp.Vertex, target=mgp.Vertex, length=mgp.Number): + sp = nx.shortest_path_length( + MemgraphMultiDiGraph(ctx=ctx), + source=source, + target=target, + weight=weight, + method=method, + ) if source and target: sp = {source: {target: sp}} @@ -548,214 +535,224 @@ def shortest_path_length(ctx: mgp.ProcCtx, else: sp = dict(sp) - return [mgp.Record(source=s, target=t, length=l) - for s, d in sp.items() - for t, l in d.items()] + return [mgp.Record(source=s, target=t, length=l) for s, d in sp.items() for t, l in d.items()] # networkx.algorithms.shortest_paths.generic.all_shortest_paths @mgp.read_proc -def all_shortest_paths(ctx: mgp.ProcCtx, - source: mgp.Vertex, - target: mgp.Vertex, - weight: mgp.Nullable[str] = None, - method: str = 'dijkstra' - ) -> mgp.Record(paths=mgp.List[mgp.List[mgp.Vertex]]): - return mgp.Record(paths=list(nx.all_shortest_paths( - MemgraphMultiDiGraph(ctx=ctx), source=source, target=target, - weight=weight, method=method))) +def all_shortest_paths( + ctx: mgp.ProcCtx, + source: mgp.Vertex, + target: mgp.Vertex, + weight: mgp.Nullable[str] = None, + method: str = "dijkstra", +) -> mgp.Record(paths=mgp.List[mgp.List[mgp.Vertex]]): + return mgp.Record( + paths=list( + nx.all_shortest_paths( + MemgraphMultiDiGraph(ctx=ctx), + source=source, + target=target, + weight=weight, + method=method, + ) + ) + ) # networkx.algorithms.shortest_paths.generic.has_path @mgp.read_proc -def has_path(ctx: mgp.ProcCtx, - source: mgp.Vertex, - target: mgp.Vertex) -> mgp.Record(has_path=bool): - return mgp.Record(has_path=nx.has_path(MemgraphMultiDiGraph(ctx=ctx), - source, target)) +def has_path(ctx: mgp.ProcCtx, source: mgp.Vertex, target: mgp.Vertex) -> mgp.Record(has_path=bool): + return mgp.Record(has_path=nx.has_path(MemgraphMultiDiGraph(ctx=ctx), source, target)) # networkx.algorithms.shortest_paths.weighted.multi_source_dijkstra_path @mgp.read_proc -def multi_source_dijkstra_path(ctx: mgp.ProcCtx, - sources: mgp.List[mgp.Vertex], - cutoff: mgp.Nullable[int] = None, - weight: str = 'weight' - ) -> mgp.Record(target=mgp.Vertex, - path=mgp.List[mgp.Vertex]): - return [mgp.Record(target=t, path=p) - for t, p in nx.multi_source_dijkstra_path( - MemgraphMultiDiGraph(ctx=ctx), sources, cutoff=cutoff, - weight=weight).items()] +def multi_source_dijkstra_path( + ctx: mgp.ProcCtx, + sources: mgp.List[mgp.Vertex], + cutoff: mgp.Nullable[int] = None, + weight: str = "weight", +) -> mgp.Record(target=mgp.Vertex, path=mgp.List[mgp.Vertex]): + return [ + mgp.Record(target=t, path=p) + for t, p in nx.multi_source_dijkstra_path( + MemgraphMultiDiGraph(ctx=ctx), sources, cutoff=cutoff, weight=weight + ).items() + ] # 
networkx.algorithms.shortest_paths.weighted.multi_source_dijkstra_path_length @mgp.read_proc -def multi_source_dijkstra_path_length(ctx: mgp.ProcCtx, - sources: mgp.List[mgp.Vertex], - cutoff: mgp.Nullable[int] = None, - weight: str = 'weight' - ) -> mgp.Record(target=mgp.Vertex, - length=mgp.Number): - return [mgp.Record(target=t, length=l) - for t, l in nx.multi_source_dijkstra_path_length( - MemgraphMultiDiGraph(ctx=ctx), sources, cutoff=cutoff, - weight=weight).items()] +def multi_source_dijkstra_path_length( + ctx: mgp.ProcCtx, + sources: mgp.List[mgp.Vertex], + cutoff: mgp.Nullable[int] = None, + weight: str = "weight", +) -> mgp.Record(target=mgp.Vertex, length=mgp.Number): + return [ + mgp.Record(target=t, length=l) + for t, l in nx.multi_source_dijkstra_path_length( + MemgraphMultiDiGraph(ctx=ctx), sources, cutoff=cutoff, weight=weight + ).items() + ] # networkx.algorithms.simple_paths.is_simple_path @mgp.read_proc -def is_simple_path(ctx: mgp.ProcCtx, - nodes: mgp.List[mgp.Vertex] - ) -> mgp.Record(is_simple_path=bool): - return mgp.Record(is_simple_path=nx.is_simple_path( - MemgraphMultiDiGraph(ctx=ctx), nodes)) +def is_simple_path(ctx: mgp.ProcCtx, nodes: mgp.List[mgp.Vertex]) -> mgp.Record(is_simple_path=bool): + return mgp.Record(is_simple_path=nx.is_simple_path(MemgraphMultiDiGraph(ctx=ctx), nodes)) # networkx.algorithms.simple_paths.all_simple_paths @mgp.read_proc -def all_simple_paths(ctx: mgp.ProcCtx, - source: mgp.Vertex, - target: mgp.Vertex, - cutoff: mgp.Nullable[int] = None - ) -> mgp.Record(paths=mgp.List[mgp.List[mgp.Vertex]]): - return mgp.Record(paths=list(nx.all_simple_paths( - MemgraphMultiDiGraph(ctx=ctx), source, target, cutoff=cutoff))) +def all_simple_paths( + ctx: mgp.ProcCtx, + source: mgp.Vertex, + target: mgp.Vertex, + cutoff: mgp.Nullable[int] = None, +) -> mgp.Record(paths=mgp.List[mgp.List[mgp.Vertex]]): + return mgp.Record(paths=list(nx.all_simple_paths(MemgraphMultiDiGraph(ctx=ctx), source, target, cutoff=cutoff))) # networkx.algorithms.tournament.is_tournament @mgp.read_proc def is_tournament(ctx: mgp.ProcCtx) -> mgp.Record(is_tournament=bool): - return mgp.Record(is_tournament=nx.tournament.is_tournament( - MemgraphDiGraph(ctx=ctx))) + return mgp.Record(is_tournament=nx.tournament.is_tournament(MemgraphDiGraph(ctx=ctx))) # networkx.algorithms.traversal.breadth_first_search.bfs_edges @mgp.read_proc -def bfs_edges(ctx: mgp.ProcCtx, - source: mgp.Vertex, - reverse: bool = False, - depth_limit: mgp.Nullable[int] = None - ) -> mgp.Record(edges=mgp.List[mgp.Edge]): - return mgp.Record(edges=list(nx.bfs_edges( - MemgraphMultiDiGraph(ctx=ctx), source, reverse=reverse, - depth_limit=depth_limit))) +def bfs_edges( + ctx: mgp.ProcCtx, + source: mgp.Vertex, + reverse: bool = False, + depth_limit: mgp.Nullable[int] = None, +) -> mgp.Record(edges=mgp.List[mgp.Edge]): + return mgp.Record( + edges=list( + nx.bfs_edges( + MemgraphMultiDiGraph(ctx=ctx), + source, + reverse=reverse, + depth_limit=depth_limit, + ) + ) + ) # networkx.algorithms.traversal.breadth_first_search.bfs_tree @mgp.read_proc -def bfs_tree(ctx: mgp.ProcCtx, - source: mgp.Vertex, - reverse: bool = False, - depth_limit: mgp.Nullable[int] = None - ) -> mgp.Record(tree=mgp.List[mgp.Vertex]): - return mgp.Record(tree=list(nx.bfs_tree( - MemgraphMultiDiGraph(ctx=ctx), source, reverse=reverse, - depth_limit=depth_limit))) +def bfs_tree( + ctx: mgp.ProcCtx, + source: mgp.Vertex, + reverse: bool = False, + depth_limit: mgp.Nullable[int] = None, +) -> mgp.Record(tree=mgp.List[mgp.Vertex]): + return 
mgp.Record( + tree=list( + nx.bfs_tree( + MemgraphMultiDiGraph(ctx=ctx), + source, + reverse=reverse, + depth_limit=depth_limit, + ) + ) + ) # networkx.algorithms.traversal.breadth_first_search.bfs_predecessors @mgp.read_proc -def bfs_predecessors(ctx: mgp.ProcCtx, - source: mgp.Vertex, - depth_limit: mgp.Nullable[int] = None - ) -> mgp.Record(node=mgp.Vertex, - predecessor=mgp.Vertex): - return [mgp.Record(node=n, predecessor=p) - for n, p in nx.bfs_predecessors( - MemgraphMultiDiGraph(ctx=ctx), source, - depth_limit=depth_limit)] +def bfs_predecessors( + ctx: mgp.ProcCtx, source: mgp.Vertex, depth_limit: mgp.Nullable[int] = None +) -> mgp.Record(node=mgp.Vertex, predecessor=mgp.Vertex): + return [ + mgp.Record(node=n, predecessor=p) + for n, p in nx.bfs_predecessors(MemgraphMultiDiGraph(ctx=ctx), source, depth_limit=depth_limit) + ] # networkx.algorithms.traversal.breadth_first_search.bfs_successors @mgp.read_proc -def bfs_successors(ctx: mgp.ProcCtx, - source: mgp.Vertex, - depth_limit: mgp.Nullable[int] = None - ) -> mgp.Record(node=mgp.Vertex, - successors=mgp.List[mgp.Vertex]): - return [mgp.Record(node=n, successors=s) - for n, s in nx.bfs_successors( - MemgraphMultiDiGraph(ctx=ctx), source, - depth_limit=depth_limit)] +def bfs_successors( + ctx: mgp.ProcCtx, source: mgp.Vertex, depth_limit: mgp.Nullable[int] = None +) -> mgp.Record(node=mgp.Vertex, successors=mgp.List[mgp.Vertex]): + return [ + mgp.Record(node=n, successors=s) + for n, s in nx.bfs_successors(MemgraphMultiDiGraph(ctx=ctx), source, depth_limit=depth_limit) + ] # networkx.algorithms.traversal.depth_first_search.dfs_tree @mgp.read_proc -def dfs_tree(ctx: mgp.ProcCtx, - source: mgp.Vertex, - depth_limit: mgp.Nullable[int] = None - ) -> mgp.Record(tree=mgp.List[mgp.Vertex]): - return mgp.Record(tree=list(nx.dfs_tree( - MemgraphMultiDiGraph(ctx=ctx), source, depth_limit=depth_limit))) +def dfs_tree( + ctx: mgp.ProcCtx, source: mgp.Vertex, depth_limit: mgp.Nullable[int] = None +) -> mgp.Record(tree=mgp.List[mgp.Vertex]): + return mgp.Record(tree=list(nx.dfs_tree(MemgraphMultiDiGraph(ctx=ctx), source, depth_limit=depth_limit))) # networkx.algorithms.traversal.depth_first_search.dfs_predecessors @mgp.read_proc -def dfs_predecessors(ctx: mgp.ProcCtx, - source: mgp.Vertex, - depth_limit: mgp.Nullable[int] = None - ) -> mgp.Record(node=mgp.Vertex, - predecessor=mgp.Vertex): - return [mgp.Record(node=n, predecessor=p) - for n, p in nx.dfs_predecessors( - MemgraphMultiDiGraph(ctx=ctx), source, - depth_limit=depth_limit).items()] +def dfs_predecessors( + ctx: mgp.ProcCtx, source: mgp.Vertex, depth_limit: mgp.Nullable[int] = None +) -> mgp.Record(node=mgp.Vertex, predecessor=mgp.Vertex): + return [ + mgp.Record(node=n, predecessor=p) + for n, p in nx.dfs_predecessors(MemgraphMultiDiGraph(ctx=ctx), source, depth_limit=depth_limit).items() + ] # networkx.algorithms.traversal.depth_first_search.dfs_successors @mgp.read_proc -def dfs_successors(ctx: mgp.ProcCtx, - source: mgp.Vertex, - depth_limit: mgp.Nullable[int] = None - ) -> mgp.Record(node=mgp.Vertex, - successors=mgp.List[mgp.Vertex]): - return [mgp.Record(node=n, successors=s) - for n, s in nx.dfs_successors( - MemgraphMultiDiGraph(ctx=ctx), source, - depth_limit=depth_limit).items()] +def dfs_successors( + ctx: mgp.ProcCtx, source: mgp.Vertex, depth_limit: mgp.Nullable[int] = None +) -> mgp.Record(node=mgp.Vertex, successors=mgp.List[mgp.Vertex]): + return [ + mgp.Record(node=n, successors=s) + for n, s in nx.dfs_successors(MemgraphMultiDiGraph(ctx=ctx), source, 
depth_limit=depth_limit).items() + ] # networkx.algorithms.traversal.depth_first_search.dfs_preorder_nodes @mgp.read_proc -def dfs_preorder_nodes(ctx: mgp.ProcCtx, - source: mgp.Vertex, - depth_limit: mgp.Nullable[int] = None - ) -> mgp.Record(nodes=mgp.List[mgp.Vertex]): - return mgp.Record(nodes=list(nx.dfs_preorder_nodes( - MemgraphMultiDiGraph(ctx=ctx), source, depth_limit=depth_limit))) +def dfs_preorder_nodes( + ctx: mgp.ProcCtx, source: mgp.Vertex, depth_limit: mgp.Nullable[int] = None +) -> mgp.Record(nodes=mgp.List[mgp.Vertex]): + return mgp.Record(nodes=list(nx.dfs_preorder_nodes(MemgraphMultiDiGraph(ctx=ctx), source, depth_limit=depth_limit))) # networkx.algorithms.traversal.depth_first_search.dfs_postorder_nodes @mgp.read_proc -def dfs_postorder_nodes(ctx: mgp.ProcCtx, - source: mgp.Vertex, - depth_limit: mgp.Nullable[int] = None - ) -> mgp.Record(nodes=mgp.List[mgp.Vertex]): - return mgp.Record(nodes=list(nx.dfs_postorder_nodes( - MemgraphMultiDiGraph(ctx=ctx), source, depth_limit=depth_limit))) +def dfs_postorder_nodes( + ctx: mgp.ProcCtx, source: mgp.Vertex, depth_limit: mgp.Nullable[int] = None +) -> mgp.Record(nodes=mgp.List[mgp.Vertex]): + return mgp.Record( + nodes=list(nx.dfs_postorder_nodes(MemgraphMultiDiGraph(ctx=ctx), source, depth_limit=depth_limit)) + ) # networkx.algorithms.traversal.edgebfs.edge_bfs @mgp.read_proc -def edge_bfs(ctx: mgp.ProcCtx, - source: mgp.Nullable[mgp.Vertex] = None, - orientation: mgp.Nullable[str] = None - ) -> mgp.Record(edges=mgp.List[mgp.Edge]): - return mgp.Record(edges=list(e for _, _, e in nx.edge_bfs( - MemgraphMultiDiGraph(ctx=ctx), source=source, - orientation=orientation))) +def edge_bfs( + ctx: mgp.ProcCtx, + source: mgp.Nullable[mgp.Vertex] = None, + orientation: mgp.Nullable[str] = None, +) -> mgp.Record(edges=mgp.List[mgp.Edge]): + return mgp.Record( + edges=list(e for _, _, e in nx.edge_bfs(MemgraphMultiDiGraph(ctx=ctx), source=source, orientation=orientation)) + ) # networkx.algorithms.traversal.edgedfs.edge_dfs @mgp.read_proc -def edge_dfs(ctx: mgp.ProcCtx, - source: mgp.Nullable[mgp.Vertex] = None, - orientation: mgp.Nullable[str] = None - ) -> mgp.Record(edges=mgp.List[mgp.Edge]): - return mgp.Record(edges=list(e for _, _, e in nx.edge_dfs( - MemgraphMultiDiGraph(ctx=ctx), source=source, - orientation=orientation))) +def edge_dfs( + ctx: mgp.ProcCtx, + source: mgp.Nullable[mgp.Vertex] = None, + orientation: mgp.Nullable[str] = None, +) -> mgp.Record(edges=mgp.List[mgp.Edge]): + return mgp.Record( + edges=list(e for _, _, e in nx.edge_dfs(MemgraphMultiDiGraph(ctx=ctx), source=source, orientation=orientation)) + ) # networkx.algorithms.tree.recognition.is_tree @@ -773,8 +770,7 @@ def is_forest(ctx: mgp.ProcCtx) -> mgp.Record(is_forest=bool): # networkx.algorithms.tree.recognition.is_arborescence @mgp.read_proc def is_arborescence(ctx: mgp.ProcCtx) -> mgp.Record(is_arborescence=bool): - return mgp.Record(is_arborescence=nx.is_arborescence( - MemgraphDiGraph(ctx=ctx))) + return mgp.Record(is_arborescence=nx.is_arborescence(MemgraphDiGraph(ctx=ctx))) # networkx.algorithms.tree.recognition.is_branching @@ -785,41 +781,34 @@ def is_branching(ctx: mgp.ProcCtx) -> mgp.Record(is_branching=bool): # networkx.algorithms.tree.mst.minimum_spanning_tree @mgp.read_proc -def minimum_spanning_tree(ctx: mgp.ProcCtx, - weight: str = 'weight', - algorithm: str = 'kruskal', - ignore_nan: bool = False - ) -> mgp.Record(nodes=mgp.List[mgp.Vertex], - edges=mgp.List[mgp.Edge]): - gres = nx.minimum_spanning_tree(MemgraphMultiGraph(ctx=ctx), - 
weight, algorithm, ignore_nan) - return mgp.Record(nodes=list(gres.nodes()), - edges=[e for _, _, e in gres.edges(keys=True)]) +def minimum_spanning_tree( + ctx: mgp.ProcCtx, + weight: str = "weight", + algorithm: str = "kruskal", + ignore_nan: bool = False, +) -> mgp.Record(nodes=mgp.List[mgp.Vertex], edges=mgp.List[mgp.Edge]): + gres = nx.minimum_spanning_tree(MemgraphMultiGraph(ctx=ctx), weight, algorithm, ignore_nan) + return mgp.Record(nodes=list(gres.nodes()), edges=[e for _, _, e in gres.edges(keys=True)]) # networkx.algorithms.triads.triadic_census @mgp.read_proc def triadic_census(ctx: mgp.ProcCtx) -> mgp.Record(triad=str, count=int): - return [mgp.Record(triad=t, count=c) - for t, c in nx.triadic_census(MemgraphDiGraph(ctx=ctx)).items()] + return [mgp.Record(triad=t, count=c) for t, c in nx.triadic_census(MemgraphDiGraph(ctx=ctx)).items()] # networkx.algorithms.voronoi.voronoi_cells @mgp.read_proc -def voronoi_cells(ctx: mgp.ProcCtx, - center_nodes: mgp.List[mgp.Vertex], - weight: str = 'weight' - ) -> mgp.Record(center=mgp.Vertex, - cell=mgp.List[mgp.Vertex]): - return [mgp.Record(center=c1, cell=list(c2)) - for c1, c2 in nx.voronoi_cells( - MemgraphMultiDiGraph(ctx=ctx), center_nodes, - weight=weight).items()] +def voronoi_cells( + ctx: mgp.ProcCtx, center_nodes: mgp.List[mgp.Vertex], weight: str = "weight" +) -> mgp.Record(center=mgp.Vertex, cell=mgp.List[mgp.Vertex]): + return [ + mgp.Record(center=c1, cell=list(c2)) + for c1, c2 in nx.voronoi_cells(MemgraphMultiDiGraph(ctx=ctx), center_nodes, weight=weight).items() + ] # networkx.algorithms.wiener.wiener_index @mgp.read_proc -def wiener_index(ctx: mgp.ProcCtx, weight: mgp.Nullable[str] = None - ) -> mgp.Record(wiener_index=mgp.Number): - return mgp.Record(wiener_index=nx.wiener_index( - MemgraphMultiDiGraph(ctx=ctx), weight=weight)) +def wiener_index(ctx: mgp.ProcCtx, weight: mgp.Nullable[str] = None) -> mgp.Record(wiener_index=mgp.Number): + return mgp.Record(wiener_index=nx.wiener_index(MemgraphMultiDiGraph(ctx=ctx), weight=weight)) diff --git a/query_modules/wcc.py b/query_modules/wcc.py index 054461af9..a3b83137a 100644 --- a/query_modules/wcc.py +++ b/query_modules/wcc.py @@ -1,23 +1,20 @@ import sys import mgp + try: import networkx as nx except ImportError as import_error: sys.stderr.write( - '\n' - 'NOTE: Please install networkx to be able to use wcc module.\n' - 'Using Python:\n' - + sys.version + - '\n') + "\n" "NOTE: Please install networkx to be able to use wcc module.\n" "Using Python:\n" + sys.version + "\n" + ) raise import_error @mgp.read_proc -def get_components(vertices: mgp.List[mgp.Vertex], - edges: mgp.List[mgp.Edge] - ) -> mgp.Record(n_components=int, - components=mgp.List[mgp.List[mgp.Vertex]]): - ''' +def get_components( + vertices: mgp.List[mgp.Vertex], edges: mgp.List[mgp.Edge] +) -> mgp.Record(n_components=int, components=mgp.List[mgp.List[mgp.Vertex]]): + """ This procedure finds weakly connected components of a given subgraph of a directed graph. 
@@ -41,7 +38,7 @@ def get_components(vertices: mgp.List[mgp.Vertex], WITH collect(n) AS nodes, collect(e) AS edges CALL wcc.get_components(nodes, edges) YIELD * RETURN n_components, components; - ''' + """ g = nx.DiGraph() g.add_nodes_from(vertices) g.add_edges_from([(edge.from_vertex, edge.to_vertex) for edge in edges]) diff --git a/release/get_version.py b/release/get_version.py index cfce88475..c67bd12db 100755 --- a/release/get_version.py +++ b/release/get_version.py @@ -104,7 +104,9 @@ def retry(retry_limit, timeout=100): except Exception: time.sleep(timeout) return func(*args, **kwargs) + return wrapper + return inner_func @@ -163,8 +165,15 @@ def format_version(variant, version, offering, distance=None, shorthash=None, su # Parse arguments. parser = argparse.ArgumentParser(description="Get the current version of Memgraph.") -parser.add_argument("--open-source", action="store_true", help="set the current offering to 'open-source'") -parser.add_argument("version", help="manual version override, if supplied the version isn't " "determined using git") +parser.add_argument( + "--open-source", + action="store_true", + help="set the current offering to 'open-source'", +) +parser.add_argument( + "version", + help="manual version override, if supplied the version isn't " "determined using git", +) parser.add_argument("suffix", help="custom suffix for the current version being built") parser.add_argument( "--variant", @@ -173,7 +182,9 @@ parser.add_argument( help="which variant of the version string should be generated", ) parser.add_argument( - "--memgraph-root-dir", help="The root directory of the checked out " "Memgraph repository.", default="." + "--memgraph-root-dir", + help="The root directory of the checked out " "Memgraph repository.", + default=".", ) args = parser.parse_args() @@ -256,14 +267,27 @@ for version in versions: if current_version is None: raise Exception("You are attempting to determine the version for a very " "old version of Memgraph!") version, branch, master_branch_merge = current_version -distance = int(get_output("git", "rev-list", "--count", "--first-parent", master_branch_merge + ".." + current_hash)) +distance = int( + get_output( + "git", + "rev-list", + "--count", + "--first-parent", + master_branch_merge + ".." 
+ current_hash, + ) +) version_str = ".".join(map(str, version)) + ".0" if distance == 0: print(format_version(args.variant, version_str, offering, suffix=args.suffix), end="") else: print( format_version( - args.variant, version_str, offering, distance=distance, shorthash=current_hash_short, suffix=args.suffix + args.variant, + version_str, + offering, + distance=distance, + shorthash=current_hash_short, + suffix=args.suffix, ), end="", ) diff --git a/src/auth/models.cpp b/src/auth/models.cpp index bfb508679..9780f84e6 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -188,7 +188,25 @@ bool operator!=(const Permissions &first, const Permissions &second) { return !( LabelPermissions::LabelPermissions(const std::unordered_map &permissions) : permissions_(permissions) {} -void LabelPermissions::Grant(const std::string &permission) { permissions_[permission] = 1; } +void LabelPermissions::Grant(const std::string &label) { permissions_[label] = 1; } + +void LabelPermissions::Deny(const std::string &label) { permissions_[label] = 0; } + +void LabelPermissions::Revoke(const std::string &label) { permissions_.erase(label); } + +nlohmann::json LabelPermissions::Serialize() const { + nlohmann::json data = nlohmann::json::object(); + data["labelPermissions"] = permissions_; + return data; +} + +LabelPermissions LabelPermissions::Deserialize(const nlohmann::json &data) { + if (!data.is_object()) { + throw AuthException("Couldn't load permissions data!"); + } + + return {data["labelPermissions"]}; +} Role::Role(const std::string &rolename) : rolename_(utils::ToLowerCase(rolename)) {} @@ -208,6 +226,8 @@ nlohmann::json Role::Serialize() const { nlohmann::json data = nlohmann::json::object(); data["rolename"] = rolename_; data["permissions"] = permissions_.Serialize(); + data["labelPermissions"] = labelPermissions_.Serialize(); + return data; } @@ -219,7 +239,9 @@ Role Role::Deserialize(const nlohmann::json &data) { throw AuthException("Couldn't load role data!"); } auto permissions = Permissions::Deserialize(data["permissions"]); - return {data["rolename"], permissions}; + auto labelPermissions = LabelPermissions::Deserialize(data["labelPermissions"]); + + return {data["rolename"], permissions, labelPermissions}; } bool operator==(const Role &first, const Role &second) { @@ -304,6 +326,7 @@ nlohmann::json User::Serialize() const { data["username"] = username_; data["password_hash"] = password_hash_; data["permissions"] = permissions_.Serialize(); + data["labelPermissions"] = labelPermissions_.Serialize(); // The role shouldn't be serialized here, it is stored as a foreign key. 
return data; } @@ -316,7 +339,9 @@ User User::Deserialize(const nlohmann::json &data) { throw AuthException("Couldn't load user data!"); } auto permissions = Permissions::Deserialize(data["permissions"]); - return {data["username"], data["password_hash"], permissions}; + auto labelPermissions = LabelPermissions::Deserialize(data["labelPermissions"]); + + return {data["username"], data["password_hash"], permissions, labelPermissions}; } bool operator==(const User &first, const User &second) { diff --git a/src/auth/models.hpp b/src/auth/models.hpp index fdd56425f..1003086c9 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -91,7 +91,7 @@ bool operator!=(const Permissions &first, const Permissions &second); class LabelPermissions final { public: - explicit LabelPermissions(const std::unordered_map &permissions_ = {}); + LabelPermissions(const std::unordered_map &permissions_ = {}); PermissionLevel Has(const std::string &label) const; diff --git a/src/auth/reference_modules/ldap.py b/src/auth/reference_modules/ldap.py index 761db8fd6..e181ce84f 100755 --- a/src/auth/reference_modules/ldap.py +++ b/src/auth/reference_modules/ldap.py @@ -18,19 +18,24 @@ roles_config = config["roles"] # Initialize LDAP server. tls = None if server_config["encryption"] != "disabled": - cert_file = server_config["cert_file"] if server_config["cert_file"] \ - else None + cert_file = server_config["cert_file"] if server_config["cert_file"] else None key_file = server_config["key_file"] if server_config["key_file"] else None ca_file = server_config["ca_file"] if server_config["ca_file"] else None - validate = ssl.CERT_REQUIRED if server_config["validate_cert"] \ - else ssl.CERT_NONE - tls = ldap3.Tls(local_private_key_file=key_file, - local_certificate_file=cert_file, - ca_certs_file=ca_file, - validate=validate) + validate = ssl.CERT_REQUIRED if server_config["validate_cert"] else ssl.CERT_NONE + tls = ldap3.Tls( + local_private_key_file=key_file, + local_certificate_file=cert_file, + ca_certs_file=ca_file, + validate=validate, + ) use_ssl = server_config["encryption"] == "ssl" -server = ldap3.Server(server_config["host"], port=server_config["port"], - tls=tls, use_ssl=use_ssl, get_info=ldap3.ALL) +server = ldap3.Server( + server_config["host"], + port=server_config["port"], + tls=tls, + use_ssl=use_ssl, + get_info=ldap3.ALL, +) # Main authentication/authorization function. 
@@ -40,14 +45,12 @@ def authenticate(username, password): return {"authenticated": False, "role": ""} # Create the DN of the user - dn = users_config["prefix"] + ldap3.utils.dn.escape_rdn(username) + \ - users_config["suffix"] + dn = users_config["prefix"] + ldap3.utils.dn.escape_rdn(username) + users_config["suffix"] # Bind to the server conn = ldap3.Connection(server, dn, password) if server_config["encryption"] == "starttls" and not conn.start_tls(): - print("ERROR: Couldn't issue STARTTLS to the LDAP server!", - file=sys.stderr) + print("ERROR: Couldn't issue STARTTLS to the LDAP server!", file=sys.stderr) return {"authenticated": False, "role": ""} if not conn.bind(): return {"authenticated": False, "role": ""} @@ -56,25 +59,32 @@ def authenticate(username, password): if roles_config["root_dn"] != "": # search for role search_filter = "(&(objectclass={objclass})({attr}={value}))".format( - objclass=roles_config["root_objectclass"], - attr=roles_config["user_attribute"], - value=ldap3.utils.conv.escape_filter_chars(dn)) - succ = conn.search(roles_config["root_dn"], search_filter, - search_scope=ldap3.LEVEL, - attributes=[roles_config["role_attribute"]]) + objclass=roles_config["root_objectclass"], + attr=roles_config["user_attribute"], + value=ldap3.utils.conv.escape_filter_chars(dn), + ) + succ = conn.search( + roles_config["root_dn"], + search_filter, + search_scope=ldap3.LEVEL, + attributes=[roles_config["role_attribute"]], + ) if not succ or len(conn.entries) == 0: return {"authenticated": True, "role": ""} if len(conn.entries) > 1: - roles = list(map(lambda x: x[roles_config["role_attribute"]].value, - conn.entries)) + roles = list(map(lambda x: x[roles_config["role_attribute"]].value, conn.entries)) # Because we don't know exactly which role the user should have # we authorize the user with an empty role. - print("WARNING: Found more than one role for " - "user '" + username + "':", ", ".join(roles) + "!", - file=sys.stderr) + print( + "WARNING: Found more than one role for " "user '" + username + "':", + ", ".join(roles) + "!", + file=sys.stderr, + ) return {"authenticated": True, "role": ""} - return {"authenticated": True, - "role": conn.entries[0][roles_config["role_attribute"]].value} + return { + "authenticated": True, + "role": conn.entries[0][roles_config["role_attribute"]].value, + } else: return {"authenticated": True, "role": ""} diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 8035af77e..955379024 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -749,21 +749,18 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { void GrantPrivilege(const std::string &user_or_role, const std::vector &privileges, const std::vector &labels) override { - EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) { + EditPermissions(user_or_role, privileges, labels, [](auto *permissions, const auto &permission) { // TODO (mferencevic): should we first check that the // privilege is granted/denied/revoked before // unconditionally granting/denying/revoking it? 
permissions->Grant(permission); }); - if (labels.size() > 0) { - EditLabels(user_or_role, labels, - [](auto *labelPermissions, const auto &label) { labelPermissions->Grant(label); }); - } } void DenyPrivilege(const std::string &user_or_role, - const std::vector &privileges) override { - EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) { + const std::vector &privileges, + const std::vector &labels) override { + EditPermissions(user_or_role, privileges, labels, [](auto *permissions, const auto &permission) { // TODO (mferencevic): should we first check that the // privilege is granted/denied/revoked before // unconditionally granting/denying/revoking it? @@ -772,8 +769,9 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { } void RevokePrivilege(const std::string &user_or_role, - const std::vector &privileges) override { - EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) { + const std::vector &privileges, + const std::vector &labels) override { + EditPermissions(user_or_role, privileges, labels, [](auto *permissions, const auto &permission) { // TODO (mferencevic): should we first check that the // privilege is granted/denied/revoked before // unconditionally granting/denying/revoking it? @@ -784,7 +782,8 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { private: template void EditPermissions(const std::string &user_or_role, - const std::vector &privileges, const TEditFun &edit_fun) { + const std::vector &privileges, + const std::vector &labels, const TEditFun &edit_fun) { if (!std::regex_match(user_or_role, name_regex_)) { throw memgraph::query::QueryRuntimeException("Invalid user or role name."); } @@ -804,36 +803,14 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { for (const auto &permission : permissions) { edit_fun(&user->permissions(), permission); } - locked_auth->SaveUser(*user); - } else { - for (const auto &permission : permissions) { - edit_fun(&role->permissions(), permission); - } - locked_auth->SaveRole(*role); - } - } catch (const memgraph::auth::AuthException &e) { - throw memgraph::query::QueryRuntimeException(e.what()); - } - } - - template - void EditLabels(const std::string &user_or_role, const std::vector &labels, const TEditFun &edit_fun) { - if (!std::regex_match(user_or_role, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user or role name."); - } - try { - auto locked_auth = auth_->Lock(); - auto user = locked_auth->GetUser(user_or_role); - auto role = locked_auth->GetRole(user_or_role); - if (!user && !role) { - throw memgraph::query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role); - } - if (user) { for (const auto &label : labels) { edit_fun(&user->labelPermissions(), label); } locked_auth->SaveUser(*user); } else { + for (const auto &permission : permissions) { + edit_fun(&role->permissions(), permission); + } for (const auto &label : labels) { edit_fun(&role->labelPermissions(), label); } diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 162552b0f..6f57b7683 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -1307,7 +1307,11 @@ antlrcpp::Any CypherMainVisitor::visitDenyPrivilege(MemgraphCypher::DenyPrivileg auth->user_or_role_ = ctx->userOrRole->accept(this).as(); if (ctx->privilegeList()) { for (auto *privilege : 
ctx->privilegeList()->privilege()) { - auth->privileges_.push_back(privilege->accept(this)); + if (privilege->LABELS()) { + auth->labels_ = privilege->labelList()->accept(this).as>(); + } else { + auth->privileges_.push_back(privilege->accept(this)); + } } } else { /* deny all privileges */ @@ -1325,7 +1329,11 @@ antlrcpp::Any CypherMainVisitor::visitRevokePrivilege(MemgraphCypher::RevokePriv auth->user_or_role_ = ctx->userOrRole->accept(this).as(); if (ctx->privilegeList()) { for (auto *privilege : ctx->privilegeList()->privilege()) { - auth->privileges_.push_back(privilege->accept(this)); + if (privilege->LABELS()) { + auth->labels_ = privilege->labelList()->accept(this).as>(); + } else { + auth->privileges_.push_back(privilege->accept(this)); + } } } else { /* revoke all privileges */ diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index ed5b46e42..f07d6e414 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -393,14 +393,14 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa }; return callback; case AuthQuery::Action::DENY_PRIVILEGE: - callback.fn = [auth, user_or_role, privileges] { - auth->DenyPrivilege(user_or_role, privileges); + callback.fn = [auth, user_or_role, privileges, labels] { + auth->DenyPrivilege(user_or_role, privileges, labels); return std::vector>(); }; return callback; case AuthQuery::Action::REVOKE_PRIVILEGE: { - callback.fn = [auth, user_or_role, privileges] { - auth->RevokePrivilege(user_or_role, privileges); + callback.fn = [auth, user_or_role, privileges, labels] { + auth->RevokePrivilege(user_or_role, privileges, labels); return std::vector>(); }; return callback; diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index ff240a327..e893d8e5d 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -103,11 +103,12 @@ class AuthQueryHandler { const std::vector &labels) = 0; /// @throw QueryRuntimeException if an error ocurred. - virtual void DenyPrivilege(const std::string &user_or_role, const std::vector &privileges) = 0; + virtual void DenyPrivilege(const std::string &user_or_role, const std::vector &privileges, + const std::vector &labels) = 0; /// @throw QueryRuntimeException if an error ocurred. 
- virtual void RevokePrivilege(const std::string &user_or_role, - const std::vector &privileges) = 0; + virtual void RevokePrivilege(const std::string &user_or_role, const std::vector &privileges, + const std::vector &labels) = 0; }; enum class QueryHandlerResult { COMMIT, ABORT, NOTHING }; diff --git a/tests/drivers/python/v4_1/docs_how_to_query.py b/tests/drivers/python/v4_1/docs_how_to_query.py index 6bfec3b0a..272f0ea93 100644 --- a/tests/drivers/python/v4_1/docs_how_to_query.py +++ b/tests/drivers/python/v4_1/docs_how_to_query.py @@ -15,33 +15,31 @@ import sys from neo4j import GraphDatabase, basic_auth -driver = GraphDatabase.driver('bolt://localhost:7687', - auth=basic_auth('', ''), - encrypted=False) +driver = GraphDatabase.driver("bolt://localhost:7687", auth=basic_auth("", ""), encrypted=False) session = driver.session() -session.run('MATCH (n) DETACH DELETE n').consume() -print('Database cleared.') +session.run("MATCH (n) DETACH DELETE n").consume() +print("Database cleared.") session.run('CREATE (alice:Person {name: "Alice", age: 22})').consume() -print('Record created.') +print("Record created.") -node = session.run('MATCH (n) RETURN n').single()['n'] -print('Record matched.') +node = session.run("MATCH (n) RETURN n").single()["n"] +print("Record matched.") label = list(node.labels)[0] -name = node['name'] -age = node['age'] +name = node["name"] +age = node["age"] -if label != 'Person' or name != 'Alice' or age != 22: - print('Data does not match') +if label != "Person" or name != "Alice" or age != 22: + print("Data does not match") sys.exit(1) -print('Label: %s' % label) -print('name: %s' % name) -print('age: %s' % age) +print("Label: %s" % label) +print("name: %s" % name) +print("age: %s" % age) session.close() driver.close() -print('All ok!') +print("All ok!") diff --git a/tests/drivers/python/v4_1/max_query_length.py b/tests/drivers/python/v4_1/max_query_length.py index 8954744b9..2dcf9ba54 100644 --- a/tests/drivers/python/v4_1/max_query_length.py +++ b/tests/drivers/python/v4_1/max_query_length.py @@ -14,9 +14,7 @@ from neo4j import GraphDatabase, basic_auth -driver = GraphDatabase.driver("bolt://localhost:7687", - auth=basic_auth("", ""), - encrypted=False) +driver = GraphDatabase.driver("bolt://localhost:7687", auth=basic_auth("", ""), encrypted=False) query_template = 'CREATE (n {name:"%s"})' template_size = len(query_template) - 2 # because of %s @@ -26,10 +24,11 @@ max_len = 1000000 # binary search because we have to find the maximum size (in number of chars) # of a query that can be executed via driver while True: - assert min_len > 0 and max_len > 0, \ - "The lengths have to be positive values! If this happens something" \ - " is terrible wrong with min & max lengths OR the database" \ + assert min_len > 0 and max_len > 0, ( + "The lengths have to be positive values! If this happens something" + " is terrible wrong with min & max lengths OR the database" " isn't available." + ) property_size = (max_len + min_len) // 2 try: driver.session().run(query_template % ("a" * property_size)).consume() @@ -42,8 +41,7 @@ while True: assert property_size == max_len, "max_len probably has to be increased!" 
-print("\nThe max length of a query from Python driver is: %s\n" % - (template_size + property_size)) +print("\nThe max length of a query from Python driver is: %s\n" % (template_size + property_size)) # sessions are not closed bacause all sessions that are # executed with wrong query size might be broken diff --git a/tests/drivers/python/v4_1/transactions.py b/tests/drivers/python/v4_1/transactions.py index 2e3dd1762..9c133c5ca 100644 --- a/tests/drivers/python/v4_1/transactions.py +++ b/tests/drivers/python/v4_1/transactions.py @@ -15,6 +15,7 @@ from neo4j import GraphDatabase, basic_auth from neo4j.exceptions import ClientError, TransientError + def tx_error(tx, name, name2): a = tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name).value() print(a[0]) @@ -22,17 +23,19 @@ def tx_error(tx, name, name2): a = tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name2).value() print(a[0]) + def tx_good(tx, name, name2): a = tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name).value() print(a[0]) a = tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name2).value() print(a[0]) + def tx_too_long(tx): tx.run("MATCH (a), (b), (c), (d), (e), (f) RETURN COUNT(*) AS cnt") -with GraphDatabase.driver("bolt://localhost:7687", auth=basic_auth("", ""), - encrypted=False) as driver: + +with GraphDatabase.driver("bolt://localhost:7687", auth=basic_auth("", ""), encrypted=False) as driver: def add_person(f, name, name2): with driver.session() as session: diff --git a/tests/e2e/magic_functions/function_example.py b/tests/e2e/magic_functions/function_example.py index 0ca208b7d..526d47f0f 100644 --- a/tests/e2e/magic_functions/function_example.py +++ b/tests/e2e/magic_functions/function_example.py @@ -106,6 +106,7 @@ def test_try_to_write(connection, function_type): f"MATCH (n) RETURN {function_type}_write.try_to_write(n, 'property', 1);", ) + @pytest.mark.parametrize("function_type", ["py", "c"]) def test_case_sensitivity(connection, function_type): cursor = connection.cursor() diff --git a/tests/e2e/replication/show_while_creating_invalid_state.py b/tests/e2e/replication/show_while_creating_invalid_state.py index 151119be8..a84ade382 100644 --- a/tests/e2e/replication/show_while_creating_invalid_state.py +++ b/tests/e2e/replication/show_while_creating_invalid_state.py @@ -141,10 +141,16 @@ def test_add_replica_invalid_timeout(connection): cursor = connection(7687, "main").cursor() with pytest.raises(mgclient.DatabaseError): - execute_and_fetch_all(cursor, "REGISTER REPLICA replica_1 SYNC WITH TIMEOUT 0 TO '127.0.0.1:10001';") + execute_and_fetch_all( + cursor, + "REGISTER REPLICA replica_1 SYNC WITH TIMEOUT 0 TO '127.0.0.1:10001';", + ) with pytest.raises(mgclient.DatabaseError): - execute_and_fetch_all(cursor, "REGISTER REPLICA replica_1 SYNC WITH TIMEOUT -5 TO '127.0.0.1:10001';") + execute_and_fetch_all( + cursor, + "REGISTER REPLICA replica_1 SYNC WITH TIMEOUT -5 TO '127.0.0.1:10001';", + ) actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") assert 0 == len(actual_data) diff --git a/tests/e2e/streams/common.py b/tests/e2e/streams/common.py index 1b3315071..a1b9aac10 100644 --- a/tests/e2e/streams/common.py +++ b/tests/e2e/streams/common.py @@ -115,7 +115,10 @@ def start_stream(cursor, stream_name): def start_stream_with_limit(cursor, stream_name, batch_limit, timeout=None): if timeout is not None: - execute_and_fetch_all(cursor, f"START STREAM {stream_name} BATCH_LIMIT {batch_limit} TIMEOUT {timeout} ") + execute_and_fetch_all( + cursor, + f"START STREAM {stream_name} 
BATCH_LIMIT {batch_limit} TIMEOUT {timeout} ", + ) else: execute_and_fetch_all(cursor, f"START STREAM {stream_name} BATCH_LIMIT {batch_limit}") @@ -156,7 +159,12 @@ def pulsar_default_namespace_topic(topic): def test_start_and_stop_during_check( - operation, connection, stream_creator, message_sender, already_stopped_error, batchSize + operation, + connection, + stream_creator, + message_sender, + already_stopped_error, + batchSize, ): # This test is quite complex. The goal is to call START/STOP queries # while a CHECK query is waiting for its result. Because the Global @@ -317,24 +325,42 @@ def test_check_stream_same_number_of_queries_than_messages(connection, stream_cr expected_queries_and_raw_messages_1 = ( [ # queries - {PARAMETERS_LITERAL: {"value": "Parameter: 01"}, QUERY_LITERAL: "Message: 01"}, - {PARAMETERS_LITERAL: {"value": "Parameter: 02"}, QUERY_LITERAL: "Message: 02"}, + { + PARAMETERS_LITERAL: {"value": "Parameter: 01"}, + QUERY_LITERAL: "Message: 01", + }, + { + PARAMETERS_LITERAL: {"value": "Parameter: 02"}, + QUERY_LITERAL: "Message: 02", + }, ], ["01", "02"], # raw message ) expected_queries_and_raw_messages_2 = ( [ # queries - {PARAMETERS_LITERAL: {"value": "Parameter: 03"}, QUERY_LITERAL: "Message: 03"}, - {PARAMETERS_LITERAL: {"value": "Parameter: 04"}, QUERY_LITERAL: "Message: 04"}, + { + PARAMETERS_LITERAL: {"value": "Parameter: 03"}, + QUERY_LITERAL: "Message: 03", + }, + { + PARAMETERS_LITERAL: {"value": "Parameter: 04"}, + QUERY_LITERAL: "Message: 04", + }, ], ["03", "04"], # raw message ) expected_queries_and_raw_messages_3 = ( [ # queries - {PARAMETERS_LITERAL: {"value": "Parameter: 05"}, QUERY_LITERAL: "Message: 05"}, - {PARAMETERS_LITERAL: {"value": "Parameter: 06"}, QUERY_LITERAL: "Message: 06"}, + { + PARAMETERS_LITERAL: {"value": "Parameter: 05"}, + QUERY_LITERAL: "Message: 05", + }, + { + PARAMETERS_LITERAL: {"value": "Parameter: 06"}, + QUERY_LITERAL: "Message: 06", + }, ], ["05", "06"], # raw message ) @@ -389,20 +415,32 @@ def test_check_stream_different_number_of_queries_than_messages(connection, stre expected_queries_and_raw_messages_2 = ( [ # queries - {PARAMETERS_LITERAL: {"value": "Parameter: 03"}, QUERY_LITERAL: "Message: 03"}, - {PARAMETERS_LITERAL: {"value": "Parameter: 04"}, QUERY_LITERAL: "Message: 04"}, + { + PARAMETERS_LITERAL: {"value": "Parameter: 03"}, + QUERY_LITERAL: "Message: 03", + }, + { + PARAMETERS_LITERAL: {"value": "Parameter: 04"}, + QUERY_LITERAL: "Message: 04", + }, ], ["03", "04"], # raw message ) expected_queries_and_raw_messages_3 = ( [ # queries - {PARAMETERS_LITERAL: {"value": "Parameter: b_05"}, QUERY_LITERAL: "Message: b_05"}, + { + PARAMETERS_LITERAL: {"value": "Parameter: b_05"}, + QUERY_LITERAL: "Message: b_05", + }, { PARAMETERS_LITERAL: {"value": "Parameter: extra_b_05"}, QUERY_LITERAL: "Message: extra_b_05", }, - {PARAMETERS_LITERAL: {"value": "Parameter: 06"}, QUERY_LITERAL: "Message: 06"}, + { + PARAMETERS_LITERAL: {"value": "Parameter: 06"}, + QUERY_LITERAL: "Message: 06", + }, ], ["b_05", "06"], # raw message ) @@ -467,7 +505,10 @@ def test_start_stream_with_batch_limit_reaching_timeout(connection, stream_creat start_time = time.time() with pytest.raises(mgclient.DatabaseError): - execute_and_fetch_all(cursor, f"START STREAM {STREAM_NAME} BATCH_LIMIT {BATCH_LIMIT} TIMEOUT {TIMEOUT}") + execute_and_fetch_all( + cursor, + f"START STREAM {STREAM_NAME} BATCH_LIMIT {BATCH_LIMIT} TIMEOUT {TIMEOUT}", + ) end_time = time.time() assert ( @@ -483,7 +524,10 @@ def test_start_stream_with_batch_limit_while_check_running( 
def start_check_stream(stream_name, batch_limit, timeout): connection = connect() cursor = connection.cursor() - execute_and_fetch_all(cursor, f"CHECK STREAM {stream_name} BATCH_LIMIT {batch_limit} TIMEOUT {timeout}") + execute_and_fetch_all( + cursor, + f"CHECK STREAM {stream_name} BATCH_LIMIT {batch_limit} TIMEOUT {timeout}", + ) def start_new_stream_with_limit(stream_name, batch_limit, timeout): connection = connect() @@ -518,7 +562,9 @@ def test_start_stream_with_batch_limit_while_check_running( # 2/ thread_stream_running = Process( - target=start_new_stream_with_limit, daemon=True, args=(STREAM_NAME, BATCH_LIMIT + 1, TIMEOUT) + target=start_new_stream_with_limit, + daemon=True, + args=(STREAM_NAME, BATCH_LIMIT + 1, TIMEOUT), ) # Sending BATCH_LIMIT + 1 messages as BATCH_LIMIT messages have already been sent during the CHECK STREAM (and not consumed) thread_stream_running.start() time.sleep(2) @@ -541,7 +587,10 @@ def test_check_while_stream_with_batch_limit_running(connection, stream_creator, def start_check_stream(stream_name, batch_limit, timeout): connection = connect() cursor = connection.cursor() - execute_and_fetch_all(cursor, f"CHECK STREAM {stream_name} BATCH_LIMIT {batch_limit} TIMEOUT {timeout}") + execute_and_fetch_all( + cursor, + f"CHECK STREAM {stream_name} BATCH_LIMIT {batch_limit} TIMEOUT {timeout}", + ) STREAM_NAME = "test_batch_limit_and_check" BATCH_LIMIT = 1 @@ -553,7 +602,9 @@ def test_check_while_stream_with_batch_limit_running(connection, stream_creator, # 1/ thread_stream_running = Process( - target=start_new_stream_with_limit, daemon=True, args=(STREAM_NAME, BATCH_LIMIT, TIMEOUT) + target=start_new_stream_with_limit, + daemon=True, + args=(STREAM_NAME, BATCH_LIMIT, TIMEOUT), ) start_time = time.time() thread_stream_running.start() @@ -561,7 +612,10 @@ def test_check_while_stream_with_batch_limit_running(connection, stream_creator, assert get_is_running(cursor, STREAM_NAME) with pytest.raises(mgclient.DatabaseError): - execute_and_fetch_all(cursor, f"CHECK STREAM {STREAM_NAME} BATCH_LIMIT {BATCH_LIMIT} TIMEOUT {TIMEOUT}") + execute_and_fetch_all( + cursor, + f"CHECK STREAM {STREAM_NAME} BATCH_LIMIT {BATCH_LIMIT} TIMEOUT {TIMEOUT}", + ) end_time = time.time() assert (end_time - start_time) < 0.8 * TIMEOUT, "The CHECK STREAM has probably thrown due to timeout!" @@ -632,7 +686,10 @@ def test_check_stream_with_batch_limit_with_invalid_batch_limit(connection, stre start_time = time.time() with pytest.raises(mgclient.DatabaseError): - execute_and_fetch_all(cursor, f"CHECK STREAM {STREAM_NAME} BATCH_LIMIT {batch_limit} TIMEOUT {TIMEOUT}") + execute_and_fetch_all( + cursor, + f"CHECK STREAM {STREAM_NAME} BATCH_LIMIT {batch_limit} TIMEOUT {TIMEOUT}", + ) end_time = time.time() assert (end_time - start_time) < 0.8 * TIMEOUT_IN_SECONDS, "The CHECK STREAM has probably thrown due to timeout!" @@ -642,7 +699,10 @@ def test_check_stream_with_batch_limit_with_invalid_batch_limit(connection, stre start_time = time.time() with pytest.raises(mgclient.DatabaseError): - execute_and_fetch_all(cursor, f"CHECK STREAM {STREAM_NAME} BATCH_LIMIT {batch_limit} TIMEOUT {TIMEOUT}") + execute_and_fetch_all( + cursor, + f"CHECK STREAM {STREAM_NAME} BATCH_LIMIT {batch_limit} TIMEOUT {TIMEOUT}", + ) end_time = time.time() assert (end_time - start_time) < 0.8 * TIMEOUT_IN_SECONDS, "The CHECK STREAM has probably thrown due to timeout!" 
diff --git a/tests/e2e/streams/conftest.py b/tests/e2e/streams/conftest.py index a26138b53..551c145af 100644 --- a/tests/e2e/streams/conftest.py +++ b/tests/e2e/streams/conftest.py @@ -37,29 +37,22 @@ def connection(): def get_topics(num): - return [f'topic_{i}' for i in range(num)] + return [f"topic_{i}" for i in range(num)] @pytest.fixture(scope="function") def kafka_topics(): - admin_client = KafkaAdminClient( - bootstrap_servers="localhost:9092", - client_id="test") + admin_client = KafkaAdminClient(bootstrap_servers="localhost:9092", client_id="test") # The issue arises if we remove default kafka topics, e.g. # "__consumer_offsets" - previous_topics = [ - topic for topic in admin_client.list_topics() if topic != "__consumer_offsets"] + previous_topics = [topic for topic in admin_client.list_topics() if topic != "__consumer_offsets"] if previous_topics: admin_client.delete_topics(topics=previous_topics, timeout_ms=5000) topics = get_topics(3) topics_to_create = [] for topic in topics: - topics_to_create.append( - NewTopic( - name=topic, - num_partitions=1, - replication_factor=1)) + topics_to_create.append(NewTopic(name=topic, num_partitions=1, replication_factor=1)) admin_client.create_topics(new_topics=topics_to_create, timeout_ms=5000) yield topics @@ -80,6 +73,5 @@ def pulsar_client(): def pulsar_topics(): topics = get_topics(3) for topic in topics: - requests.delete( - f'http://127.0.0.1:6652/admin/v2/persistent/public/default/{topic}?force=true') + requests.delete(f"http://127.0.0.1:6652/admin/v2/persistent/public/default/{topic}?force=true") yield topics diff --git a/tests/e2e/streams/kafka_streams_tests.py b/tests/e2e/streams/kafka_streams_tests.py index 65ad1b6b0..b74a41514 100755 --- a/tests/e2e/streams/kafka_streams_tests.py +++ b/tests/e2e/streams/kafka_streams_tests.py @@ -20,7 +20,10 @@ import common TRANSFORMATIONS_TO_CHECK_C = ["c_transformations.empty_transformation"] -TRANSFORMATIONS_TO_CHECK_PY = ["kafka_transform.simple", "kafka_transform.with_parameters"] +TRANSFORMATIONS_TO_CHECK_PY = [ + "kafka_transform.simple", + "kafka_transform.with_parameters", +] @pytest.mark.parametrize("transformation", TRANSFORMATIONS_TO_CHECK_PY) @@ -463,7 +466,11 @@ def test_start_stream_with_batch_limit_while_check_running(kafka_producer, kafka kafka_producer.send(kafka_topics[0], message).get(timeout=6000) def setup_function(start_check_stream, cursor, stream_name, batch_limit, timeout): - thread_stream_check = Process(target=start_check_stream, daemon=True, args=(stream_name, batch_limit, timeout)) + thread_stream_check = Process( + target=start_check_stream, + daemon=True, + args=(stream_name, batch_limit, timeout), + ) thread_stream_check.start() time.sleep(2) assert common.get_is_running(cursor, stream_name) diff --git a/tests/e2e/streams/pulsar_streams_tests.py b/tests/e2e/streams/pulsar_streams_tests.py index e0aecbebc..5131a04f4 100755 --- a/tests/e2e/streams/pulsar_streams_tests.py +++ b/tests/e2e/streams/pulsar_streams_tests.py @@ -18,13 +18,20 @@ import time from multiprocessing import Process, Value import common -TRANSFORMATIONS_TO_CHECK = ["pulsar_transform.simple", "pulsar_transform.with_parameters"] +TRANSFORMATIONS_TO_CHECK = [ + "pulsar_transform.simple", + "pulsar_transform.with_parameters", +] def check_vertex_exists_with_topic_and_payload(cursor, topic, payload_byte): decoded_payload = payload_byte.decode("utf-8") common.check_vertex_exists_with_properties( - cursor, {"topic": f'"{common.pulsar_default_namespace_topic(topic)}"', "payload": f'"{decoded_payload}"'} 
+ cursor, + { + "topic": f'"{common.pulsar_default_namespace_topic(topic)}"', + "payload": f'"{decoded_payload}"', + }, ) @@ -100,7 +107,8 @@ def test_start_from_latest_messages(pulsar_client, pulsar_topics, connection): assert len(vertices_with_msg) == 0 producer = pulsar_client.create_producer( - common.pulsar_default_namespace_topic(pulsar_topics[0]), send_timeout_millis=60000 + common.pulsar_default_namespace_topic(pulsar_topics[0]), + send_timeout_millis=60000, ) producer.send(common.SIMPLE_MSG) @@ -157,7 +165,8 @@ def test_check_stream(pulsar_client, pulsar_topics, connection, transformation): time.sleep(1) producer = pulsar_client.create_producer( - common.pulsar_default_namespace_topic(pulsar_topics[0]), send_timeout_millis=60000 + common.pulsar_default_namespace_topic(pulsar_topics[0]), + send_timeout_millis=60000, ) producer.send(common.SIMPLE_MSG) check_vertex_exists_with_topic_and_payload(cursor, pulsar_topics[0], common.SIMPLE_MSG) @@ -263,7 +272,8 @@ def test_start_and_stop_during_check(pulsar_client, pulsar_topics, connection, o return f"CREATE PULSAR STREAM {stream_name} TOPICS {pulsar_topics[0]} TRANSFORM pulsar_transform.simple BATCH_SIZE {BATCH_SIZE}" producer = pulsar_client.create_producer( - common.pulsar_default_namespace_topic(pulsar_topics[0]), send_timeout_millis=60000 + common.pulsar_default_namespace_topic(pulsar_topics[0]), + send_timeout_millis=60000, ) def message_sender(msg): @@ -311,7 +321,8 @@ def test_restart_after_error(pulsar_client, pulsar_topics, connection): time.sleep(1) producer = pulsar_client.create_producer( - common.pulsar_default_namespace_topic(pulsar_topics[0]), send_timeout_millis=60000 + common.pulsar_default_namespace_topic(pulsar_topics[0]), + send_timeout_millis=60000, ) producer.send(common.SIMPLE_MSG) assert common.timed_wait(lambda: not common.get_is_running(cursor, "test_stream")) @@ -351,7 +362,8 @@ def test_start_stream_with_batch_limit(pulsar_client, pulsar_topics, connection) return f"CREATE PULSAR STREAM {stream_name} TOPICS {pulsar_topics[0]} TRANSFORM pulsar_transform.simple BATCH_SIZE 1" producer = pulsar_client.create_producer( - common.pulsar_default_namespace_topic(pulsar_topics[0]), send_timeout_millis=60000 + common.pulsar_default_namespace_topic(pulsar_topics[0]), + send_timeout_millis=60000, ) def messages_sender(nof_messages): @@ -386,7 +398,8 @@ def test_start_stream_with_batch_limit_while_check_running(pulsar_client, pulsar return f"CREATE PULSAR STREAM {stream_name} TOPICS {pulsar_topics[0]} TRANSFORM pulsar_transform.simple BATCH_SIZE 1" producer = pulsar_client.create_producer( - common.pulsar_default_namespace_topic(pulsar_topics[0]), send_timeout_millis=60000 + common.pulsar_default_namespace_topic(pulsar_topics[0]), + send_timeout_millis=60000, ) def message_sender(message): @@ -402,7 +415,8 @@ def test_check_while_stream_with_batch_limit_running(pulsar_client, pulsar_topic return f"CREATE PULSAR STREAM {stream_name} TOPICS {pulsar_topics[0]} TRANSFORM pulsar_transform.simple BATCH_SIZE 1" producer = pulsar_client.create_producer( - common.pulsar_default_namespace_topic(pulsar_topics[0]), send_timeout_millis=60000 + common.pulsar_default_namespace_topic(pulsar_topics[0]), + send_timeout_millis=60000, ) def message_sender(message): @@ -420,7 +434,8 @@ def test_check_stream_same_number_of_queries_than_messages(pulsar_client, pulsar return f"CREATE PULSAR STREAM {stream_name} TOPICS {pulsar_topics[0]} TRANSFORM {TRANSFORMATION} BATCH_INTERVAL 3000 BATCH_SIZE {batch_size} " producer = 
pulsar_client.create_producer( - common.pulsar_default_namespace_topic(pulsar_topics[0]), send_timeout_millis=60000 + common.pulsar_default_namespace_topic(pulsar_topics[0]), + send_timeout_millis=60000, ) def message_sender(msg): @@ -438,7 +453,8 @@ def test_check_stream_different_number_of_queries_than_messages(pulsar_client, p return f"CREATE PULSAR STREAM {stream_name} TOPICS {pulsar_topics[0]} TRANSFORM {TRANSFORMATION} BATCH_INTERVAL 3000 BATCH_SIZE {batch_size} " producer = pulsar_client.create_producer( - common.pulsar_default_namespace_topic(pulsar_topics[0]), send_timeout_millis=60000 + common.pulsar_default_namespace_topic(pulsar_topics[0]), + send_timeout_millis=60000, ) def message_sender(msg): diff --git a/tests/e2e/streams/streams_owner_tests.py b/tests/e2e/streams/streams_owner_tests.py index 6fd658fa6..8a21c2cec 100644 --- a/tests/e2e/streams/streams_owner_tests.py +++ b/tests/e2e/streams/streams_owner_tests.py @@ -15,6 +15,7 @@ import time import mgclient import common + def get_cursor_with_user(username): connection = common.connect(username=username, password="") return connection.cursor() @@ -22,23 +23,21 @@ def get_cursor_with_user(username): def create_admin_user(cursor, admin_user): common.execute_and_fetch_all(cursor, f"CREATE USER {admin_user}") - common.execute_and_fetch_all( - cursor, f"GRANT ALL PRIVILEGES TO {admin_user}") + common.execute_and_fetch_all(cursor, f"GRANT ALL PRIVILEGES TO {admin_user}") def create_stream_user(cursor, stream_user): common.execute_and_fetch_all(cursor, f"CREATE USER {stream_user}") - common.execute_and_fetch_all( - cursor, f"GRANT STREAM TO {stream_user}") + common.execute_and_fetch_all(cursor, f"GRANT STREAM TO {stream_user}") def test_ownerless_stream(kafka_producer, kafka_topics, connection): assert len(kafka_topics) > 0 userless_cursor = connection.cursor() - common.execute_and_fetch_all(userless_cursor, - "CREATE KAFKA STREAM ownerless " - f"TOPICS {kafka_topics[0]} " - f"TRANSFORM kafka_transform.simple") + common.execute_and_fetch_all( + userless_cursor, + "CREATE KAFKA STREAM ownerless " f"TOPICS {kafka_topics[0]} " f"TRANSFORM kafka_transform.simple", + ) common.start_stream(userless_cursor, "ownerless") time.sleep(1) @@ -46,11 +45,9 @@ def test_ownerless_stream(kafka_producer, kafka_topics, connection): create_admin_user(userless_cursor, admin_user) kafka_producer.send(kafka_topics[0], b"first message").get(timeout=60) - assert common.timed_wait( - lambda: not common.get_is_running(userless_cursor, "ownerless")) + assert common.timed_wait(lambda: not common.get_is_running(userless_cursor, "ownerless")) - assert len(common.execute_and_fetch_all( - userless_cursor, "MATCH (n) RETURN n")) == 0 + assert len(common.execute_and_fetch_all(userless_cursor, "MATCH (n) RETURN n")) == 0 common.execute_and_fetch_all(userless_cursor, f"DROP USER {admin_user}") common.start_stream(userless_cursor, "ownerless") @@ -58,11 +55,9 @@ def test_ownerless_stream(kafka_producer, kafka_topics, connection): second_message = b"second message" kafka_producer.send(kafka_topics[0], second_message).get(timeout=60) - common.kafka_check_vertex_exists_with_topic_and_payload( - userless_cursor, kafka_topics[0], second_message) + common.kafka_check_vertex_exists_with_topic_and_payload(userless_cursor, kafka_topics[0], second_message) - assert len(common.execute_and_fetch_all( - userless_cursor, "MATCH (n) RETURN n")) == 1 + assert len(common.execute_and_fetch_all(userless_cursor, "MATCH (n) RETURN n")) == 1 def test_owner_is_shown(kafka_topics, 
connection): @@ -73,12 +68,16 @@ def test_owner_is_shown(kafka_topics, connection): create_stream_user(userless_cursor, stream_user) stream_cursor = get_cursor_with_user(stream_user) - common.execute_and_fetch_all(stream_cursor, "CREATE KAFKA STREAM test " - f"TOPICS {kafka_topics[0]} " - f"TRANSFORM kafka_transform.simple") + common.execute_and_fetch_all( + stream_cursor, + "CREATE KAFKA STREAM test " f"TOPICS {kafka_topics[0]} " f"TRANSFORM kafka_transform.simple", + ) - common.check_stream_info(userless_cursor, "test", ("test", "kafka", 100, 1000, - "kafka_transform.simple", stream_user, False)) + common.check_stream_info( + userless_cursor, + "test", + ("test", "kafka", 100, 1000, "kafka_transform.simple", stream_user, False), + ) def test_insufficient_privileges(kafka_producer, kafka_topics, connection): @@ -93,10 +92,10 @@ def test_insufficient_privileges(kafka_producer, kafka_topics, connection): create_stream_user(userless_cursor, stream_user) stream_cursor = get_cursor_with_user(stream_user) - common.execute_and_fetch_all(stream_cursor, - "CREATE KAFKA STREAM insufficient_test " - f"TOPICS {kafka_topics[0]} " - f"TRANSFORM kafka_transform.simple") + common.execute_and_fetch_all( + stream_cursor, + "CREATE KAFKA STREAM insufficient_test " f"TOPICS {kafka_topics[0]} " f"TRANSFORM kafka_transform.simple", + ) # the stream is started by admin, but should check against the owner # privileges @@ -104,24 +103,19 @@ def test_insufficient_privileges(kafka_producer, kafka_topics, connection): time.sleep(1) kafka_producer.send(kafka_topics[0], b"first message").get(timeout=60) - assert common.timed_wait( - lambda: not common.get_is_running(userless_cursor, "insufficient_test")) + assert common.timed_wait(lambda: not common.get_is_running(userless_cursor, "insufficient_test")) - assert len(common.execute_and_fetch_all( - userless_cursor, "MATCH (n) RETURN n")) == 0 + assert len(common.execute_and_fetch_all(userless_cursor, "MATCH (n) RETURN n")) == 0 - common.execute_and_fetch_all( - admin_cursor, f"GRANT CREATE TO {stream_user}") + common.execute_and_fetch_all(admin_cursor, f"GRANT CREATE TO {stream_user}") common.start_stream(userless_cursor, "insufficient_test") time.sleep(1) second_message = b"second message" kafka_producer.send(kafka_topics[0], second_message).get(timeout=60) - common.kafka_check_vertex_exists_with_topic_and_payload( - userless_cursor, kafka_topics[0], second_message) + common.kafka_check_vertex_exists_with_topic_and_payload(userless_cursor, kafka_topics[0], second_message) - assert len(common.execute_and_fetch_all( - userless_cursor, "MATCH (n) RETURN n")) == 1 + assert len(common.execute_and_fetch_all(userless_cursor, "MATCH (n) RETURN n")) == 1 def test_happy_case(kafka_producer, kafka_topics, connection): @@ -135,13 +129,12 @@ def test_happy_case(kafka_producer, kafka_topics, connection): stream_user = "stream_user" create_stream_user(userless_cursor, stream_user) stream_cursor = get_cursor_with_user(stream_user) - common.execute_and_fetch_all( - admin_cursor, f"GRANT CREATE TO {stream_user}") + common.execute_and_fetch_all(admin_cursor, f"GRANT CREATE TO {stream_user}") - common.execute_and_fetch_all(stream_cursor, - "CREATE KAFKA STREAM insufficient_test " - f"TOPICS {kafka_topics[0]} " - f"TRANSFORM kafka_transform.simple") + common.execute_and_fetch_all( + stream_cursor, + "CREATE KAFKA STREAM insufficient_test " f"TOPICS {kafka_topics[0]} " f"TRANSFORM kafka_transform.simple", + ) common.start_stream(stream_cursor, "insufficient_test") time.sleep(1) @@ -149,11 
+142,9 @@ def test_happy_case(kafka_producer, kafka_topics, connection): first_message = b"first message" kafka_producer.send(kafka_topics[0], first_message).get(timeout=60) - common.kafka_check_vertex_exists_with_topic_and_payload( - userless_cursor, kafka_topics[0], first_message) + common.kafka_check_vertex_exists_with_topic_and_payload(userless_cursor, kafka_topics[0], first_message) - assert len(common.execute_and_fetch_all( - userless_cursor, "MATCH (n) RETURN n")) == 1 + assert len(common.execute_and_fetch_all(userless_cursor, "MATCH (n) RETURN n")) == 1 if __name__ == "__main__": diff --git a/tests/e2e/streams/transformations/common_transform.py b/tests/e2e/streams/transformations/common_transform.py index 25836c207..1621ec161 100644 --- a/tests/e2e/streams/transformations/common_transform.py +++ b/tests/e2e/streams/transformations/common_transform.py @@ -23,7 +23,10 @@ def check_stream_no_filtering( message = messages.message_at(i) payload_as_str = message.payload().decode("utf-8") result_queries.append( - mgp.Record(query=f"Message: {payload_as_str}", parameters={"value": f"Parameter: {payload_as_str}"}) + mgp.Record( + query=f"Message: {payload_as_str}", + parameters={"value": f"Parameter: {payload_as_str}"}, + ) ) return result_queries @@ -44,13 +47,17 @@ def check_stream_with_filtering( continue result_queries.append( - mgp.Record(query=f"Message: {payload_as_str}", parameters={"value": f"Parameter: {payload_as_str}"}) + mgp.Record( + query=f"Message: {payload_as_str}", + parameters={"value": f"Parameter: {payload_as_str}"}, + ) ) if "b" in payload_as_str: result_queries.append( mgp.Record( - query=f"Message: extra_{payload_as_str}", parameters={"value": f"Parameter: extra_{payload_as_str}"} + query=f"Message: extra_{payload_as_str}", + parameters={"value": f"Parameter: extra_{payload_as_str}"}, ) ) diff --git a/tests/e2e/streams/transformations/pulsar_transform.py b/tests/e2e/streams/transformations/pulsar_transform.py index d8560c2a7..08379149c 100644 --- a/tests/e2e/streams/transformations/pulsar_transform.py +++ b/tests/e2e/streams/transformations/pulsar_transform.py @@ -59,7 +59,9 @@ def with_parameters(context: mgp.TransCtx, messages: mgp.Messages) -> mgp.Record @mgp.transformation -def query(messages: mgp.Messages) -> mgp.Record(query=str, parameters=mgp.Nullable[mgp.Map]): +def query( + messages: mgp.Messages, +) -> mgp.Record(query=str, parameters=mgp.Nullable[mgp.Map]): result_queries = [] for i in range(0, messages.total_messages()): diff --git a/tests/e2e/triggers/procedures/write.py b/tests/e2e/triggers/procedures/write.py index 1b4b7ad99..8cb77dfd8 100644 --- a/tests/e2e/triggers/procedures/write.py +++ b/tests/e2e/triggers/procedures/write.py @@ -11,6 +11,7 @@ import mgp + @mgp.write_proc def create_vertex(ctx: mgp.ProcCtx, id: mgp.Any) -> mgp.Record(v=mgp.Any): v = None @@ -36,15 +37,14 @@ def detach_delete_vertex(ctx: mgp.ProcCtx, v: mgp.Any) -> mgp.Record(): @mgp.write_proc -def create_edge(ctx: mgp.ProcCtx, from_vertex: mgp.Vertex, - to_vertex: mgp.Vertex, - edge_type: str) -> mgp.Record(e=mgp.Any): +def create_edge( + ctx: mgp.ProcCtx, from_vertex: mgp.Vertex, to_vertex: mgp.Vertex, edge_type: str +) -> mgp.Record(e=mgp.Any): e = None try: - e = ctx.graph.create_edge( - from_vertex, to_vertex, mgp.EdgeType(edge_type)) - e.properties.set("id", 1); - e.properties.set("tbd", 0); + e = ctx.graph.create_edge(from_vertex, to_vertex, mgp.EdgeType(edge_type)) + e.properties.set("id", 1) + e.properties.set("tbd", 0) except RuntimeError as ex: return 
mgp.Record(e=str(ex)) return mgp.Record(e=e) @@ -61,19 +61,20 @@ def set_property(ctx: mgp.ProcCtx, object: mgp.Any) -> mgp.Record(): object.properties.set("id", 2) return mgp.Record() + @mgp.write_proc def remove_property(ctx: mgp.ProcCtx, object: mgp.Any) -> mgp.Record(): object.properties.set("tbd", None) return mgp.Record() + @mgp.write_proc -def add_label(ctx: mgp.ProcCtx, object: mgp.Any, - name: str) -> mgp.Record(o=mgp.Any): +def add_label(ctx: mgp.ProcCtx, object: mgp.Any, name: str) -> mgp.Record(o=mgp.Any): object.add_label(name) return mgp.Record(o=object) + @mgp.write_proc -def remove_label(ctx: mgp.ProcCtx, object: mgp.Any, - name: str) -> mgp.Record(o=mgp.Any): +def remove_label(ctx: mgp.ProcCtx, object: mgp.Any, name: str) -> mgp.Record(o=mgp.Any): object.remove_label(name) return mgp.Record(o=object) diff --git a/tests/e2e/write_procedures/common.py b/tests/e2e/write_procedures/common.py index 1ad71d350..4d2333c19 100644 --- a/tests/e2e/write_procedures/common.py +++ b/tests/e2e/write_procedures/common.py @@ -13,8 +13,7 @@ import mgclient import typing -def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, - params: dict = {}) -> typing.List[tuple]: +def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]: cursor.execute(query, params) return cursor.fetchall() diff --git a/tests/e2e/write_procedures/procedures/read.py b/tests/e2e/write_procedures/procedures/read.py index 3f791a37a..0532fda1a 100644 --- a/tests/e2e/write_procedures/procedures/read.py +++ b/tests/e2e/write_procedures/procedures/read.py @@ -13,8 +13,7 @@ import mgp @mgp.read_proc -def underlying_graph_is_mutable(ctx: mgp.ProcCtx, - object: mgp.Any) -> mgp.Record(mutable=bool): +def underlying_graph_is_mutable(ctx: mgp.ProcCtx, object: mgp.Any) -> mgp.Record(mutable=bool): return mgp.Record(mutable=object.underlying_graph_is_mutable()) diff --git a/tests/e2e/write_procedures/procedures/write.py b/tests/e2e/write_procedures/procedures/write.py index e180be40e..3e5f6bc78 100644 --- a/tests/e2e/write_procedures/procedures/write.py +++ b/tests/e2e/write_procedures/procedures/write.py @@ -35,13 +35,12 @@ def detach_delete_vertex(ctx: mgp.ProcCtx, v: mgp.Any) -> mgp.Record(): @mgp.write_proc -def create_edge(ctx: mgp.ProcCtx, from_vertex: mgp.Vertex, - to_vertex: mgp.Vertex, - edge_type: str) -> mgp.Record(e=mgp.Any): +def create_edge( + ctx: mgp.ProcCtx, from_vertex: mgp.Vertex, to_vertex: mgp.Vertex, edge_type: str +) -> mgp.Record(e=mgp.Any): e = None try: - e = ctx.graph.create_edge( - from_vertex, to_vertex, mgp.EdgeType(edge_type)) + e = ctx.graph.create_edge(from_vertex, to_vertex, mgp.EdgeType(edge_type)) except RuntimeError as ex: return mgp.Record(e=str(ex)) return mgp.Record(e=e) @@ -54,29 +53,25 @@ def delete_edge(ctx: mgp.ProcCtx, edge: mgp.Edge) -> mgp.Record(): @mgp.write_proc -def set_property(ctx: mgp.ProcCtx, object: mgp.Any, - name: str, value: mgp.Nullable[mgp.Any]) -> mgp.Record(): +def set_property(ctx: mgp.ProcCtx, object: mgp.Any, name: str, value: mgp.Nullable[mgp.Any]) -> mgp.Record(): object.properties.set(name, value) return mgp.Record() @mgp.write_proc -def add_label(ctx: mgp.ProcCtx, object: mgp.Any, - name: str) -> mgp.Record(o=mgp.Any): +def add_label(ctx: mgp.ProcCtx, object: mgp.Any, name: str) -> mgp.Record(o=mgp.Any): object.add_label(name) return mgp.Record(o=object) @mgp.write_proc -def remove_label(ctx: mgp.ProcCtx, object: mgp.Any, - name: str) -> mgp.Record(o=mgp.Any): +def remove_label(ctx: mgp.ProcCtx, object: 
mgp.Any, name: str) -> mgp.Record(o=mgp.Any): object.remove_label(name) return mgp.Record(o=object) @mgp.write_proc -def underlying_graph_is_mutable(ctx: mgp.ProcCtx, - object: mgp.Any) -> mgp.Record(mutable=bool): +def underlying_graph_is_mutable(ctx: mgp.ProcCtx, object: mgp.Any) -> mgp.Record(mutable=bool): return mgp.Record(mutable=object.underlying_graph_is_mutable()) diff --git a/tests/e2e/write_procedures/simple_write.py b/tests/e2e/write_procedures/simple_write.py index 30a43f626..151266298 100644 --- a/tests/e2e/write_procedures/simple_write.py +++ b/tests/e2e/write_procedures/simple_write.py @@ -13,8 +13,7 @@ import typing import mgclient import sys import pytest -from common import (execute_and_fetch_all, - has_one_result_row, has_n_result_row) +from common import execute_and_fetch_all, has_one_result_row, has_n_result_row def test_is_write(connection): @@ -22,15 +21,19 @@ def test_is_write(connection): result_order = "name, signature, is_write" cursor = connection.cursor() for proc in execute_and_fetch_all( - cursor, "CALL mg.procedures() YIELD * WITH name, signature, " - "is_write WHERE name STARTS WITH 'write' " - f"RETURN {result_order}"): + cursor, + "CALL mg.procedures() YIELD * WITH name, signature, " + "is_write WHERE name STARTS WITH 'write' " + f"RETURN {result_order}", + ): assert proc[is_write] is True for proc in execute_and_fetch_all( - cursor, "CALL mg.procedures() YIELD * WITH name, signature, " - "is_write WHERE NOT name STARTS WITH 'write' " - f"RETURN {result_order}"): + cursor, + "CALL mg.procedures() YIELD * WITH name, signature, " + "is_write WHERE NOT name STARTS WITH 'write' " + f"RETURN {result_order}", + ): assert proc[is_write] is False assert cursor.description[0].name == "name" @@ -41,8 +44,7 @@ def test_is_write(connection): def test_single_vertex(connection): cursor = connection.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) - result = execute_and_fetch_all( - cursor, "CALL write.create_vertex() YIELD v RETURN v") + result = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v") vertex = result[0][0] assert isinstance(vertex, mgclient.Node) assert has_one_result_row(cursor, "MATCH (n) RETURN n") @@ -50,14 +52,13 @@ def test_single_vertex(connection): assert vertex.properties == {} def add_label(label: str): - execute_and_fetch_all( - cursor, f"MATCH (n) CALL write.add_label(n, '{label}') " - "YIELD * RETURN *") + execute_and_fetch_all(cursor, f"MATCH (n) CALL write.add_label(n, '{label}') " "YIELD * RETURN *") def remove_label(label: str): execute_and_fetch_all( - cursor, f"MATCH (n) CALL write.remove_label(n, '{label}') " - "YIELD * RETURN *") + cursor, + f"MATCH (n) CALL write.remove_label(n, '{label}') " "YIELD * RETURN *", + ) def get_vertex() -> mgclient.Node: return execute_and_fetch_all(cursor, "MATCH (n) RETURN n")[0][0] @@ -65,8 +66,10 @@ def test_single_vertex(connection): def set_property(property_name: str, property: typing.Any): nonlocal cursor execute_and_fetch_all( - cursor, f"MATCH (n) CALL write.set_property(n, '{property_name}', " - "$property) YIELD * RETURN *", {"property": property}) + cursor, + f"MATCH (n) CALL write.set_property(n, '{property_name}', " "$property) YIELD * RETURN *", + {"property": property}, + ) label_1 = "LABEL1" label_2 = "LABEL2" @@ -89,24 +92,23 @@ def test_single_vertex(connection): set_property(property_name, None) assert get_vertex().properties == {} - execute_and_fetch_all( - cursor, "MATCH (n) CALL write.delete_vertex(n) YIELD * RETURN 1") + 
execute_and_fetch_all(cursor, "MATCH (n) CALL write.delete_vertex(n) YIELD * RETURN 1") assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) def test_single_edge(connection): cursor = connection.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) - v1_id = execute_and_fetch_all( - cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id - v2_id = execute_and_fetch_all( - cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id + v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id + v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id edge_type = "EDGE" edge = execute_and_fetch_all( - cursor, f"MATCH (n) WHERE id(n) = {v1_id} " - f"MATCH (m) WHERE id(m) = {v2_id} " - f"CALL write.create_edge(n, m, '{edge_type}') " - "YIELD e RETURN e")[0][0] + cursor, + f"MATCH (n) WHERE id(n) = {v1_id} " + f"MATCH (m) WHERE id(m) = {v2_id} " + f"CALL write.create_edge(n, m, '{edge_type}') " + "YIELD e RETURN e", + )[0][0] assert edge.type == edge_type assert edge.properties == {} @@ -120,9 +122,10 @@ def test_single_edge(connection): def set_property(property_name: str, property: typing.Any): nonlocal cursor execute_and_fetch_all( - cursor, "MATCH ()-[e]->() " - f"CALL write.set_property(e, '{property_name}', " - "$property) YIELD * RETURN *", {"property": property}) + cursor, + "MATCH ()-[e]->() " f"CALL write.set_property(e, '{property_name}', " "$property) YIELD * RETURN *", + {"property": property}, + ) set_property(property_name, property_value_1) assert get_edge().properties == {property_name: property_value_1} @@ -130,60 +133,68 @@ def test_single_edge(connection): assert get_edge().properties == {property_name: property_value_2} set_property(property_name, None) assert get_edge().properties == {} - execute_and_fetch_all( - cursor, "MATCH ()-[e]->() CALL write.delete_edge(e) YIELD * RETURN 1") + execute_and_fetch_all(cursor, "MATCH ()-[e]->() CALL write.delete_edge(e) YIELD * RETURN 1") assert has_n_result_row(cursor, "MATCH ()-[e]->() RETURN e", 0) def test_detach_delete_vertex(connection): cursor = connection.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) - v1_id = execute_and_fetch_all( - cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id - v2_id = execute_and_fetch_all( - cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id + v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id + v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id execute_and_fetch_all( - cursor, f"MATCH (n) WHERE id(n) = {v1_id} " + cursor, + f"MATCH (n) WHERE id(n) = {v1_id} " f"MATCH (m) WHERE id(m) = {v2_id} " f"CALL write.create_edge(n, m, 'EDGE') " - "YIELD e RETURN e") + "YIELD e RETURN e", + ) assert has_one_result_row(cursor, "MATCH (n)-[e]->(m) RETURN n, e, m") execute_and_fetch_all( - cursor, f"MATCH (n) WHERE id(n) = {v1_id} " - "CALL write.detach_delete_vertex(n) YIELD * RETURN 1") + cursor, + f"MATCH (n) WHERE id(n) = {v1_id} " "CALL write.detach_delete_vertex(n) YIELD * RETURN 1", + ) assert has_n_result_row(cursor, "MATCH (n)-[e]->(m) RETURN n, e, m", 0) assert has_n_result_row(cursor, "MATCH ()-[e]->() RETURN e", 0) - assert has_one_result_row( - cursor, f"MATCH (n) WHERE id(n) = {v2_id} RETURN n") + assert has_one_result_row(cursor, f"MATCH (n) WHERE id(n) = {v2_id} RETURN n") def test_graph_mutability(connection): cursor = connection.cursor() assert has_n_result_row(cursor, 
"MATCH (n) RETURN n", 0) - v1_id = execute_and_fetch_all( - cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id - v2_id = execute_and_fetch_all( - cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id + v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id + v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id execute_and_fetch_all( - cursor, f"MATCH (n) WHERE id(n) = {v1_id} " + cursor, + f"MATCH (n) WHERE id(n) = {v1_id} " f"MATCH (m) WHERE id(m) = {v2_id} " f"CALL write.create_edge(n, m, 'EDGE') " - "YIELD e RETURN e") + "YIELD e RETURN e", + ) def test_mutability(is_write: bool): module = "write" if is_write else "read" - assert execute_and_fetch_all( - cursor, f"CALL {module}.graph_is_mutable() " - "YIELD mutable RETURN mutable")[0][0] is is_write - assert execute_and_fetch_all( - cursor, "MATCH (n) " - f"CALL {module}.underlying_graph_is_mutable(n) " - "YIELD mutable RETURN mutable")[0][0] is is_write - assert execute_and_fetch_all( - cursor, "MATCH (n)-[e]->(m) " - f"CALL {module}.underlying_graph_is_mutable(e) " - "YIELD mutable RETURN mutable")[0][0] is is_write + assert ( + execute_and_fetch_all(cursor, f"CALL {module}.graph_is_mutable() " "YIELD mutable RETURN mutable",)[ + 0 + ][0] + is is_write + ) + assert ( + execute_and_fetch_all( + cursor, + "MATCH (n) " f"CALL {module}.underlying_graph_is_mutable(n) " "YIELD mutable RETURN mutable", + )[0][0] + is is_write + ) + assert ( + execute_and_fetch_all( + cursor, + "MATCH (n)-[e]->(m) " f"CALL {module}.underlying_graph_is_mutable(e) " "YIELD mutable RETURN mutable", + )[0][0] + is is_write + ) test_mutability(True) test_mutability(False) diff --git a/tests/gql_behave/environment.py b/tests/gql_behave/environment.py index 0c0122850..50507d974 100644 --- a/tests/gql_behave/environment.py +++ b/tests/gql_behave/environment.py @@ -20,6 +20,7 @@ from neo4j import GraphDatabase, basic_auth # Helper class and functions + class TestResults: def __init__(self): self.total = 0 @@ -39,18 +40,16 @@ class TestResults: # Behave specific functions + def before_all(context): # logging logging.basicConfig(level="DEBUG") context.log = logging.getLogger(__name__) # driver - uri = "bolt://{}:{}".format(context.config.db_host, - context.config.db_port) - auth_token = basic_auth( - context.config.db_user, context.config.db_pass) - context.driver = GraphDatabase.driver(uri, auth=auth_token, - encrypted=False) + uri = "bolt://{}:{}".format(context.config.db_host, context.config.db_port) + auth_token = basic_auth(context.config.db_user, context.config.db_pass) + context.driver = GraphDatabase.driver(uri, auth=auth_token, encrypted=False) # test results context.test_results = TestResults() @@ -63,8 +62,7 @@ def before_scenario(context, scenario): def after_scenario(context, scenario): context.test_results.add_test(scenario.status) - if context.config.single_scenario or \ - (context.config.single_fail and scenario.status == "failed"): + if context.config.single_scenario or (context.config.single_fail and scenario.status == "failed"): print("Press enter to continue") sys.stdin.readline() @@ -87,5 +85,5 @@ def after_all(context): "test_suite": context.config.test_suite, } - with open(context.config.stats_file, 'w') as f: + with open(context.config.stats_file, "w") as f: json.dump(js, f) diff --git a/tests/gql_behave/run.py b/tests/gql_behave/run.py index 50035384a..815c7b638 100755 --- a/tests/gql_behave/run.py +++ b/tests/gql_behave/run.py @@ -55,22 +55,14 
@@ def main(): add_config("--test-directory") # Arguments that should be passed on to Behave - add_argument("--db-host", default="127.0.0.1", - help="server host (default is 127.0.0.1)") - add_argument("--db-port", default="7687", - help="server port (default is 7687)") - add_argument("--db-user", default="memgraph", - help="server user (default is memgraph)") - add_argument("--db-pass", default="memgraph", - help="server pass (default is memgraph)") - add_argument("--stop", action="store_true", - help="stop testing after first fail") - add_argument("--single-fail", action="store_true", - help="pause after failed scenario") - add_argument("--single-scenario", action="store_true", - help="pause after every scenario") - add_argument("--single-feature", action="store_true", - help="pause after every feature") + add_argument("--db-host", default="127.0.0.1", help="server host (default is 127.0.0.1)") + add_argument("--db-port", default="7687", help="server port (default is 7687)") + add_argument("--db-user", default="memgraph", help="server user (default is memgraph)") + add_argument("--db-pass", default="memgraph", help="server pass (default is memgraph)") + add_argument("--stop", action="store_true", help="stop testing after first fail") + add_argument("--single-fail", action="store_true", help="pause after failed scenario") + add_argument("--single-scenario", action="store_true", help="pause after every scenario") + add_argument("--single-feature", action="store_true", help="pause after every feature") add_argument("--stats-file", default="", help="statistics output file") # Parse arguments @@ -96,5 +88,5 @@ def main(): return behave_main(behave_args) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/tests/gql_behave/steps/binary_tree.py b/tests/gql_behave/steps/binary_tree.py index 5f35b1e6e..6f5041934 100644 --- a/tests/gql_behave/steps/binary_tree.py +++ b/tests/gql_behave/steps/binary_tree.py @@ -15,11 +15,11 @@ from behave import given import graph -@given(u'the binary-tree-1 graph') +@given("the binary-tree-1 graph") def step_impl(context): - graph.create_graph('binary-tree-1', context) + graph.create_graph("binary-tree-1", context) -@given(u'the binary-tree-2 graph') +@given("the binary-tree-2 graph") def step_impl(context): - graph.create_graph('binary-tree-2', context) + graph.create_graph("binary-tree-2", context) diff --git a/tests/gql_behave/steps/database.py b/tests/gql_behave/steps/database.py index 339e706ee..c63beb211 100644 --- a/tests/gql_behave/steps/database.py +++ b/tests/gql_behave/steps/database.py @@ -11,6 +11,7 @@ # -*- coding: utf-8 -*- + def query(q, context, params={}): """ Function used to execute query on database. Query results are @@ -44,7 +45,7 @@ def query(q, context, params={}): except Exception as e: # exception context.exception = e - context.log.info('%s', str(e)) + context.log.info("%s", str(e)) finally: session.close() diff --git a/tests/gql_behave/steps/errors.py b/tests/gql_behave/steps/errors.py index 94412c54c..a2484dbdb 100644 --- a/tests/gql_behave/steps/errors.py +++ b/tests/gql_behave/steps/errors.py @@ -24,234 +24,234 @@ def handle_error(context): @param context: behave.runner.Context, context of behave. 
""" - assert(context.exception is not None) + assert context.exception is not None -@then('an error should be raised') +@then("an error should be raised") def error(context): handle_error(context) -@then('a SyntaxError should be raised at compile time: NestedAggregation') +@then("a SyntaxError should be raised at compile time: NestedAggregation") def syntax_error(context): handle_error(context) -@then('TypeError should be raised at compile time: IncomparableValues') +@then("TypeError should be raised at compile time: IncomparableValues") def type_error(context): handle_error(context) -@then(u'a TypeError should be raised at compile time: IncomparableValues') +@then("a TypeError should be raised at compile time: IncomparableValues") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: RequiresDirectedRelationship') +@then("a SyntaxError should be raised at compile time: RequiresDirectedRelationship") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidRelationshipPattern') +@then("a SyntaxError should be raised at compile time: InvalidRelationshipPattern") def syntax_error(context): handle_error(context) -@then(u'a TypeError should be raised at runtime: MapElementAccessByNonString') +@then("a TypeError should be raised at runtime: MapElementAccessByNonString") def type_error(context): handle_error(context) -@then(u'a ConstraintVerificationFailed should be raised at runtime: DeleteConnectedNode') +@then("a ConstraintVerificationFailed should be raised at runtime: DeleteConnectedNode") def step(context): handle_error(context) -@then(u'a TypeError should be raised at runtime: ListElementAccessByNonInteger') +@then("a TypeError should be raised at runtime: ListElementAccessByNonInteger") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidArgumentType') +@then("a SyntaxError should be raised at compile time: InvalidArgumentType") def step(context): handle_error(context) -@then(u'a TypeError should be raised at runtime: InvalidElementAccess') +@then("a TypeError should be raised at runtime: InvalidElementAccess") def step(context): handle_error(context) -@then(u'a ArgumentError should be raised at runtime: NumberOutOfRange') +@then("a ArgumentError should be raised at runtime: NumberOutOfRange") def step(context): handle_error(context) -@then(u'a TypeError should be raised at runtime: InvalidArgumentValue') +@then("a TypeError should be raised at runtime: InvalidArgumentValue") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: VariableAlreadyBound') +@then("a SyntaxError should be raised at compile time: VariableAlreadyBound") def step(context): handle_error(context) -@then(u'a TypeError should be raised at runtime: IncomparableValues') +@then("a TypeError should be raised at runtime: IncomparableValues") def step(context): handle_error(context) -@then(u'a TypeError should be raised at runtime: PropertyAccessOnNonMap') +@then("a TypeError should be raised at runtime: PropertyAccessOnNonMap") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidUnicodeLiteral') +@then("a SyntaxError should be raised at compile time: InvalidUnicodeLiteral") def step(context): handle_error(context) -@then(u'a SemanticError should be raised at compile time: MergeReadOwnWrites') +@then("a SemanticError should be raised at compile time: MergeReadOwnWrites") def 
step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidAggregation') +@then("a SyntaxError should be raised at compile time: InvalidAggregation") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: NoExpressionAlias') +@then("a SyntaxError should be raised at compile time: NoExpressionAlias") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: UndefinedVariable') +@then("a SyntaxError should be raised at compile time: UndefinedVariable") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: VariableTypeConflict') +@then("a SyntaxError should be raised at compile time: VariableTypeConflict") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: DifferentColumnsInUnion') +@then("a SyntaxError should be raised at compile time: DifferentColumnsInUnion") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidClauseComposition') +@then("a SyntaxError should be raised at compile time: InvalidClauseComposition") def step(context): handle_error(context) -@then(u'a TypeError should be raised at compile time: InvalidPropertyType') +@then("a TypeError should be raised at compile time: InvalidPropertyType") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: ColumnNameConflict') +@then("a SyntaxError should be raised at compile time: ColumnNameConflict") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: NoVariablesInScope') +@then("a SyntaxError should be raised at compile time: NoVariablesInScope") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidDelete') +@then("a SyntaxError should be raised at compile time: InvalidDelete") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: NegativeIntegerArgument') +@then("a SyntaxError should be raised at compile time: NegativeIntegerArgument") def step(context): handle_error(context) -@then(u'a EntityNotFound should be raised at runtime: DeletedEntityAccess') +@then("a EntityNotFound should be raised at runtime: DeletedEntityAccess") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: RelationshipUniquenessViolation') +@then("a SyntaxError should be raised at compile time: RelationshipUniquenessViolation") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: CreatingVarLength') +@then("a SyntaxError should be raised at compile time: CreatingVarLength") def step_impl(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidParameterUse') +@then("a SyntaxError should be raised at compile time: InvalidParameterUse") def step_impl(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: FloatingPointOverflow') +@then("a SyntaxError should be raised at compile time: FloatingPointOverflow") def step_impl(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time InvalidArgumentExpression') +@then("a SyntaxError should be raised at compile time InvalidArgumentExpression") def step_impl(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time InvalidUnicodeCharacter') +@then("a SyntaxError 
should be raised at compile time InvalidUnicodeCharacter") def step_impl(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: NonConstantExpression') +@then("a SyntaxError should be raised at compile time: NonConstantExpression") def step_impl(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: NoSingleRelationshipType') +@then("a SyntaxError should be raised at compile time: NoSingleRelationshipType") def step_impl(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: UnknownFunction') +@then("a SyntaxError should be raised at compile time: UnknownFunction") def step_impl(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidNumberLiteral') +@then("a SyntaxError should be raised at compile time: InvalidNumberLiteral") def step_impl(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidArgumentExpression') +@then("a SyntaxError should be raised at compile time: InvalidArgumentExpression") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidUnicodeCharacter') +@then("a SyntaxError should be raised at compile time: InvalidUnicodeCharacter") def step(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidArgumentPassingMode') +@then("a SyntaxError should be raised at compile time: InvalidArgumentPassingMode") def step_impl(context): handle_error(context) -@then(u'a SyntaxError should be raised at compile time: InvalidNumberOfArguments') +@then("a SyntaxError should be raised at compile time: InvalidNumberOfArguments") def step_impl(context): handle_error(context) -@then(u'a ParameterMissing should be raised at compile time: MissingParameter') +@then("a ParameterMissing should be raised at compile time: MissingParameter") def step_impl(context): handle_error(context) -@then(u'a ProcedureError should be raised at compile time: ProcedureNotFound') +@then("a ProcedureError should be raised at compile time: ProcedureNotFound") def step_impl(context): handle_error(context) diff --git a/tests/gql_behave/steps/graph.py b/tests/gql_behave/steps/graph.py index c0c5c07ce..6ce0d5d4e 100644 --- a/tests/gql_behave/steps/graph.py +++ b/tests/gql_behave/steps/graph.py @@ -23,12 +23,12 @@ def clear_graph(context): database.query("MATCH (n) DETACH DELETE n", context) -@given('an empty graph') +@given("an empty graph") def empty_graph_step(context): clear_graph(context) -@given('any graph') +@given("any graph") def any_graph_step(context): clear_graph(context) @@ -46,20 +46,18 @@ def create_graph(name, context): and sets graph properties to beginning values. 
""" clear_graph(context) - path = os.path.join(context.config.test_directory, "graphs", - name + ".cypher") + path = os.path.join(context.config.test_directory, "graphs", name + ".cypher") - q_marks = ["'", '"', '`'] + q_marks = ["'", '"', "`"] - with open(path, 'r') as f: - content = f.read().replace('\n', ' ') - single_query = '' + with open(path, "r") as f: + content = f.read().replace("\n", " ") + single_query = "" quote = None i = 0 while i < len(content): ch = content[i] - if ch == '\\' and i != len(content) - 1 and \ - content[i + 1] in q_marks: + if ch == "\\" and i != len(content) - 1 and content[i + 1] in q_marks: single_query += ch + content[i + 1] i += 2 else: @@ -68,9 +66,9 @@ def create_graph(name, context): quote = None elif ch in q_marks and quote is None: quote = ch - if ch == ';' and quote is None: + if ch == ";" and quote is None: database.query(single_query, context) - single_query = '' + single_query = "" i += 1 - if single_query.strip() != '': + if single_query.strip() != "": database.query(single_query, context) diff --git a/tests/gql_behave/steps/parser.py b/tests/gql_behave/steps/parser.py index 202d47741..9e64e10b7 100644 --- a/tests/gql_behave/steps/parser.py +++ b/tests/gql_behave/steps/parser.py @@ -29,13 +29,13 @@ def parse(el, ignore_order): @return: Parsed string of element. """ - if el.startswith('(') and el.endswith(')'): + if el.startswith("(") and el.endswith(")"): return parse_node(el, ignore_order) - if el.startswith('<') and el.endswith('>'): + if el.startswith("<") and el.endswith(">"): return parse_path(el, ignore_order) - if el.startswith('{') and el.endswith('}'): + if el.startswith("{") and el.endswith("}"): return parse_map(el, ignore_order) - if el.startswith('[') and el.endswith(']'): + if el.startswith("[") and el.endswith("]"): if is_list(el): return parse_list(el, ignore_order) else: @@ -51,7 +51,7 @@ def is_list(el): @return: true if el is list. 
""" - if el[1] == ':': + if el[1] == ":": return False return True @@ -64,20 +64,20 @@ def parse_path(path, ignore_order): @return: parsed path """ - parsed_path = '<' + parsed_path = "<" dif_open_closed_brackets = 0 for i in range(1, len(path) - 1): - if path[i] == '(' or path[i] == '{' or path[i] == '[': + if path[i] == "(" or path[i] == "{" or path[i] == "[": dif_open_closed_brackets += 1 if dif_open_closed_brackets == 1: start = i - if path[i] == ')' or path[i] == '}' or path[i] == ']': + if path[i] == ")" or path[i] == "}" or path[i] == "]": dif_open_closed_brackets -= 1 if dif_open_closed_brackets == 0: - parsed_path += parse(path[start:(i + 1)], ignore_order) + parsed_path += parse(path[start : (i + 1)], ignore_order) elif dif_open_closed_brackets == 0: parsed_path += path[i] - parsed_path += '>' + parsed_path += ">" return parsed_path @@ -89,28 +89,27 @@ def parse_node(node_str, ignore_order): @return: parsed node """ - label = '' + label = "" labels = [] props_start = None for i in range(1, len(node_str)): - if node_str[i] == ':' or node_str[i] == ')' or node_str[i] == '{': - if label.startswith(':'): + if node_str[i] == ":" or node_str[i] == ")" or node_str[i] == "{": + if label.startswith(":"): labels.append(label) - label = '' + label = "" label += node_str[i] - if node_str[i] == '{': + if node_str[i] == "{": props_start = i break labels.sort() - parsed_node = '(' + parsed_node = "(" for label in labels: parsed_node += label if props_start is not None: - parsed_node += parse_map( - node_str[props_start:len(node_str) - 1], ignore_order) - parsed_node += ')' + parsed_node += parse_map(node_str[props_start : len(node_str) - 1], ignore_order) + parsed_node += ")" return parsed_node @@ -123,23 +122,23 @@ def parse_map(props, ignore_order): parsed map """ dif_open_closed_brackets = 0 - prop = '' + prop = "" list_props = [] for i in range(1, len(props) - 1): - if props[i] == ',' and dif_open_closed_brackets == 0: + if props[i] == "," and dif_open_closed_brackets == 0: list_props.append(prop_to_str(prop, ignore_order)) - prop = '' + prop = "" else: prop += props[i] - if props[i] == '(' or props[i] == '{' or props[i] == '[': + if props[i] == "(" or props[i] == "{" or props[i] == "[": dif_open_closed_brackets += 1 - elif props[i] == ')' or props[i] == '}' or props[i] == ']': + elif props[i] == ")" or props[i] == "}" or props[i] == "]": dif_open_closed_brackets -= 1 - if prop != '': + if prop != "": list_props.append(prop_to_str(prop, ignore_order)) list_props.sort() - return '{' + ','.join(list_props) + '}' + return "{" + ",".join(list_props) + "}" def prop_to_str(prop, ignore_order): @@ -152,8 +151,8 @@ def prop_to_str(prop, ignore_order): @return: parsed prop """ - key = prop.split(':', 1)[0] - val = prop.split(':', 1)[1] + key = prop.split(":", 1)[0] + val = prop.split(":", 1)[1] return key + ":" + parse(val, ignore_order) @@ -166,25 +165,25 @@ def parse_list(l, ignore_order): parsed list """ dif_open_closed_brackets = 0 - el = '' + el = "" list_el = [] for i in range(1, len(l) - 1): - if l[i] == ',' and dif_open_closed_brackets == 0: + if l[i] == "," and dif_open_closed_brackets == 0: list_el.append(parse(el, ignore_order)) - el = '' + el = "" else: el += l[i] - if l[i] == '(' or l[i] == '{' or l[i] == '[': + if l[i] == "(" or l[i] == "{" or l[i] == "[": dif_open_closed_brackets += 1 - elif l[i] == ')' or l[i] == '}' or l[i] == ']': + elif l[i] == ")" or l[i] == "}" or l[i] == "]": dif_open_closed_brackets -= 1 - if el != '': + if el != "": list_el.append(parse(el, ignore_order)) if 
ignore_order: list_el.sort() - return '[' + ','.join(list_el) + ']' + return "[" + ",".join(list_el) + "]" def parse_rel(rel, ignore_order): @@ -195,25 +194,25 @@ def parse_rel(rel, ignore_order): @return: parsed relationship """ - label = '' + label = "" labels = [] props_start = None for i in range(1, len(rel)): - if rel[i] == ':' or rel[i] == ']' or rel[i] == '{': - if label.startswith(':'): + if rel[i] == ":" or rel[i] == "]" or rel[i] == "{": + if label.startswith(":"): labels.append(label) - label = '' + label = "" label += rel[i] - if rel[i] == '{': + if rel[i] == "{": props_start = i break labels.sort() - parsed_rel = '[' + parsed_rel = "[" for label in labels: parsed_rel += label if props_start is not None: - parsed_rel += parse_map(rel[props_start:len(rel) - 1], ignore_order) - parsed_rel += ']' + parsed_rel += parse_map(rel[props_start : len(rel) - 1], ignore_order) + parsed_rel += "]" return parsed_rel diff --git a/tests/gql_behave/steps/query.py b/tests/gql_behave/steps/query.py index 5a050d4df..00549537b 100644 --- a/tests/gql_behave/steps/query.py +++ b/tests/gql_behave/steps/query.py @@ -17,32 +17,29 @@ from behave import given, then, step, when from neo4j.graph import Node, Path, Relationship -@given('parameters are') +@given("parameters are") def parameters_step(context): context.test_parameters.set_parameters_from_table(context.table) -@then('parameters are') +@then("parameters are") def parameters_step(context): context.test_parameters.set_parameters_from_table(context.table) -@step('having executed') +@step("having executed") def having_executed_step(context): - context.results = database.query( - context.text, context, context.test_parameters.get_parameters()) + context.results = database.query(context.text, context, context.test_parameters.get_parameters()) -@when('executing query') +@when("executing query") def executing_query_step(context): - context.results = database.query( - context.text, context, context.test_parameters.get_parameters()) + context.results = database.query(context.text, context, context.test_parameters.get_parameters()) -@when('executing control query') +@when("executing control query") def executing_query_step(context): - context.results = database.query( - context.text, context, context.test_parameters.get_parameters()) + context.results = database.query(context.text, context, context.test_parameters.get_parameters()) def parse_props(props_key_value): @@ -93,11 +90,11 @@ def to_string(element): # parsing Node sol = "(" if element.labels: - sol += ':' + ': '.join(element.labels) + sol += ":" + ": ".join(element.labels) if element.keys(): if element.labels: - sol += ' ' + sol += " " sol += parse_props(element.items()) sol += ")" @@ -109,7 +106,7 @@ def to_string(element): if element.type: sol += element.type if element.keys(): - sol += ' ' + sol += " " sol += parse_props(element.items()) sol += "]" return sol @@ -144,12 +141,12 @@ def to_string(element): elif isinstance(element, list): # parsing list - sol = '[' + sol = "[" el_str = [] for el in element: el_str.append(to_string(el)) - sol += ', '.join(el_str) - sol += ']' + sol += ", ".join(el_str) + sol += "]" return sol @@ -162,23 +159,22 @@ def to_string(element): elif isinstance(element, dict): # parsing map if len(element) == 0: - return '{}' - sol = '{' + return "{}" + sol = "{" for key, val in element.items(): - sol += key + ':' + to_string(val) + ',' - sol = sol[:-1] + '}' + sol += key + ":" + to_string(val) + "," + sol = sol[:-1] + "}" return sol elif isinstance(element, float): # 
parsing float, scientific - if 'e' in str(element): - if str(element)[-3] == '-': + if "e" in str(element): + if str(element)[-3] == "-": zeroes = int(str(element)[-2:]) - 1 - num_str = '' - if str(element)[0] == '-': - num_str += '-' - num_str += '.' + zeroes * '0' + \ - str(element)[:-4].replace("-", "").replace(".", "") + num_str = "" + if str(element)[0] == "-": + num_str += "-" + num_str += "." + zeroes * "0" + str(element)[:-4].replace("-", "").replace(".", "") return num_str return str(element) @@ -201,9 +197,14 @@ def get_result_rows(context, ignore_order): keys = result.keys() values = result.values() for i in range(0, len(keys)): - result_rows.append(keys[i] + ":" + parser.parse( - to_string(values[i]).replace("\n", "\\n").replace(" ", ""), - ignore_order)) + result_rows.append( + keys[i] + + ":" + + parser.parse( + to_string(values[i]).replace("\n", "\\n").replace(" ", ""), + ignore_order, + ) + ) return result_rows @@ -221,9 +222,7 @@ def get_expected_rows(context, ignore_order): expected_rows = [] for row in context.table: for col in context.table.headings: - expected_rows.append( - col + ":" + parser.parse(row[col].replace(" ", ""), - ignore_order)) + expected_rows.append(col + ":" + parser.parse(row[col].replace(" ", ""), ignore_order)) return expected_rows @@ -242,13 +241,13 @@ def validate(context, ignore_order): context.log.info("Expected: %s", str(expected_rows)) context.log.info("Results: %s", str(result_rows)) - assert(len(expected_rows) == len(result_rows)) + assert len(expected_rows) == len(result_rows) for i in range(0, len(expected_rows)): if expected_rows[i] in result_rows: result_rows.remove(expected_rows[i]) else: - assert(False) + assert False def validate_in_order(context, ignore_order): @@ -267,26 +266,26 @@ def validate_in_order(context, ignore_order): context.log.info("Expected: %s", str(expected_rows)) context.log.info("Results: %s", str(result_rows)) - assert(len(expected_rows) == len(result_rows)) + assert len(expected_rows) == len(result_rows) for i in range(0, len(expected_rows)): if expected_rows[i] != result_rows[i]: - assert(False) + assert False -@then('the result should be') +@then("the result should be") def expected_result_step(context): validate(context, False) check_exception(context) -@then('the result should be, in order') +@then("the result should be, in order") def expected_result_step(context): validate_in_order(context, False) check_exception(context) -@then('the result should be (ignoring element order for lists)') +@then("the result should be (ignoring element order for lists)") def expected_result_step(context): validate(context, True) check_exception(context) @@ -295,20 +294,20 @@ def expected_result_step(context): def check_exception(context): if context.exception is not None: context.log.info("Exception when executing query!") - assert(False) + assert False -@then('the result should be empty') +@then("the result should be empty") def empty_result_step(context): - assert(len(context.results) == 0) + assert len(context.results) == 0 check_exception(context) -@then('the side effects should be') +@then("the side effects should be") def side_effects_step(context): return -@then('no side effects') +@then("no side effects") def side_effects_step(context): return diff --git a/tests/gql_behave/steps/test_parameters.py b/tests/gql_behave/steps/test_parameters.py index ee92d9ff3..30e81c1fc 100644 --- a/tests/gql_behave/steps/test_parameters.py +++ b/tests/gql_behave/steps/test_parameters.py @@ -40,15 +40,15 @@ class TestParameters: par = 
dict() for row in table: par[row[0]] = self.parse_parameters(row[1]) - if isinstance(par[row[0]], str) and par[row[0]].startswith("'") \ - and par[row[0]].endswith("'"): - par[row[0]] = par[row[0]][1:len(par[row[0]]) - 1] + if isinstance(par[row[0]], str) and par[row[0]].startswith("'") and par[row[0]].endswith("'"): + par[row[0]] = par[row[0]][1 : len(par[row[0]]) - 1] par[table.headings[0]] = self.parse_parameters(table.headings[1]) - if isinstance(par[table.headings[0]], str) and \ - par[table.headings[0]].startswith("'") and \ - par[table.headings[0]].endswith("'"): - par[table.headings[0]] = \ - par[table.headings[0]][1:len(par[table.headings[0]]) - 1] + if ( + isinstance(par[table.headings[0]], str) + and par[table.headings[0]].startswith("'") + and par[table.headings[0]].endswith("'") + ): + par[table.headings[0]] = par[table.headings[0]][1 : len(par[table.headings[0]]) - 1] self.parameters = par diff --git a/tests/integration/audit/runner.py b/tests/integration/audit/runner.py index 92e175a1b..46c6b4e72 100755 --- a/tests/integration/audit/runner.py +++ b/tests/integration/audit/runner.py @@ -37,20 +37,14 @@ QUERIES = [ ("CREATE (n {name: $name})", {"name": 5, "leftover": 42}), ("MATCH (n), (m) CREATE (n)-[:e {when: $when}]->(m)", {"when": 42}), ("MATCH (n) RETURN n", {}), - ( - "MATCH (n), (m {type: $type}) RETURN count(n), count(m)", - {"type": "dadada"} - ), + ("MATCH (n), (m {type: $type}) RETURN count(n), count(m)", {"type": "dadada"}), ( "MERGE (n) ON CREATE SET n.created = timestamp() " "ON MATCH SET n.lastSeen = timestamp() " "RETURN n.name, n.created, n.lastSeen", - {} - ), - ( - "MATCH (n {value: $value}) SET n.value = 0 RETURN n", - {"value": "nandare!"} + {}, ), + ("MATCH (n {value: $value}) SET n.value = 0 RETURN n", {"value": "nandare!"}), ("MATCH (n), (m) SET n.value = m.value", {}), ("MATCH (n {test: $test}) REMOVE n.value", {"test": 48}), ("MATCH (n), (m) REMOVE n.value, m.value", {}), @@ -74,7 +68,8 @@ def execute_test(memgraph_binary, tester_binary): storage_directory.name, "--audit-enabled", "--log-file=memgraph.log", - "--log-level=TRACE"] + "--log-level=TRACE", + ] # Start the memgraph binary memgraph = subprocess.Popen(list(map(str, memgraph_args))) @@ -92,8 +87,13 @@ def execute_test(memgraph_binary, tester_binary): def execute_queries(queries): for query, params in queries: print(query, params) - args = [tester_binary, "--query", query, - "--params-json", json.dumps(params)] + args = [ + tester_binary, + "--query", + query, + "--params-json", + json.dumps(params), + ] subprocess.run(args).check_returncode() # Execute all queries @@ -109,10 +109,17 @@ def execute_test(memgraph_binary, tester_binary): # Verify the written log print("\033[1;36m~~ Starting log verification ~~\033[0m") with open(os.path.join(storage_directory.name, "audit", "audit.log")) as f: - reader = csv.reader(f, delimiter=',', doublequote=False, - escapechar='\\', lineterminator='\n', - quotechar='"', quoting=csv.QUOTE_MINIMAL, - skipinitialspace=False, strict=True) + reader = csv.reader( + f, + delimiter=",", + doublequote=False, + escapechar="\\", + lineterminator="\n", + quotechar='"', + quoting=csv.QUOTE_MINIMAL, + skipinitialspace=False, + strict=True, + ) queries = [] for line in reader: timestamp, address, username, query, params = line @@ -120,15 +127,13 @@ def execute_test(memgraph_binary, tester_binary): queries.append((query, params)) print(query, params) - assert queries == QUERIES, "Logged queries don't match " \ - "executed queries!" 
+ assert queries == QUERIES, "Logged queries don't match " "executed queries!" print("\033[1;36m~~ Finished log verification ~~\033[0m\n") if __name__ == "__main__": memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph") - tester_binary = os.path.join(PROJECT_DIR, "build", "tests", - "integration", "audit", "tester") + tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "audit", "tester") parser = argparse.ArgumentParser() parser.add_argument("--memgraph", default=memgraph_binary) diff --git a/tests/integration/auth/runner.py b/tests/integration/auth/runner.py index 7953bc1ee..bb2bd7650 100755 --- a/tests/integration/auth/runner.py +++ b/tests/integration/auth/runner.py @@ -29,15 +29,8 @@ PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) QUERIES = [ # CREATE - ( - "CREATE (n)", - ("CREATE",) - ), - ( - "MATCH (n), (m) CREATE (n)-[:e]->(m)", - ("CREATE", "MATCH") - ), - + ("CREATE (n)", ("CREATE",)), + ("MATCH (n), (m) CREATE (n)-[:e]->(m)", ("CREATE", "MATCH")), # DELETE ( "MATCH (n) DELETE n", @@ -47,116 +40,43 @@ QUERIES = [ "MATCH (n) DETACH DELETE n", ("DELETE", "MATCH"), ), - # MATCH - ( - "MATCH (n) RETURN n", - ("MATCH",) - ), - ( - "MATCH (n), (m) RETURN count(n), count(m)", - ("MATCH",) - ), - + ("MATCH (n) RETURN n", ("MATCH",)), + ("MATCH (n), (m) RETURN count(n), count(m)", ("MATCH",)), # MERGE ( "MERGE (n) ON CREATE SET n.created = timestamp() " "ON MATCH SET n.lastSeen = timestamp() " "RETURN n.name, n.created, n.lastSeen", - ("MERGE",) + ("MERGE",), ), - # SET - ( - "MATCH (n) SET n.value = 0 RETURN n", - ("SET", "MATCH") - ), - ( - "MATCH (n), (m) SET n.value = m.value", - ("SET", "MATCH") - ), - + ("MATCH (n) SET n.value = 0 RETURN n", ("SET", "MATCH")), + ("MATCH (n), (m) SET n.value = m.value", ("SET", "MATCH")), # REMOVE - ( - "MATCH (n) REMOVE n.value", - ("REMOVE", "MATCH") - ), - ( - "MATCH (n), (m) REMOVE n.value, m.value", - ("REMOVE", "MATCH") - ), - + ("MATCH (n) REMOVE n.value", ("REMOVE", "MATCH")), + ("MATCH (n), (m) REMOVE n.value, m.value", ("REMOVE", "MATCH")), # INDEX - ( - "CREATE INDEX ON :User (id)", - ("INDEX",) - ), - + ("CREATE INDEX ON :User (id)", ("INDEX",)), # AUTH - ( - "CREATE ROLE test_role", - ("AUTH",) - ), - ( - "DROP ROLE test_role", - ("AUTH",) - ), - ( - "SHOW ROLES", - ("AUTH",) - ), - ( - "CREATE USER test_user", - ("AUTH",) - ), - ( - "SET PASSWORD FOR test_user TO '1234'", - ("AUTH",) - ), - ( - "DROP USER test_user", - ("AUTH",) - ), - ( - "SHOW USERS", - ("AUTH",) - ), - ( - "SET ROLE FOR test_user TO test_role", - ("AUTH",) - ), - ( - "CLEAR ROLE FOR test_user", - ("AUTH",) - ), - ( - "GRANT ALL PRIVILEGES TO test_user", - ("AUTH",) - ), - ( - "DENY ALL PRIVILEGES TO test_user", - ("AUTH",) - ), - ( - "REVOKE ALL PRIVILEGES FROM test_user", - ("AUTH",) - ), - ( - "SHOW PRIVILEGES FOR test_user", - ("AUTH",) - ), - ( - "SHOW ROLE FOR test_user", - ("AUTH",) - ), - ( - "SHOW USERS FOR test_role", - ("AUTH",) - ), + ("CREATE ROLE test_role", ("AUTH",)), + ("DROP ROLE test_role", ("AUTH",)), + ("SHOW ROLES", ("AUTH",)), + ("CREATE USER test_user", ("AUTH",)), + ("SET PASSWORD FOR test_user TO '1234'", ("AUTH",)), + ("DROP USER test_user", ("AUTH",)), + ("SHOW USERS", ("AUTH",)), + ("SET ROLE FOR test_user TO test_role", ("AUTH",)), + ("CLEAR ROLE FOR test_user", ("AUTH",)), + ("GRANT ALL PRIVILEGES TO test_user", ("AUTH",)), + ("DENY ALL PRIVILEGES TO test_user", ("AUTH",)), + ("REVOKE ALL PRIVILEGES FROM test_user", ("AUTH",)), + ("SHOW PRIVILEGES FOR test_user", ("AUTH",)), + 
("SHOW ROLE FOR test_user", ("AUTH",)), + ("SHOW USERS FOR test_role", ("AUTH",)), ] -UNAUTHORIZED_ERROR = "You are not authorized to execute this query! Please " \ - "contact your database administrator." +UNAUTHORIZED_ERROR = "You are not authorized to execute this query! Please " "contact your database administrator." def wait_for_server(port, delay=0.1): @@ -166,8 +86,15 @@ def wait_for_server(port, delay=0.1): time.sleep(delay) -def execute_tester(binary, queries, should_fail=False, failure_message="", - username="", password="", check_failure=True): +def execute_tester( + binary, + queries, + should_fail=False, + failure_message="", + username="", + password="", + check_failure=True, +): args = [binary, "--username", username, "--password", password] if should_fail: args.append("--should-fail") @@ -200,18 +127,28 @@ def check_permissions(query_perms, user_perms): def execute_test(memgraph_binary, tester_binary, checker_binary): storage_directory = tempfile.TemporaryDirectory() - memgraph_args = [memgraph_binary, - "--data-directory", storage_directory.name] + memgraph_args = [memgraph_binary, "--data-directory", storage_directory.name] def execute_admin_queries(queries): - return execute_tester(tester_binary, queries, should_fail=False, - check_failure=True, username="admin", - password="admin") + return execute_tester( + tester_binary, + queries, + should_fail=False, + check_failure=True, + username="admin", + password="admin", + ) - def execute_user_queries(queries, should_fail=False, failure_message="", - check_failure=True): - return execute_tester(tester_binary, queries, should_fail, - failure_message, "user", "user", check_failure) + def execute_user_queries(queries, should_fail=False, failure_message="", check_failure=True): + return execute_tester( + tester_binary, + queries, + should_fail, + failure_message, + "user", + "user", + check_failure, + ) # Start the memgraph binary memgraph = subprocess.Popen(list(map(str, memgraph_args))) @@ -227,11 +164,13 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" 
# Prepare all users - execute_admin_queries([ - "CREATE USER ADmin IDENTIFIED BY 'admin'", - "GRANT ALL PRIVILEGES TO admIN", - "CREATE USER usEr IDENTIFIED BY 'user'", - ]) + execute_admin_queries( + [ + "CREATE USER ADmin IDENTIFIED BY 'admin'", + "GRANT ALL PRIVILEGES TO admIN", + "CREATE USER usEr IDENTIFIED BY 'user'", + ] + ) # Find all existing permissions permissions = set() @@ -243,12 +182,14 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): print("\033[1;36m~~ Starting query test ~~\033[0m") for mask in range(0, 2 ** len(permissions)): user_perms = get_permissions(permissions, mask) - print("\033[1;34m~~ Checking queries with privileges: ", - ", ".join(user_perms), " ~~\033[0m") + print( + "\033[1;34m~~ Checking queries with privileges: ", + ", ".join(user_perms), + " ~~\033[0m", + ) admin_queries = ["REVOKE ALL PRIVILEGES FROM uSer"] if len(user_perms) > 0: - admin_queries.append( - "GRANT {} TO User".format(", ".join(user_perms))) + admin_queries.append("GRANT {} TO User".format(", ".join(user_perms))) execute_admin_queries(admin_queries) authorized, unauthorized = [], [] for query, query_perms in QUERIES: @@ -256,35 +197,43 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): authorized.append(query) else: unauthorized.append(query) - execute_user_queries(authorized, check_failure=False, - failure_message=UNAUTHORIZED_ERROR) - execute_user_queries(unauthorized, should_fail=True, - failure_message=UNAUTHORIZED_ERROR) + execute_user_queries(authorized, check_failure=False, failure_message=UNAUTHORIZED_ERROR) + execute_user_queries(unauthorized, should_fail=True, failure_message=UNAUTHORIZED_ERROR) print("\033[1;36m~~ Finished query test ~~\033[0m\n") # Run the user/role permissions test print("\033[1;36m~~ Starting permissions test ~~\033[0m") - execute_admin_queries([ - "CREATE ROLE roLe", - "REVOKE ALL PRIVILEGES FROM uSeR", - ]) + execute_admin_queries( + [ + "CREATE ROLE roLe", + "REVOKE ALL PRIVILEGES FROM uSeR", + ] + ) execute_checker(checker_binary, []) for user_perm in ["GRANT", "DENY", "REVOKE"]: for role_perm in ["GRANT", "DENY", "REVOKE"]: for mapped in [True, False]: - print("\033[1;34m~~ Checking permissions with user ", - user_perm, ", role ", role_perm, - "user mapped to role:", mapped, " ~~\033[0m") + print( + "\033[1;34m~~ Checking permissions with user ", + user_perm, + ", role ", + role_perm, + "user mapped to role:", + mapped, + " ~~\033[0m", + ) if mapped: execute_admin_queries(["SET ROLE FOR USER TO roLE"]) else: execute_admin_queries(["CLEAR ROLE FOR user"]) user_prep = "FROM" if user_perm == "REVOKE" else "TO" role_prep = "FROM" if role_perm == "REVOKE" else "TO" - execute_admin_queries([ - "{} MATCH {} user".format(user_perm, user_prep), - "{} MATCH {} rOLe".format(role_perm, role_prep) - ]) + execute_admin_queries( + [ + "{} MATCH {} user".format(user_perm, user_prep), + "{} MATCH {} rOLe".format(role_perm, role_prep), + ] + ) expected = [] perms = [user_perm, role_perm] if mapped else [user_perm] if "DENY" in perms: @@ -313,10 +262,8 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): if __name__ == "__main__": memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph") - tester_binary = os.path.join(PROJECT_DIR, "build", "tests", - "integration", "auth", "tester") - checker_binary = os.path.join(PROJECT_DIR, "build", "tests", - "integration", "auth", "checker") + tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "auth", "tester") + checker_binary = 
os.path.join(PROJECT_DIR, "build", "tests", "integration", "auth", "checker") parser = argparse.ArgumentParser() parser.add_argument("--memgraph", default=memgraph_binary) diff --git a/tests/integration/durability/runner.py b/tests/integration/durability/runner.py index dd8c41456..e877d3ab0 100755 --- a/tests/integration/durability/runner.py +++ b/tests/integration/durability/runner.py @@ -40,7 +40,7 @@ def wait_for_server(port, delay=0.1): def sorted_content(file_path): - with open(file_path, 'r') as fin: + with open(file_path, "r") as fin: return sorted(list(map(lambda x: x.strip(), fin.readlines()))) @@ -52,32 +52,30 @@ def list_to_string(data): return ret -def execute_test( - memgraph_binary, - dump_binary, - test_directory, - test_type, - write_expected): - assert test_type in ["SNAPSHOT", "WAL"], \ - "Test type should be either 'SNAPSHOT' or 'WAL'." - print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m" - .format(os.path.relpath(test_directory, TESTS_DIR), test_type)) +def execute_test(memgraph_binary, dump_binary, test_directory, test_type, write_expected): + assert test_type in [ + "SNAPSHOT", + "WAL", + ], "Test type should be either 'SNAPSHOT' or 'WAL'." + print("\033[1;36m~~ Executing test {} ({}) ~~\033[0m".format(os.path.relpath(test_directory, TESTS_DIR), test_type)) working_data_directory = tempfile.TemporaryDirectory() if test_type == "SNAPSHOT": snapshots_dir = os.path.join(working_data_directory.name, "snapshots") os.makedirs(snapshots_dir) - shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME), - snapshots_dir) + shutil.copy(os.path.join(test_directory, SNAPSHOT_FILE_NAME), snapshots_dir) else: wal_dir = os.path.join(working_data_directory.name, "wal") os.makedirs(wal_dir) shutil.copy(os.path.join(test_directory, WAL_FILE_NAME), wal_dir) - memgraph_args = [memgraph_binary, - "--storage-recover-on-startup", - "--storage-properties-on-edges", - "--data-directory", working_data_directory.name] + memgraph_args = [ + memgraph_binary, + "--storage-recover-on-startup", + "--storage-properties-on-edges", + "--data-directory", + working_data_directory.name, + ] # Start the memgraph binary memgraph = subprocess.Popen(memgraph_args) @@ -104,22 +102,21 @@ def execute_test( dump_file_name = DUMP_SNAPSHOT_FILE_NAME if test_type == "SNAPSHOT" else DUMP_WAL_FILE_NAME if write_expected: - with open(dump_output_file.name, 'r') as dump: + with open(dump_output_file.name, "r") as dump: queries_got = dump.readlines() # Write dump files expected_dump_file = os.path.join(test_directory, dump_file_name) - with open(expected_dump_file, 'w') as expected: + with open(expected_dump_file, "w") as expected: expected.writelines(queries_got) else: # Compare dump files expected_dump_file = os.path.join(test_directory, dump_file_name) - assert os.path.exists(expected_dump_file), \ - "Could not find expected dump path {}".format(expected_dump_file) + assert os.path.exists(expected_dump_file), "Could not find expected dump path {}".format(expected_dump_file) queries_got = sorted_content(dump_output_file.name) queries_expected = sorted_content(expected_dump_file) - assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \ - "{}".format(list_to_string(queries_got), - list_to_string(queries_expected)) + assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format( + list_to_string(queries_got), list_to_string(queries_expected) + ) print("\033[1;32m~~ Test successful ~~\033[0m\n") @@ -141,15 +138,17 @@ def find_test_directories(directory): continue 
snapshot_file = os.path.join(test_dir_path, SNAPSHOT_FILE_NAME) wal_file = os.path.join(test_dir_path, WAL_FILE_NAME) - dump_snapshot_file = os.path.join( - test_dir_path, DUMP_SNAPSHOT_FILE_NAME) + dump_snapshot_file = os.path.join(test_dir_path, DUMP_SNAPSHOT_FILE_NAME) dump_wal_file = os.path.join(test_dir_path, DUMP_WAL_FILE_NAME) - if (os.path.isfile(snapshot_file) and os.path.isfile(dump_snapshot_file) - and os.path.isfile(wal_file) and os.path.isfile(dump_wal_file)): + if ( + os.path.isfile(snapshot_file) + and os.path.isfile(dump_snapshot_file) + and os.path.isfile(wal_file) + and os.path.isfile(dump_wal_file) + ): test_dirs.append(test_dir_path) else: - raise Exception("Missing data in test directory '{}'" - .format(test_dir_path)) + raise Exception("Missing data in test directory '{}'".format(test_dir_path)) return test_dirs @@ -161,26 +160,17 @@ if __name__ == "__main__": parser.add_argument("--memgraph", default=memgraph_binary) parser.add_argument("--dump", default=dump_binary) parser.add_argument( - '--write-expected', - action='store_true', - help='Overwrite the expected cypher with results from current run') + "--write-expected", + action="store_true", + help="Overwrite the expected cypher with results from current run", + ) args = parser.parse_args() test_directories = find_test_directories(TESTS_DIR) assert len(test_directories) > 0, "No tests have been found!" for test_directory in test_directories: - execute_test( - args.memgraph, - args.dump, - test_directory, - "SNAPSHOT", - args.write_expected) - execute_test( - args.memgraph, - args.dump, - test_directory, - "WAL", - args.write_expected) + execute_test(args.memgraph, args.dump, test_directory, "SNAPSHOT", args.write_expected) + execute_test(args.memgraph, args.dump, test_directory, "WAL", args.write_expected) sys.exit(0) diff --git a/tests/integration/ldap/runner.py b/tests/integration/ldap/runner.py index 6b0446690..1177cb9f0 100755 --- a/tests/integration/ldap/runner.py +++ b/tests/integration/ldap/runner.py @@ -52,8 +52,14 @@ def wait_for_server(port, delay=0.1): time.sleep(delay) -def execute_tester(binary, queries, username="", password="", - auth_should_fail=False, query_should_fail=False): +def execute_tester( + binary, + queries, + username="", + password="", + auth_should_fail=False, + query_should_fail=False, +): if password == "": password = username args = [binary, "--username", username, "--password", password] @@ -76,18 +82,14 @@ class Memgraph: def start(self, **kwargs): self.stop() self._storage_directory = tempfile.TemporaryDirectory() - self._auth_module = os.path.join(self._storage_directory.name, - "ldap.py") - self._auth_config = os.path.join(self._storage_directory.name, - "ldap.yaml") - script_file = os.path.join(PROJECT_DIR, "src", "auth", - "reference_modules", "ldap.py") + self._auth_module = os.path.join(self._storage_directory.name, "ldap.py") + self._auth_config = os.path.join(self._storage_directory.name, "ldap.yaml") + script_file = os.path.join(PROJECT_DIR, "src", "auth", "reference_modules", "ldap.py") virtualenv_bin = os.path.join(SCRIPT_DIR, "ve3", "bin", "python3") with open(script_file) as fin: data = fin.read() data = data.replace("/usr/bin/python3", virtualenv_bin) - data = data.replace("/etc/memgraph/auth/ldap.yaml", - self._auth_config) + data = data.replace("/etc/memgraph/auth/ldap.yaml", self._auth_config) with open(self._auth_module, "w") as fout: fout.write(data) os.chmod(self._auth_module, stat.S_IRWXU | stat.S_IRWXG) @@ -106,10 +108,13 @@ class Memgraph: } with 
open(self._auth_config, "w") as f: f.write(CONFIG_TEMPLATE.format(**config)) - args = [self._binary, - "--data-directory", self._storage_directory.name, - "--auth-module-executable", - kwargs.pop("module_executable", self._auth_module)] + args = [ + self._binary, + "--data-directory", + self._storage_directory.name, + "--auth-module-executable", + kwargs.pop("module_executable", self._auth_module), + ] for key, value in kwargs.items(): ldap_key = "--auth-module-" + key.replace("_", "-") if isinstance(value, bool): @@ -119,8 +124,7 @@ class Memgraph: args.append(value) self._process = subprocess.Popen(args) time.sleep(0.1) - assert self._process.poll() is None, "Memgraph process died " \ - "prematurely!" + assert self._process.poll() is None, "Memgraph process died " "prematurely!" wait_for_server(7687) def stop(self, check=True): @@ -137,8 +141,7 @@ class Memgraph: def initialize_test(memgraph, tester_binary, **kwargs): memgraph.start(module_executable="") - execute_tester(tester_binary, - ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"]) + execute_tester(tester_binary, ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"]) check_login = kwargs.pop("check_login", True) memgraph.restart(**kwargs) if check_login: @@ -170,18 +173,15 @@ def test_role_mapping(memgraph, tester_binary): initialize_test(memgraph, tester_binary) execute_tester(tester_binary, [], "alice") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, [], "bob") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", query_should_fail=True) execute_tester(tester_binary, [], "carol") - execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", - query_should_fail=True) + execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=True) execute_tester(tester_binary, ["GRANT CREATE TO admin"], "root") execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol") execute_tester(tester_binary, ["CREATE (n) RETURN n"], "dave") @@ -192,15 +192,13 @@ def test_role_mapping(memgraph, tester_binary): def test_role_removal(memgraph, tester_binary): initialize_test(memgraph, tester_binary) execute_tester(tester_binary, [], "alice") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.restart(manage_roles=False) execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") execute_tester(tester_binary, ["CLEAR ROLE FOR alice"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) memgraph.stop() @@ -229,28 +227,22 @@ def test_user_is_role(memgraph, tester_binary): def test_user_permissions_persistancy(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, - ["CREATE USER alice", "GRANT MATCH TO alice"], "root") + execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], 
"root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() def test_role_permissions_persistancy(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, - ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], - "root") + execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() def test_only_authentication(memgraph, tester_binary): initialize_test(memgraph, tester_binary, manage_roles=False) - execute_tester(tester_binary, - ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], - "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) memgraph.stop() @@ -267,22 +259,16 @@ def test_wrong_suffix(memgraph, tester_binary): def test_suffix_with_spaces(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, - suffix=", ou= people, dc = memgraph, dc = com") - execute_tester(tester_binary, - ["CREATE USER alice", "GRANT MATCH TO alice"], "root") + initialize_test(memgraph, tester_binary, suffix=", ou= people, dc = memgraph, dc = com") + execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() def test_role_mapping_wrong_root_dn(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, - root_dn="ou=invalid,dc=memgraph,dc=com") - execute_tester(tester_binary, - ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], - "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com") + execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) memgraph.restart() execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() @@ -290,11 +276,8 @@ def test_role_mapping_wrong_root_dn(memgraph, tester_binary): def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary): initialize_test(memgraph, tester_binary, root_objectclass="person") - execute_tester(tester_binary, - ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], - "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) memgraph.restart() execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() @@ -302,11 +285,8 @@ def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary): def test_role_mapping_wrong_user_attribute(memgraph, tester_binary): initialize_test(memgraph, tester_binary, user_attribute="cn") - execute_tester(tester_binary, - ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], - "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", - query_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) memgraph.restart() 
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() @@ -314,8 +294,7 @@ def test_role_mapping_wrong_user_attribute(memgraph, tester_binary): def test_wrong_password(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "root", password="sudo", - auth_should_fail=True) + execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True) execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") memgraph.stop() @@ -326,12 +305,10 @@ def test_password_persistancy(memgraph, tester_binary): execute_tester(tester_binary, ["SHOW USERS"], "root", password="sudo") execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") memgraph.restart() - execute_tester(tester_binary, [], "root", password="sudo", - auth_should_fail=True) + execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True) execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") memgraph.restart(module_executable="") - execute_tester(tester_binary, [], "root", password="sudo", - auth_should_fail=True) + execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True) execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") memgraph.stop() @@ -339,33 +316,25 @@ def test_password_persistancy(memgraph, tester_binary): def test_user_multiple_roles(memgraph, tester_binary): initialize_test(memgraph, tester_binary, check_login=False) memgraph.restart() - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", - query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) + execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) memgraph.restart(manage_roles=False) - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", - query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) + execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) memgraph.restart(manage_roles=False, root_dn="") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", - query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", - query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) + execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) memgraph.stop() def test_starttls_failure(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, encryption="starttls", - check_login=False) + initialize_test(memgraph, tester_binary, encryption="starttls", check_login=False) execute_tester(tester_binary, [], "root", auth_should_fail=True) memgraph.stop() def test_ssl_failure(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, encryption="ssl", - check_login=False) + initialize_test(memgraph, tester_binary, encryption="ssl", check_login=False) execute_tester(tester_binary, [], "root", auth_should_fail=True) memgraph.stop() @@ -375,22 +344,25 @@ def test_ssl_failure(memgraph, tester_binary): if __name__ == "__main__": memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph") - tester_binary = os.path.join(PROJECT_DIR, "build", "tests", - "integration", "ldap", 
"tester") + tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "ldap", "tester") parser = argparse.ArgumentParser() parser.add_argument("--memgraph", default=memgraph_binary) parser.add_argument("--tester", default=tester_binary) - parser.add_argument("--openldap-dir", - default=os.path.join(SCRIPT_DIR, "openldap-2.4.47")) + parser.add_argument("--openldap-dir", default=os.path.join(SCRIPT_DIR, "openldap-2.4.47")) args = parser.parse_args() # Setup Memgraph handler memgraph = Memgraph(args.memgraph) # Start the slapd binary - slapd_args = [os.path.join(args.openldap_dir, "exe", "libexec", "slapd"), - "-h", "ldap://127.0.0.1:1389/", "-d", "0"] + slapd_args = [ + os.path.join(args.openldap_dir, "exe", "libexec", "slapd"), + "-h", + "ldap://127.0.0.1:1389/", + "-d", + "0", + ] slapd = subprocess.Popen(slapd_args) time.sleep(0.1) assert slapd.poll() is None, "slapd process died prematurely!" @@ -409,8 +381,7 @@ if __name__ == "__main__": if slapd_stat != 0: print("slapd process didn't exit cleanly!") - assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " \ - "(memgraph, slapd) crashed!" + assert mg_stat == 0 and slapd_stat == 0, "Some of the processes " "(memgraph, slapd) crashed!" # Execute tests names = sorted(globals().keys()) diff --git a/tests/integration/mg_import_csv/runner.py b/tests/integration/mg_import_csv/runner.py index afaa58e28..60c292709 100755 --- a/tests/integration/mg_import_csv/runner.py +++ b/tests/integration/mg_import_csv/runner.py @@ -46,17 +46,18 @@ def list_to_string(data): def verify_lifetime(memgraph_binary, mg_import_csv_binary): - print("\033[1;36m~~ Verifying that mg_import_csv can't be started while " - "memgraph is running ~~\033[0m") + print("\033[1;36m~~ Verifying that mg_import_csv can't be started while " "memgraph is running ~~\033[0m") storage_directory = tempfile.TemporaryDirectory() # Generate common args - common_args = ["--data-directory", storage_directory.name, - "--storage-properties-on-edges=false"] + common_args = [ + "--data-directory", + storage_directory.name, + "--storage-properties-on-edges=false", + ] # Start the memgraph binary - memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \ - common_args + memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args memgraph = subprocess.Popen(list(map(str, memgraph_args))) time.sleep(0.1) assert memgraph.poll() is None, "Memgraph process died prematurely!" @@ -70,14 +71,12 @@ def verify_lifetime(memgraph_binary, mg_import_csv_binary): assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" # Execute mg_import_csv. 
- mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + \ - common_args + mg_import_csv_args = [mg_import_csv_binary, "--nodes", "/dev/null"] + common_args ret = subprocess.run(mg_import_csv_args) # Check the return code if ret.returncode == 0: - raise Exception( - "The importer was able to run while memgraph was running!") + raise Exception("The importer was able to run while memgraph was running!") # Shutdown the memgraph binary memgraph.terminate() @@ -86,27 +85,34 @@ def verify_lifetime(memgraph_binary, mg_import_csv_binary): print("\033[1;32m~~ Test successful ~~\033[0m\n") -def execute_test(name, test_path, test_config, memgraph_binary, - mg_import_csv_binary, tester_binary, write_expected): +def execute_test( + name, + test_path, + test_config, + memgraph_binary, + mg_import_csv_binary, + tester_binary, + write_expected, +): print("\033[1;36m~~ Executing test", name, "~~\033[0m") storage_directory = tempfile.TemporaryDirectory() # Verify test configuration - if ("import_should_fail" not in test_config and - "expected" not in test_config) or \ - ("import_should_fail" in test_config and - "expected" in test_config): - raise Exception("The test should specify either 'import_should_fail' " - "or 'expected'!") + if ("import_should_fail" not in test_config and "expected" not in test_config) or ( + "import_should_fail" in test_config and "expected" in test_config + ): + raise Exception("The test should specify either 'import_should_fail' " "or 'expected'!") expected_path = test_config.pop("expected", "") import_should_fail = test_config.pop("import_should_fail", False) # Generate common args properties_on_edges = bool(test_config.pop("properties_on_edges", False)) - common_args = ["--data-directory", storage_directory.name, - "--storage-properties-on-edges=" + - str(properties_on_edges).lower()] + common_args = [ + "--data-directory", + storage_directory.name, + "--storage-properties-on-edges=" + str(properties_on_edges).lower(), + ] # Generate mg_import_csv args using flags specified in the test mg_import_csv_args = [mg_import_csv_binary] + common_args @@ -125,19 +131,16 @@ def execute_test(name, test_path, test_config, memgraph_binary, if import_should_fail: if ret.returncode == 0: - raise Exception("The import should have failed, but it " - "succeeded instead!") + raise Exception("The import should have failed, but it " "succeeded instead!") else: print("\033[1;32m~~ Test successful ~~\033[0m\n") return else: if ret.returncode != 0: - raise Exception("The import should have succeeded, but it " - "failed instead!") + raise Exception("The import should have succeeded, but it " "failed instead!") # Start the memgraph binary - memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + \ - common_args + memgraph_args = [memgraph_binary, "--storage-recover-on-startup"] + common_args memgraph = subprocess.Popen(list(map(str, memgraph_args))) time.sleep(0.1) assert memgraph.poll() is None, "Memgraph process died prematurely!" @@ -151,17 +154,17 @@ def execute_test(name, test_path, test_config, memgraph_binary, assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" 
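The configuration check in execute_test above is an exclusive-or in disguise: a test case must carry exactly one of import_should_fail or expected, and the runner raises when both or neither are present. A more direct but purely illustrative formulation of the same condition:

has_fail = "import_should_fail" in test_config
has_expected = "expected" in test_config
if has_fail == has_expected:
    # Either both keys are present or both are missing.
    raise Exception("The test should specify either 'import_should_fail' or 'expected'!")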
# Get the contents of the database - queries_got = extract_rows(subprocess.run( - [tester_binary], stdout=subprocess.PIPE, - check=True).stdout.decode("utf-8")) + queries_got = extract_rows( + subprocess.run([tester_binary], stdout=subprocess.PIPE, check=True).stdout.decode("utf-8") + ) # Shutdown the memgraph binary memgraph.terminate() assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" if write_expected: - with open(os.path.join(test_path, expected_path), 'w') as expected: - expected.write('\n'.join(queries_got)) + with open(os.path.join(test_path, expected_path), "w") as expected: + expected.write("\n".join(queries_got)) else: if expected_path: @@ -173,18 +176,16 @@ def execute_test(name, test_path, test_config, memgraph_binary, # Verify the queries queries_expected.sort() queries_got.sort() - assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" \ - "{}".format(list_to_string(queries_got), - list_to_string(queries_expected)) + assert queries_got == queries_expected, "Expected\n{}\nto be equal to\n" "{}".format( + list_to_string(queries_got), list_to_string(queries_expected) + ) print("\033[1;32m~~ Test successful ~~\033[0m\n") if __name__ == "__main__": memgraph_binary = os.path.join(BUILD_DIR, "memgraph") - mg_import_csv_binary = os.path.join( - BUILD_DIR, "src", "mg_import_csv") - tester_binary = os.path.join( - BUILD_DIR, "tests", "integration", "mg_import_csv", "tester") + mg_import_csv_binary = os.path.join(BUILD_DIR, "src", "mg_import_csv") + tester_binary = os.path.join(BUILD_DIR, "tests", "integration", "mg_import_csv", "tester") parser = argparse.ArgumentParser() parser.add_argument("--memgraph", default=memgraph_binary) @@ -193,7 +194,8 @@ if __name__ == "__main__": parser.add_argument( "--write-expected", action="store_true", - help="Overwrite the expected values with the results of the current run") + help="Overwrite the expected values with the results of the current run", + ) args = parser.parse_args() # First test whether the CSV importer can be started while the main @@ -211,7 +213,14 @@ if __name__ == "__main__": testcases = yaml.safe_load(f) for test_config in testcases: test_name = name + "/" + test_config.pop("name") - execute_test(test_name, test_path, test_config, args.memgraph, - args.mg_import_csv, args.tester, args.write_expected) + execute_test( + test_name, + test_path, + test_config, + args.memgraph, + args.mg_import_csv, + args.tester, + args.write_expected, + ) sys.exit(0) diff --git a/tests/integration/telemetry/runner.py b/tests/integration/telemetry/runner.py index 6fa81d91f..84ce90471 100755 --- a/tests/integration/telemetry/runner.py +++ b/tests/integration/telemetry/runner.py @@ -36,8 +36,7 @@ def execute_test(**kwargs): timeout = duration * 2 if "hang" not in kwargs else duration * 2 + 60 success = False - server_args = [server_binary, "--interval", interval, - "--duration", duration] + server_args = [server_binary, "--interval", interval, "--duration", duration] for flag, value in kwargs.items(): flag = "--" + flag.replace("_", "-") # We handle boolean flags here. 
The type of value must be `bool`, and @@ -48,9 +47,15 @@ def execute_test(**kwargs): else: server_args.extend([flag, value]) - client_args = [client_binary, "--interval", interval, - "--duration", duration, - "--storage-directory", storage_directory] + client_args = [ + client_binary, + "--interval", + interval, + "--duration", + duration, + "--storage-directory", + storage_directory, + ] if endpoint: client_args.extend(["--endpoint", endpoint]) @@ -61,8 +66,7 @@ def execute_test(**kwargs): assert server.poll() is None, "Server process died prematurely!" try: - subprocess.run(list(map(str, client_args)), timeout=timeout, - check=True) + subprocess.run(list(map(str, client_args)), timeout=timeout, check=True) finally: if server is None: success = True @@ -88,16 +92,14 @@ TESTS = [ {"endpoint": "http://127.0.0.1:9000/nonexistant/", "no_check": True}, {"start_server": False}, {"startups": 4, "no_check_duration": True}, # the last 3 tests failed - # to send any data + this test - {"add_garbage": True} + # to send any data + this test + {"add_garbage": True}, ] if __name__ == "__main__": server_binary = os.path.join(SCRIPT_DIR, "server.py") - client_binary = os.path.join(PROJECT_DIR, "build", "tests", - "integration", "telemetry", "client") - kvstore_console_binary = os.path.join(PROJECT_DIR, "build", "tests", - "manual", "kvstore_console") + client_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "telemetry", "client") + kvstore_console_binary = os.path.join(PROJECT_DIR, "build", "tests", "manual", "kvstore_console") parser = argparse.ArgumentParser() parser.add_argument("--client", default=client_binary) @@ -108,19 +110,23 @@ if __name__ == "__main__": storage = tempfile.TemporaryDirectory() for test in TESTS: - print("\033[1;36m~~ Executing test with arguments:", - json.dumps(test, sort_keys=True), "~~\033[0m") + print( + "\033[1;36m~~ Executing test with arguments:", + json.dumps(test, sort_keys=True), + "~~\033[0m", + ) if test.pop("add_garbage", False): - proc = subprocess.Popen([args.kvstore_console, "--path", - storage.name], stdin=subprocess.PIPE, - stdout=subprocess.DEVNULL) + proc = subprocess.Popen( + [args.kvstore_console, "--path", storage.name], + stdin=subprocess.PIPE, + stdout=subprocess.DEVNULL, + ) proc.communicate("put garbage garbage".encode("utf-8")) assert proc.wait() == 0 try: - success = execute_test(client=args.client, server=args.server, - storage=storage.name, **test) + success = execute_test(client=args.client, server=args.server, storage=storage.name, **test) except Exception as e: print("\033[1;33m", e, "\033[0m", sep="") success = False diff --git a/tests/integration/telemetry/server.py b/tests/integration/telemetry/server.py index 0fca6986d..24dad7c2b 100755 --- a/tests/integration/telemetry/server.py +++ b/tests/integration/telemetry/server.py @@ -46,7 +46,7 @@ def build_handler(storage, args): assert self.headers["accept"] == "application/json" assert self.headers["content-type"] == "application/json" - content_len = int(self.headers.get('content-length', 0)) + content_len = int(self.headers.get("content-length", 0)) data = json.loads(self.rfile.read(content_len).decode("utf-8")) if self.path not in [args.path, args.redirect_path]: @@ -195,4 +195,4 @@ if __name__ == "__main__": verify_storage(startup, args) # machine id has to be same for every run on the same machine - assert len(set(map(lambda x: x['machine_id'], itertools.chain(*startups)))) == 1 + assert len(set(map(lambda x: x["machine_id"], itertools.chain(*startups)))) == 1 diff --git 
a/tests/macro_benchmark/clients.py b/tests/macro_benchmark/clients.py index 859b6f1c3..e67441f3f 100644 --- a/tests/macro_benchmark/clients.py +++ b/tests/macro_benchmark/clients.py @@ -34,7 +34,8 @@ class QueryClient: self.default_num_workers = default_num_workers def __call__(self, queries, database, num_workers=None): - if num_workers is None: num_workers = self.default_num_workers + if num_workers is None: + num_workers = self.default_num_workers self.log.debug("execute('%s')", str(queries)) client_path = "tests/macro_benchmark/query_client" @@ -53,29 +54,36 @@ class QueryClient: output_fd, output = tempfile.mkstemp() os.close(output_fd) - client_args = ["--port", database.args.port, - "--num-workers", str(num_workers), - "--output", output] + client_args = [ + "--port", + database.args.port, + "--num-workers", + str(num_workers), + "--output", + output, + ] cpu_time_start = database.database_bin.get_usage()["cpu"] # TODO make the timeout configurable per query or something - return_code = self.client.run_and_wait( - client, client_args, timeout=600, stdin=queries_path) + return_code = self.client.run_and_wait(client, client_args, timeout=600, stdin=queries_path) usage = database.database_bin.get_usage() cpu_time_end = usage["cpu"] os.remove(queries_path) if return_code != 0: with open(self.client.get_stderr()) as f: stderr = f.read() - self.log.error("Error while executing queries '%s'. " - "Failed with return_code %d and stderr:\n%s", - str(queries), return_code, stderr) + self.log.error( + "Error while executing queries '%s'. " "Failed with return_code %d and stderr:\n%s", + str(queries), + return_code, + stderr, + ) raise Exception("BoltClient execution failed") - data = {"groups" : []} + data = {"groups": []} with open(output) as f: for line in f: - data["groups"].append(json.loads(line)) + data["groups"].append(json.loads(line)) data[CPU_TIME] = cpu_time_end - cpu_time_start data[MAX_MEMORY] = usage["max_memory"] @@ -94,7 +102,8 @@ class LongRunningClient: # TODO: This is quite similar to __call__ method of QueryClient. Remove # duplication. def __call__(self, config, database, duration, client, num_workers=None): - if num_workers is None: num_workers = self.default_num_workers + if num_workers is None: + num_workers = self.default_num_workers self.log.debug("execute('%s')", config) client_path = "tests/macro_benchmark/{}".format(client) @@ -113,32 +122,41 @@ class LongRunningClient: output_fd, output = tempfile.mkstemp() os.close(output_fd) - client_args = ["--port", database.args.port, - "--num-workers", str(num_workers), - "--output", output, - "--duration", str(duration), - "--db", database.name, - "--scenario", self.workload] + client_args = [ + "--port", + database.args.port, + "--num-workers", + str(num_workers), + "--output", + output, + "--duration", + str(duration), + "--db", + database.name, + "--scenario", + self.workload, + ] - return_code = self.client.run_and_wait( - client, client_args, timeout=600, stdin=config_path) + return_code = self.client.run_and_wait(client, client_args, timeout=600, stdin=config_path) os.remove(config_path) if return_code != 0: with open(self.client.get_stderr()) as f: stderr = f.read() - self.log.error("Error while executing config '%s'. " - "Failed with return_code %d and stderr:\n%s", - str(config), return_code, stderr) + self.log.error( + "Error while executing config '%s'. 
" "Failed with return_code %d and stderr:\n%s", + str(config), + return_code, + stderr, + ) raise Exception("BoltClient execution failed") - # TODO: We shouldn't wait for process to finish to start reading output. # We should implement periodic reading of data and stream data when it # becomes available. data = [] with open(output) as f: for line in f: - data.append(json.loads(line)) + data.append(json.loads(line)) os.remove(output) return data diff --git a/tests/macro_benchmark/common.py b/tests/macro_benchmark/common.py index 9e8b08afd..e2fdad35f 100644 --- a/tests/macro_benchmark/common.py +++ b/tests/macro_benchmark/common.py @@ -14,9 +14,11 @@ from argparse import ArgumentParser try: import jail + APOLLO = True except: import jail_faker as jail + APOLLO = False @@ -45,13 +47,15 @@ def get_absolute_path(path, base=""): def set_cpus(flag_name, process, args): argp = ArgumentParser() # named, optional arguments - argp.add_argument("--" + flag_name, nargs="+", type=int, help="cpus that " - "will be used by process. Obligatory on Apollo, ignored " - "otherwise.") + argp.add_argument( + "--" + flag_name, + nargs="+", + type=int, + help="cpus that " "will be used by process. Obligatory on Apollo, ignored " "otherwise.", + ) args, _ = argp.parse_known_args(args) attr_flag_name = flag_name.replace("-", "_") cpus = getattr(args, attr_flag_name) - assert not APOLLO or cpus, \ - "flag --{} is obligatory on Apollo".format(flag_name) + assert not APOLLO or cpus, "flag --{} is obligatory on Apollo".format(flag_name) if cpus: - process.set_cpus(cpus, hyper = False) + process.set_cpus(cpus, hyper=False) diff --git a/tests/macro_benchmark/databases.py b/tests/macro_benchmark/databases.py index 439464c85..6bfff5fd2 100644 --- a/tests/macro_benchmark/databases.py +++ b/tests/macro_benchmark/databases.py @@ -36,13 +36,12 @@ class Memgraph: """ Knows how to start and stop memgraph. """ + def __init__(self, args, num_workers): self.log = logging.getLogger("MemgraphRunner") argp = ArgumentParser("MemgraphArgumentParser") - argp.add_argument("--runner-bin", - default=get_absolute_path("memgraph", "build")) - argp.add_argument("--port", default="7687", - help="Database and client port") + argp.add_argument("--runner-bin", default=get_absolute_path("memgraph", "build")) + argp.add_argument("--port", default="7687", help="Database and client port") argp.add_argument("--data-directory", default=None) argp.add_argument("--storage-snapshot-on-exit", action="store_true") argp.add_argument("--storage-recover-on-startup", action="store_true") @@ -55,8 +54,12 @@ class Memgraph: def start(self): self.log.info("start") - database_args = ["--bolt-port", self.args.port, - "--query-execution-timeout-sec", "0"] + database_args = [ + "--bolt-port", + self.args.port, + "--query-execution-timeout-sec", + "0", + ] if self.num_workers: database_args += ["--bolt-num-workers", str(self.num_workers)] if self.args.data_directory: @@ -82,15 +85,13 @@ class Neo: """ Knows how to start and stop neo4j. 
""" + def __init__(self, args, config): self.log = logging.getLogger("NeoRunner") argp = ArgumentParser("NeoArgumentParser") - argp.add_argument("--runner-bin", default=get_absolute_path( - "neo4j/bin/neo4j", "libs")) - argp.add_argument("--port", default="7687", - help="Database and client port") - argp.add_argument("--http-port", default="7474", - help="Database and client port") + argp.add_argument("--runner-bin", default=get_absolute_path("neo4j/bin/neo4j", "libs")) + argp.add_argument("--port", default="7687", help="Database and client port") + argp.add_argument("--http-port", default="7474", help="Database and client port") self.log.info("Initializing Runner with arguments %r", args) self.args, _ = argp.parse_known_args(args) self.config = config @@ -105,24 +106,23 @@ class Neo: self.neo4j_home_path = tempfile.mkdtemp(dir="/dev/shm") try: - os.symlink(os.path.join(get_absolute_path("neo4j", "libs"), "lib"), - os.path.join(self.neo4j_home_path, "lib")) + os.symlink( + os.path.join(get_absolute_path("neo4j", "libs"), "lib"), + os.path.join(self.neo4j_home_path, "lib"), + ) neo4j_conf_dir = os.path.join(self.neo4j_home_path, "conf") neo4j_conf_file = os.path.join(neo4j_conf_dir, "neo4j.conf") os.mkdir(neo4j_conf_dir) shutil.copyfile(self.config, neo4j_conf_file) with open(neo4j_conf_file, "a") as f: - f.write("\ndbms.connector.bolt.listen_address=:" + - self.args.port + "\n") - f.write("\ndbms.connector.http.listen_address=:" + - self.args.http_port + "\n") + f.write("\ndbms.connector.bolt.listen_address=:" + self.args.port + "\n") + f.write("\ndbms.connector.http.listen_address=:" + self.args.http_port + "\n") # environment cwd = os.path.dirname(self.args.runner_bin) env = {"NEO4J_HOME": self.neo4j_home_path} - self.database_bin.run(self.args.runner_bin, args=["console"], - env=env, timeout=600, cwd=cwd) + self.database_bin.run(self.args.runner_bin, args=["console"], env=env, timeout=600, cwd=cwd) except: shutil.rmtree(self.neo4j_home_path) raise Exception("Couldn't run Neo4j!") diff --git a/tests/macro_benchmark/groups/1000_create/vertex_big.run.py b/tests/macro_benchmark/groups/1000_create/vertex_big.run.py index 7a7d19e84..26810e9eb 100644 --- a/tests/macro_benchmark/groups/1000_create/vertex_big.run.py +++ b/tests/macro_benchmark/groups/1000_create/vertex_big.run.py @@ -9,4 +9,7 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. 
-print("""CREATE (:L1:L2:L3:L4:L5:L6:L7 {p1: true, p2: 42, p3: "Here is some text that is not extremely short", p4:"Short text", p5: 234.434, p6: 11.11, p7: false});""" * 1000) +print( + """CREATE (:L1:L2:L3:L4:L5:L6:L7 {p1: true, p2: 42, p3: "Here is some text that is not extremely short", p4:"Short text", p5: 234.434, p6: 11.11, p7: false});""" + * 1000 +) diff --git a/tests/macro_benchmark/groups/aggregation/setup.py b/tests/macro_benchmark/groups/aggregation/setup.py index 6eaa7f6e7..b875ce19e 100644 --- a/tests/macro_benchmark/groups/aggregation/setup.py +++ b/tests/macro_benchmark/groups/aggregation/setup.py @@ -15,6 +15,5 @@ VERTEX_COUNT = 100000 for i in range(VERTEX_COUNT): print("CREATE (n%d {x: %d})" % (i, i)) # batch CREATEs because we can't execute all at once - if (i != 0 and i % BATCH_SIZE == 0) or \ - (i + 1 == VERTEX_COUNT): + if (i != 0 and i % BATCH_SIZE == 0) or (i + 1 == VERTEX_COUNT): print(";") diff --git a/tests/macro_benchmark/groups/aggregation_parallel/setup.py b/tests/macro_benchmark/groups/aggregation_parallel/setup.py index fd4d89471..2bc43beba 100644 --- a/tests/macro_benchmark/groups/aggregation_parallel/setup.py +++ b/tests/macro_benchmark/groups/aggregation_parallel/setup.py @@ -15,6 +15,5 @@ VERTEX_COUNT = 1000000 for i in range(VERTEX_COUNT): print("CREATE (n%d {x: %d})" % (i, i)) # batch CREATEs because we can't execute all at once - if (i != 0 and i % BATCH_SIZE == 0) or \ - (i + 1 == VERTEX_COUNT): + if (i != 0 and i % BATCH_SIZE == 0) or (i + 1 == VERTEX_COUNT): print(";") diff --git a/tests/macro_benchmark/groups/bfs_parallel/bfs.run.py b/tests/macro_benchmark/groups/bfs_parallel/bfs.run.py index a0c516a05..46a5f4b45 100644 --- a/tests/macro_benchmark/groups/bfs_parallel/bfs.run.py +++ b/tests/macro_benchmark/groups/bfs_parallel/bfs.run.py @@ -19,8 +19,9 @@ random.seed(1) for i in range(common.BFS_ITERS): a = int(random.random() * common.VERTEX_COUNT) b = int(random.random() * common.VERTEX_COUNT) - print("MATCH (from: Node {id: %d}) WITH from " - "MATCH (to: Node {id: %d}) WITH to " - "MATCH path = (from)-[*bfs..%d (e, n | true)]->(to) WITH path " - "LIMIT 10 RETURN 0;" - % (a, b, common.PATH_LENGTH)) + print( + "MATCH (from: Node {id: %d}) WITH from " + "MATCH (to: Node {id: %d}) WITH to " + "MATCH path = (from)-[*bfs..%d (e, n | true)]->(to) WITH path " + "LIMIT 10 RETURN 0;" % (a, b, common.PATH_LENGTH) + ) diff --git a/tests/macro_benchmark/groups/bfs_parallel/common.py b/tests/macro_benchmark/groups/bfs_parallel/common.py index 6b019710e..a40504b9a 100644 --- a/tests/macro_benchmark/groups/bfs_parallel/common.py +++ b/tests/macro_benchmark/groups/bfs_parallel/common.py @@ -13,4 +13,3 @@ VERTEX_COUNT = 1000 SPARSE_FACTOR = 10 BFS_ITERS = 50 PATH_LENGTH = 5000 - diff --git a/tests/macro_benchmark/groups/bfs_parallel/setup.py b/tests/macro_benchmark/groups/bfs_parallel/setup.py index 143f44780..52f78fd9e 100644 --- a/tests/macro_benchmark/groups/bfs_parallel/setup.py +++ b/tests/macro_benchmark/groups/bfs_parallel/setup.py @@ -32,4 +32,3 @@ for i in range(common.VERTEX_COUNT * common.VERTEX_COUNT // common.SPARSE_FACTOR a = int(random.random() * common.VERTEX_COUNT) b = int(random.random() * common.VERTEX_COUNT) print("MATCH (a: Node {id: %d}), (b: Node {id: %d}) CREATE (a)-[:Friend]->(b);" % (a, b)) - diff --git a/tests/macro_benchmark/groups/card_fraud/setup.py b/tests/macro_benchmark/groups/card_fraud/setup.py index 428388154..f4cbc1c24 100644 --- a/tests/macro_benchmark/groups/card_fraud/setup.py +++ 
b/tests/macro_benchmark/groups/card_fraud/setup.py @@ -11,13 +11,10 @@ import random + def init_data(card_count, pos_count): - print("UNWIND range(0, {} - 1) AS id " - "CREATE (:Card {{id: id, compromised: false}});".format( - card_count)) - print("UNWIND range(0, {} - 1) AS id " - "CREATE (:Pos {{id: id, compromised: false}});".format( - pos_count)) + print("UNWIND range(0, {} - 1) AS id " "CREATE (:Card {{id: id, compromised: false}});".format(card_count)) + print("UNWIND range(0, {} - 1) AS id " "CREATE (:Pos {{id: id, compromised: false}});".format(pos_count)) def compromise_pos_device(pos_id): @@ -34,20 +31,24 @@ def pump_transactions(card_count, pos_count, tx_count, report_pct): # Card of the transaction gets compromised too. If the card # is compromised, there is a 0.1 chance the transaction is # fraudulent and detected (regardless of POS). - q = ("MATCH (c:Card {{id: {}}}), (p:Pos {{id: {}}}) " - "CREATE (t:Transaction " - "{{id: {}, fraud_reported: c.compromised AND (rand() < %f)}}) " - "CREATE (c)<-[:Using]-(t)-[:At]->(p) " - "SET c.compromised = p.compromised;" % report_pct) + q = ( + "MATCH (c:Card {{id: {}}}), (p:Pos {{id: {}}}) " + "CREATE (t:Transaction " + "{{id: {}, fraud_reported: c.compromised AND (rand() < %f)}}) " + "CREATE (c)<-[:Using]-(t)-[:At]->(p) " + "SET c.compromised = p.compromised;" % report_pct + ) + + def rint(max): + return random.randint(0, max - 1) # NOQA - def rint(max): return random.randint(0, max - 1) # NOQA for i in range(tx_count): print(q.format(rint(card_count), rint(pos_count), i)) POS_COUNT = 1000 CARD_COUNT = 10000 -FRAUD_POS_COUNT = 20 +FRAUD_POS_COUNT = 20 TX_COUNT = 50000 REPORT_PCT = 0.1 diff --git a/tests/macro_benchmark/groups/delete/common.py b/tests/macro_benchmark/groups/delete/common.py index b52923a24..07ba80c3a 100644 --- a/tests/macro_benchmark/groups/delete/common.py +++ b/tests/macro_benchmark/groups/delete/common.py @@ -21,22 +21,21 @@ seed(0) def create_vertices(vertex_count): for vertex in range(vertex_count): print("CREATE (:Label {id: %d})" % vertex) - if (vertex != 0 and vertex % BATCH_SIZE == 0) or \ - (vertex + 1 == vertex_count): + if (vertex != 0 and vertex % BATCH_SIZE == 0) or (vertex + 1 == vertex_count): print(";") def create_edges(edge_count, vertex_count): - """ vertex_count is the number of already existing vertices in graph """ + """vertex_count is the number of already existing vertices in graph""" matches = [] merges = [] for edge in range(edge_count): - matches.append("MATCH (a%d :Label {id: %d}), (b%d :Label {id: %d})" % - (edge, randint(0, vertex_count - 1), - edge, randint(0, vertex_count - 1))) + matches.append( + "MATCH (a%d :Label {id: %d}), (b%d :Label {id: %d})" + % (edge, randint(0, vertex_count - 1), edge, randint(0, vertex_count - 1)) + ) merges.append("CREATE (a%d)-[:Type]->(b%d)" % (edge, edge)) - if (edge != 0 and edge % BATCH_SIZE == 0) or \ - ((edge + 1) == edge_count): + if (edge != 0 and edge % BATCH_SIZE == 0) or ((edge + 1) == edge_count): print(" ".join(matches + merges)) print(";") matches = [] diff --git a/tests/macro_benchmark/groups/expression/common.py b/tests/macro_benchmark/groups/expression/common.py index cb8dff26f..ca14e892d 100644 --- a/tests/macro_benchmark/groups/expression/common.py +++ b/tests/macro_benchmark/groups/expression/common.py @@ -9,8 +9,10 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. 
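The delete and aggregation setup scripts above share one batching idiom: they print a CREATE statement per vertex and close the running batch with a ";" every BATCH_SIZE statements and once more after the last vertex, since the generated stream is too large to execute as a single query. A standalone sketch of the idiom, with an arbitrary BATCH_SIZE because each group defines its own:

BATCH_SIZE = 50  # illustrative value


def create_vertices(vertex_count):
    for vertex in range(vertex_count):
        print("CREATE (:Label {id: %d})" % vertex)
        # End the batch every BATCH_SIZE statements and after the final vertex.
        if (vertex != 0 and vertex % BATCH_SIZE == 0) or (vertex + 1 == vertex_count):
            print(";")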
+ def generate(expressions, repetitions): idx = 0 + def get_alias(): nonlocal idx idx += 1 diff --git a/tests/macro_benchmark/groups/expression/expression.run.py b/tests/macro_benchmark/groups/expression/expression.run.py index abb127623..0970897f1 100644 --- a/tests/macro_benchmark/groups/expression/expression.run.py +++ b/tests/macro_benchmark/groups/expression/expression.run.py @@ -11,11 +11,31 @@ import common -expressions = ['1 + 3', '2 - 1', '2 * 5', '5 / 2', '5 % 5', '-5' + '1.4 + 3.3', - '6.2 - 5.4', '6.5 * 1.2', '6.6 / 1.2', '8.7 % 3.2', '-6.6', - '"Flo" + "Lasta"', 'true AND false', 'true OR false', - 'true XOR false', 'NOT true', '1 < 2', '2 = 3', '6.66 < 10.2', - '3.14 = 3.2', '"Ana" < "Ivana"', '"Ana" = "Mmmmm"', - 'Null < Null', 'Null = Null'] +expressions = [ + "1 + 3", + "2 - 1", + "2 * 5", + "5 / 2", + "5 % 5", + "-5" + "1.4 + 3.3", + "6.2 - 5.4", + "6.5 * 1.2", + "6.6 / 1.2", + "8.7 % 3.2", + "-6.6", + '"Flo" + "Lasta"', + "true AND false", + "true OR false", + "true XOR false", + "NOT true", + "1 < 2", + "2 = 3", + "6.66 < 10.2", + "3.14 = 3.2", + '"Ana" < "Ivana"', + '"Ana" = "Mmmmm"', + "Null < Null", + "Null = Null", +] print(common.generate(expressions, 30)) diff --git a/tests/macro_benchmark/groups/match/setup.py b/tests/macro_benchmark/groups/match/setup.py index 1a40a031c..52fa123a2 100644 --- a/tests/macro_benchmark/groups/match/setup.py +++ b/tests/macro_benchmark/groups/match/setup.py @@ -18,9 +18,11 @@ from random import randint, seed seed(0) + def rint(upper_bound_exclusive): return randint(0, upper_bound_exclusive - 1) + VERTEX_COUNT = 1500 EDGE_COUNT = VERTEX_COUNT * 15 @@ -28,7 +30,7 @@ EDGE_COUNT = VERTEX_COUNT * 15 LABEL_COUNT = 10 MAX_LABELS = 5 # maximum number of labels in a vertex -MAX_PROPS = 4 # maximum number of properties in a vertex/edge +MAX_PROPS = 4 # maximum number of properties in a vertex/edge MAX_PROP_VALUE = 1000 # some consts used in mutiple files @@ -38,7 +40,6 @@ PROP_PREFIX = "Prop" ID = "id" - def labels(): labels = ":" + LABEL_INDEX for _ in range(rint(MAX_LABELS)): @@ -47,12 +48,11 @@ def labels(): def properties(id): - """ Generates a properties string with [0, MAX_PROPS) properties. + """Generates a properties string with [0, MAX_PROPS) properties. Note that if PropX is generated, then all the PropY where Y < X are generated. Thus most labels have Prop0, and least have PropMAX_PROPS. 
""" - props = {"%s%d" % (PROP_PREFIX, i): rint(MAX_PROP_VALUE) - for i in range(rint(MAX_PROPS))} + props = {"%s%d" % (PROP_PREFIX, i): rint(MAX_PROP_VALUE) for i in range(rint(MAX_PROPS))} props[ID] = id return "{" + ", ".join("%s: %s" % kv for kv in props.items()) + "}" @@ -74,21 +74,20 @@ def main(): # create vertices for vertex_index in range(VERTEX_COUNT): print("CREATE %s" % vertex(vertex_index)) - if (vertex_index != 0 and vertex_index % BATCH_SIZE == 0) or \ - vertex_index + 1 == VERTEX_COUNT: + if (vertex_index != 0 and vertex_index % BATCH_SIZE == 0) or vertex_index + 1 == VERTEX_COUNT: print(";") print("MATCH (n) RETURN assert(count(n) = %d);" % VERTEX_COUNT) # create edges stohastically - attempts = VERTEX_COUNT ** 2 - p = EDGE_COUNT / VERTEX_COUNT ** 2 - print("MATCH (a) WITH a MATCH (b) WITH a, b WHERE rand() < %f " - " CREATE (a)-[:EdgeType]->(b);" % p) + attempts = VERTEX_COUNT**2 + p = EDGE_COUNT / VERTEX_COUNT**2 + print("MATCH (a) WITH a MATCH (b) WITH a, b WHERE rand() < %f " " CREATE (a)-[:EdgeType]->(b);" % p) sigma = (attempts * p * (1 - p)) ** 0.5 delta = 5 * sigma - print("MATCH (n)-[r]->() WITH count(r) AS c " - "RETURN assert(c >= %d AND c <= %d);" % ( - EDGE_COUNT - delta, EDGE_COUNT + delta)) + print( + "MATCH (n)-[r]->() WITH count(r) AS c " + "RETURN assert(c >= %d AND c <= %d);" % (EDGE_COUNT - delta, EDGE_COUNT + delta) + ) if __name__ == "__main__": diff --git a/tests/macro_benchmark/groups/match/vertex_on_index.run.py b/tests/macro_benchmark/groups/match/vertex_on_index.run.py index 435b01bd2..23a97c4d2 100644 --- a/tests/macro_benchmark/groups/match/vertex_on_index.run.py +++ b/tests/macro_benchmark/groups/match/vertex_on_index.run.py @@ -11,6 +11,6 @@ from setup import LABEL_INDEX, ID, VERTEX_COUNT, rint -print("UNWIND range(0, 10000) AS i " - "MATCH (n:%s {%s: %d}) RETURN n SKIP 1000000" % ( - LABEL_INDEX, ID, rint(VERTEX_COUNT))) +print( + "UNWIND range(0, 10000) AS i " "MATCH (n:%s {%s: %d}) RETURN n SKIP 1000000" % (LABEL_INDEX, ID, rint(VERTEX_COUNT)) +) diff --git a/tests/macro_benchmark/groups/match/vertex_on_label.run.py b/tests/macro_benchmark/groups/match/vertex_on_label.run.py index 56c04ab3b..95ee1c994 100644 --- a/tests/macro_benchmark/groups/match/vertex_on_label.run.py +++ b/tests/macro_benchmark/groups/match/vertex_on_label.run.py @@ -12,5 +12,4 @@ from setup import LABEL_COUNT, LABEL_PREFIX for i in range(LABEL_COUNT): - print("UNWIND range(0, 30) AS i MATCH (n:%s%d) " - "RETURN n SKIP 1000000;" % (LABEL_PREFIX, i)) + print("UNWIND range(0, 30) AS i MATCH (n:%s%d) " "RETURN n SKIP 1000000;" % (LABEL_PREFIX, i)) diff --git a/tests/macro_benchmark/groups/match/vertex_on_label_property.run.py b/tests/macro_benchmark/groups/match/vertex_on_label_property.run.py index 6c63543b3..8e744f8ba 100644 --- a/tests/macro_benchmark/groups/match/vertex_on_label_property.run.py +++ b/tests/macro_benchmark/groups/match/vertex_on_label_property.run.py @@ -9,8 +9,17 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. 
-from setup import LABEL_PREFIX, PROP_PREFIX, MAX_PROPS, MAX_PROP_VALUE, LABEL_COUNT, rint +from setup import ( + LABEL_PREFIX, + PROP_PREFIX, + MAX_PROPS, + MAX_PROP_VALUE, + LABEL_COUNT, + rint, +) for i in range(LABEL_COUNT): - print("UNWIND range(0, 50) AS i MATCH (n:%s%d {%s%d: %d}) RETURN n SKIP 10000;" % ( - LABEL_PREFIX, i, PROP_PREFIX, rint(MAX_PROPS), rint(MAX_PROP_VALUE))) + print( + "UNWIND range(0, 50) AS i MATCH (n:%s%d {%s%d: %d}) RETURN n SKIP 10000;" + % (LABEL_PREFIX, i, PROP_PREFIX, rint(MAX_PROPS), rint(MAX_PROP_VALUE)) + ) diff --git a/tests/macro_benchmark/groups/match/vertex_on_property.run.py b/tests/macro_benchmark/groups/match/vertex_on_property.run.py index 4188b0e0f..65781811f 100644 --- a/tests/macro_benchmark/groups/match/vertex_on_property.run.py +++ b/tests/macro_benchmark/groups/match/vertex_on_property.run.py @@ -11,5 +11,7 @@ from setup import PROP_PREFIX, MAX_PROPS, rint, MAX_PROP_VALUE -print("UNWIND range(0, 50) AS i MATCH (n {%s%d: %d}) RETURN n SKIP 10000" % ( - PROP_PREFIX, rint(MAX_PROPS), rint(MAX_PROP_VALUE))) +print( + "UNWIND range(0, 50) AS i MATCH (n {%s%d: %d}) RETURN n SKIP 10000" + % (PROP_PREFIX, rint(MAX_PROPS), rint(MAX_PROP_VALUE)) +) diff --git a/tests/macro_benchmark/groups/return/setup.py b/tests/macro_benchmark/groups/return/setup.py index 1578ef802..f50d28e37 100644 --- a/tests/macro_benchmark/groups/return/setup.py +++ b/tests/macro_benchmark/groups/return/setup.py @@ -21,5 +21,6 @@ def main(): if i != 0 and i % BATCH_SIZE == 0: print(";") -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/tests/macro_benchmark/jail_faker.py b/tests/macro_benchmark/jail_faker.py index 9eb8ed3f1..050fb7dbe 100644 --- a/tests/macro_benchmark/jail_faker.py +++ b/tests/macro_benchmark/jail_faker.py @@ -28,6 +28,7 @@ from signal import * class ProcessException(Exception): pass + class StorageException(Exception): pass @@ -41,10 +42,11 @@ class Process: self._usage = {} self._files = [] - def run(self, binary, args=None, env=None, timeout=120, - stdin="/dev/null", cwd="."): - if args is None: args = [] - if env is None: env = {} + def run(self, binary, args=None, env=None, timeout=120, stdin="/dev/null", cwd="."): + if args is None: + args = [] + if env is None: + env = {} # don't start a new process if one is already running if self._proc != None and self._proc.returncode == None: raise ProcessException @@ -65,15 +67,14 @@ class Process: self._timeout = timeout # start process - self._proc = subprocess.Popen(exe, env=env, cwd=cwd, - stdin=open(stdin, "r")) + self._proc = subprocess.Popen(exe, env=env, cwd=cwd, stdin=open(stdin, "r")) def run_and_wait(self, *args, **kwargs): check = kwargs.pop("check", True) self.run(*args, **kwargs) return self.wait(check) - def wait(self, check = True): + def wait(self, check=True): if self._proc == None: raise ProcessException self._proc.wait() @@ -100,18 +101,17 @@ class Process: # this is implemented only in the real API def set_cpus(self, cpus, hyper=True): s = "out" if not hyper else "" - sys.stderr.write("WARNING: Trying to set cpus for {} to " - "{} with{} hyperthreading!\n".format(str(self), cpus, s)) + sys.stderr.write( + "WARNING: Trying to set cpus for {} to " "{} with{} hyperthreading!\n".format(str(self), cpus, s) + ) # this is implemented only in the real API def set_nproc(self, nproc): - sys.stderr.write("WARNING: Trying to set nproc for {} to " - "{}!\n".format(str(self), nproc)) + sys.stderr.write("WARNING: Trying to set nproc for {} to " "{}!\n".format(str(self), nproc)) # 
this is implemented only in the real API def set_memory(self, memory): - sys.stderr.write("WARNING: Trying to set memory for {} to " - "{}\n".format(str(self), memory)) + sys.stderr.write("WARNING: Trying to set memory for {} to " "{}\n".format(str(self), memory)) # WARNING: this won't be implemented in the real API def get_pid(self): @@ -121,7 +121,8 @@ class Process: def _set_usage(self, val, name, only_value=False): self._usage[name] = val - if only_value: return + if only_value: + return maxname = "max_" + name maxval = val if maxname in self._usage: @@ -133,7 +134,8 @@ class Process: self._watchdog() def _update_usage(self): - if self._proc == None: return + if self._proc == None: + return try: f = open("/proc/{}/stat".format(self._proc.pid), "r") data_stat = f.read().split() @@ -144,21 +146,20 @@ class Process: except: return # for a description of these fields see: man proc; man times - utime, stime, cutime, cstime = map( - lambda x: int(x) / self._ticks_per_sec, data_stat[13:17]) + utime, stime, cutime, cstime = map(lambda x: int(x) / self._ticks_per_sec, data_stat[13:17]) self._set_usage(utime + stime + cutime + cstime, "cpu", only_value=True) self._set_usage(utime + cutime, "cpu_user", only_value=True) self._set_usage(stime + cstime, "cpu_sys", only_value=True) self._set_usage(int(data_stat[19]), "threads") - mem_vm, mem_res, mem_shr = map( - lambda x: int(x) * self._page_size // 1024, data_statm[:3]) + mem_vm, mem_res, mem_shr = map(lambda x: int(x) * self._page_size // 1024, data_statm[:3]) self._set_usage(mem_res, "memory") def _watchdog(self): - if self._proc == None or self._proc.returncode != None: return - if time.time() - self._start_time < self._timeout: return - sys.stderr.write("Timeout of {}s reached, sending " - "SIGKILL to {}!\n".format(self._timeout, self)) + if self._proc == None or self._proc.returncode != None: + return + if time.time() - self._start_time < self._timeout: + return + sys.stderr.write("Timeout of {}s reached, sending " "SIGKILL to {}!\n".format(self._timeout, self)) self.send_signal(SIGKILL) self.get_status() @@ -172,22 +173,27 @@ PROCESSES_NUM = 8 _processes = [Process(i) for i in range(1, PROCESSES_NUM + 1)] _last_process = 0 + def _usage_updater(): while True: for proc in _processes: proc._do_background_tasks() time.sleep(0.1) + _thread = threading.Thread(target=_usage_updater, daemon=True) _thread.start() + @atexit.register def cleanup(): for proc in _processes: - if proc._proc == None: continue + if proc._proc == None: + continue proc.send_signal(SIGKILL) proc.get_status() + # end of private methods ------------------------------------------------------ @@ -199,6 +205,7 @@ def get_process(): return proc return None + def get_host_info(): with open("/proc/meminfo") as f: memdata = f.read() @@ -215,21 +222,24 @@ def get_host_info(): threads, cpus = 0, set() for row in cpudata.split("\n\n"): - if not row: continue + if not row: + continue data = row.split("\n") core_id, physical_id = -1, -1 for line in data: name, val = map(lambda x: x.strip(), line.split(":")) - if name == "physical id": physical_id = int(val) - elif name == "core id": core_id = int(val) + if name == "physical id": + physical_id = int(val) + elif name == "core id": + core_id = int(val) threads += 1 cpus.add((core_id, physical_id)) cpus = len(cpus) hyper = True if cpus != threads else False - return {"cpus": cpus, "memory": memory, "hyperthreading": hyper, - "threads": threads} + return {"cpus": cpus, "memory": memory, "hyperthreading": hyper, "threads": threads} + # placeholder 
function that stores a label in the real API def store_label(label): @@ -253,6 +263,8 @@ If chain is None, this function performs the following commands: If chain is either "INPUT" or "OUTPUT" then only that chain is cleared using the appropriate subset of the above mentioned commands. """ + + def network_flush_rules(chain=None): print("Network flush rules: chain={}".format(chain)) @@ -300,12 +312,13 @@ in the following diagram: Other combinations of `chain`, `src`/`dst` and `sport`/`dport` can be used, but are advised to be used only when you exactly know what you are doing :) """ -def network_block_tcp(chain=None, - src=None, dst=None, - sport=None, dport=None, - action=None): - print("Network block TCP: chain={}, src={}, dst={}, sport={}, dport={}, " - "action={}".format(chain, src, dst, sport, dport, action)) + + +def network_block_tcp(chain=None, src=None, dst=None, sport=None, dport=None, action=None): + print( + "Network block TCP: chain={}, src={}, dst={}, sport={}, dport={}, " + "action={}".format(chain, src, dst, sport, dport, action) + ) """ @@ -319,28 +332,24 @@ same* parameters that were used to define the rule in the first place. All other documentation for this function is the same as for `network_block_tcp`, so take a look there. """ -def network_unblock_tcp(chain=None, - src=None, dst=None, - sport=None, dport=None, - action=None): - print("Network unblock TCP: chain={}, src={}, dst={}, sport={}, dport={}, " - "action={}".format(chain, src, dst, sport, dport, action)) + + +def network_unblock_tcp(chain=None, src=None, dst=None, sport=None, dport=None, action=None): + print( + "Network unblock TCP: chain={}, src={}, dst={}, sport={}, dport={}, " + "action={}".format(chain, src, dst, sport, dport, action) + ) # this function is deprecated def store_data(data): pass + # placeholder function that returns real data in the real API def get_network_usage(): usage = { - "lo": { - "bytes": {"rx": 0, "tx": 0}, - "packets": {"rx": 0, "tx": 0} - }, - "eth0": { - "bytes": {"rx": 0, "tx": 0}, - "packets": {"rx": 0, "tx": 0} - } + "lo": {"bytes": {"rx": 0, "tx": 0}, "packets": {"rx": 0, "tx": 0}}, + "eth0": {"bytes": {"rx": 0, "tx": 0}, "packets": {"rx": 0, "tx": 0}}, } return usage diff --git a/tests/macro_benchmark/long_running_suite.py b/tests/macro_benchmark/long_running_suite.py index 9de374aec..51a9c3d6d 100644 --- a/tests/macro_benchmark/long_running_suite.py +++ b/tests/macro_benchmark/long_running_suite.py @@ -38,27 +38,30 @@ class LongRunningSuite: duration = config["duration"] if self.args.duration: duration = self.args.duration - log.info("Executing run for {} seconds".format( - duration)) + log.info("Executing run for {} seconds".format(duration)) results = runner.run(next(scenario.get("run")()), duration, config["client"]) runner.stop() measurements = [] summary_format = "{:>15} {:>22} {:>22}\n" - self.summary = summary_format.format( - "elapsed_time", "num_executed_queries", "num_executed_steps") + self.summary = summary_format.format("elapsed_time", "num_executed_queries", "num_executed_steps") for result in results: self.summary += summary_format.format( - result["elapsed_time"], result["num_executed_queries"], - result["num_executed_steps"]) - measurements.append({ - "target": "throughput", - "time": result["elapsed_time"], - "value": result["num_executed_queries"], - "steps": result["num_executed_steps"], - "unit": "number of executed queries", - "type": "throughput"}) + result["elapsed_time"], + result["num_executed_queries"], + result["num_executed_steps"], + ) + 
measurements.append( + { + "target": "throughput", + "time": result["elapsed_time"], + "value": result["num_executed_queries"], + "steps": result["num_executed_steps"], + "unit": "number of executed queries", + "type": "throughput", + } + ) self.summary += "\n\nThroughput: " + str(measurements[-1]["value"]) self.summary += "\nExecuted steps: " + str(measurements[-1]["steps"]) return measurements @@ -75,8 +78,7 @@ class _LongRunningRunner: self.log = logging.getLogger("_LongRunningRunner") self.database = database self.query_client = QueryClient(args, num_client_workers) - self.long_running_client = LongRunningClient(args, num_client_workers, - workload) + self.long_running_client = LongRunningClient(args, num_client_workers, workload) def start(self): self.database.start() @@ -85,8 +87,7 @@ class _LongRunningRunner: return self.query_client(queries, self.database, num_client_workers) def run(self, config, duration, client, num_client_workers=None): - return self.long_running_client( - config, self.database, duration, client, num_client_workers) + return self.long_running_client(config, self.database, duration, client, num_client_workers) def stop(self): self.log.info("stop") @@ -99,44 +100,46 @@ class MemgraphRunner(_LongRunningRunner): """ Configures memgraph database for LongRunningSuite execution. """ + def __init__(self, args): argp = ArgumentParser("MemgraphRunnerArgumentParser") - argp.add_argument("--num-database-workers", type=int, default=8, - help="Number of workers") - argp.add_argument("--num-client-workers", type=int, default=24, - help="Number of clients") - argp.add_argument("--workload", type=str, default="", - help="Type of client workload. Sets \ - scenario flag for 'TestClient'") + argp.add_argument("--num-database-workers", type=int, default=8, help="Number of workers") + argp.add_argument("--num-client-workers", type=int, default=24, help="Number of clients") + argp.add_argument( + "--workload", + type=str, + default="", + help="Type of client workload. Sets \ + scenario flag for 'TestClient'", + ) self.args, remaining_args = argp.parse_known_args(args) - assert not APOLLO or self.args.num_database_workers, \ - "--num-database-workers is obligatory flag on apollo" - assert not APOLLO or self.args.num_client_workers, \ - "--num-client-workers is obligatory flag on apollo" + assert not APOLLO or self.args.num_database_workers, "--num-database-workers is obligatory flag on apollo" + assert not APOLLO or self.args.num_client_workers, "--num-client-workers is obligatory flag on apollo" database = Memgraph(remaining_args, self.args.num_database_workers) - super(MemgraphRunner, self).__init__( - remaining_args, database, self.args.num_client_workers, - self.args.workload) + super(MemgraphRunner, self).__init__(remaining_args, database, self.args.num_client_workers, self.args.workload) class NeoRunner(_LongRunningRunner): """ Configures neo4j database for QuerySuite execution. """ + def __init__(self, args): argp = ArgumentParser("NeoRunnerArgumentParser") - argp.add_argument("--runner-config", - default=get_absolute_path("config/neo4j.conf"), - help="Path to neo config file") - argp.add_argument("--num-client-workers", type=int, default=24, - help="Number of clients") - argp.add_argument("--workload", type=str, default="", - help="Type of client workload. 
Sets \ - scenario flag for 'TestClient'") + argp.add_argument( + "--runner-config", + default=get_absolute_path("config/neo4j.conf"), + help="Path to neo config file", + ) + argp.add_argument("--num-client-workers", type=int, default=24, help="Number of clients") + argp.add_argument( + "--workload", + type=str, + default="", + help="Type of client workload. Sets \ + scenario flag for 'TestClient'", + ) self.args, remaining_args = argp.parse_known_args(args) - assert not APOLLO or self.args.num_client_workers, \ - "--client-num-clients is obligatory flag on apollo" + assert not APOLLO or self.args.num_client_workers, "--client-num-clients is obligatory flag on apollo" database = Neo(remaining_args, self.args.runner_config) - super(NeoRunner, self).__init__( - remaining_args, database, self.args.num_client_workers, - self.args.workload) + super(NeoRunner, self).__init__(remaining_args, database, self.args.num_client_workers, self.args.workload) diff --git a/tests/macro_benchmark/query_suite.py b/tests/macro_benchmark/query_suite.py index 7d2f7d5ea..6a5af358e 100644 --- a/tests/macro_benchmark/query_suite.py +++ b/tests/macro_benchmark/query_suite.py @@ -33,21 +33,48 @@ class _QuerySuite: a single Cypher query that is benchmarked, and teardown steps (Cypher queries) executed after the benchmark. """ + # what the QuerySuite can work with - KNOWN_KEYS = {"config", "setup", "itersetup", "run", "iterteardown", - "teardown", "common"} - FORMAT = ["{:>24}", "{:>28}", "{:>16}", "{:>18}", "{:>22}", - "{:>16}", "{:>16}", "{:>16}"] + KNOWN_KEYS = { + "config", + "setup", + "itersetup", + "run", + "iterteardown", + "teardown", + "common", + } + FORMAT = [ + "{:>24}", + "{:>28}", + "{:>16}", + "{:>18}", + "{:>22}", + "{:>16}", + "{:>16}", + "{:>16}", + ] FULL_FORMAT = "".join(FORMAT) + "\n" - headers = ["group_name", "scenario_name", "parsing_time", - "planning_time", "plan_execution_time", - WALL_TIME, CPU_TIME, MAX_MEMORY] + headers = [ + "group_name", + "scenario_name", + "parsing_time", + "planning_time", + "plan_execution_time", + WALL_TIME, + CPU_TIME, + MAX_MEMORY, + ] summary = summary_raw = FULL_FORMAT.format(*headers) def __init__(self, args): argp = ArgumentParser("MemgraphRunnerArgumentParser") - argp.add_argument("--perf", default=False, action="store_true", - help="Run perf on running tests and store data") + argp.add_argument( + "--perf", + default=False, + action="store_true", + help="Run perf on running tests and store data", + ) self.args, remaining_args = argp.parse_known_args(args) def run(self, scenario, group_name, scenario_name, runner): @@ -62,8 +89,7 @@ class _QuerySuite: r_val = runner.execute(queries(), num_client_workers) else: r_val = None - log.info("\t%s done in %.2f seconds" % (config_name, - time.time() - start_time)) + log.info("\t%s done in %.2f seconds" % (config_name, time.time() - start_time)) return r_val measurements = defaultdict(list) @@ -75,8 +101,12 @@ class _QuerySuite: execute("setup") # warmup phase - for _ in range(min(scenario_config.get("iterations", 1), - scenario_config.get("warmup", 2))): + for _ in range( + min( + scenario_config.get("iterations", 1), + scenario_config.get("warmup", 2), + ) + ): execute("itersetup") execute("run") execute("iterteardown") @@ -91,15 +121,28 @@ class _QuerySuite: execute("itersetup") if self.args.perf: - file_directory = './perf_results/run_%d/%s/%s/' \ - % (rerun_cnt, group_name, scenario_name) + file_directory = "./perf_results/run_%d/%s/%s/" % ( + rerun_cnt, + group_name, + scenario_name, + ) 
os.makedirs(file_directory, exist_ok=True) - file_name = '%d.perf.data' % iteration + file_name = "%d.perf.data" % iteration path = file_directory + file_name database_pid = str(runner.database.database_bin._proc.pid) self.perf_proc = subprocess.Popen( - ["perf", "record", "-F", "999", "-g", "-o", path, "-p", - database_pid]) + [ + "perf", + "record", + "-F", + "999", + "-g", + "-o", + path, + "-p", + database_pid, + ] + ) run_result = execute("run") @@ -110,16 +153,15 @@ class _QuerySuite: measurements["cpu_time"].append(run_result["cpu_time"]) measurements["max_memory"].append(run_result["max_memory"]) - assert len(run_result["groups"]) == 1, \ - "Multiple groups in run step not yet supported" + assert len(run_result["groups"]) == 1, "Multiple groups in run step not yet supported" group = run_result["groups"][0] measurements["wall_time"].append(group["wall_time"]) - for key in ["parsing_time", "plan_execution_time", - "planning_time"]: + for key in ["parsing_time", "plan_execution_time", "planning_time"]: for i in range(len(group.get("metadatas", []))): - if not key in group["metadatas"][i]: continue + if not key in group["metadatas"][i]: + continue measurements[key].append(group["metadatas"][i][key]) execute("iterteardown") @@ -127,27 +169,35 @@ class _QuerySuite: execute("teardown") runner.stop() - self.append_scenario_summary(group_name, scenario_name, - measurements, num_iterations) + self.append_scenario_summary(group_name, scenario_name, measurements, num_iterations) # calculate mean, median and stdev of measurements for key in measurements: samples = measurements[key] - measurements[key] = {"mean": mean(samples), - "median": median(samples), - "stdev": stdev(samples), - "count": len(samples)} + measurements[key] = { + "mean": mean(samples), + "median": median(samples), + "stdev": stdev(samples), + "count": len(samples), + } measurements["group_name"] = group_name measurements["scenario_name"] = scenario_name return measurements - def append_scenario_summary(self, group_name, scenario_name, - measurement_lists, num_iterations): + def append_scenario_summary(self, group_name, scenario_name, measurement_lists, num_iterations): self.summary += self.FORMAT[0].format(group_name) self.summary += self.FORMAT[1].format(scenario_name) - for i, key in enumerate(("parsing_time", "planning_time", - "plan_execution_time", WALL_TIME, CPU_TIME, MAX_MEMORY)): + for i, key in enumerate( + ( + "parsing_time", + "planning_time", + "plan_execution_time", + WALL_TIME, + CPU_TIME, + MAX_MEMORY, + ) + ): if key not in measurement_lists: time = "-" else: @@ -162,11 +212,11 @@ class _QuerySuite: self.summary += "\n" def runners(self): - """ Which runners can execute a QuerySuite scenario """ + """Which runners can execute a QuerySuite scenario""" assert False, "This is a base class, use one of derived suites" def groups(self): - """ Which groups can be executed by a QuerySuite scenario """ + """Which groups can be executed by a QuerySuite scenario""" assert False, "This is a base class, use one of derived suites" @@ -175,11 +225,20 @@ class QuerySuite(_QuerySuite): _QuerySuite.__init__(self, args) def runners(self): - return {"MemgraphRunner" : MemgraphRunner, "NeoRunner" : NeoRunner} + return {"MemgraphRunner": MemgraphRunner, "NeoRunner": NeoRunner} def groups(self): - return ["1000_create", "unwind_create", "match", "dense_expand", - "expression", "aggregation", "return", "update", "delete"] + return [ + "1000_create", + "unwind_create", + "match", + "dense_expand", + "expression", + "aggregation", + 
"return", + "update", + "delete", + ] class QueryParallelSuite(_QuerySuite): @@ -187,8 +246,10 @@ class QueryParallelSuite(_QuerySuite): _QuerySuite.__init__(self, args) def runners(self): - return {"MemgraphRunner" : MemgraphParallelRunner, "NeoRunner" : - NeoParallelRunner} + return { + "MemgraphRunner": MemgraphParallelRunner, + "NeoRunner": NeoParallelRunner, + } def groups(self): return ["aggregation_parallel", "create_parallel", "bfs_parallel"] @@ -201,6 +262,7 @@ class _QueryRunner: Execution returns benchmarking data (execution times, memory usage etc). """ + def __init__(self, args, database, num_client_workers): self.log = logging.getLogger("_HarnessClientRunner") self.database = database @@ -221,6 +283,7 @@ class MemgraphRunner(_QueryRunner): """ Configures memgraph database for QuerySuite execution. """ + def __init__(self, args): database = Memgraph(args, 1) super(MemgraphRunner, self).__init__(args, database, 1) @@ -230,11 +293,14 @@ class NeoRunner(_QueryRunner): """ Configures neo4j database for QuerySuite execution. """ + def __init__(self, args): argp = ArgumentParser("NeoRunnerArgumentParser") - argp.add_argument("--runner-config", - default=get_absolute_path("config/neo4j.conf"), - help="Path to neo config file") + argp.add_argument( + "--runner-config", + default=get_absolute_path("config/neo4j.conf"), + help="Path to neo config file", + ) self.args, remaining_args = argp.parse_known_args(args) database = Neo(remaining_args, self.args.runner_config) super(NeoRunner, self).__init__(remaining_args, database) @@ -244,36 +310,32 @@ class NeoParallelRunner(_QueryRunner): """ Configures neo4j database for QuerySuite execution. """ + def __init__(self, args): argp = ArgumentParser("NeoRunnerArgumentParser") - argp.add_argument("--runner-config", - default=get_absolute_path("config/neo4j.conf"), - help="Path to neo config file") - argp.add_argument("--num-client-workers", type=int, default=24, - help="Number of clients") + argp.add_argument( + "--runner-config", + default=get_absolute_path("config/neo4j.conf"), + help="Path to neo config file", + ) + argp.add_argument("--num-client-workers", type=int, default=24, help="Number of clients") self.args, remaining_args = argp.parse_known_args(args) - assert not APOLLO or self.args.num_client_workers, \ - "--client-num-clients is obligatory flag on apollo" + assert not APOLLO or self.args.num_client_workers, "--client-num-clients is obligatory flag on apollo" database = Neo(remaining_args, self.args.runner_config) - super(NeoRunner, self).__init__( - remaining_args, database, self.args.num_client_workers) + super(NeoRunner, self).__init__(remaining_args, database, self.args.num_client_workers) class MemgraphParallelRunner(_QueryRunner): """ Configures memgraph database for QuerySuite execution. 
""" + def __init__(self, args): argp = ArgumentParser("MemgraphRunnerArgumentParser") - argp.add_argument("--num-database-workers", type=int, default=8, - help="Number of workers") - argp.add_argument("--num-client-workers", type=int, default=24, - help="Number of clients") + argp.add_argument("--num-database-workers", type=int, default=8, help="Number of workers") + argp.add_argument("--num-client-workers", type=int, default=24, help="Number of clients") self.args, remaining_args = argp.parse_known_args(args) - assert not APOLLO or self.args.num_database_workers, \ - "--num-database-workers is obligatory flag on apollo" - assert not APOLLO or self.args.num_client_workers, \ - "--num-client-workers is obligatory flag on apollo" + assert not APOLLO or self.args.num_database_workers, "--num-database-workers is obligatory flag on apollo" + assert not APOLLO or self.args.num_client_workers, "--num-client-workers is obligatory flag on apollo" database = Memgraph(remaining_args, self.args.num_database_workers) - super(MemgraphParallelRunner, self).__init__( - remaining_args, database, self.args.num_client_workers) + super(MemgraphParallelRunner, self).__init__(remaining_args, database, self.args.num_client_workers) diff --git a/tests/mgbench/benchmark.py b/tests/mgbench/benchmark.py index 5ce715571..e2e9d3f13 100755 --- a/tests/mgbench/benchmark.py +++ b/tests/mgbench/benchmark.py @@ -37,8 +37,7 @@ def get_queries(gen, count): return ret -def match_patterns(dataset, variant, group, test, is_default_variant, - patterns): +def match_patterns(dataset, variant, group, test, is_default_variant, patterns): for pattern in patterns: verdict = [fnmatch.fnmatchcase(dataset, pattern[0])] if pattern[1] != "": @@ -58,7 +57,7 @@ def filter_benchmarks(generators, patterns): pattern = patterns[i].split("/") if len(pattern) > 4 or len(pattern) == 0: raise Exception("Invalid benchmark description '" + pattern + "'!") - pattern.extend(["", "*", "*"][len(pattern) - 1:]) + pattern.extend(["", "*", "*"][len(pattern) - 1 :]) patterns[i] = pattern filtered = [] for dataset in sorted(generators.keys()): @@ -68,8 +67,7 @@ def filter_benchmarks(generators, patterns): current = collections.defaultdict(list) for group in tests: for test_name, test_func in tests[group]: - if match_patterns(dataset, variant, group, test_name, - is_default_variant, patterns): + if match_patterns(dataset, variant, group, test_name, is_default_variant, patterns): current[group].append((test_name, test_func)) if len(current) > 0: filtered.append((generator(variant), dict(current))) @@ -79,43 +77,71 @@ def filter_benchmarks(generators, patterns): # Parse options. 
parser = argparse.ArgumentParser( description="Memgraph benchmark executor.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument("benchmarks", nargs="*", default="", - help="descriptions of benchmarks that should be run; " - "multiple descriptions can be specified to run multiple " - "benchmarks; the description is specified as " - "dataset/variant/group/test; Unix shell-style wildcards " - "can be used in the descriptions; variant, group and test " - "are optional and they can be left out; the default " - "variant is '' which selects the default dataset variant; " - "the default group is '*' which selects all groups; the " - "default test is '*' which selects all tests") -parser.add_argument("--memgraph-binary", - default=helpers.get_binary_path("memgraph"), - help="Memgraph binary used for benchmarking") -parser.add_argument("--client-binary", - default=helpers.get_binary_path("tests/mgbench/client"), - help="client binary used for benchmarking") -parser.add_argument("--num-workers-for-import", type=int, - default=multiprocessing.cpu_count() // 2, - help="number of workers used to import the dataset") -parser.add_argument("--num-workers-for-benchmark", type=int, - default=1, - help="number of workers used to execute the benchmark") -parser.add_argument("--single-threaded-runtime-sec", type=int, - default=10, - help="single threaded duration of each test") -parser.add_argument("--no-load-query-counts", action="store_true", - help="disable loading of cached query counts") -parser.add_argument("--no-save-query-counts", action="store_true", - help="disable storing of cached query counts") -parser.add_argument("--export-results", default="", - help="file path into which results should be exported") -parser.add_argument("--temporary-directory", default="/tmp", - help="directory path where temporary data should " - "be stored") -parser.add_argument("--no-properties-on-edges", action="store_true", - help="disable properties on edges") + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) +parser.add_argument( + "benchmarks", + nargs="*", + default="", + help="descriptions of benchmarks that should be run; " + "multiple descriptions can be specified to run multiple " + "benchmarks; the description is specified as " + "dataset/variant/group/test; Unix shell-style wildcards " + "can be used in the descriptions; variant, group and test " + "are optional and they can be left out; the default " + "variant is '' which selects the default dataset variant; " + "the default group is '*' which selects all groups; the " + "default test is '*' which selects all tests", +) +parser.add_argument( + "--memgraph-binary", + default=helpers.get_binary_path("memgraph"), + help="Memgraph binary used for benchmarking", +) +parser.add_argument( + "--client-binary", + default=helpers.get_binary_path("tests/mgbench/client"), + help="client binary used for benchmarking", +) +parser.add_argument( + "--num-workers-for-import", + type=int, + default=multiprocessing.cpu_count() // 2, + help="number of workers used to import the dataset", +) +parser.add_argument( + "--num-workers-for-benchmark", + type=int, + default=1, + help="number of workers used to execute the benchmark", +) +parser.add_argument( + "--single-threaded-runtime-sec", + type=int, + default=10, + help="single threaded duration of each test", +) +parser.add_argument( + "--no-load-query-counts", + action="store_true", + help="disable loading of cached query counts", +) +parser.add_argument( + "--no-save-query-counts", + 
action="store_true", + help="disable storing of cached query counts", +) +parser.add_argument( + "--export-results", + default="", + help="file path into which results should be exported", +) +parser.add_argument( + "--temporary-directory", + default="/tmp", + help="directory path where temporary data should " "be stored", +) +parser.add_argument("--no-properties-on-edges", action="store_true", help="disable properties on edges") args = parser.parse_args() # Detect available datasets. @@ -124,8 +150,7 @@ for key in dir(datasets): if key.startswith("_"): continue dataset = getattr(datasets, key) - if not inspect.isclass(dataset) or dataset == datasets.Dataset or \ - not issubclass(dataset, datasets.Dataset): + if not inspect.isclass(dataset) or dataset == datasets.Dataset or not issubclass(dataset, datasets.Dataset): continue tests = collections.defaultdict(list) for funcname in dir(dataset): @@ -135,8 +160,9 @@ for key in dir(datasets): tests[group].append((test, funcname)) generators[dataset.NAME] = (dataset, dict(tests)) if dataset.PROPERTIES_ON_EDGES and args.no_properties_on_edges: - raise Exception("The \"{}\" dataset requires properties on edges, " - "but you have disabled them!".format(dataset.NAME)) + raise Exception( + 'The "{}" dataset requires properties on edges, ' "but you have disabled them!".format(dataset.NAME) + ) # List datasets if there is no specified dataset. if len(args.benchmarks) == 0: @@ -144,8 +170,11 @@ if len(args.benchmarks) == 0: for name in sorted(generators.keys()): print("Dataset:", name) dataset, tests = generators[name] - print(" Variants:", ", ".join(dataset.VARIANTS), - "(default: " + dataset.DEFAULT_VARIANT + ")") + print( + " Variants:", + ", ".join(dataset.VARIANTS), + "(default: " + dataset.DEFAULT_VARIANT + ")", + ) for group in sorted(tests.keys()): print(" Group:", group) for test_name, test_func in tests[group]: @@ -165,31 +194,38 @@ benchmarks = filter_benchmarks(generators, args.benchmarks) # Run all specified benchmarks. for dataset, tests in benchmarks: - log.init("Preparing", dataset.NAME + "/" + dataset.get_variant(), - "dataset") - dataset.prepare(cache.cache_directory("datasets", dataset.NAME, - dataset.get_variant())) + log.init("Preparing", dataset.NAME + "/" + dataset.get_variant(), "dataset") + dataset.prepare(cache.cache_directory("datasets", dataset.NAME, dataset.get_variant())) # Prepare runners and import the dataset. - memgraph = runners.Memgraph(args.memgraph_binary, args.temporary_directory, - not args.no_properties_on_edges) + memgraph = runners.Memgraph(args.memgraph_binary, args.temporary_directory, not args.no_properties_on_edges) client = runners.Client(args.client_binary, args.temporary_directory) memgraph.start_preparation() - ret = client.execute(file_path=dataset.get_file(), - num_workers=args.num_workers_for_import) + ret = client.execute(file_path=dataset.get_file(), num_workers=args.num_workers_for_import) usage = memgraph.stop() # Display import statistics. 
print() for row in ret: - print("Executed", row["count"], "queries in", row["duration"], - "seconds using", row["num_workers"], - "workers with a total throughput of", row["throughput"], - "queries/second.") + print( + "Executed", + row["count"], + "queries in", + row["duration"], + "seconds using", + row["num_workers"], + "workers with a total throughput of", + row["throughput"], + "queries/second.", + ) print() - print("The database used", usage["cpu"], - "seconds of CPU time and peaked at", - usage["memory"] / 1024 / 1024, "MiB of RAM.") + print( + "The database used", + usage["cpu"], + "seconds of CPU time and peaked at", + usage["memory"] / 1024 / 1024, + "MiB of RAM.", + ) # Save import results. import_key = [dataset.NAME, dataset.get_variant(), "__import__"] @@ -208,24 +244,26 @@ for dataset, tests in benchmarks: config_key = [dataset.NAME, dataset.get_variant(), group, test] cached_count = config.get_value(*config_key) if cached_count is None: - print("Determining the number of queries necessary for", - args.single_threaded_runtime_sec, - "seconds of single-threaded runtime...") + print( + "Determining the number of queries necessary for", + args.single_threaded_runtime_sec, + "seconds of single-threaded runtime...", + ) # First run to prime the query caches. memgraph.start_benchmark() client.execute(queries=get_queries(func, 1), num_workers=1) # Get a sense of the runtime. count = 1 while True: - ret = client.execute(queries=get_queries(func, count), - num_workers=1) + ret = client.execute(queries=get_queries(func, count), num_workers=1) duration = ret[0]["duration"] - should_execute = int(args.single_threaded_runtime_sec / - (duration / count)) - print("executed_queries={}, total_duration={}, " - "query_duration={}, estimated_count={}".format( - count, duration, duration / count, - should_execute)) + should_execute = int(args.single_threaded_runtime_sec / (duration / count)) + print( + "executed_queries={}, total_duration={}, " + "query_duration={}, estimated_count={}".format( + count, duration, duration / count, should_execute + ) + ) # We don't have to execute the next iteration when # `should_execute` becomes the same order of magnitude as # `count * 10`. @@ -235,45 +273,52 @@ for dataset, tests in benchmarks: else: count = count * 10 memgraph.stop() - config.set_value(*config_key, value={ - "count": count, - "duration": args.single_threaded_runtime_sec}) + config.set_value(*config_key, value={"count": count, "duration": args.single_threaded_runtime_sec}) else: - print("Using cached query count of", cached_count["count"], - "queries for", cached_count["duration"], - "seconds of single-threaded runtime.") - count = int(cached_count["count"] * - args.single_threaded_runtime_sec / - cached_count["duration"]) + print( + "Using cached query count of", + cached_count["count"], + "queries for", + cached_count["duration"], + "seconds of single-threaded runtime.", + ) + count = int(cached_count["count"] * args.single_threaded_runtime_sec / cached_count["duration"]) # Benchmark run. 
print("Sample query:", get_queries(func, 1)[0][0]) - print("Executing benchmark with", count, "queries that should " - "yield a single-threaded runtime of", - args.single_threaded_runtime_sec, "seconds.") - print("Queries are executed using", args.num_workers_for_benchmark, - "concurrent clients.") + print( + "Executing benchmark with", + count, + "queries that should " "yield a single-threaded runtime of", + args.single_threaded_runtime_sec, + "seconds.", + ) + print( + "Queries are executed using", + args.num_workers_for_benchmark, + "concurrent clients.", + ) memgraph.start_benchmark() - ret = client.execute(queries=get_queries(func, count), - num_workers=args.num_workers_for_benchmark)[0] + ret = client.execute( + queries=get_queries(func, count), + num_workers=args.num_workers_for_benchmark, + )[0] usage = memgraph.stop() ret["database"] = usage # Output summary. print() - print("Executed", ret["count"], "queries in", - ret["duration"], "seconds.") + print("Executed", ret["count"], "queries in", ret["duration"], "seconds.") print("Queries have been retried", ret["retries"], "times.") - print("Database used {:.3f} seconds of CPU time.".format( - usage["cpu"])) - print("Database peaked at {:.3f} MiB of memory.".format( - usage["memory"] / 1024.0 / 1024.0)) - print("{:<31} {:>20} {:>20} {:>20}".format("Metadata:", "min", - "avg", "max")) + print("Database used {:.3f} seconds of CPU time.".format(usage["cpu"])) + print("Database peaked at {:.3f} MiB of memory.".format(usage["memory"] / 1024.0 / 1024.0)) + print("{:<31} {:>20} {:>20} {:>20}".format("Metadata:", "min", "avg", "max")) metadata = ret["metadata"] for key in sorted(metadata.keys()): - print("{name:>30}: {minimum:>20.06f} {average:>20.06f} " - "{maximum:>20.06f}".format(name=key, **metadata[key])) + print( + "{name:>30}: {minimum:>20.06f} {average:>20.06f} " + "{maximum:>20.06f}".format(name=key, **metadata[key]) + ) log.success("Throughput: {:02f} QPS".format(ret["throughput"])) # Save results. 
diff --git a/tests/mgbench/compare_results.py b/tests/mgbench/compare_results.py index 2179bb408..27d4db70b 100755 --- a/tests/mgbench/compare_results.py +++ b/tests/mgbench/compare_results.py @@ -85,39 +85,41 @@ def compare_results(results_from, results_to, fields): if group == "__import__": continue for scenario, summary_to in scenarios.items(): - summary_from = recursive_get( - results_from, dataset, variant, group, scenario, - value={}) - if len(summary_from) > 0 and \ - summary_to["count"] != summary_from["count"] or \ - summary_to["num_workers"] != \ - summary_from["num_workers"]: + summary_from = recursive_get(results_from, dataset, variant, group, scenario, value={}) + if ( + len(summary_from) > 0 + and summary_to["count"] != summary_from["count"] + or summary_to["num_workers"] != summary_from["num_workers"] + ): raise Exception("Incompatible results!") - testcode = "/".join([dataset, variant, group, scenario, - "{:02d}".format( - summary_to["num_workers"])]) + testcode = "/".join( + [ + dataset, + variant, + group, + scenario, + "{:02d}".format(summary_to["num_workers"]), + ] + ) row = {} performance_changed = False for field in fields: key = field["name"] if key in summary_to: - row[key] = compute_diff( - summary_from.get(key, None), - summary_to[key]) + row[key] = compute_diff(summary_from.get(key, None), summary_to[key]) elif key in summary_to["database"]: row[key] = compute_diff( - recursive_get(summary_from, "database", key, - value=None), - summary_to["database"][key]) + recursive_get(summary_from, "database", key, value=None), + summary_to["database"][key], + ) else: row[key] = compute_diff( - recursive_get(summary_from, "metadata", key, - "average", value=None), - summary_to["metadata"][key]["average"]) - if "diff" not in row[key] or \ - ("diff_treshold" in field and - abs(row[key]["diff"]) >= - field["diff_treshold"]): + recursive_get(summary_from, "metadata", key, "average", value=None), + summary_to["metadata"][key]["average"], + ) + if "diff" not in row[key] or ( + "diff_treshold" in field and abs(row[key]["diff"]) >= field["diff_treshold"] + ): performance_changed = True if performance_changed: ret[testcode] = row @@ -130,8 +132,15 @@ def generate_remarkup(fields, data): ret += "\n" ret += " \n" ret += " \n" - ret += "\n".join(map(lambda x: " ".format( - x["name"].replace("_", " ").capitalize()), fields)) + "\n" + ret += ( + "\n".join( + map( + lambda x: " ".format(x["name"].replace("_", " ").capitalize()), + fields, + ) + ) + + "\n" + ) ret += " \n" for testcode in sorted(data.keys()): ret += " \n" @@ -147,12 +156,9 @@ def generate_remarkup(fields, data): else: color = "red" sign = "{{icon {} color={}}}".format(arrow, color) - ret += " \n".format( - value, field["unit"], diff, sign) + ret += " \n".format(value, field["unit"], diff, sign) else: - ret += " \n".format( - value, field["unit"]) + ret += " \n".format(value, field["unit"]) ret += " \n" ret += "
Testcode{}{}
{:.3f}{} //({:+.2%})// {}{:.3f}{} //({:+.2%})// {}{:.3f}{} //(new)// " \ - "{{icon plus color=blue}}{:.3f}{} //(new)// " "{{icon plus color=blue}}
\n" else: @@ -161,11 +167,14 @@ def generate_remarkup(fields, data): if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Compare results of multiple benchmark runs.") - parser.add_argument("--compare", action="append", nargs=2, - metavar=("from", "to"), - help="compare results between `from` and `to` files") + parser = argparse.ArgumentParser(description="Compare results of multiple benchmark runs.") + parser.add_argument( + "--compare", + action="append", + nargs=2, + metavar=("from", "to"), + help="compare results between `from` and `to` files", + ) parser.add_argument("--output", default="", help="output file name") args = parser.parse_args() diff --git a/tests/mgbench/datasets.py b/tests/mgbench/datasets.py index dbaaa2de9..9e9680b77 100644 --- a/tests/mgbench/datasets.py +++ b/tests/mgbench/datasets.py @@ -45,13 +45,10 @@ class Dataset: variant = self.DEFAULT_VARIANT if variant not in self.VARIANTS: raise ValueError("Invalid test variant!") - if (self.FILES and variant not in self.FILES) and \ - (self.URLS and variant not in self.URLS): - raise ValueError("The variant doesn't have a defined URL or " - "file path!") + if (self.FILES and variant not in self.FILES) and (self.URLS and variant not in self.URLS): + raise ValueError("The variant doesn't have a defined URL or " "file path!") if variant not in self.SIZES: - raise ValueError("The variant doesn't have a defined dataset " - "size!") + raise ValueError("The variant doesn't have a defined dataset " "size!") self._variant = variant if self.FILES is not None: self._file = self.FILES.get(variant, None) @@ -63,8 +60,7 @@ class Dataset: self._url = None self._size = self.SIZES[variant] if "vertices" not in self._size or "edges" not in self._size: - raise ValueError("The size defined for this variant doesn't " - "have the number of vertices and/or edges!") + raise ValueError("The size defined for this variant doesn't " "have the number of vertices and/or edges!") self._num_vertices = self._size["vertices"] self._num_edges = self._size["edges"] @@ -76,8 +72,7 @@ class Dataset: cached_input, exists = directory.get_file("dataset.cypher") if not exists: print("Downloading dataset file:", self._url) - downloaded_file = helpers.download_file( - self._url, directory.get_path()) + downloaded_file = helpers.download_file(self._url, directory.get_path()) print("Unpacking and caching file:", downloaded_file) helpers.unpack_and_move_file(downloaded_file, cached_input) print("Using cached dataset file:", cached_input) @@ -137,18 +132,20 @@ class Pokec(Dataset): # Arango benchmarks def benchmark__arango__single_vertex_read(self): - return ("MATCH (n:User {id : $id}) RETURN n", - {"id": self._get_random_vertex()}) + return ("MATCH (n:User {id : $id}) RETURN n", {"id": self._get_random_vertex()}) def benchmark__arango__single_vertex_write(self): - return ("CREATE (n:UserTemp {id : $id}) RETURN n", - {"id": random.randint(1, self._num_vertices * 10)}) + return ( + "CREATE (n:UserTemp {id : $id}) RETURN n", + {"id": random.randint(1, self._num_vertices * 10)}, + ) def benchmark__arango__single_edge_write(self): vertex_from, vertex_to = self._get_random_from_to() - return ("MATCH (n:User {id: $from}), (m:User {id: $to}) WITH n, m " - "CREATE (n)-[e:Temp]->(m) RETURN e", - {"from": vertex_from, "to": vertex_to}) + return ( + "MATCH (n:User {id: $from}), (m:User {id: $to}) WITH n, m " "CREATE (n)-[e:Temp]->(m) RETURN e", + {"from": vertex_from, "to": vertex_to}, + ) def benchmark__arango__aggregate(self): return ("MATCH (n:User) 
RETURN n.age, COUNT(*)", {}) @@ -157,92 +154,103 @@ class Pokec(Dataset): return ("MATCH (n:User) WHERE n.age >= 18 RETURN n.age, COUNT(*)", {}) def benchmark__arango__expansion_1(self): - return ("MATCH (s:User {id: $id})-->(n:User) " - "RETURN n.id", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-->(n:User) " "RETURN n.id", + {"id": self._get_random_vertex()}, + ) def benchmark__arango__expansion_1_with_filter(self): - return ("MATCH (s:User {id: $id})-->(n:User) " - "WHERE n.age >= 18 " - "RETURN n.id", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-->(n:User) " "WHERE n.age >= 18 " "RETURN n.id", + {"id": self._get_random_vertex()}, + ) def benchmark__arango__expansion_2(self): - return ("MATCH (s:User {id: $id})-->()-->(n:User) " - "RETURN DISTINCT n.id", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-->()-->(n:User) " "RETURN DISTINCT n.id", + {"id": self._get_random_vertex()}, + ) def benchmark__arango__expansion_2_with_filter(self): - return ("MATCH (s:User {id: $id})-->()-->(n:User) " - "WHERE n.age >= 18 " - "RETURN DISTINCT n.id", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-->()-->(n:User) " "WHERE n.age >= 18 " "RETURN DISTINCT n.id", + {"id": self._get_random_vertex()}, + ) def benchmark__arango__expansion_3(self): - return ("MATCH (s:User {id: $id})-->()-->()-->(n:User) " - "RETURN DISTINCT n.id", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-->()-->()-->(n:User) " "RETURN DISTINCT n.id", + {"id": self._get_random_vertex()}, + ) def benchmark__arango__expansion_3_with_filter(self): - return ("MATCH (s:User {id: $id})-->()-->()-->(n:User) " - "WHERE n.age >= 18 " - "RETURN DISTINCT n.id", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-->()-->()-->(n:User) " "WHERE n.age >= 18 " "RETURN DISTINCT n.id", + {"id": self._get_random_vertex()}, + ) def benchmark__arango__expansion_4(self): - return ("MATCH (s:User {id: $id})-->()-->()-->()-->(n:User) " - "RETURN DISTINCT n.id", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-->()-->()-->()-->(n:User) " "RETURN DISTINCT n.id", + {"id": self._get_random_vertex()}, + ) def benchmark__arango__expansion_4_with_filter(self): - return ("MATCH (s:User {id: $id})-->()-->()-->()-->(n:User) " - "WHERE n.age >= 18 " - "RETURN DISTINCT n.id", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-->()-->()-->()-->(n:User) " "WHERE n.age >= 18 " "RETURN DISTINCT n.id", + {"id": self._get_random_vertex()}, + ) def benchmark__arango__neighbours_2(self): - return ("MATCH (s:User {id: $id})-[*1..2]->(n:User) " - "RETURN DISTINCT n.id", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-[*1..2]->(n:User) " "RETURN DISTINCT n.id", + {"id": self._get_random_vertex()}, + ) def benchmark__arango__neighbours_2_with_filter(self): - return ("MATCH (s:User {id: $id})-[*1..2]->(n:User) " - "WHERE n.age >= 18 " - "RETURN DISTINCT n.id", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-[*1..2]->(n:User) " "WHERE n.age >= 18 " "RETURN DISTINCT n.id", + {"id": self._get_random_vertex()}, + ) def benchmark__arango__neighbours_2_with_data(self): - return ("MATCH (s:User {id: $id})-[*1..2]->(n:User) " - "RETURN DISTINCT n.id, n", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-[*1..2]->(n:User) " "RETURN DISTINCT n.id, n", + {"id": self._get_random_vertex()}, + ) def 
benchmark__arango__neighbours_2_with_data_and_filter(self): - return ("MATCH (s:User {id: $id})-[*1..2]->(n:User) " - "WHERE n.age >= 18 " - "RETURN DISTINCT n.id, n", - {"id": self._get_random_vertex()}) + return ( + "MATCH (s:User {id: $id})-[*1..2]->(n:User) " "WHERE n.age >= 18 " "RETURN DISTINCT n.id, n", + {"id": self._get_random_vertex()}, + ) def benchmark__arango__shortest_path(self): vertex_from, vertex_to = self._get_random_from_to() - return ("MATCH (n:User {id: $from}), (m:User {id: $to}) WITH n, m " - "MATCH p=(n)-[*bfs..15]->(m) " - "RETURN extract(n in nodes(p) | n.id) AS path", - {"from": vertex_from, "to": vertex_to}) + return ( + "MATCH (n:User {id: $from}), (m:User {id: $to}) WITH n, m " + "MATCH p=(n)-[*bfs..15]->(m) " + "RETURN extract(n in nodes(p) | n.id) AS path", + {"from": vertex_from, "to": vertex_to}, + ) def benchmark__arango__shortest_path_with_filter(self): vertex_from, vertex_to = self._get_random_from_to() - return ("MATCH (n:User {id: $from}), (m:User {id: $to}) WITH n, m " - "MATCH p=(n)-[*bfs..15 (e, n | n.age >= 18)]->(m) " - "RETURN extract(n in nodes(p) | n.id) AS path", - {"from": vertex_from, "to": vertex_to}) + return ( + "MATCH (n:User {id: $from}), (m:User {id: $to}) WITH n, m " + "MATCH p=(n)-[*bfs..15 (e, n | n.age >= 18)]->(m) " + "RETURN extract(n in nodes(p) | n.id) AS path", + {"from": vertex_from, "to": vertex_to}, + ) # Our benchmark queries def benchmark__create__edge(self): vertex_from, vertex_to = self._get_random_from_to() - return ("MATCH (a:User {id: $from}), (b:User {id: $to}) " - "CREATE (a)-[:TempEdge]->(b)", - {"from": vertex_from, "to": vertex_to}) + return ( + "MATCH (a:User {id: $from}), (b:User {id: $to}) " "CREATE (a)-[:TempEdge]->(b)", + {"from": vertex_from, "to": vertex_to}, + ) def benchmark__create__pattern(self): return ("CREATE ()-[:TempEdge]->()", {}) @@ -251,9 +259,12 @@ class Pokec(Dataset): return ("CREATE ()", {}) def benchmark__create__vertex_big(self): - return ("CREATE (:L1:L2:L3:L4:L5:L6:L7 {p1: true, p2: 42, " - "p3: \"Here is some text that is not extremely short\", " - "p4:\"Short text\", p5: 234.434, p6: 11.11, p7: false})", {}) + return ( + "CREATE (:L1:L2:L3:L4:L5:L6:L7 {p1: true, p2: 42, " + 'p3: "Here is some text that is not extremely short", ' + 'p4:"Short text", p5: 234.434, p6: 11.11, p7: false})', + {}, + ) def benchmark__aggregation__count(self): return ("MATCH (n) RETURN count(n), count(n.age)", {}) @@ -262,29 +273,31 @@ class Pokec(Dataset): return ("MATCH (n) RETURN min(n.age), max(n.age), avg(n.age)", {}) def benchmark__match__pattern_cycle(self): - return ("MATCH (n:User {id: $id})-[e1]->(m)-[e2]->(n) " - "RETURN e1, m, e2", - {"id": self._get_random_vertex()}) + return ( + "MATCH (n:User {id: $id})-[e1]->(m)-[e2]->(n) " "RETURN e1, m, e2", + {"id": self._get_random_vertex()}, + ) def benchmark__match__pattern_long(self): - return ("MATCH (n1:User {id: $id})-[e1]->(n2)-[e2]->" - "(n3)-[e3]->(n4)<-[e4]-(n5) " - "RETURN n5 LIMIT 1", - {"id": self._get_random_vertex()}) + return ( + "MATCH (n1:User {id: $id})-[e1]->(n2)-[e2]->" "(n3)-[e3]->(n4)<-[e4]-(n5) " "RETURN n5 LIMIT 1", + {"id": self._get_random_vertex()}, + ) def benchmark__match__pattern_short(self): - return ("MATCH (n:User {id: $id})-[e]->(m) " - "RETURN m LIMIT 1", - {"id": self._get_random_vertex()}) + return ( + "MATCH (n:User {id: $id})-[e]->(m) " "RETURN m LIMIT 1", + {"id": self._get_random_vertex()}, + ) def benchmark__match__vertex_on_label_property(self): - return ("MATCH (n:User) WITH n WHERE n.id = $id RETURN n", - 
{"id": self._get_random_vertex()}) + return ( + "MATCH (n:User) WITH n WHERE n.id = $id RETURN n", + {"id": self._get_random_vertex()}, + ) def benchmark__match__vertex_on_label_property_index(self): - return ("MATCH (n:User {id: $id}) RETURN n", - {"id": self._get_random_vertex()}) + return ("MATCH (n:User {id: $id}) RETURN n", {"id": self._get_random_vertex()}) def benchmark__match__vertex_on_property(self): - return ("MATCH (n {id: $id}) RETURN n", - {"id": self._get_random_vertex()}) + return ("MATCH (n {id: $id}) RETURN n", {"id": self._get_random_vertex()}) diff --git a/tests/mgbench/helpers.py b/tests/mgbench/helpers.py index 7488b1443..efd1dc9a8 100644 --- a/tests/mgbench/helpers.py +++ b/tests/mgbench/helpers.py @@ -28,18 +28,21 @@ def get_binary_path(path, base=""): def download_file(url, path): - ret = subprocess.run(["wget", "-nv", "--content-disposition", url], - stderr=subprocess.PIPE, cwd=path, check=True) + ret = subprocess.run( + ["wget", "-nv", "--content-disposition", url], + stderr=subprocess.PIPE, + cwd=path, + check=True, + ) data = ret.stderr.decode("utf-8") tmp = data.split("->")[1] - name = tmp[tmp.index('"') + 1:tmp.rindex('"')] + name = tmp[tmp.index('"') + 1 : tmp.rindex('"')] return os.path.join(path, name) def unpack_and_move_file(input_path, output_path): if input_path.endswith(".gz"): - subprocess.run(["gunzip", input_path], - stdout=subprocess.DEVNULL, check=True) + subprocess.run(["gunzip", input_path], stdout=subprocess.DEVNULL, check=True) input_path = input_path[:-3] os.rename(input_path, output_path) diff --git a/tests/mgbench/runners.py b/tests/mgbench/runners.py index 891a7cddd..64857038e 100644 --- a/tests/mgbench/runners.py +++ b/tests/mgbench/runners.py @@ -40,8 +40,7 @@ def _convert_args_to_flags(*args, **kwargs): def _get_usage(pid): total_cpu = 0 with open("/proc/{}/stat".format(pid)) as f: - total_cpu = (sum(map(int, f.read().split(")")[1].split()[11:15])) / - os.sysconf(os.sysconf_names["SC_CLK_TCK"])) + total_cpu = sum(map(int, f.read().split(")")[1].split()[11:15])) / os.sysconf(os.sysconf_names["SC_CLK_TCK"]) peak_rss = 0 with open("/proc/{}/status".format(pid)) as f: for row in f: @@ -60,10 +59,8 @@ class Memgraph: atexit.register(self._cleanup) # Determine Memgraph version - ret = subprocess.run([memgraph_binary, "--version"], - stdout=subprocess.PIPE, check=True) - version = re.search(r"[0-9]+\.[0-9]+\.[0-9]+", - ret.stdout.decode("utf-8")).group(0) + ret = subprocess.run([memgraph_binary, "--version"], stdout=subprocess.PIPE, check=True) + version = re.search(r"[0-9]+\.[0-9]+\.[0-9]+", ret.stdout.decode("utf-8")).group(0) self._memgraph_version = tuple(map(int, version.split("."))) def __del__(self): @@ -79,8 +76,7 @@ class Memgraph: if self._memgraph_version >= (0, 50, 0): kwargs["storage_properties_on_edges"] = self._properties_on_edges else: - assert self._properties_on_edges, \ - "Older versions of Memgraph can't disable properties on edges!" + assert self._properties_on_edges, "Older versions of Memgraph can't disable properties on edges!" 
return _convert_args_to_flags(self._memgraph_binary, **kwargs) def _start(self, **kwargs): @@ -94,8 +90,7 @@ class Memgraph: raise Exception("The database process died prematurely!") wait_for_server(7687) ret = self._proc_mg.poll() - assert ret is None, "The database process died prematurely " \ - "({})!".format(ret) + assert ret is None, "The database process died prematurely " "({})!".format(ret) def _cleanup(self): if self._proc_mg is None: @@ -121,8 +116,7 @@ class Memgraph: def stop(self): ret, usage = self._cleanup() - assert ret == 0, "The database process exited with a non-zero " \ - "status ({})!".format(ret) + assert ret == 0, "The database process exited with a non-zero " "status ({})!".format(ret) return usage @@ -135,8 +129,7 @@ class Client: return _convert_args_to_flags(self._client_binary, **kwargs) def execute(self, queries=None, file_path=None, num_workers=1): - if (queries is None and file_path is None) or \ - (queries is not None and file_path is not None): + if (queries is None and file_path is None) or (queries is not None and file_path is not None): raise ValueError("Either queries or input_path must be specified!") # TODO: check `file_path.endswith(".json")` to support advanced @@ -151,8 +144,7 @@ class Client: json.dump(query, f) f.write("\n") - args = self._get_args(input=file_path, num_workers=num_workers, - queries_json=queries_json) + args = self._get_args(input=file_path, num_workers=num_workers, queries_json=queries_json) ret = subprocess.run(args, stdout=subprocess.PIPE, check=True) data = ret.stdout.decode("utf-8").strip().split("\n") return list(map(json.loads, data)) diff --git a/tests/stress/bipartite.py b/tests/stress/bipartite.py index 3932906c4..cb7244aa9 100644 --- a/tests/stress/bipartite.py +++ b/tests/stress/bipartite.py @@ -12,44 +12,60 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. -''' +""" Large bipartite graph stress test. 
-''' +""" import logging import multiprocessing import time import atexit -from common import connection_argument_parser, assert_equal, \ - OutputData, execute_till_success, \ - batch, render, SessionCache +from common import ( + connection_argument_parser, + assert_equal, + OutputData, + execute_till_success, + batch, + render, + SessionCache, +) def parse_args(): - ''' + """ Parses user arguments :return: parsed arguments - ''' + """ parser = connection_argument_parser() - parser.add_argument('--worker-count', type=int, - default=multiprocessing.cpu_count(), - help='Number of concurrent workers.') - parser.add_argument("--logging", default="INFO", - choices=["INFO", "DEBUG", "WARNING", "ERROR"], - help="Logging level") - parser.add_argument('--u-count', type=int, default=100, - help='Size of U set in the bipartite graph.') - parser.add_argument('--v-count', type=int, default=100, - help='Size of V set in the bipartite graph.') - parser.add_argument('--vertex-batch-size', type=int, default=100, - help="Create vertices in batches of this size.") - parser.add_argument('--edge-batching', action='store_true', - help='Create edges in batches.') - parser.add_argument('--edge-batch-size', type=int, default=100, - help='Number of edges in a batch when edges ' - 'are created in batches.') + parser.add_argument( + "--worker-count", + type=int, + default=multiprocessing.cpu_count(), + help="Number of concurrent workers.", + ) + parser.add_argument( + "--logging", + default="INFO", + choices=["INFO", "DEBUG", "WARNING", "ERROR"], + help="Logging level", + ) + parser.add_argument("--u-count", type=int, default=100, help="Size of U set in the bipartite graph.") + parser.add_argument("--v-count", type=int, default=100, help="Size of V set in the bipartite graph.") + parser.add_argument( + "--vertex-batch-size", + type=int, + default=100, + help="Create vertices in batches of this size.", + ) + parser.add_argument("--edge-batching", action="store_true", help="Create edges in batches.") + parser.add_argument( + "--edge-batch-size", + type=int, + default=100, + help="Number of edges in a batch when edges " "are created in batches.", + ) return parser.parse_args() @@ -62,18 +78,18 @@ atexit.register(SessionCache.cleanup) def create_u_v_edges(u): - ''' + """ Creates nodes and checks that all nodes were created. create edges from one vertex in U set to all vertex of V set :param worker_id: worker id :return: tuple (worker_id, create execution time, time unit) - ''' + """ start_time = time.time() session = SessionCache.argument_session(args) no_failures = 0 - match_u = 'MATCH (u:U {id: %d})' % u + match_u = "MATCH (u:U {id: %d})" % u if args.edge_batching: # TODO: try to randomize execution, the execution time should # be smaller, add randomize flag @@ -83,143 +99,126 @@ def create_u_v_edges(u): query = match_u + "".join(match_v) + "".join(create_u) no_failures += execute_till_success(session, query)[1] else: - no_failures += execute_till_success( - session, match_u + ' MATCH (v:V) CREATE (u)-[:R]->(v)')[1] + no_failures += execute_till_success(session, match_u + " MATCH (v:V) CREATE (u)-[:R]->(v)")[1] end_time = time.time() return u, end_time - start_time, "s", no_failures def traverse_from_u_worker(u): - ''' + """ Traverses edges starting from an element of U set. Traversed labels are: :U -> :V -> :U. 
- ''' + """ session = SessionCache.argument_session(args) start_time = time.time() assert_equal( args.u_count * args.v_count - args.v_count, # cypher morphism - session.run("MATCH (u1:U {id: %s})-[e1]->(v:V)<-[e2]-(u2:U) " - "RETURN count(v) AS cnt" % u).data()[0]['cnt'], - "Number of traversed edges started " - "from U(id:%s) is wrong!. " % u + - "Expected: %s Actual: %s") + session.run("MATCH (u1:U {id: %s})-[e1]->(v:V)<-[e2]-(u2:U) " "RETURN count(v) AS cnt" % u).data()[0]["cnt"], + "Number of traversed edges started " "from U(id:%s) is wrong!. " % u + "Expected: %s Actual: %s", + ) end_time = time.time() - return u, end_time - start_time, 's' + return u, end_time - start_time, "s" def traverse_from_v_worker(v): - ''' + """ Traverses edges starting from an element of V set. Traversed labels are: :V -> :U -> :V. - ''' + """ session = SessionCache.argument_session(args) start_time = time.time() assert_equal( - args.u_count * args.v_count - args.u_count, # cypher morphism - session.run("MATCH (v1:V {id: %s})<-[e1]-(u:U)-[e2]->(v2:V) " - "RETURN count(u) AS cnt" % v).data()[0]['cnt'], - "Number of traversed edges started " - "from V(id:%s) is wrong!. " % v + - "Expected: %s Actual: %s") + args.u_count * args.v_count - args.u_count, # cypher morphism + session.run("MATCH (v1:V {id: %s})<-[e1]-(u:U)-[e2]->(v2:V) " "RETURN count(u) AS cnt" % v).data()[0]["cnt"], + "Number of traversed edges started " "from V(id:%s) is wrong!. " % v + "Expected: %s Actual: %s", + ) end_time = time.time() - return v, end_time - start_time, 's' + return v, end_time - start_time, "s" def execution_handler(): - ''' + """ Initializes client processes, database and starts the execution. - ''' + """ # instance cleanup session = SessionCache.argument_session(args) start_time = time.time() # clean existing database - session.run('MATCH (n) DETACH DELETE n').consume() + session.run("MATCH (n) DETACH DELETE n").consume() cleanup_end_time = time.time() - output_data.add_measurement("cleanup_time", - cleanup_end_time - start_time) + output_data.add_measurement("cleanup_time", cleanup_end_time - start_time) log.info("Database is clean.") # create indices - session.run('CREATE INDEX ON :U').consume() - session.run('CREATE INDEX ON :V').consume() + session.run("CREATE INDEX ON :U").consume() + session.run("CREATE INDEX ON :V").consume() # create U vertices - for b in batch(render('CREATE (:U {{id: {}}})', range(args.u_count)), - args.vertex_batch_size): + for b in batch(render("CREATE (:U {{id: {}}})", range(args.u_count)), args.vertex_batch_size): session.run(" ".join(b)).consume() # create V vertices - for b in batch(render('CREATE (:V {{id: {}}})', range(args.v_count)), - args.vertex_batch_size): + for b in batch(render("CREATE (:V {{id: {}}})", range(args.v_count)), args.vertex_batch_size): session.run(" ".join(b)).consume() vertices_create_end_time = time.time() - output_data.add_measurement( - 'vertices_create_time', - vertices_create_end_time - cleanup_end_time) + output_data.add_measurement("vertices_create_time", vertices_create_end_time - cleanup_end_time) log.info("All nodes created.") # concurrent create execution & tests with multiprocessing.Pool(args.worker_count) as p: create_edges_start_time = time.time() - for worker_id, create_time, time_unit, no_failures in \ - p.map(create_u_v_edges, [i for i in range(args.u_count)]): - log.info('Worker ID: %s; Create time: %s%s Failures: %s' % - (worker_id, create_time, time_unit, no_failures)) + for worker_id, create_time, time_unit, no_failures in p.map(create_u_v_edges, 
[i for i in range(args.u_count)]): + log.info("Worker ID: %s; Create time: %s%s Failures: %s" % (worker_id, create_time, time_unit, no_failures)) create_edges_end_time = time.time() - output_data.add_measurement( - 'edges_create_time', - create_edges_end_time - create_edges_start_time) + output_data.add_measurement("edges_create_time", create_edges_end_time - create_edges_start_time) # check total number of edges assert_equal( args.v_count * args.u_count, - session.run( - 'MATCH ()-[r]->() ' - 'RETURN count(r) AS cnt').data()[0]['cnt'], - "Total number of edges isn't correct! Expected: %s Actual: %s") + session.run("MATCH ()-[r]->() " "RETURN count(r) AS cnt").data()[0]["cnt"], + "Total number of edges isn't correct! Expected: %s Actual: %s", + ) # check traversals starting from all elements of U traverse_from_u_start_time = time.time() - for u, traverse_u_time, time_unit in \ - p.map(traverse_from_u_worker, - [i for i in range(args.u_count)]): + for u, traverse_u_time, time_unit in p.map(traverse_from_u_worker, [i for i in range(args.u_count)]): log.info("U {id: %s} %s%s" % (u, traverse_u_time, time_unit)) traverse_from_u_end_time = time.time() output_data.add_measurement( - 'traverse_from_u_time', - traverse_from_u_end_time - traverse_from_u_start_time) + "traverse_from_u_time", + traverse_from_u_end_time - traverse_from_u_start_time, + ) # check traversals starting from all elements of V traverse_from_v_start_time = time.time() - for v, traverse_v_time, time_unit in \ - p.map(traverse_from_v_worker, - [i for i in range(args.v_count)]): + for v, traverse_v_time, time_unit in p.map(traverse_from_v_worker, [i for i in range(args.v_count)]): log.info("V {id: %s} %s%s" % (v, traverse_v_time, time_unit)) traverse_from_v_end_time = time.time() output_data.add_measurement( - 'traverse_from_v_time', - traverse_from_v_end_time - traverse_from_v_start_time) + "traverse_from_v_time", + traverse_from_v_end_time - traverse_from_v_start_time, + ) # check total number of vertices assert_equal( args.v_count + args.u_count, - session.run('MATCH (n) RETURN count(n) AS cnt').data()[0]['cnt'], - "Total number of vertices isn't correct! Expected: %s Actual: %s") + session.run("MATCH (n) RETURN count(n) AS cnt").data()[0]["cnt"], + "Total number of vertices isn't correct! Expected: %s Actual: %s", + ) # check total number of edges assert_equal( args.v_count * args.u_count, - session.run( - 'MATCH ()-[r]->() RETURN count(r) AS cnt').data()[0]['cnt'], - "Total number of edges isn't correct! Expected: %s Actual: %s") + session.run("MATCH ()-[r]->() RETURN count(r) AS cnt").data()[0]["cnt"], + "Total number of edges isn't correct! Expected: %s Actual: %s", + ) end_time = time.time() - output_data.add_measurement("total_execution_time", - end_time - start_time) + output_data.add_measurement("total_execution_time", end_time - start_time) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=args.logging) if args.logging != "DEBUG": logging.getLogger("neo4j").setLevel(logging.WARNING) diff --git a/tests/stress/common.py b/tests/stress/common.py index 3648d0e0a..f6aa48c57 100644 --- a/tests/stress/common.py +++ b/tests/stress/common.py @@ -11,12 +11,12 @@ # -*- coding: utf-8 -*- -''' +""" Common methods for writing graph database integration tests in python. Only Bolt communication protocol is supported. -''' +""" import contextlib import os @@ -28,9 +28,9 @@ from neo4j import GraphDatabase, TRUST_ALL_CERTIFICATES class OutputData: - ''' + """ Encapsulates results and info about the tests. 
- ''' + """ def __init__(self): # data in time format (name, time, unit) @@ -39,32 +39,32 @@ class OutputData: self._statuses = [] def add_measurement(self, name, time, unit="s"): - ''' + """ Stores measurement. :param name: str, name of measurement :param time: float, time value :param unit: str, time unit - ''' + """ self._measurements.append((name, time, unit)) def add_status(self, name, status): - ''' + """ Stores status data point. :param name: str, name of data point :param status: printable value - ''' + """ self._statuses.append((name, status)) def dump(self, print_f=print): - ''' + """ Dumps output using the given ouput function. Args: print_f - the function that consumes ouptput. Defaults to the 'print' function. - ''' + """ print_f("Output data:") for name, status in self._statuses: print_f(" %s: %s" % (name, status)) @@ -73,7 +73,7 @@ class OutputData: def execute_till_success(session, query, max_retries=1000): - ''' + """ Executes a query within Bolt session until the query is successfully executed against the database. @@ -86,7 +86,7 @@ def execute_till_success(session, query, max_retries=1000): :param query: query to execute :return: tuple (results_data_list, number_of_failures, result_summary) - ''' + """ no_failures = 0 while True: try: @@ -97,12 +97,11 @@ def execute_till_success(session, query, max_retries=1000): except Exception: no_failures += 1 if no_failures >= max_retries: - raise Exception("Query '%s' failed %d times, aborting" % - (query, max_retries)) + raise Exception("Query '%s' failed %d times, aborting" % (query, max_retries)) def batch(input, batch_size): - """ Batches the given input (must be iterable). + """Batches the given input (must be iterable). Supports input generators. Returns a generator. All is lazy. The last batch can contain less elements then `batch_size`, but is for sure more then zero. @@ -134,7 +133,7 @@ def render(template, iterable_arguments): def assert_equal(expected, actual, message): - ''' + """ Compares expected and actual values. If values are not the same terminate the execution. @@ -142,45 +141,41 @@ def assert_equal(expected, actual, message): :param actual: actual value :param message: str, message in case that the values are not equal, must contain two placeholders (%s) to print the values. - ''' + """ assert expected == actual, message % (expected, actual) def connection_argument_parser(): - ''' + """ Parses arguments related to establishing database connection like host, port, username, etc. :return: An instance of ArgumentParser - ''' + """ parser = ArgumentParser(description=__doc__) - parser.add_argument('--endpoint', type=str, default='127.0.0.1:7687', - help='DBMS instance endpoint. ' - 'Bolt protocol is the only option.') - parser.add_argument('--username', type=str, default='neo4j', - help='DBMS instance username.') - parser.add_argument('--password', type=int, default='1234', - help='DBMS instance password.') - parser.add_argument('--use-ssl', action='store_true', - help="Is SSL enabled?") + parser.add_argument( + "--endpoint", + type=str, + default="127.0.0.1:7687", + help="DBMS instance endpoint. 
" "Bolt protocol is the only option.", + ) + parser.add_argument("--username", type=str, default="neo4j", help="DBMS instance username.") + parser.add_argument("--password", type=int, default="1234", help="DBMS instance password.") + parser.add_argument("--use-ssl", action="store_true", help="Is SSL enabled?") return parser @contextlib.contextmanager def bolt_session(url, auth, ssl=False): - ''' + """ with wrapper around Bolt session. :param url: str, e.g. "bolt://127.0.0.1:7687" :param auth: auth method, goes directly to the Bolt driver constructor :param ssl: bool, is ssl enabled - ''' - driver = GraphDatabase.driver( - url, - auth=auth, - encrypted=ssl, - trust=TRUST_ALL_CERTIFICATES) + """ + driver = GraphDatabase.driver(url, auth=auth, encrypted=ssl, trust=TRUST_ALL_CERTIFICATES) session = driver.session() try: yield session @@ -192,19 +187,20 @@ def bolt_session(url, auth, ssl=False): # If you are using session with multiprocessing take a look at SesssionCache # in bipartite for an idea how to reuse sessions. def argument_session(args): - ''' + """ :return: Bolt session context manager based on program arguments - ''' - return bolt_session('bolt://' + args.endpoint, - (args.username, str(args.password)), - args.use_ssl) + """ + return bolt_session("bolt://" + args.endpoint, (args.username, str(args.password)), args.use_ssl) def argument_driver(args): return GraphDatabase.driver( - 'bolt://' + args.endpoint, + "bolt://" + args.endpoint, auth=(args.username, str(args.password)), - encrypted=args.use_ssl, trust=TRUST_ALL_CERTIFICATES) + encrypted=args.use_ssl, + trust=TRUST_ALL_CERTIFICATES, + ) + # This class is used to create and cache sessions. Session is cached by args # used to create it and process' pid in which it was created. This makes it @@ -219,8 +215,8 @@ class SessionCache: key = tuple(vars(args).items()) + (os.getpid(),) if key in SessionCache.cache: return SessionCache.cache[key][1] - driver = argument_driver(args) # | - session = driver.session() # V + driver = argument_driver(args) # | + session = driver.session() # V SessionCache.cache[key] = (driver, session) return session @@ -241,6 +237,7 @@ def periodically_execute(callable, args, interval, daemon=True): interval - time (in seconds) between two calls deamon - if the execution thread should be a daemon """ + def periodic_call(): while True: sleep(interval) diff --git a/tests/stress/create_match.py b/tests/stress/create_match.py index 70fc6cb29..434492cba 100644 --- a/tests/stress/create_match.py +++ b/tests/stress/create_match.py @@ -12,11 +12,11 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. -''' +""" Large scale stress test. Tests only node creation. The idea is to run this test on machines with huge amount of memory e.g. 2TB. -''' +""" import logging import multiprocessing @@ -28,28 +28,41 @@ from common import connection_argument_parser, argument_session def parse_args(): - ''' + """ Parses user arguments :return: parsed arguments - ''' + """ parser = connection_argument_parser() # specific - parser.add_argument('--worker-count', type=int, - default=multiprocessing.cpu_count(), - help='Number of concurrent workers.') - parser.add_argument("--logging", default="INFO", - choices=["INFO", "DEBUG", "WARNING", "ERROR"], - help="Logging level") - parser.add_argument('--vertex-count', type=int, default=100, - help='Number of created vertices.') - parser.add_argument('--max-property-value', type=int, default=1000, - help='Maximum value of property - 1. 
A created node ' - 'will have a property with random value from 0 to ' - 'max_property_value - 1.') - parser.add_argument('--create-pack-size', type=int, default=1, - help='Number of CREATE clauses in a query') + parser.add_argument( + "--worker-count", + type=int, + default=multiprocessing.cpu_count(), + help="Number of concurrent workers.", + ) + parser.add_argument( + "--logging", + default="INFO", + choices=["INFO", "DEBUG", "WARNING", "ERROR"], + help="Logging level", + ) + parser.add_argument("--vertex-count", type=int, default=100, help="Number of created vertices.") + parser.add_argument( + "--max-property-value", + type=int, + default=1000, + help="Maximum value of property - 1. A created node " + "will have a property with random value from 0 to " + "max_property_value - 1.", + ) + parser.add_argument( + "--create-pack-size", + type=int, + default=1, + help="Number of CREATE clauses in a query", + ) return parser.parse_args() @@ -58,51 +71,49 @@ args = parse_args() def create_worker(worker_id): - ''' + """ Creates nodes and checks that all nodes were created. :param worker_id: worker id :return: tuple (worker_id, create execution time, time unit) - ''' - assert args.vertex_count > 0, 'Number of vertices has to be positive int' + """ + assert args.vertex_count > 0, "Number of vertices has to be positive int" generated_xs = defaultdict(int) - create_query = '' + create_query = "" with argument_session(args) as session: # create vertices start_time = time.time() for i in range(0, args.vertex_count): random_number = random.randint(0, args.max_property_value - 1) generated_xs[random_number] += 1 - create_query += 'CREATE (:Label_T%s {x: %s}) ' % \ - (worker_id, random_number) + create_query += "CREATE (:Label_T%s {x: %s}) " % (worker_id, random_number) # if full back or last item -> execute query - if (i + 1) % args.create_pack_size == 0 or \ - i == args.vertex_count - 1: + if (i + 1) % args.create_pack_size == 0 or i == args.vertex_count - 1: session.run(create_query).consume() - create_query = '' + create_query = "" create_time = time.time() # check total count - result_set = session.run('MATCH (n:Label_T%s) RETURN count(n) AS cnt' % - worker_id).data()[0] - assert result_set['cnt'] == args.vertex_count, \ - 'Create vertices Expected: %s Created: %s' % \ - (args.vertex_count, result_set['cnt']) + result_set = session.run("MATCH (n:Label_T%s) RETURN count(n) AS cnt" % worker_id).data()[0] + assert result_set["cnt"] == args.vertex_count, "Create vertices Expected: %s Created: %s" % ( + args.vertex_count, + result_set["cnt"], + ) # check count per property value for i, size in generated_xs.items(): - result_set = session.run('MATCH (n:Label_T%s {x: %s}) ' - 'RETURN count(n) AS cnt' - % (worker_id, i)).data()[0] - assert result_set['cnt'] == size, "Per x count isn't good " \ - "(Label: Label_T%s, prop x: %s" % (worker_id, i) + result_set = session.run("MATCH (n:Label_T%s {x: %s}) " "RETURN count(n) AS cnt" % (worker_id, i)).data()[0] + assert result_set["cnt"] == size, "Per x count isn't good " "(Label: Label_T%s, prop x: %s" % ( + worker_id, + i, + ) return (worker_id, create_time - start_time, "s") def create_handler(): - ''' + """ Initializes processes and starts the execution. 
- ''' + """ # instance cleanup with argument_session(args) as session: session.run("MATCH (n) DETACH DELETE n").consume() @@ -113,21 +124,19 @@ def create_handler(): # concurrent create execution & tests with multiprocessing.Pool(args.worker_count) as p: - for worker_id, create_time, time_unit in \ - p.map(create_worker, [i for i in range(args.worker_count)]): - log.info('Worker ID: %s; Create time: %s%s' % - (worker_id, create_time, time_unit)) + for worker_id, create_time, time_unit in p.map(create_worker, [i for i in range(args.worker_count)]): + log.info("Worker ID: %s; Create time: %s%s" % (worker_id, create_time, time_unit)) # check total count expected_total_count = args.worker_count * args.vertex_count - total_count = session.run( - 'MATCH (n) RETURN count(n) AS cnt').data()[0]['cnt'] - assert total_count == expected_total_count, \ - 'Total vertex number: %s Expected: %s' % \ - (total_count, expected_total_count) + total_count = session.run("MATCH (n) RETURN count(n) AS cnt").data()[0]["cnt"] + assert total_count == expected_total_count, "Total vertex number: %s Expected: %s" % ( + total_count, + expected_total_count, + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=args.logging) if args.logging != "DEBUG": logging.getLogger("neo4j").setLevel(logging.WARNING) diff --git a/tools/bench-graph-client/main.py b/tools/bench-graph-client/main.py index edb09ecde..41329a288 100755 --- a/tools/bench-graph-client/main.py +++ b/tools/bench-graph-client/main.py @@ -20,9 +20,7 @@ GITHUB_REPOSITORY = os.getenv("GITHUB_REPOSITORY", "") GITHUB_SHA = os.getenv("GITHUB_SHA", "") GITHUB_REF = os.getenv("GITHUB_REF", "") -BENCH_GRAPH_SERVER_ENDPOINT = os.getenv( - "BENCH_GRAPH_SERVER_ENDPOINT", - "http://bench-graph-api:9001") +BENCH_GRAPH_SERVER_ENDPOINT = os.getenv("BENCH_GRAPH_SERVER_ENDPOINT", "http://bench-graph-api:9001") log = logging.getLogger(__name__) @@ -52,12 +50,12 @@ def post_measurement(args): "github_run_id": args.github_run_id, "github_run_number": args.github_run_number, "results": data, - "git_branch": args.head_branch_name}, - timeout=1) - assert req.status_code == 200, \ - f"Uploading {args.benchmark_name} data failed." - log.info(f"{args.benchmark_name} data sent to " - f"{BENCH_GRAPH_SERVER_ENDPOINT}") + "git_branch": args.head_branch_name, + }, + timeout=1, + ) + assert req.status_code == 200, f"Uploading {args.benchmark_name} data failed." + log.info(f"{args.benchmark_name} data sent to " f"{BENCH_GRAPH_SERVER_ENDPOINT}") if __name__ == "__main__": diff --git a/tools/gdb-plugins/operator_tree.py b/tools/gdb-plugins/operator_tree.py index 0f72b4b65..505d74c03 100644 --- a/tools/gdb-plugins/operator_tree.py +++ b/tools/gdb-plugins/operator_tree.py @@ -4,14 +4,14 @@ import gdb def _logical_operator_type(): - '''Returns the LogicalOperator gdb.Type''' + """Returns the LogicalOperator gdb.Type""" # This is a function, because the type may appear during gdb runtime. # Therefore, we cannot assign it on import. 
- return gdb.lookup_type('memgraph::query::plan::LogicalOperator') + return gdb.lookup_type("memgraph::query::plan::LogicalOperator") def _iter_fields_and_base_classes(value): - '''Iterate all fields of value.type''' + """Iterate all fields of value.type""" types_to_process = [value.type] while types_to_process: for field in types_to_process.pop().fields(): @@ -21,31 +21,27 @@ def _iter_fields_and_base_classes(value): def _fields(value): - '''Return a list of value.type fields.''' - return [f for f in _iter_fields_and_base_classes(value) - if not f.is_base_class] + """Return a list of value.type fields.""" + return [f for f in _iter_fields_and_base_classes(value) if not f.is_base_class] def _has_field(value, field_name): - '''Return True if value.type has a field named field_name.''' + """Return True if value.type has a field named field_name.""" return field_name in [f.name for f in _fields(value)] def _base_classes(value): - '''Return a list of base classes for value.type.''' - return [f for f in _iter_fields_and_base_classes(value) - if f.is_base_class] + """Return a list of base classes for value.type.""" + return [f for f in _iter_fields_and_base_classes(value) if f.is_base_class] def _is_instance(value, type_): - '''Return True if value is an instance of type.''' - return value.type.unqualified() == type_ or \ - type_ in [base.type for base in _base_classes(value)] + """Return True if value is an instance of type.""" + return value.type.unqualified() == type_ or type_ in [base.type for base in _base_classes(value)] # Pattern for matching std::unique_ptr and std::shared_ptr -_SMART_PTR_TYPE_PATTERN = \ - re.compile('^std::(unique|shared)_ptr<(?P[\w:]*)') +_SMART_PTR_TYPE_PATTERN = re.compile("^std::(unique|shared)_ptr<(?P[\w:]*)") def _is_smart_ptr(maybe_smart_ptr, type_name=None): @@ -55,40 +51,39 @@ def _is_smart_ptr(maybe_smart_ptr, type_name=None): match = _SMART_PTR_TYPE_PATTERN.match(type_.name) if match is None or type_name is None: return bool(match) - return type_name == match.group('pointee_type') + return type_name == match.group("pointee_type") def _smart_ptr_pointee(smart_ptr): - '''Returns the pointer to object in shared_ptr/unique_ptr.''' + """Returns the pointer to object in shared_ptr/unique_ptr.""" # This function may not be needed when gdb adds dereferencing # shared_ptr/unique_ptr via Python API. 
- if _has_field(smart_ptr, '_M_ptr'): + if _has_field(smart_ptr, "_M_ptr"): # shared_ptr - return smart_ptr['_M_ptr'] - if _has_field(smart_ptr, '_M_t'): + return smart_ptr["_M_ptr"] + if _has_field(smart_ptr, "_M_t"): # unique_ptr - smart_ptr = smart_ptr['_M_t'] - if _has_field(smart_ptr, '_M_t'): + smart_ptr = smart_ptr["_M_t"] + if _has_field(smart_ptr, "_M_t"): # Check for one more level of _M_t - smart_ptr = smart_ptr['_M_t'] - if _has_field(smart_ptr, '_M_head_impl'): - return smart_ptr['_M_head_impl'] + smart_ptr = smart_ptr["_M_t"] + if _has_field(smart_ptr, "_M_head_impl"): + return smart_ptr["_M_head_impl"] def _get_operator_input(operator): - '''Returns the input operator of given operator, if it has any.''' - if not _has_field(operator, 'input_'): + """Returns the input operator of given operator, if it has any.""" + if not _has_field(operator, "input_"): return None - input_op = _smart_ptr_pointee(operator['input_']).dereference() + input_op = _smart_ptr_pointee(operator["input_"]).dereference() return input_op.cast(input_op.dynamic_type) class PrintOperatorTree(gdb.Command): - '''Print the tree of logical operators from the expression.''' + """Print the tree of logical operators from the expression.""" + def __init__(self): - super(PrintOperatorTree, self).__init__("print-operator-tree", - gdb.COMMAND_USER, - gdb.COMPLETE_EXPRESSION) + super(PrintOperatorTree, self).__init__("print-operator-tree", gdb.COMMAND_USER, gdb.COMPLETE_EXPRESSION) def invoke(self, argument, from_tty): try: @@ -98,17 +93,16 @@ class PrintOperatorTree(gdb.Command): logical_operator_type = _logical_operator_type() if operator.type.code in (gdb.TYPE_CODE_PTR, gdb.TYPE_CODE_REF): operator = operator.referenced_value() - if _is_smart_ptr(operator, 'memgraph::query::plan::LogicalOperator'): + if _is_smart_ptr(operator, "memgraph::query::plan::LogicalOperator"): operator = _smart_ptr_pointee(operator).dereference() if not _is_instance(operator, logical_operator_type): - raise gdb.GdbError("Expected a '%s', but got '%s'" % - (logical_operator_type, operator.type)) + raise gdb.GdbError("Expected a '%s', but got '%s'" % (logical_operator_type, operator.type)) next_op = operator.cast(operator.dynamic_type) tree = [] while next_op is not None: - tree.append('* %s <%s>' % (next_op.type.name, next_op.address)) + tree.append("* %s <%s>" % (next_op.type.name, next_op.address)) next_op = _get_operator_input(next_op) - print('\n'.join(tree)) + print("\n".join(tree)) PrintOperatorTree() diff --git a/tools/gdb-plugins/pretty_printers.py b/tools/gdb-plugins/pretty_printers.py index af75a9724..7f3d58840 100644 --- a/tools/gdb-plugins/pretty_printers.py +++ b/tools/gdb-plugins/pretty_printers.py @@ -3,43 +3,49 @@ import gdb.printing def build_memgraph_pretty_printers(): - '''Instantiate and return all memgraph pretty printer classes.''' - pp = gdb.printing.RegexpCollectionPrettyPrinter('memgraph') - pp.add_printer('memgraph::query::TypedValue', '^memgraph::query::TypedValue$', TypedValuePrinter) + """Instantiate and return all memgraph pretty printer classes.""" + pp = gdb.printing.RegexpCollectionPrettyPrinter("memgraph") + pp.add_printer( + "memgraph::query::TypedValue", + "^memgraph::query::TypedValue$", + TypedValuePrinter, + ) return pp class TypedValuePrinter(gdb.printing.PrettyPrinter): - '''Pretty printer for memgraph::query::TypedValue''' + """Pretty printer for memgraph::query::TypedValue""" + def __init__(self, val): - super(TypedValuePrinter, self).__init__('TypedValue') + super(TypedValuePrinter, 
self).__init__("TypedValue") self.val = val def to_string(self): def _to_str(val): - return '{%s %s}' % (value_type, self.val[val]) - value_type = str(self.val['type_']) - if value_type == 'memgraph::query::TypedValue::Type::Null': - return '{%s}' % value_type - elif value_type == 'memgraph::query::TypedValue::Type::Bool': - return _to_str('bool_v') - elif value_type == 'memgraph::query::TypedValue::Type::Int': - return _to_str('int_v') - elif value_type == 'memgraph::query::TypedValue::Type::Double': - return _to_str('double_v') - elif value_type == 'memgraph::query::TypedValue::Type::String': - return _to_str('string_v') - elif value_type == 'memgraph::query::TypedValue::Type::List': - return _to_str('list_v') - elif value_type == 'memgraph::query::TypedValue::Type::Map': - return _to_str('map_v') - elif value_type == 'memgraph::query::TypedValue::Type::Vertex': - return _to_str('vertex_v') - elif value_type == 'memgraph::query::TypedValue::Type::Edge': - return _to_str('edge_v') - elif value_type == 'memgraph::query::TypedValue::Type::Path': - return _to_str('path_v') - return '{%s}' % value_type + return "{%s %s}" % (value_type, self.val[val]) -gdb.printing.register_pretty_printer(None, build_memgraph_pretty_printers(), - replace=True) + value_type = str(self.val["type_"]) + if value_type == "memgraph::query::TypedValue::Type::Null": + return "{%s}" % value_type + elif value_type == "memgraph::query::TypedValue::Type::Bool": + return _to_str("bool_v") + elif value_type == "memgraph::query::TypedValue::Type::Int": + return _to_str("int_v") + elif value_type == "memgraph::query::TypedValue::Type::Double": + return _to_str("double_v") + elif value_type == "memgraph::query::TypedValue::Type::String": + return _to_str("string_v") + elif value_type == "memgraph::query::TypedValue::Type::List": + return _to_str("list_v") + elif value_type == "memgraph::query::TypedValue::Type::Map": + return _to_str("map_v") + elif value_type == "memgraph::query::TypedValue::Type::Vertex": + return _to_str("vertex_v") + elif value_type == "memgraph::query::TypedValue::Type::Edge": + return _to_str("edge_v") + elif value_type == "memgraph::query::TypedValue::Type::Path": + return _to_str("path_v") + return "{%s}" % value_type + + +gdb.printing.register_pretty_printer(None, build_memgraph_pretty_printers(), replace=True) diff --git a/tools/github/clang-tidy/clang-tidy-diff.py b/tools/github/clang-tidy/clang-tidy-diff.py index a20b1f1f4..852227f4f 100755 --- a/tools/github/clang-tidy/clang-tidy-diff.py +++ b/tools/github/clang-tidy/clang-tidy-diff.py @@ -1,12 +1,12 @@ #!/usr/bin/env python3 # -#===- clang-tidy-diff.py - ClangTidy Diff Checker -----------*- python -*--===# +# ===- clang-tidy-diff.py - ClangTidy Diff Checker -----------*- python -*--===# # # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # -#===-----------------------------------------------------------------------===# +# ===-----------------------------------------------------------------------===# r""" ClangTidy Diff Checker @@ -37,11 +37,11 @@ import threading import traceback try: - import yaml + import yaml except ImportError: - yaml = None + yaml = None -is_py2 = sys.version[0] == '2' +is_py2 = sys.version[0] == "2" if is_py2: import Queue as queue @@ -50,220 +50,242 @@ else: def run_tidy(task_queue, lock, timeout): - watchdog = None - while True: - command = task_queue.get() - try: - proc = subprocess.Popen(command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + watchdog = None + while True: + command = task_queue.get() + try: + proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - if timeout is not None: - watchdog = threading.Timer(timeout, proc.kill) - watchdog.start() + if timeout is not None: + watchdog = threading.Timer(timeout, proc.kill) + watchdog.start() - stdout, stderr = proc.communicate() + stdout, stderr = proc.communicate() - with lock: - sys.stdout.write(stdout.decode('utf-8') + '\n') - sys.stdout.flush() - if stderr: - sys.stderr.write(stderr.decode('utf-8') + '\n') - sys.stderr.flush() - except Exception as e: - with lock: - sys.stderr.write('Failed: ' + str(e) + ': '.join(command) + '\n') - finally: - with lock: - if not (timeout is None or watchdog is None): - if not watchdog.is_alive(): - sys.stderr.write('Terminated by timeout: ' + - ' '.join(command) + '\n') - watchdog.cancel() - task_queue.task_done() + with lock: + sys.stdout.write(stdout.decode("utf-8") + "\n") + sys.stdout.flush() + if stderr: + sys.stderr.write(stderr.decode("utf-8") + "\n") + sys.stderr.flush() + except Exception as e: + with lock: + sys.stderr.write("Failed: " + str(e) + ": ".join(command) + "\n") + finally: + with lock: + if not (timeout is None or watchdog is None): + if not watchdog.is_alive(): + sys.stderr.write("Terminated by timeout: " + " ".join(command) + "\n") + watchdog.cancel() + task_queue.task_done() def start_workers(max_tasks, tidy_caller, task_queue, lock, timeout): - for _ in range(max_tasks): - t = threading.Thread(target=tidy_caller, args=(task_queue, lock, timeout)) - t.daemon = True - t.start() + for _ in range(max_tasks): + t = threading.Thread(target=tidy_caller, args=(task_queue, lock, timeout)) + t.daemon = True + t.start() def merge_replacement_files(tmpdir, mergefile): - """Merge all replacement files in a directory into a single file""" - # The fixes suggested by clang-tidy >= 4.0.0 are given under - # the top level key 'Diagnostics' in the output yaml files - mergekey = "Diagnostics" - merged = [] - for replacefile in glob.iglob(os.path.join(tmpdir, '*.yaml')): - content = yaml.safe_load(open(replacefile, 'r')) - if not content: - continue # Skip empty files. - merged.extend(content.get(mergekey, [])) + """Merge all replacement files in a directory into a single file""" + # The fixes suggested by clang-tidy >= 4.0.0 are given under + # the top level key 'Diagnostics' in the output yaml files + mergekey = "Diagnostics" + merged = [] + for replacefile in glob.iglob(os.path.join(tmpdir, "*.yaml")): + content = yaml.safe_load(open(replacefile, "r")) + if not content: + continue # Skip empty files. 
+ merged.extend(content.get(mergekey, [])) - if merged: - # MainSourceFile: The key is required by the definition inside - # include/clang/Tooling/ReplacementsYaml.h, but the value - # is actually never used inside clang-apply-replacements, - # so we set it to '' here. - output = {'MainSourceFile': '', mergekey: merged} - with open(mergefile, 'w') as out: - yaml.safe_dump(output, out) - else: - # Empty the file: - open(mergefile, 'w').close() + if merged: + # MainSourceFile: The key is required by the definition inside + # include/clang/Tooling/ReplacementsYaml.h, but the value + # is actually never used inside clang-apply-replacements, + # so we set it to '' here. + output = {"MainSourceFile": "", mergekey: merged} + with open(mergefile, "w") as out: + yaml.safe_dump(output, out) + else: + # Empty the file: + open(mergefile, "w").close() def main(): - parser = argparse.ArgumentParser(description= - 'Run clang-tidy against changed files, and ' - 'output diagnostics only for modified ' - 'lines.') - parser.add_argument('-clang-tidy-binary', metavar='PATH', - default='clang-tidy', - help='path to clang-tidy binary') - parser.add_argument('-p', metavar='NUM', default=0, - help='strip the smallest prefix containing P slashes') - parser.add_argument('-regex', metavar='PATTERN', default=None, - help='custom pattern selecting file paths to check ' - '(case sensitive, overrides -iregex)') - parser.add_argument('-iregex', metavar='PATTERN', default= - r'.*\.(cpp|cc|c\+\+|cxx|c|cl|h|hpp|m|mm|inc)', - help='custom pattern selecting file paths to check ' - '(case insensitive, overridden by -regex)') - parser.add_argument('-j', type=int, default=1, - help='number of tidy instances to be run in parallel.') - parser.add_argument('-timeout', type=int, default=None, - help='timeout per each file in seconds.') - parser.add_argument('-fix', action='store_true', default=False, - help='apply suggested fixes') - parser.add_argument('-checks', - help='checks filter, when not specified, use clang-tidy ' - 'default', - default='') - parser.add_argument('-path', dest='build_path', - help='Path used to read a compile command database.') - if yaml: - parser.add_argument('-export-fixes', metavar='FILE', dest='export_fixes', - help='Create a yaml file to store suggested fixes in, ' - 'which can be applied with clang-apply-replacements.') - parser.add_argument('-extra-arg', dest='extra_arg', - action='append', default=[], - help='Additional argument to append to the compiler ' - 'command line.') - parser.add_argument('-extra-arg-before', dest='extra_arg_before', - action='append', default=[], - help='Additional argument to prepend to the compiler ' - 'command line.') - parser.add_argument('-quiet', action='store_true', default=False, - help='Run clang-tidy in quiet mode') - clang_tidy_args = [] - argv = sys.argv[1:] - if '--' in argv: - clang_tidy_args.extend(argv[argv.index('--'):]) - argv = argv[:argv.index('--')] + parser = argparse.ArgumentParser( + description="Run clang-tidy against changed files, and " "output diagnostics only for modified " "lines." 
+ ) + parser.add_argument( + "-clang-tidy-binary", + metavar="PATH", + default="clang-tidy", + help="path to clang-tidy binary", + ) + parser.add_argument( + "-p", + metavar="NUM", + default=0, + help="strip the smallest prefix containing P slashes", + ) + parser.add_argument( + "-regex", + metavar="PATTERN", + default=None, + help="custom pattern selecting file paths to check " "(case sensitive, overrides -iregex)", + ) + parser.add_argument( + "-iregex", + metavar="PATTERN", + default=r".*\.(cpp|cc|c\+\+|cxx|c|cl|h|hpp|m|mm|inc)", + help="custom pattern selecting file paths to check " "(case insensitive, overridden by -regex)", + ) + parser.add_argument( + "-j", + type=int, + default=1, + help="number of tidy instances to be run in parallel.", + ) + parser.add_argument("-timeout", type=int, default=None, help="timeout per each file in seconds.") + parser.add_argument("-fix", action="store_true", default=False, help="apply suggested fixes") + parser.add_argument( + "-checks", + help="checks filter, when not specified, use clang-tidy " "default", + default="", + ) + parser.add_argument("-path", dest="build_path", help="Path used to read a compile command database.") + if yaml: + parser.add_argument( + "-export-fixes", + metavar="FILE", + dest="export_fixes", + help="Create a yaml file to store suggested fixes in, " + "which can be applied with clang-apply-replacements.", + ) + parser.add_argument( + "-extra-arg", + dest="extra_arg", + action="append", + default=[], + help="Additional argument to append to the compiler " "command line.", + ) + parser.add_argument( + "-extra-arg-before", + dest="extra_arg_before", + action="append", + default=[], + help="Additional argument to prepend to the compiler " "command line.", + ) + parser.add_argument( + "-quiet", + action="store_true", + default=False, + help="Run clang-tidy in quiet mode", + ) + clang_tidy_args = [] + argv = sys.argv[1:] + if "--" in argv: + clang_tidy_args.extend(argv[argv.index("--") :]) + argv = argv[: argv.index("--")] - args = parser.parse_args(argv) + args = parser.parse_args(argv) - # Extract changed lines for each file. - filename = None - lines_by_file = {} - for line in sys.stdin: - match = re.search('^\+\+\+\ \"?(.*?/){%s}([^ \t\n\"]*)' % args.p, line) - if match: - filename = match.group(2) - if filename is None: - continue + # Extract changed lines for each file. 
+ filename = None + lines_by_file = {} + for line in sys.stdin: + match = re.search('^\+\+\+\ "?(.*?/){%s}([^ \t\n"]*)' % args.p, line) + if match: + filename = match.group(2) + if filename is None: + continue - if args.regex is not None: - if not re.match('^%s$' % args.regex, filename): - continue - else: - if not re.match('^%s$' % args.iregex, filename, re.IGNORECASE): - continue + if args.regex is not None: + if not re.match("^%s$" % args.regex, filename): + continue + else: + if not re.match("^%s$" % args.iregex, filename, re.IGNORECASE): + continue - match = re.search('^@@.*\+(\d+)(,(\d+))?', line) - if match: - start_line = int(match.group(1)) - line_count = 1 - if match.group(3): - line_count = int(match.group(3)) - if line_count == 0: - continue - end_line = start_line + line_count - 1 - lines_by_file.setdefault(filename, []).append([start_line, end_line]) + match = re.search("^@@.*\+(\d+)(,(\d+))?", line) + if match: + start_line = int(match.group(1)) + line_count = 1 + if match.group(3): + line_count = int(match.group(3)) + if line_count == 0: + continue + end_line = start_line + line_count - 1 + lines_by_file.setdefault(filename, []).append([start_line, end_line]) - if not any(lines_by_file): - print("No relevant changes found.") - sys.exit(0) + if not any(lines_by_file): + print("No relevant changes found.") + sys.exit(0) - max_task_count = args.j - if max_task_count == 0: - max_task_count = multiprocessing.cpu_count() - max_task_count = min(len(lines_by_file), max_task_count) + max_task_count = args.j + if max_task_count == 0: + max_task_count = multiprocessing.cpu_count() + max_task_count = min(len(lines_by_file), max_task_count) - tmpdir = None - if yaml and args.export_fixes: - tmpdir = tempfile.mkdtemp() - - # Tasks for clang-tidy. - task_queue = queue.Queue(max_task_count) - # A lock for console output. - lock = threading.Lock() - - # Run a pool of clang-tidy workers. - start_workers(max_task_count, run_tidy, task_queue, lock, args.timeout) - - # Form the common args list. - common_clang_tidy_args = [] - if args.fix: - common_clang_tidy_args.append('-fix') - if args.checks != '': - common_clang_tidy_args.append('-checks=' + args.checks) - if args.quiet: - common_clang_tidy_args.append('-quiet') - if args.build_path is not None: - common_clang_tidy_args.append('-p=%s' % args.build_path) - for arg in args.extra_arg: - common_clang_tidy_args.append('-extra-arg=%s' % arg) - for arg in args.extra_arg_before: - common_clang_tidy_args.append('-extra-arg-before=%s' % arg) - - for name in lines_by_file: - line_filter_json = json.dumps( - [{"name": name, "lines": lines_by_file[name]}], - separators=(',', ':')) - - # Run clang-tidy on files containing changes. - command = [args.clang_tidy_binary] - command.append('-line-filter=' + line_filter_json) + tmpdir = None if yaml and args.export_fixes: - # Get a temporary file. We immediately close the handle so clang-tidy can - # overwrite it. - (handle, tmp_name) = tempfile.mkstemp(suffix='.yaml', dir=tmpdir) - os.close(handle) - command.append('-export-fixes=' + tmp_name) - command.extend(common_clang_tidy_args) - command.append(name) - command.extend(clang_tidy_args) + tmpdir = tempfile.mkdtemp() - task_queue.put(command) + # Tasks for clang-tidy. + task_queue = queue.Queue(max_task_count) + # A lock for console output. + lock = threading.Lock() - # Wait for all threads to be done. - task_queue.join() + # Run a pool of clang-tidy workers. 
+ start_workers(max_task_count, run_tidy, task_queue, lock, args.timeout) - if yaml and args.export_fixes: - print('Writing fixes to ' + args.export_fixes + ' ...') - try: - merge_replacement_files(tmpdir, args.export_fixes) - except: - sys.stderr.write('Error exporting fixes.\n') - traceback.print_exc() + # Form the common args list. + common_clang_tidy_args = [] + if args.fix: + common_clang_tidy_args.append("-fix") + if args.checks != "": + common_clang_tidy_args.append("-checks=" + args.checks) + if args.quiet: + common_clang_tidy_args.append("-quiet") + if args.build_path is not None: + common_clang_tidy_args.append("-p=%s" % args.build_path) + for arg in args.extra_arg: + common_clang_tidy_args.append("-extra-arg=%s" % arg) + for arg in args.extra_arg_before: + common_clang_tidy_args.append("-extra-arg-before=%s" % arg) - if tmpdir: - shutil.rmtree(tmpdir) + for name in lines_by_file: + line_filter_json = json.dumps([{"name": name, "lines": lines_by_file[name]}], separators=(",", ":")) + + # Run clang-tidy on files containing changes. + command = [args.clang_tidy_binary] + command.append("-line-filter=" + line_filter_json) + if yaml and args.export_fixes: + # Get a temporary file. We immediately close the handle so clang-tidy can + # overwrite it. + (handle, tmp_name) = tempfile.mkstemp(suffix=".yaml", dir=tmpdir) + os.close(handle) + command.append("-export-fixes=" + tmp_name) + command.extend(common_clang_tidy_args) + command.append(name) + command.extend(clang_tidy_args) + + task_queue.put(command) + + # Wait for all threads to be done. + task_queue.join() + + if yaml and args.export_fixes: + print("Writing fixes to " + args.export_fixes + " ...") + try: + merge_replacement_files(tmpdir, args.export_fixes) + except: + sys.stderr.write("Error exporting fixes.\n") + traceback.print_exc() + + if tmpdir: + shutil.rmtree(tmpdir) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/tools/github/clang-tidy/run-clang-tidy.py b/tools/github/clang-tidy/run-clang-tidy.py index 0dbac0b25..c034cc68b 100755 --- a/tools/github/clang-tidy/run-clang-tidy.py +++ b/tools/github/clang-tidy/run-clang-tidy.py @@ -1,12 +1,12 @@ #!/usr/bin/env python3 # -#===- run-clang-tidy.py - Parallel clang-tidy runner --------*- python -*--===# +# ===- run-clang-tidy.py - Parallel clang-tidy runner --------*- python -*--===# # # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # -#===-----------------------------------------------------------------------===# +# ===-----------------------------------------------------------------------===# # FIXME: Integrate with clang-tidy-diff.py @@ -50,11 +50,11 @@ import threading import traceback try: - import yaml + import yaml except ImportError: - yaml = None + yaml = None -is_py2 = sys.version[0] == '2' +is_py2 = sys.version[0] == "2" if is_py2: import Queue as queue @@ -63,275 +63,327 @@ else: def find_compilation_database(path): - """Adjusts the directory until a compilation database is found.""" - result = './' - while not os.path.isfile(os.path.join(result, path)): - if os.path.realpath(result) == '/': - print('Error: could not find compilation database.') - sys.exit(1) - result += '../' - return os.path.realpath(result) + """Adjusts the directory until a compilation database is found.""" + result = "./" + while not os.path.isfile(os.path.join(result, path)): + if os.path.realpath(result) == "/": + print("Error: could not find compilation database.") + sys.exit(1) + result += "../" + return os.path.realpath(result) def make_absolute(f, directory): - if os.path.isabs(f): - return f - return os.path.normpath(os.path.join(directory, f)) + if os.path.isabs(f): + return f + return os.path.normpath(os.path.join(directory, f)) -def get_tidy_invocation(f, clang_tidy_binary, checks, tmpdir, build_path, - header_filter, allow_enabling_alpha_checkers, - extra_arg, extra_arg_before, quiet, config): - """Gets a command line for clang-tidy.""" - start = [clang_tidy_binary] - if allow_enabling_alpha_checkers: - start.append('-allow-enabling-analyzer-alpha-checkers') - if header_filter is not None: - start.append('-header-filter=' + header_filter) - if checks: - start.append('-checks=' + checks) - if tmpdir is not None: - start.append('-export-fixes') - # Get a temporary file. We immediately close the handle so clang-tidy can - # overwrite it. - (handle, name) = tempfile.mkstemp(suffix='.yaml', dir=tmpdir) - os.close(handle) - start.append(name) - for arg in extra_arg: - start.append('-extra-arg=%s' % arg) - for arg in extra_arg_before: - start.append('-extra-arg-before=%s' % arg) - start.append('-p=' + build_path) - if quiet: - start.append('-quiet') - if config: - start.append('-config=' + config) - start.append(f) - return start +def get_tidy_invocation( + f, + clang_tidy_binary, + checks, + tmpdir, + build_path, + header_filter, + allow_enabling_alpha_checkers, + extra_arg, + extra_arg_before, + quiet, + config, +): + """Gets a command line for clang-tidy.""" + start = [clang_tidy_binary] + if allow_enabling_alpha_checkers: + start.append("-allow-enabling-analyzer-alpha-checkers") + if header_filter is not None: + start.append("-header-filter=" + header_filter) + if checks: + start.append("-checks=" + checks) + if tmpdir is not None: + start.append("-export-fixes") + # Get a temporary file. We immediately close the handle so clang-tidy can + # overwrite it. 
+ (handle, name) = tempfile.mkstemp(suffix=".yaml", dir=tmpdir) + os.close(handle) + start.append(name) + for arg in extra_arg: + start.append("-extra-arg=%s" % arg) + for arg in extra_arg_before: + start.append("-extra-arg-before=%s" % arg) + start.append("-p=" + build_path) + if quiet: + start.append("-quiet") + if config: + start.append("-config=" + config) + start.append(f) + return start def merge_replacement_files(tmpdir, mergefile): - """Merge all replacement files in a directory into a single file""" - # The fixes suggested by clang-tidy >= 4.0.0 are given under - # the top level key 'Diagnostics' in the output yaml files - mergekey = "Diagnostics" - merged=[] - for replacefile in glob.iglob(os.path.join(tmpdir, '*.yaml')): - content = yaml.safe_load(open(replacefile, 'r')) - if not content: - continue # Skip empty files. - merged.extend(content.get(mergekey, [])) + """Merge all replacement files in a directory into a single file""" + # The fixes suggested by clang-tidy >= 4.0.0 are given under + # the top level key 'Diagnostics' in the output yaml files + mergekey = "Diagnostics" + merged = [] + for replacefile in glob.iglob(os.path.join(tmpdir, "*.yaml")): + content = yaml.safe_load(open(replacefile, "r")) + if not content: + continue # Skip empty files. + merged.extend(content.get(mergekey, [])) - if merged: - # MainSourceFile: The key is required by the definition inside - # include/clang/Tooling/ReplacementsYaml.h, but the value - # is actually never used inside clang-apply-replacements, - # so we set it to '' here. - output = {'MainSourceFile': '', mergekey: merged} - with open(mergefile, 'w') as out: - yaml.safe_dump(output, out) - else: - # Empty the file: - open(mergefile, 'w').close() + if merged: + # MainSourceFile: The key is required by the definition inside + # include/clang/Tooling/ReplacementsYaml.h, but the value + # is actually never used inside clang-apply-replacements, + # so we set it to '' here. + output = {"MainSourceFile": "", mergekey: merged} + with open(mergefile, "w") as out: + yaml.safe_dump(output, out) + else: + # Empty the file: + open(mergefile, "w").close() def check_clang_apply_replacements_binary(args): - """Checks if invoking supplied clang-apply-replacements binary works.""" - try: - subprocess.check_call([args.clang_apply_replacements_binary, '--version']) - except: - print('Unable to run clang-apply-replacements. Is clang-apply-replacements ' - 'binary correctly specified?', file=sys.stderr) - traceback.print_exc() - sys.exit(1) + """Checks if invoking supplied clang-apply-replacements binary works.""" + try: + subprocess.check_call([args.clang_apply_replacements_binary, "--version"]) + except: + print( + "Unable to run clang-apply-replacements. 
Is clang-apply-replacements " "binary correctly specified?", + file=sys.stderr, + ) + traceback.print_exc() + sys.exit(1) def apply_fixes(args, tmpdir): - """Calls clang-apply-fixes on a given directory.""" - invocation = [args.clang_apply_replacements_binary] - if args.format: - invocation.append('-format') - if args.style: - invocation.append('-style=' + args.style) - invocation.append(tmpdir) - subprocess.call(invocation) + """Calls clang-apply-fixes on a given directory.""" + invocation = [args.clang_apply_replacements_binary] + if args.format: + invocation.append("-format") + if args.style: + invocation.append("-style=" + args.style) + invocation.append(tmpdir) + subprocess.call(invocation) def run_tidy(args, tmpdir, build_path, queue, lock, failed_files): - """Takes filenames out of queue and runs clang-tidy on them.""" - while True: - name = queue.get() - invocation = get_tidy_invocation(name, args.clang_tidy_binary, args.checks, - tmpdir, build_path, args.header_filter, - args.allow_enabling_alpha_checkers, - args.extra_arg, args.extra_arg_before, - args.quiet, args.config) + """Takes filenames out of queue and runs clang-tidy on them.""" + while True: + name = queue.get() + invocation = get_tidy_invocation( + name, + args.clang_tidy_binary, + args.checks, + tmpdir, + build_path, + args.header_filter, + args.allow_enabling_alpha_checkers, + args.extra_arg, + args.extra_arg_before, + args.quiet, + args.config, + ) - proc = subprocess.Popen(invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - output, err = proc.communicate() - if proc.returncode != 0: - failed_files.append(name) - with lock: - sys.stdout.write(' '.join(invocation) + '\n' + output.decode('utf-8')) - if len(err) > 0: - sys.stdout.flush() - sys.stderr.write(err.decode('utf-8')) - queue.task_done() + proc = subprocess.Popen(invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, err = proc.communicate() + if proc.returncode != 0: + failed_files.append(name) + with lock: + sys.stdout.write(" ".join(invocation) + "\n" + output.decode("utf-8")) + if len(err) > 0: + sys.stdout.flush() + sys.stderr.write(err.decode("utf-8")) + queue.task_done() def main(): - parser = argparse.ArgumentParser(description='Runs clang-tidy over all files ' - 'in a compilation database. Requires ' - 'clang-tidy and clang-apply-replacements in ' - '$PATH.') - parser.add_argument('-allow-enabling-alpha-checkers', - action='store_true', help='allow alpha checkers from ' - 'clang-analyzer.') - parser.add_argument('-clang-tidy-binary', metavar='PATH', - default='clang-tidy-11', - help='path to clang-tidy binary') - parser.add_argument('-clang-apply-replacements-binary', metavar='PATH', - default='clang-apply-replacements-11', - help='path to clang-apply-replacements binary') - parser.add_argument('-checks', default=None, - help='checks filter, when not specified, use clang-tidy ' - 'default') - parser.add_argument('-config', default=None, - help='Specifies a configuration in YAML/JSON format: ' - ' -config="{Checks: \'*\', ' - ' CheckOptions: [{key: x, ' - ' value: y}]}" ' - 'When the value is empty, clang-tidy will ' - 'attempt to find a file named .clang-tidy for ' - 'each source file in its parent directories.') - parser.add_argument('-header-filter', default=None, - help='regular expression matching the names of the ' - 'headers to output diagnostics from. 
Diagnostics from ' - 'the main file of each translation unit are always ' - 'displayed.') - if yaml: - parser.add_argument('-export-fixes', metavar='filename', dest='export_fixes', - help='Create a yaml file to store suggested fixes in, ' - 'which can be applied with clang-apply-replacements.') - parser.add_argument('-j', type=int, default=0, - help='number of tidy instances to be run in parallel.') - parser.add_argument('files', nargs='*', default=['.*'], - help='files to be processed (regex on path)') - parser.add_argument('-fix', action='store_true', help='apply fix-its') - parser.add_argument('-format', action='store_true', help='Reformat code ' - 'after applying fixes') - parser.add_argument('-style', default='file', help='The style of reformat ' - 'code after applying fixes') - parser.add_argument('-p', dest='build_path', - help='Path used to read a compile command database.') - parser.add_argument('-extra-arg', dest='extra_arg', - action='append', default=[], - help='Additional argument to append to the compiler ' - 'command line.') - parser.add_argument('-extra-arg-before', dest='extra_arg_before', - action='append', default=[], - help='Additional argument to prepend to the compiler ' - 'command line.') - parser.add_argument('-quiet', action='store_true', - help='Run clang-tidy in quiet mode') - args = parser.parse_args() + parser = argparse.ArgumentParser( + description="Runs clang-tidy over all files " + "in a compilation database. Requires " + "clang-tidy and clang-apply-replacements in " + "$PATH." + ) + parser.add_argument( + "-allow-enabling-alpha-checkers", + action="store_true", + help="allow alpha checkers from " "clang-analyzer.", + ) + parser.add_argument( + "-clang-tidy-binary", + metavar="PATH", + default="clang-tidy-11", + help="path to clang-tidy binary", + ) + parser.add_argument( + "-clang-apply-replacements-binary", + metavar="PATH", + default="clang-apply-replacements-11", + help="path to clang-apply-replacements binary", + ) + parser.add_argument( + "-checks", + default=None, + help="checks filter, when not specified, use clang-tidy " "default", + ) + parser.add_argument( + "-config", + default=None, + help="Specifies a configuration in YAML/JSON format: " + " -config=\"{Checks: '*', " + " CheckOptions: [{key: x, " + ' value: y}]}" ' + "When the value is empty, clang-tidy will " + "attempt to find a file named .clang-tidy for " + "each source file in its parent directories.", + ) + parser.add_argument( + "-header-filter", + default=None, + help="regular expression matching the names of the " + "headers to output diagnostics from. 
Diagnostics from " + "the main file of each translation unit are always " + "displayed.", + ) + if yaml: + parser.add_argument( + "-export-fixes", + metavar="filename", + dest="export_fixes", + help="Create a yaml file to store suggested fixes in, " + "which can be applied with clang-apply-replacements.", + ) + parser.add_argument( + "-j", + type=int, + default=0, + help="number of tidy instances to be run in parallel.", + ) + parser.add_argument("files", nargs="*", default=[".*"], help="files to be processed (regex on path)") + parser.add_argument("-fix", action="store_true", help="apply fix-its") + parser.add_argument("-format", action="store_true", help="Reformat code " "after applying fixes") + parser.add_argument( + "-style", + default="file", + help="The style of reformat " "code after applying fixes", + ) + parser.add_argument("-p", dest="build_path", help="Path used to read a compile command database.") + parser.add_argument( + "-extra-arg", + dest="extra_arg", + action="append", + default=[], + help="Additional argument to append to the compiler " "command line.", + ) + parser.add_argument( + "-extra-arg-before", + dest="extra_arg_before", + action="append", + default=[], + help="Additional argument to prepend to the compiler " "command line.", + ) + parser.add_argument("-quiet", action="store_true", help="Run clang-tidy in quiet mode") + args = parser.parse_args() - db_path = 'compile_commands.json' + db_path = "compile_commands.json" - if args.build_path is not None: - build_path = args.build_path - else: - # Find our database - build_path = find_compilation_database(db_path) - - try: - invocation = [args.clang_tidy_binary, '-list-checks'] - if args.allow_enabling_alpha_checkers: - invocation.append('-allow-enabling-analyzer-alpha-checkers') - invocation.append('-p=' + build_path) - if args.checks: - invocation.append('-checks=' + args.checks) - invocation.append('-') - if args.quiet: - # Even with -quiet we still want to check if we can call clang-tidy. - with open(os.devnull, 'w') as dev_null: - subprocess.check_call(invocation, stdout=dev_null) + if args.build_path is not None: + build_path = args.build_path else: - subprocess.check_call(invocation) - except: - print("Unable to run clang-tidy.", file=sys.stderr) - sys.exit(1) + # Find our database + build_path = find_compilation_database(db_path) - # Load the database and extract all files. - database = json.load(open(os.path.join(build_path, db_path))) - files = [make_absolute(entry['file'], entry['directory']) - for entry in database] + try: + invocation = [args.clang_tidy_binary, "-list-checks"] + if args.allow_enabling_alpha_checkers: + invocation.append("-allow-enabling-analyzer-alpha-checkers") + invocation.append("-p=" + build_path) + if args.checks: + invocation.append("-checks=" + args.checks) + invocation.append("-") + if args.quiet: + # Even with -quiet we still want to check if we can call clang-tidy. + with open(os.devnull, "w") as dev_null: + subprocess.check_call(invocation, stdout=dev_null) + else: + subprocess.check_call(invocation) + except: + print("Unable to run clang-tidy.", file=sys.stderr) + sys.exit(1) - max_task = args.j - if max_task == 0: - max_task = multiprocessing.cpu_count() + # Load the database and extract all files. 
+ database = json.load(open(os.path.join(build_path, db_path))) + files = [make_absolute(entry["file"], entry["directory"]) for entry in database] - tmpdir = None - if args.fix or (yaml and args.export_fixes): - check_clang_apply_replacements_binary(args) - tmpdir = tempfile.mkdtemp() + max_task = args.j + if max_task == 0: + max_task = multiprocessing.cpu_count() - # Build up a big regexy filter from all command line arguments. - file_name_re = re.compile('|'.join(args.files)) + tmpdir = None + if args.fix or (yaml and args.export_fixes): + check_clang_apply_replacements_binary(args) + tmpdir = tempfile.mkdtemp() - return_code = 0 - try: - # Spin up a bunch of tidy-launching threads. - task_queue = queue.Queue(max_task) - # List of files with a non-zero return code. - failed_files = [] - lock = threading.Lock() - for _ in range(max_task): - t = threading.Thread(target=run_tidy, - args=(args, tmpdir, build_path, task_queue, lock, failed_files)) - t.daemon = True - t.start() + # Build up a big regexy filter from all command line arguments. + file_name_re = re.compile("|".join(args.files)) - # Fill the queue with files. - for name in files: - if file_name_re.search(name): - task_queue.put(name) + return_code = 0 + try: + # Spin up a bunch of tidy-launching threads. + task_queue = queue.Queue(max_task) + # List of files with a non-zero return code. + failed_files = [] + lock = threading.Lock() + for _ in range(max_task): + t = threading.Thread( + target=run_tidy, + args=(args, tmpdir, build_path, task_queue, lock, failed_files), + ) + t.daemon = True + t.start() - # Wait for all threads to be done. - task_queue.join() - if len(failed_files): - return_code = 1 + # Fill the queue with files. + for name in files: + if file_name_re.search(name): + task_queue.put(name) + + # Wait for all threads to be done. + task_queue.join() + if len(failed_files): + return_code = 1 + + except KeyboardInterrupt: + # This is a sad hack. Unfortunately subprocess goes + # bonkers with ctrl-c and we start forking merrily. + print("\nCtrl-C detected, goodbye.") + if tmpdir: + shutil.rmtree(tmpdir) + os.kill(0, 9) + + if yaml and args.export_fixes: + print("Writing fixes to " + args.export_fixes + " ...") + try: + merge_replacement_files(tmpdir, args.export_fixes) + except: + print("Error exporting fixes.\n", file=sys.stderr) + traceback.print_exc() + return_code = 1 + + if args.fix: + print("Applying fixes ...") + try: + apply_fixes(args, tmpdir) + except: + print("Error applying fixes.\n", file=sys.stderr) + traceback.print_exc() + return_code = 1 - except KeyboardInterrupt: - # This is a sad hack. Unfortunately subprocess goes - # bonkers with ctrl-c and we start forking merrily. 
- print('\nCtrl-C detected, goodbye.') if tmpdir: - shutil.rmtree(tmpdir) - os.kill(0, 9) - - if yaml and args.export_fixes: - print('Writing fixes to ' + args.export_fixes + ' ...') - try: - merge_replacement_files(tmpdir, args.export_fixes) - except: - print('Error exporting fixes.\n', file=sys.stderr) - traceback.print_exc() - return_code=1 - - if args.fix: - print('Applying fixes ...') - try: - apply_fixes(args, tmpdir) - except: - print('Error applying fixes.\n', file=sys.stderr) - traceback.print_exc() - return_code = 1 - - if tmpdir: - shutil.rmtree(tmpdir) - sys.exit(return_code) + shutil.rmtree(tmpdir) + sys.exit(return_code) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() From 1abe8f8bfc39d882dd5549407e7d916ace934b7f Mon Sep 17 00:00:00 2001 From: niko4299 Date: Thu, 7 Jul 2022 13:15:01 +0200 Subject: [PATCH 3/6] Labels defined with colon --- src/memgraph.cpp | 2 +- src/query/frontend/opencypher/grammar/MemgraphCypher.g4 | 2 +- src/query/interpreter.cpp | 9 +++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 955379024..e935c03ec 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -501,7 +501,7 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { if (first_user) { spdlog::info("{} is first created user. Granting all privileges.", username); - GrantPrivilege(username, memgraph::query::kPrivilegesAll, {}); + GrantPrivilege(username, memgraph::query::kPrivilegesAll, {"*"}); } return user_added; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index 12d54a916..cb9cc2ba0 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -260,7 +260,7 @@ privilege : CREATE privilegeList : privilege ( ',' privilege )* ; -labelList : label ( ',' label )* ; +labelList : COLON label ( ',' COLON label )* ; label : ( '*' | symbolicName ) ; diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index f07d6e414..ec2eed347 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -295,10 +295,11 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa AuthQuery::Action::REVOKE_PRIVILEGE, AuthQuery::Action::SHOW_PRIVILEGES, AuthQuery::Action::SHOW_USERS_FOR_ROLE, AuthQuery::Action::SHOW_ROLE_FOR_USER}; - if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) { - throw utils::BasicException( - utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication features")); - } + // if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) { + // throw utils::BasicException( + // utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication + // features")); + // } switch (auth_query->action_) { case AuthQuery::Action::CREATE_USER: From b63db202d666914902bc1ff3a628c7ba0491e1d6 Mon Sep 17 00:00:00 2001 From: niko4299 Date: Thu, 7 Jul 2022 13:21:05 +0200 Subject: [PATCH 4/6] uncommented --- src/query/interpreter.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index ec2eed347..2299935d6 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -295,11 +295,11 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa 
AuthQuery::Action::REVOKE_PRIVILEGE, AuthQuery::Action::SHOW_PRIVILEGES, AuthQuery::Action::SHOW_USERS_FOR_ROLE, AuthQuery::Action::SHOW_ROLE_FOR_USER};
 
-  // if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) {
-  //   throw utils::BasicException(
-  //       utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication
-  //       features"));
-  // }
+  if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) {
+    throw utils::BasicException(
+        utils::license::LicenseCheckErrorToString(license_check_result.GetError(),
+                                                  "advanced authentication features"));
+  }
 
   switch (auth_query->action_) {
     case AuthQuery::Action::CREATE_USER:

From c2a1328dcc473ff72e79cdb601369531867cc7f2 Mon Sep 17 00:00:00 2001
From: niko4299
Date: Thu, 7 Jul 2022 16:03:22 +0200
Subject: [PATCH 5/6] added Boris class

---
 src/auth/models.cpp | 72 ++++++++++++++++++++++++++++++++++++++++-----
 src/auth/models.hpp | 24 +++++++++------
 2 files changed, 80 insertions(+), 16 deletions(-)

diff --git a/src/auth/models.cpp b/src/auth/models.cpp
index 9780f84e6..85da82df8 100644
--- a/src/auth/models.cpp
+++ b/src/auth/models.cpp
@@ -185,18 +185,67 @@ bool operator==(const Permissions &first, const Permissions &second) {
 
 bool operator!=(const Permissions &first, const Permissions &second) { return !(first == second); }
 
-LabelPermissions::LabelPermissions(const std::unordered_map &permissions)
-    : permissions_(permissions) {}
+LabelPermissions::LabelPermissions(const std::unordered_set<std::string> &grants,
+                                   const std::unordered_set<std::string> &denies)
+    : grants_(grants), denies_(denies) {}
 
-void LabelPermissions::Grant(const std::string &permission) { permissions_[permission] = 1; }
+PermissionLevel LabelPermissions::Has(const std::string &permission) const {
+  if (denies_.find(permission) != denies_.end()) {
+    return PermissionLevel::DENY;
+  }
 
-void LabelPermissions::Deny(const std::string &label) { permissions_[label] = 0; }
+  if (grants_.find(permission) != grants_.end()) {
+    return PermissionLevel::GRANT;
+  }
 
-void LabelPermissions::Revoke(const std::string &label) { permissions_.erase(label); }
+  return PermissionLevel::NEUTRAL;
+}
+
+void LabelPermissions::Grant(const std::string &permission) {
+  auto deniedPermissionIter = denies_.find(permission);
+
+  if (deniedPermissionIter != denies_.end()) {
+    denies_.erase(deniedPermissionIter);
+  }
+
+  if (grants_.find(permission) == grants_.end()) {
+    grants_.insert(permission);
+  }
+}
+
+void LabelPermissions::Revoke(const std::string &permission) {
+  auto deniedPermissionIter = denies_.find(permission);
+  auto grantedPermissionIter = grants_.find(permission);
+
+  if (deniedPermissionIter != denies_.end()) {
+    denies_.erase(deniedPermissionIter);
+  }
+
+  if (grantedPermissionIter != grants_.end()) {
+    grants_.erase(grantedPermissionIter);
+  }
+}
+
+void LabelPermissions::Deny(const std::string &permission) {
+  auto grantedPermissionIter = grants_.find(permission);
+
+  if (grantedPermissionIter != grants_.end()) {
+    grants_.erase(grantedPermissionIter);
+  }
+
+  if (denies_.find(permission) == denies_.end()) {
+    denies_.insert(permission);
+  }
+}
+
+std::unordered_set<std::string> LabelPermissions::GetGrants() const { return grants_; }
+
+std::unordered_set<std::string> LabelPermissions::GetDenies() const { return denies_; }
 
 nlohmann::json LabelPermissions::Serialize() const {
   nlohmann::json data = nlohmann::json::object();
-  data["labelPermissions"] = permissions_;
+  data["grants"] = grants_;
+  data["denies"] = denies_;
   return data;
 }
 
@@ -205,9 +254,18 @@ LabelPermissions LabelPermissions::Deserialize(const nlohmann::json &data) {
     throw AuthException("Couldn't load permissions data!");
  }
 
-  return {data["labelPermissions"]};
+  return {LabelPermissions(data["grants"], data["denies"])};
 }
 
+std::unordered_set<std::string> LabelPermissions::grants() const { return grants_; }
+std::unordered_set<std::string> LabelPermissions::denies() const { return denies_; }
+
+bool operator==(const LabelPermissions &first, const LabelPermissions &second) {
+  return first.grants() == second.grants() && first.denies() == second.denies();
+}
+
+bool operator!=(const LabelPermissions &first, const LabelPermissions &second) { return !(first == second); }
+
 Role::Role(const std::string &rolename) : rolename_(utils::ToLowerCase(rolename)) {}
 
 Role::Role(const std::string &rolename, const Permissions &permissions)
diff --git a/src/auth/models.hpp b/src/auth/models.hpp
index 1003086c9..91e4e2174 100644
--- a/src/auth/models.hpp
+++ b/src/auth/models.hpp
@@ -12,6 +12,7 @@
 #include 
 #include 
+#include <unordered_set>
 
 namespace memgraph::auth {
 // These permissions must have values that are applicable for usage in a
@@ -91,29 +92,36 @@ bool operator!=(const Permissions &first, const Permissions &second);
 
 class LabelPermissions final {
  public:
-  LabelPermissions(const std::unordered_map &permissions_ = {});
+  LabelPermissions(const std::unordered_set<std::string> &grants = {},
+                   const std::unordered_set<std::string> &denies = {});
 
-  PermissionLevel Has(const std::string &label) const;
+  PermissionLevel Has(const std::string &permission) const;
 
-  void Grant(const std::string &label);
+  void Grant(const std::string &permission);
 
-  void Revoke(const std::string &label);
+  void Revoke(const std::string &permission);
 
-  void Deny(const std::string &label);
+  void Deny(const std::string &permission);
+
+  std::unordered_set<std::string> GetGrants() const;
+  std::unordered_set<std::string> GetDenies() const;
 
   nlohmann::json Serialize() const;
 
   /// @throw AuthException if unable to deserialize.
   static LabelPermissions Deserialize(const nlohmann::json &data);
 
-  std::unordered_map permissions() const;
+  std::unordered_set<std::string> grants() const;
+  std::unordered_set<std::string> denies() const;
 
  private:
-  std::unordered_map permissions_;
+  std::unordered_set<std::string> grants_{};
+  std::unordered_set<std::string> denies_{};
 };
 
 bool operator==(const LabelPermissions &first, const LabelPermissions &second);
+bool operator!=(const LabelPermissions &first, const LabelPermissions &second);
 
 class Role final {
  public:
  Role(const std::string &rolename);
@@ -192,5 +200,3 @@ class User final {
 
 bool operator==(const User &first, const User &second);
 }  // namespace memgraph::auth
-
-// namespace memgraph::auth

From 83aa71a29f3b3cdafed00c0427bfea65d3a91d17 Mon Sep 17 00:00:00 2001
From: niko4299
Date: Fri, 8 Jul 2022 10:59:54 +0200
Subject: [PATCH 6/6] interpreter.cpp 286 line convert vector label strings to vector LabelId, possible changes to LabelPermission

---
 src/query/interpreter.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp
index 2299935d6..2f28b1205 100644
--- a/src/query/interpreter.cpp
+++ b/src/query/interpreter.cpp
@@ -283,6 +283,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
   std::string user_or_role = auth_query->user_or_role_;
   std::vector<AuthQuery::Privilege> privileges = auth_query->privileges_;
   std::vector<std::string> labels = auth_query->labels_;
+  // std::vector labels = NamesToLabels(labels, db_accessor);
   auto password = EvaluateOptionalExpression(auth_query->password_, &evaluator);
 
   Callback callback;
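
Note on the semantics introduced in PATCH 5/6: a label can be explicitly granted, explicitly denied, or left unmentioned (neutral); a deny always wins over a grant, and a revoke simply forgets the label. The standalone sketch below mirrors that behaviour using only the standard library so it can be compiled and run on its own. It is an illustration, not the code from the patch; the real class lives in src/auth/models.hpp/.cpp and is fed from the grammar change in PATCH 3/6, where labels are written with a leading colon.

// Standalone sketch (not part of the patch series): mirrors the intended
// grant/deny/revoke semantics of the LabelPermissions class from PATCH 5/6.
#include <cassert>
#include <string>
#include <unordered_set>

enum class PermissionLevel { GRANT, NEUTRAL, DENY };

class LabelPermissionsSketch {
 public:
  // Deny wins over grant; a label that was never mentioned is NEUTRAL.
  PermissionLevel Has(const std::string &label) const {
    if (denies_.count(label) != 0) return PermissionLevel::DENY;
    if (grants_.count(label) != 0) return PermissionLevel::GRANT;
    return PermissionLevel::NEUTRAL;
  }

  // Granting removes any previous deny for the same label.
  void Grant(const std::string &label) {
    denies_.erase(label);
    grants_.insert(label);
  }

  // Denying removes any previous grant for the same label.
  void Deny(const std::string &label) {
    grants_.erase(label);
    denies_.insert(label);
  }

  // Revoking forgets the label entirely, returning it to NEUTRAL.
  void Revoke(const std::string &label) {
    grants_.erase(label);
    denies_.erase(label);
  }

 private:
  std::unordered_set<std::string> grants_;
  std::unordered_set<std::string> denies_;
};

int main() {
  LabelPermissionsSketch perms;
  perms.Grant("Person");
  assert(perms.Has("Person") == PermissionLevel::GRANT);
  perms.Deny("Person");  // deny overrides the earlier grant
  assert(perms.Has("Person") == PermissionLevel::DENY);
  perms.Revoke("Person");  // back to no explicit rule
  assert(perms.Has("Person") == PermissionLevel::NEUTRAL);
  assert(perms.Has("City") == PermissionLevel::NEUTRAL);  // never mentioned
}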
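
PATCH 6/6 only leaves a TODO comment about mapping the parsed label names to storage-level LabelIds. The accessor API that would perform that lookup is not part of this diff, so the sketch below uses a plain map as a stand-in resolver; NamesToLabels and LabelId are taken from the commit message, while the registry parameter is an illustrative placeholder rather than a Memgraph symbol.

// Standalone sketch of the name -> LabelId conversion hinted at in PATCH 6/6.
// The real code would resolve names through the database accessor; here a
// simple map assigns ids so the example is self-contained and runnable.
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

using LabelId = uint64_t;  // placeholder for the storage-level label id type

std::vector<LabelId> NamesToLabels(const std::vector<std::string> &label_names,
                                   std::unordered_map<std::string, LabelId> &name_to_id) {
  std::vector<LabelId> ids;
  ids.reserve(label_names.size());
  for (const auto &name : label_names) {
    // Assign the next free id the first time a name is seen, reuse it afterwards.
    const auto it = name_to_id.try_emplace(name, static_cast<LabelId>(name_to_id.size())).first;
    ids.push_back(it->second);
  }
  return ids;
}

int main() {
  std::unordered_map<std::string, LabelId> registry;
  const std::vector<std::string> names{"Person", "City", "Person"};
  const auto ids = NamesToLabels(names, registry);
  // ids[0] == ids[2] because both entries refer to the label "Person".
  return ids[0] == ids[2] ? 0 : 1;
}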