diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 449fa866b..a922e92bf 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -4648,7 +4648,7 @@ LoadCsv::LoadCsv(std::shared_ptr input, Expression *file, bool MG_ASSERT(file_, "Something went wrong - '{}' member file_ shouldn't be a nullptr", __func__); } -bool LoadCsv::Accept(HierarchicalLogicalOperatorVisitor &visitor) { return false; }; +ACCEPT_WITH_INPUT(LoadCsv) class LoadCsvCursor; @@ -4699,14 +4699,12 @@ TypedValue CsvRowToTypedMap(csv::Reader::Row &row, csv::Reader::Header header) { class LoadCsvCursor : public Cursor { const LoadCsv *self_; const UniqueCursorPtr input_cursor_; - bool input_is_once_; + bool did_pull_; std::optional reader_{}; public: LoadCsvCursor(const LoadCsv *self, utils::MemoryResource *mem) - : self_(self), input_cursor_(self_->input_->MakeCursor(mem)) { - input_is_once_ = dynamic_cast(self_->input_.get()); - } + : self_(self), input_cursor_(self_->input_->MakeCursor(mem)), did_pull_{false} {} bool Pull(Frame &frame, ExecutionContext &context) override { SCOPED_PROFILE_OP("LoadCsv"); @@ -4722,14 +4720,14 @@ class LoadCsvCursor : public Cursor { reader_ = MakeReader(&context.evaluation_context); } - bool input_pulled = input_cursor_->Pull(frame, context); + if (input_cursor_->Pull(frame, context)) { + if (did_pull_) { + throw QueryRuntimeException( + "LOAD CSV can be executed only once, please check if the cardinality of the operator before LOAD CSV is 1"); + } + did_pull_ = true; + } - // If the input is Once, we have to keep going until we read all the rows, - // regardless of whether the pull on Once returned false. - // If we have e.g. MATCH(n) LOAD CSV ... AS x SET n.name = x.name, then we - // have to read at most cardinality(n) rows (but we can read less and stop - // pulling MATCH). - if (!input_is_once_ && !input_pulled) return false; auto row = reader_->GetNextRow(context.evaluation_context.memory); if (!row) { return false; diff --git a/src/query/plan/pretty_print.cpp b/src/query/plan/pretty_print.cpp index 1d0512d85..3b5c2303b 100644 --- a/src/query/plan/pretty_print.cpp +++ b/src/query/plan/pretty_print.cpp @@ -874,11 +874,27 @@ bool PlanToJsonVisitor::PreVisit(query::plan::CallProcedure &op) { bool PlanToJsonVisitor::PreVisit(query::plan::LoadCsv &op) { json self; self["name"] = "LoadCsv"; - self["file"] = ToJson(op.file_); - self["with_header"] = op.with_header_; - self["ignore_bad"] = op.ignore_bad_; - self["delimiter"] = ToJson(op.delimiter_); - self["quote"] = ToJson(op.quote_); + + if (op.file_) { + self["file"] = ToJson(op.file_); + } + + if (op.with_header_) { + self["with_header"] = op.with_header_; + } + + if (op.ignore_bad_) { + self["ignore_bad"] = op.ignore_bad_; + } + + if (op.delimiter_) { + self["delimiter"] = ToJson(op.delimiter_); + } + + if (op.quote_) { + self["quote"] = ToJson(op.quote_); + } + self["row_variable"] = ToJson(op.row_var_); op.input_->Accept(*this); diff --git a/src/query/plan/rewrite/index_lookup.hpp b/src/query/plan/rewrite/index_lookup.hpp index 1bcf2cb09..feac431fe 100644 --- a/src/query/plan/rewrite/index_lookup.hpp +++ b/src/query/plan/rewrite/index_lookup.hpp @@ -477,6 +477,16 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { return true; } + bool PreVisit(LoadCsv &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PostVisit(LoadCsv & /*op*/) override { + prev_ops_.pop_back(); + return true; + } + std::shared_ptr new_root_; private: diff --git a/src/utils/csv_parsing.cpp b/src/utils/csv_parsing.cpp index 49d8a0949..4744f2100 100644 --- a/src/utils/csv_parsing.cpp +++ b/src/utils/csv_parsing.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/src/utils/csv_parsing.hpp b/src/utils/csv_parsing.hpp index 37d438b41..928654ca8 100644 --- a/src/utils/csv_parsing.hpp +++ b/src/utils/csv_parsing.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/e2e/CMakeLists.txt b/tests/e2e/CMakeLists.txt index e1820aa93..3162c7857 100644 --- a/tests/e2e/CMakeLists.txt +++ b/tests/e2e/CMakeLists.txt @@ -55,6 +55,7 @@ add_subdirectory(python_query_modules_reloading) add_subdirectory(analyze_graph) add_subdirectory(transaction_queue) add_subdirectory(mock_api) +add_subdirectory(load_csv) add_subdirectory(init_file_flags) copy_e2e_python_files(pytest_runner pytest_runner.sh "") diff --git a/tests/e2e/load_csv/CMakeLists.txt b/tests/e2e/load_csv/CMakeLists.txt new file mode 100644 index 000000000..06e6d6e33 --- /dev/null +++ b/tests/e2e/load_csv/CMakeLists.txt @@ -0,0 +1,10 @@ +function(copy_load_csv_e2e_python_files FILE_NAME) + copy_e2e_python_files(load_csv ${FILE_NAME}) +endfunction() + +function(copy_load_csv_e2e_files FILE_NAME) + copy_e2e_python_files(load_csv ${FILE_NAME}) +endfunction() + +copy_load_csv_e2e_python_files(load_csv.py) +copy_load_csv_e2e_files(simple.csv) diff --git a/tests/e2e/load_csv/load_csv.py b/tests/e2e/load_csv/load_csv.py new file mode 100644 index 000000000..5d91a6e1d --- /dev/null +++ b/tests/e2e/load_csv/load_csv.py @@ -0,0 +1,56 @@ +# Copyright 2022 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import os +import sys +from pathlib import Path + +import pytest +from gqlalchemy import Memgraph +from mgclient import DatabaseError + +SIMPLE_CSV_FILE = "simple.csv" + + +def get_file_path(file: str) -> str: + return os.path.join(Path(__file__).parent.absolute(), file) + + +def test_given_two_rows_in_db_when_load_csv_after_match_then_throw_exception(): + memgraph = Memgraph("localhost", 7687) + + with pytest.raises(DatabaseError): + next( + memgraph.execute_and_fetch( + f"""MATCH (n) LOAD CSV + FROM '{get_file_path(SIMPLE_CSV_FILE)}' WITH HEADER AS row + CREATE (:Person {{name: row.name}}) + """ + ) + ) + + +def test_given_one_row_in_db_when_load_csv_after_match_then_pass(): + memgraph = Memgraph("localhost", 7687) + + results = memgraph.execute_and_fetch( + f"""MATCH (n {{prop: 1}}) LOAD CSV + FROM '{get_file_path(SIMPLE_CSV_FILE)}' WITH HEADER AS row + CREATE (:Person {{name: row.name}}) + RETURN n + """ + ) + + assert len(list(results)) == 4 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/load_csv/simple.csv b/tests/e2e/load_csv/simple.csv new file mode 100644 index 000000000..42f9de93d --- /dev/null +++ b/tests/e2e/load_csv/simple.csv @@ -0,0 +1,5 @@ +id,name +1,Joseph +2,Peter +3,Ella +4,Joe diff --git a/tests/e2e/load_csv/workloads.yaml b/tests/e2e/load_csv/workloads.yaml new file mode 100644 index 000000000..e54fa728f --- /dev/null +++ b/tests/e2e/load_csv/workloads.yaml @@ -0,0 +1,15 @@ +load_csv_cluster: &load_csv_cluster + cluster: + main: + args: ["--bolt-port", "7687", "--log-level=TRACE"] + log_file: "load_csv_log_file.txt" + setup_queries: + - "CREATE (n {prop: 1});" + - "CREATE (n {prop: 2});" + validation_queries: [] + +workloads: + - name: "MATCH + LOAD CSV" + binary: "tests/e2e/pytest_runner.sh" + args: ["load_csv/load_csv.py"] + <<: *load_csv_cluster diff --git a/tests/setup.sh b/tests/setup.sh index 32545d68b..1630a6454 100755 --- a/tests/setup.sh +++ b/tests/setup.sh @@ -1,6 +1,7 @@ #!/bin/bash # shellcheck disable=1091 + set -Eeuo pipefail DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" diff --git a/tests/unit/utils_csv_parsing.cpp b/tests/unit/utils_csv_parsing.cpp index 3c852b171..9fef48af1 100644 --- a/tests/unit/utils_csv_parsing.cpp +++ b/tests/unit/utils_csv_parsing.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source