Add command NULLIF for identifying nulls in LOAD CSV (#914)

Add a NULLIF option to LOAD CSV which converts every row value that matches the given character sequence into a null.
Josipmrden 2023-06-21 14:50:46 +02:00 committed by GitHub
parent 63f8298033
commit 05cc35bf93
12 changed files with 124 additions and 13 deletions
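For illustration (not part of the commit), a query using the new option could look like the following; the file path, delimiter, and Person label are made up, and the clause order follows the extended loadCsv grammar rule below:

LOAD CSV FROM "/path/to/people.csv" WITH HEADER DELIMITER "," NULLIF "N/A" AS row
CREATE (:Person {name: row.name, age: row.age});

Any cell whose contents exactly match the NULLIF character sequence ("N/A" here) is bound to the row as a null instead of a string.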


@@ -3022,6 +3022,7 @@ class LoadCsv : public memgraph::query::Clause {
bool ignore_bad_;
memgraph::query::Expression *delimiter_{nullptr};
memgraph::query::Expression *quote_{nullptr};
memgraph::query::Expression *nullif_{nullptr};
memgraph::query::Identifier *row_var_{nullptr};
LoadCsv *Clone(AstStorage *storage) const override {
@@ -3031,18 +3032,20 @@ class LoadCsv : public memgraph::query::Clause {
object->ignore_bad_ = ignore_bad_;
object->delimiter_ = delimiter_ ? delimiter_->Clone(storage) : nullptr;
object->quote_ = quote_ ? quote_->Clone(storage) : nullptr;
object->nullif_ = nullif_;
object->row_var_ = row_var_ ? row_var_->Clone(storage) : nullptr;
return object;
}
protected:
explicit LoadCsv(Expression *file, bool with_header, bool ignore_bad, Expression *delimiter, Expression *quote,
Identifier *row_var)
Expression *nullif, Identifier *row_var)
: file_(file),
with_header_(with_header),
ignore_bad_(ignore_bad),
delimiter_(delimiter),
quote_(quote),
nullif_(nullif),
row_var_(row_var) {
DMG_ASSERT(row_var, "LoadCsv cannot take nullptr for identifier");
}


@@ -362,6 +362,11 @@ antlrcpp::Any CypherMainVisitor::visitLoadCsv(MemgraphCypher::LoadCsvContext *ct
// handle skip bad row option
load_csv->ignore_bad_ = ctx->IGNORE() && ctx->BAD();
// handle character sequence which will correspond to nulls
if (ctx->NULLIF()) {
load_csv->nullif_ = std::any_cast<Expression *>(ctx->nullif()->accept(this));
}
// handle delimiter
if (ctx->DELIMITER()) {
if (ctx->delimiter()->literal()->StringLiteral()) {


@@ -59,6 +59,7 @@ memgraphCypherKeyword : cypherKeyword
| GRANT
| HEADER
| IDENTIFIED
| NULLIF
| ISOLATION
| IN_MEMORY_ANALYTICAL
| IN_MEMORY_TRANSACTIONAL
@@ -224,6 +225,7 @@ loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER
( IGNORE BAD ) ?
( DELIMITER delimiter ) ?
( QUOTE quote ) ?
( NULLIF nullif ) ?
AS rowVar ;
csvFile : literal ;
@@ -232,6 +234,8 @@ delimiter : literal ;
quote : literal ;
nullif : literal ;
rowVar : variable ;
userOrRoleName : symbolicName ;


@@ -85,6 +85,7 @@ MODULE_WRITE : M O D U L E UNDERSCORE W R I T E ;
NEXT : N E X T ;
NO : N O ;
NOTHING : N O T H I N G ;
NULLIF : N U L L I F ;
PASSWORD : P A S S W O R D ;
PORT : P O R T ;
PRIVILEGES : P R I V I L E G E S ;


@@ -4637,13 +4637,14 @@ UniqueCursorPtr CallProcedure::MakeCursor(utils::MemoryResource *mem) const {
}
LoadCsv::LoadCsv(std::shared_ptr<LogicalOperator> input, Expression *file, bool with_header, bool ignore_bad,
Expression *delimiter, Expression *quote, Symbol row_var)
Expression *delimiter, Expression *quote, Expression *nullif, Symbol row_var)
: input_(input ? input : (std::make_shared<Once>())),
file_(file),
with_header_(with_header),
ignore_bad_(ignore_bad),
delimiter_(delimiter),
quote_(quote),
nullif_(nullif),
row_var_(row_var) {
MG_ASSERT(file_, "Something went wrong - '{}' member file_ shouldn't be a nullptr", __func__);
}
@@ -4674,22 +4675,31 @@ auto ToOptionalString(ExpressionEvaluator *evaluator, Expression *expression) ->
return std::nullopt;
};
TypedValue CsvRowToTypedList(csv::Reader::Row &row) {
TypedValue CsvRowToTypedList(csv::Reader::Row &row, std::optional<utils::pmr::string> &nullif) {
auto *mem = row.get_allocator().GetMemoryResource();
auto typed_columns = utils::pmr::vector<TypedValue>(mem);
typed_columns.reserve(row.size());
for (auto &column : row) {
typed_columns.emplace_back(std::move(column));
if (!nullif.has_value() || column != nullif.value()) {
typed_columns.emplace_back(std::move(column));
} else {
typed_columns.emplace_back();
}
}
return {std::move(typed_columns), mem};
}
TypedValue CsvRowToTypedMap(csv::Reader::Row &row, csv::Reader::Header header) {
TypedValue CsvRowToTypedMap(csv::Reader::Row &row, csv::Reader::Header header,
std::optional<utils::pmr::string> &nullif) {
// a valid row has the same number of elements as the header
auto *mem = row.get_allocator().GetMemoryResource();
utils::pmr::map<utils::pmr::string, TypedValue> m(mem);
for (auto i = 0; i < row.size(); ++i) {
m.emplace(std::move(header[i]), std::move(row[i]));
if (!nullif.has_value() || row[i] != nullif.value()) {
m.emplace(std::move(header[i]), std::move(row[i]));
} else {
m.emplace(std::piecewise_construct, std::forward_as_tuple(std::move(header[i])), std::forward_as_tuple());
}
}
return {std::move(m), mem};
}
@@ -4701,6 +4711,7 @@ class LoadCsvCursor : public Cursor {
const UniqueCursorPtr input_cursor_;
bool did_pull_;
std::optional<csv::Reader> reader_{};
std::optional<utils::pmr::string> nullif_;
public:
LoadCsvCursor(const LoadCsv *self, utils::MemoryResource *mem)
@@ -4718,6 +4729,7 @@ class LoadCsvCursor : public Cursor {
// without massacring the code even worse than I did here
if (UNLIKELY(!reader_)) {
reader_ = MakeReader(&context.evaluation_context);
nullif_ = ParseNullif(&context.evaluation_context);
}
if (input_cursor_->Pull(frame, context)) {
@@ -4733,10 +4745,10 @@ class LoadCsvCursor : public Cursor {
return false;
}
if (!reader_->HasHeader()) {
frame[self_->row_var_] = CsvRowToTypedList(*row);
frame[self_->row_var_] = CsvRowToTypedList(*row, nullif_);
} else {
frame[self_->row_var_] =
CsvRowToTypedMap(*row, csv::Reader::Header(reader_->GetHeader(), context.evaluation_context.memory));
CsvRowToTypedMap(*row, csv::Reader::Header(reader_->GetHeader(), context.evaluation_context.memory), nullif_);
}
if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(self_->row_var_.name())) {
context.frame_change_collector->ResetTrackingValue(self_->row_var_.name());
@@ -4768,6 +4780,15 @@ class LoadCsvCursor : public Cursor {
csv::Reader::Config(self_->with_header_, self_->ignore_bad_, std::move(maybe_delim), std::move(maybe_quote)),
utils::NewDeleteResource());
}
std::optional<utils::pmr::string> ParseNullif(EvaluationContext *eval_context) {
Frame frame(0);
SymbolTable symbol_table;
DbAccessor *dba = nullptr;
auto evaluator = ExpressionEvaluator(&frame, symbol_table, *eval_context, dba, storage::View::OLD);
return ToOptionalString(&evaluator, self_->nullif_);
}
};
UniqueCursorPtr LoadCsv::MakeCursor(utils::MemoryResource *mem) const {


@@ -2227,7 +2227,7 @@ class LoadCsv : public memgraph::query::plan::LogicalOperator {
LoadCsv() = default;
LoadCsv(std::shared_ptr<LogicalOperator> input, Expression *file, bool with_header, bool ignore_bad,
Expression *delimiter, Expression *quote, Symbol row_var);
Expression *delimiter, Expression *quote, Expression *nullif, Symbol row_var);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
@@ -2243,6 +2243,7 @@ class LoadCsv : public memgraph::query::plan::LogicalOperator {
bool ignore_bad_;
Expression *delimiter_{nullptr};
Expression *quote_{nullptr};
Expression *nullif_{nullptr};
Symbol row_var_;
std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
@@ -2253,6 +2254,7 @@ class LoadCsv : public memgraph::query::plan::LogicalOperator {
object->ignore_bad_ = ignore_bad_;
object->delimiter_ = delimiter_ ? delimiter_->Clone(storage) : nullptr;
object->quote_ = quote_ ? quote_->Clone(storage) : nullptr;
object->nullif_ = nullif_;
object->row_var_ = row_var_;
return object;
}


@@ -895,6 +895,10 @@ bool PlanToJsonVisitor::PreVisit(query::plan::LoadCsv &op) {
self["quote"] = ToJson(op.quote_);
}
if (op.nullif_) {
self["nullif"] = ToJson(op.nullif_);
}
self["row_variable"] = ToJson(op.row_var_);
op.input_->Accept(*this);


@@ -226,10 +226,9 @@ class RuleBasedPlanner {
const auto &row_sym = context.symbol_table->at(*load_csv->row_var_);
context.bound_symbols.insert(row_sym);
input_op =
std::make_unique<plan::LoadCsv>(std::move(input_op), load_csv->file_, load_csv->with_header_,
load_csv->ignore_bad_, load_csv->delimiter_, load_csv->quote_, row_sym);
input_op = std::make_unique<plan::LoadCsv>(std::move(input_op), load_csv->file_, load_csv->with_header_,
load_csv->ignore_bad_, load_csv->delimiter_, load_csv->quote_,
load_csv->nullif_, row_sym);
} else if (auto *foreach = utils::Downcast<query::Foreach>(clause)) {
context.is_write_query = true;
input_op = HandleForeachClause(foreach, std::move(input_op), *context.symbol_table, context.bound_symbols,


@@ -8,3 +8,6 @@ endfunction()
copy_load_csv_e2e_python_files(load_csv.py)
copy_load_csv_e2e_files(simple.csv)
copy_load_csv_e2e_python_files(load_csv_nullif.py)
copy_load_csv_e2e_files(nullif.csv)


@@ -0,0 +1,53 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import os
import sys
from pathlib import Path

import pytest
from gqlalchemy import Memgraph

NULLIF_CSV_FILE = "nullif.csv"


def get_file_path(file: str) -> str:
    parent_path = Path(__file__).parent.absolute()
    return os.path.join(parent_path, file)


def test_given_csv_when_nullif_then_all_identical_rows_are_null():
    memgraph = Memgraph("localhost", 7687)

    results = list(
        memgraph.execute_and_fetch(
            f"""LOAD CSV FROM '{get_file_path(NULLIF_CSV_FILE)}'
            WITH HEADER NULLIF 'N/A' AS row
            CREATE (n:Person {{name: row.name, age: row.age,
            percentage: row.percentage, works_in_IT: row.works_in_IT}})
            RETURN n
            """
        )
    )

    expected_properties = [
        {"age": "10", "percentage": "15.0", "works_in_IT": "false"},
        {"name": "John", "percentage": "35.4", "works_in_IT": "false"},
        {"name": "Milewa", "age": "34", "works_in_IT": "false"},
        {"name": "Lucas", "age": "50", "percentage": "12.5"},
    ]

    properties = [result["n"]._properties for result in results]
    assert expected_properties == properties


if __name__ == "__main__":
    sys.exit(pytest.main([__file__, "-rA"]))


@@ -0,0 +1,5 @@
name,age,percentage,works_in_IT
N/A,10,15.0,false
John,N/A,35.4,false
Milewa,34,N/A,false
Lucas,50,12.5,N/A
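For illustration (not part of the commit): with NULLIF 'N/A', every 'N/A' cell above is bound as null, so the corresponding property is never set on the created node, which is why each map in expected_properties in the test above is missing exactly one key. A minimal check, assuming the full path to nullif.csv is supplied:

LOAD CSV FROM "/path/to/nullif.csv" WITH HEADER NULLIF "N/A" AS row
RETURN row.name, row.age LIMIT 1;

For the first data row this returns row.name as null and row.age as the string "10".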


@@ -1,3 +1,10 @@
nullif_cluster: &nullif_cluster
  cluster:
    main:
      args: ["--bolt-port", "7687", "--log-level=TRACE"]
      log_file: "load_csv_log_file.txt"
      validation_queries: []

load_csv_cluster: &load_csv_cluster
  cluster:
    main:
@@ -9,6 +16,10 @@ load_csv_cluster: &load_csv_cluster
      validation_queries: []

workloads:
  - name: "LOAD CSV nullif"
    binary: "tests/e2e/pytest_runner.sh"
    args: ["load_csv/load_csv_nullif.py"]
    <<: *nullif_cluster
  - name: "MATCH + LOAD CSV"
    binary: "tests/e2e/pytest_runner.sh"
    args: ["load_csv/load_csv.py"]