diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 1de5e55ff..d8ce6df22 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -467,8 +467,10 @@ antlrcpp::Any CypherMainVisitor::visitLoadCsv(MemgraphCypher::LoadCsvContext *ct auto *load_csv = storage_->Create<LoadCsv>(); // handle file name - if (ctx->csvFile()->literal()->StringLiteral()) { + if (ctx->csvFile()->literal() && ctx->csvFile()->literal()->StringLiteral()) { load_csv->file_ = std::any_cast<Expression *>(ctx->csvFile()->accept(this)); + } else if (ctx->csvFile()->parameter()) { + load_csv->file_ = std::any_cast<ParameterLookup *>(ctx->csvFile()->accept(this)); } else { throw SemanticException("CSV file path should be a string literal"); } diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index 0597967c7..892f1b1e3 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -265,7 +265,7 @@ loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER ( NULLIF nullif ) ? AS rowVar ; -csvFile : literal ; +csvFile : literal | parameter ; delimiter : literal ; diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp index 092710628..7fba3b623 100644 --- a/src/query/plan/rule_based_planner.hpp +++ b/src/query/plan/rule_based_planner.hpp @@ -232,7 +232,6 @@ class RuleBasedPlanner { } else if (auto *load_csv = utils::Downcast<query::LoadCsv>(clause)) { const auto &row_sym = context.symbol_table->at(*load_csv->row_var_); context.bound_symbols.insert(row_sym); - input_op = std::make_unique<plan::LoadCsv>(std::move(input_op), load_csv->file_, load_csv->with_header_, load_csv->ignore_bad_, load_csv->delimiter_, load_csv->quote_, load_csv->nullif_, row_sym); diff --git a/tests/e2e/load_csv/load_csv.py b/tests/e2e/load_csv/load_csv.py index 5d91a6e1d..371803ed1 100644 --- a/tests/e2e/load_csv/load_csv.py +++ b/tests/e2e/load_csv/load_csv.py @@ -16,6 +16,7 @@ from pathlib import Path import pytest from gqlalchemy import Memgraph from mgclient import DatabaseError +from neo4j import GraphDatabase SIMPLE_CSV_FILE = "simple.csv" @@ -52,5 +53,22 @@ def test_given_one_row_in_db_when_load_csv_after_match_then_pass(): assert len(list(results)) == 4 +def test_load_csv_with_parameters(): + memgraph = Memgraph("localhost", 7687) + URI = "bolt://localhost:7687" + AUTH = ("", "") + + with GraphDatabase.driver(URI, auth=AUTH) as client: + with client.session(database="memgraph") as session: + results = session.run( + f"""MATCH (n {{prop: 1}}) LOAD CSV + FROM $file WITH HEADER AS row + CREATE (:Person {{name: row.name}}) + RETURN n""", + file=get_file_path(SIMPLE_CSV_FILE), + ) + assert len(list(results)) == 4 + + if __name__ == "__main__": sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/gql_behave/tests/memgraph_V1/features/functions.feature b/tests/gql_behave/tests/memgraph_V1/features/functions.feature index 19a7f2332..4a58cbfbe 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/functions.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/functions.feature @@ -64,7 +64,7 @@ Feature: Functions Given an empty graph And having executed """ - CREATE (:Node {prop: ToBoolean("t")}); + CREATE (:Node {prop: TOBOOLEAN("t")}); """ When executing query: """ @@ -74,11 +74,11 @@ Feature: Functions | n.prop | | true | - Scenario: ToBoolean test 03: + Scenario: ToBoolean test 04: Given an empty graph And having executed """ - CREATE (:Node {prop: ToBoolean("f")}); + CREATE (:Node {prop: TOBOOLEAN("f")}); """ When executing query: """