From 05cc35bf93b3c492c0ea1f0d2f469f5d504b1d4c Mon Sep 17 00:00:00 2001 From: Josipmrden Date: Wed, 21 Jun 2023 14:50:46 +0200 Subject: [PATCH] Add command NULLIF for identifying nulls in LOAD CSV (#914) Add NULLIF command which turns all row values corresponding to the string to the nullif character sequence. --- src/query/frontend/ast/ast.hpp | 5 +- .../frontend/ast/cypher_main_visitor.cpp | 5 ++ .../opencypher/grammar/MemgraphCypher.g4 | 4 ++ .../opencypher/grammar/MemgraphCypherLexer.g4 | 1 + src/query/plan/operator.cpp | 35 +++++++++--- src/query/plan/operator.hpp | 4 +- src/query/plan/pretty_print.cpp | 4 ++ src/query/plan/rule_based_planner.hpp | 7 ++- tests/e2e/load_csv/CMakeLists.txt | 3 ++ tests/e2e/load_csv/load_csv_nullif.py | 53 +++++++++++++++++++ tests/e2e/load_csv/nullif.csv | 5 ++ tests/e2e/load_csv/workloads.yaml | 11 ++++ 12 files changed, 124 insertions(+), 13 deletions(-) create mode 100644 tests/e2e/load_csv/load_csv_nullif.py create mode 100644 tests/e2e/load_csv/nullif.csv diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 3bdf704ce..caee103a2 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -3022,6 +3022,7 @@ class LoadCsv : public memgraph::query::Clause { bool ignore_bad_; memgraph::query::Expression *delimiter_{nullptr}; memgraph::query::Expression *quote_{nullptr}; + memgraph::query::Expression *nullif_{nullptr}; memgraph::query::Identifier *row_var_{nullptr}; LoadCsv *Clone(AstStorage *storage) const override { @@ -3031,18 +3032,20 @@ class LoadCsv : public memgraph::query::Clause { object->ignore_bad_ = ignore_bad_; object->delimiter_ = delimiter_ ? delimiter_->Clone(storage) : nullptr; object->quote_ = quote_ ? quote_->Clone(storage) : nullptr; + object->nullif_ = nullif_; object->row_var_ = row_var_ ? row_var_->Clone(storage) : nullptr; return object; } protected: explicit LoadCsv(Expression *file, bool with_header, bool ignore_bad, Expression *delimiter, Expression *quote, - Identifier *row_var) + Expression *nullif, Identifier *row_var) : file_(file), with_header_(with_header), ignore_bad_(ignore_bad), delimiter_(delimiter), quote_(quote), + nullif_(nullif), row_var_(row_var) { DMG_ASSERT(row_var, "LoadCsv cannot take nullptr for identifier"); } diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 8b772ca0d..f2e037172 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -362,6 +362,11 @@ antlrcpp::Any CypherMainVisitor::visitLoadCsv(MemgraphCypher::LoadCsvContext *ct // handle skip bad row option load_csv->ignore_bad_ = ctx->IGNORE() && ctx->BAD(); + // handle character sequence which will correspond to nulls + if (ctx->NULLIF()) { + load_csv->nullif_ = std::any_cast(ctx->nullif()->accept(this)); + } + // handle delimiter if (ctx->DELIMITER()) { if (ctx->delimiter()->literal()->StringLiteral()) { diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index 9bb1bfabc..f7ffe8803 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -59,6 +59,7 @@ memgraphCypherKeyword : cypherKeyword | GRANT | HEADER | IDENTIFIED + | NULLIF | ISOLATION | IN_MEMORY_ANALYTICAL | IN_MEMORY_TRANSACTIONAL @@ -224,6 +225,7 @@ loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER ( IGNORE BAD ) ? ( DELIMITER delimiter ) ? ( QUOTE quote ) ? + ( NULLIF nullif ) ? AS rowVar ; csvFile : literal ; @@ -232,6 +234,8 @@ delimiter : literal ; quote : literal ; +nullif : literal ; + rowVar : variable ; userOrRoleName : symbolicName ; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index 862682d6e..674a5f61d 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -85,6 +85,7 @@ MODULE_WRITE : M O D U L E UNDERSCORE W R I T E ; NEXT : N E X T ; NO : N O ; NOTHING : N O T H I N G ; +NULLIF : N U L L I F ; PASSWORD : P A S S W O R D ; PORT : P O R T ; PRIVILEGES : P R I V I L E G E S ; diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index a922e92bf..d50ef68e4 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -4637,13 +4637,14 @@ UniqueCursorPtr CallProcedure::MakeCursor(utils::MemoryResource *mem) const { } LoadCsv::LoadCsv(std::shared_ptr input, Expression *file, bool with_header, bool ignore_bad, - Expression *delimiter, Expression *quote, Symbol row_var) + Expression *delimiter, Expression *quote, Expression *nullif, Symbol row_var) : input_(input ? input : (std::make_shared())), file_(file), with_header_(with_header), ignore_bad_(ignore_bad), delimiter_(delimiter), quote_(quote), + nullif_(nullif), row_var_(row_var) { MG_ASSERT(file_, "Something went wrong - '{}' member file_ shouldn't be a nullptr", __func__); } @@ -4674,22 +4675,31 @@ auto ToOptionalString(ExpressionEvaluator *evaluator, Expression *expression) -> return std::nullopt; }; -TypedValue CsvRowToTypedList(csv::Reader::Row &row) { +TypedValue CsvRowToTypedList(csv::Reader::Row &row, std::optional &nullif) { auto *mem = row.get_allocator().GetMemoryResource(); auto typed_columns = utils::pmr::vector(mem); typed_columns.reserve(row.size()); for (auto &column : row) { - typed_columns.emplace_back(std::move(column)); + if (!nullif.has_value() || column != nullif.value()) { + typed_columns.emplace_back(std::move(column)); + } else { + typed_columns.emplace_back(); + } } return {std::move(typed_columns), mem}; } -TypedValue CsvRowToTypedMap(csv::Reader::Row &row, csv::Reader::Header header) { +TypedValue CsvRowToTypedMap(csv::Reader::Row &row, csv::Reader::Header header, + std::optional &nullif) { // a valid row has the same number of elements as the header auto *mem = row.get_allocator().GetMemoryResource(); utils::pmr::map m(mem); for (auto i = 0; i < row.size(); ++i) { - m.emplace(std::move(header[i]), std::move(row[i])); + if (!nullif.has_value() || row[i] != nullif.value()) { + m.emplace(std::move(header[i]), std::move(row[i])); + } else { + m.emplace(std::piecewise_construct, std::forward_as_tuple(std::move(header[i])), std::forward_as_tuple()); + } } return {std::move(m), mem}; } @@ -4701,6 +4711,7 @@ class LoadCsvCursor : public Cursor { const UniqueCursorPtr input_cursor_; bool did_pull_; std::optional reader_{}; + std::optional nullif_; public: LoadCsvCursor(const LoadCsv *self, utils::MemoryResource *mem) @@ -4718,6 +4729,7 @@ class LoadCsvCursor : public Cursor { // without massacring the code even worse than I did here if (UNLIKELY(!reader_)) { reader_ = MakeReader(&context.evaluation_context); + nullif_ = ParseNullif(&context.evaluation_context); } if (input_cursor_->Pull(frame, context)) { @@ -4733,10 +4745,10 @@ class LoadCsvCursor : public Cursor { return false; } if (!reader_->HasHeader()) { - frame[self_->row_var_] = CsvRowToTypedList(*row); + frame[self_->row_var_] = CsvRowToTypedList(*row, nullif_); } else { frame[self_->row_var_] = - CsvRowToTypedMap(*row, csv::Reader::Header(reader_->GetHeader(), context.evaluation_context.memory)); + CsvRowToTypedMap(*row, csv::Reader::Header(reader_->GetHeader(), context.evaluation_context.memory), nullif_); } if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(self_->row_var_.name())) { context.frame_change_collector->ResetTrackingValue(self_->row_var_.name()); @@ -4768,6 +4780,15 @@ class LoadCsvCursor : public Cursor { csv::Reader::Config(self_->with_header_, self_->ignore_bad_, std::move(maybe_delim), std::move(maybe_quote)), utils::NewDeleteResource()); } + + std::optional ParseNullif(EvaluationContext *eval_context) { + Frame frame(0); + SymbolTable symbol_table; + DbAccessor *dba = nullptr; + auto evaluator = ExpressionEvaluator(&frame, symbol_table, *eval_context, dba, storage::View::OLD); + + return ToOptionalString(&evaluator, self_->nullif_); + } }; UniqueCursorPtr LoadCsv::MakeCursor(utils::MemoryResource *mem) const { diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 9c0a0c831..8cbbccc9c 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -2227,7 +2227,7 @@ class LoadCsv : public memgraph::query::plan::LogicalOperator { LoadCsv() = default; LoadCsv(std::shared_ptr input, Expression *file, bool with_header, bool ignore_bad, - Expression *delimiter, Expression *quote, Symbol row_var); + Expression *delimiter, Expression *quote, Expression *nullif, Symbol row_var); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; std::vector OutputSymbols(const SymbolTable &) const override; @@ -2243,6 +2243,7 @@ class LoadCsv : public memgraph::query::plan::LogicalOperator { bool ignore_bad_; Expression *delimiter_{nullptr}; Expression *quote_{nullptr}; + Expression *nullif_{nullptr}; Symbol row_var_; std::unique_ptr Clone(AstStorage *storage) const override { @@ -2253,6 +2254,7 @@ class LoadCsv : public memgraph::query::plan::LogicalOperator { object->ignore_bad_ = ignore_bad_; object->delimiter_ = delimiter_ ? delimiter_->Clone(storage) : nullptr; object->quote_ = quote_ ? quote_->Clone(storage) : nullptr; + object->nullif_ = nullif_; object->row_var_ = row_var_; return object; } diff --git a/src/query/plan/pretty_print.cpp b/src/query/plan/pretty_print.cpp index 3b5c2303b..3c23a2e8c 100644 --- a/src/query/plan/pretty_print.cpp +++ b/src/query/plan/pretty_print.cpp @@ -895,6 +895,10 @@ bool PlanToJsonVisitor::PreVisit(query::plan::LoadCsv &op) { self["quote"] = ToJson(op.quote_); } + if (op.nullif_) { + self["nullif"] = ToJson(op.nullif_); + } + self["row_variable"] = ToJson(op.row_var_); op.input_->Accept(*this); diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp index b05b8d06e..09a53cf29 100644 --- a/src/query/plan/rule_based_planner.hpp +++ b/src/query/plan/rule_based_planner.hpp @@ -226,10 +226,9 @@ class RuleBasedPlanner { const auto &row_sym = context.symbol_table->at(*load_csv->row_var_); context.bound_symbols.insert(row_sym); - input_op = - std::make_unique(std::move(input_op), load_csv->file_, load_csv->with_header_, - load_csv->ignore_bad_, load_csv->delimiter_, load_csv->quote_, row_sym); - + input_op = std::make_unique(std::move(input_op), load_csv->file_, load_csv->with_header_, + load_csv->ignore_bad_, load_csv->delimiter_, load_csv->quote_, + load_csv->nullif_, row_sym); } else if (auto *foreach = utils::Downcast(clause)) { context.is_write_query = true; input_op = HandleForeachClause(foreach, std::move(input_op), *context.symbol_table, context.bound_symbols, diff --git a/tests/e2e/load_csv/CMakeLists.txt b/tests/e2e/load_csv/CMakeLists.txt index 06e6d6e33..368915dbe 100644 --- a/tests/e2e/load_csv/CMakeLists.txt +++ b/tests/e2e/load_csv/CMakeLists.txt @@ -8,3 +8,6 @@ endfunction() copy_load_csv_e2e_python_files(load_csv.py) copy_load_csv_e2e_files(simple.csv) + +copy_load_csv_e2e_python_files(load_csv_nullif.py) +copy_load_csv_e2e_files(nullif.csv) diff --git a/tests/e2e/load_csv/load_csv_nullif.py b/tests/e2e/load_csv/load_csv_nullif.py new file mode 100644 index 000000000..018781683 --- /dev/null +++ b/tests/e2e/load_csv/load_csv_nullif.py @@ -0,0 +1,53 @@ +# Copyright 2022 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import os +import sys +from pathlib import Path + +import pytest +from gqlalchemy import Memgraph + +NULLIF_CSV_FILE = "nullif.csv" + + +def get_file_path(file: str) -> str: + parent_path = Path(__file__).parent.absolute() + return os.path.join(parent_path, file) + + +def test_given_csv_when_nullif_then_all_identical_rows_are_null(): + memgraph = Memgraph("localhost", 7687) + + results = list( + memgraph.execute_and_fetch( + f"""LOAD CSV FROM '{get_file_path(NULLIF_CSV_FILE)}' + WITH HEADER NULLIF 'N/A' AS row + CREATE (n:Person {{name: row.name, age: row.age, + percentage: row.percentage, works_in_IT: row.works_in_IT}}) + RETURN n + """ + ) + ) + + expected_properties = [ + {"age": "10", "percentage": "15.0", "works_in_IT": "false"}, + {"name": "John", "percentage": "35.4", "works_in_IT": "false"}, + {"name": "Milewa", "age": "34", "works_in_IT": "false"}, + {"name": "Lucas", "age": "50", "percentage": "12.5"}, + ] + properties = [result["n"]._properties for result in results] + + assert expected_properties == properties + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/load_csv/nullif.csv b/tests/e2e/load_csv/nullif.csv new file mode 100644 index 000000000..3a38bf775 --- /dev/null +++ b/tests/e2e/load_csv/nullif.csv @@ -0,0 +1,5 @@ +name,age,percentage,works_in_IT +N/A,10,15.0,false +John,N/A,35.4,false +Milewa,34,N/A,false +Lucas,50,12.5,N/A diff --git a/tests/e2e/load_csv/workloads.yaml b/tests/e2e/load_csv/workloads.yaml index e54fa728f..d07609699 100644 --- a/tests/e2e/load_csv/workloads.yaml +++ b/tests/e2e/load_csv/workloads.yaml @@ -1,3 +1,10 @@ +nullif_cluster: &nullif_cluster + cluster: + main: + args: ["--bolt-port", "7687", "--log-level=TRACE"] + log_file: "load_csv_log_file.txt" + validation_queries: [] + load_csv_cluster: &load_csv_cluster cluster: main: @@ -9,6 +16,10 @@ load_csv_cluster: &load_csv_cluster validation_queries: [] workloads: + - name: "LOAD CSV nullif" + binary: "tests/e2e/pytest_runner.sh" + args: ["load_csv/load_csv_nullif.py"] + <<: *nullif_cluster - name: "MATCH + LOAD CSV" binary: "tests/e2e/pytest_runner.sh" args: ["load_csv/load_csv.py"]