diff --git a/src/auth/models.cpp b/src/auth/models.cpp index 42b94c44d..92e19f841 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -42,6 +42,8 @@ std::string PermissionToString(Permission permission) { return "DUMP"; case Permission::REPLICATION: return "REPLICATION"; + case Permission::LOCK_PATH: + return "LOCK_PATH"; case Permission::AUTH: return "AUTH"; } diff --git a/src/auth/models.hpp b/src/auth/models.hpp index c03492a5f..fbc1df6f9 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -9,27 +9,31 @@ namespace auth { // These permissions must have values that are applicable for usage in a // bitmask. +// clang-format off enum class Permission : uint64_t { - MATCH = 0x00000001, - CREATE = 0x00000002, - MERGE = 0x00000004, - DELETE = 0x00000008, - SET = 0x00000010, - REMOVE = 0x00000020, - INDEX = 0x00000040, - STATS = 0x00000080, - CONSTRAINT = 0x00000100, - DUMP = 0x00000200, - REPLICATION = 0x00000400, - AUTH = 0x00010000, + MATCH = 1, + CREATE = 1U << 1U, + MERGE = 1U << 2U, + DELETE = 1U << 3U, + SET = 1U << 4U, + REMOVE = 1U << 5U, + INDEX = 1U << 6U, + STATS = 1U << 7U, + CONSTRAINT = 1U << 8U, + DUMP = 1U << 9U, + REPLICATION = 1U << 10U, + LOCK_PATH = 1U << 11U, + AUTH = 1U << 16U }; +// clang-format on // Constant list of all available permissions. const std::vector kPermissionsAll = { - Permission::MATCH, Permission::CREATE, Permission::MERGE, - Permission::DELETE, Permission::SET, Permission::REMOVE, - Permission::INDEX, Permission::STATS, Permission::CONSTRAINT, - Permission::DUMP, Permission::AUTH, Permission::REPLICATION}; + Permission::MATCH, Permission::CREATE, Permission::MERGE, + Permission::DELETE, Permission::SET, Permission::REMOVE, + Permission::INDEX, Permission::STATS, Permission::CONSTRAINT, + Permission::DUMP, Permission::AUTH, Permission::REPLICATION, + Permission::LOCK_PATH}; // Function that converts a permission to its string representation. std::string PermissionToString(Permission permission); diff --git a/src/glue/auth.cpp b/src/glue/auth.cpp index bb6a4f1ab..d5b8b706b 100644 --- a/src/glue/auth.cpp +++ b/src/glue/auth.cpp @@ -26,8 +26,10 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) { return auth::Permission::DUMP; case query::AuthQuery::Privilege::REPLICATION: return auth::Permission::REPLICATION; + case query::AuthQuery::Privilege::LOCK_PATH: + return auth::Permission::LOCK_PATH; case query::AuthQuery::Privilege::AUTH: return auth::Permission::AUTH; } } -} +} // namespace glue diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp index db61eae88..6e0727205 100644 --- a/src/query/exceptions.hpp +++ b/src/query/exceptions.hpp @@ -174,4 +174,11 @@ class ReplicationModificationInMulticommandTxException : public QueryException { "Replication clause not allowed in multicommand transactions.") {} }; +class LockPathModificationInMulticommandTxException : public QueryException { + public: + LockPathModificationInMulticommandTxException() + : QueryException( + "Lock path clause not allowed in multicommand transactions.") {} +}; + } // namespace query diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index 242fa7d23..1bc57aadc 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -2191,7 +2191,7 @@ cpp<# (:serialize)) (lcp:define-enum privilege (create delete match merge set remove index stats auth constraint - dump replication) + dump replication lock_path) (:serialize)) #>cpp AuthQuery() = default; @@ -2227,7 +2227,8 @@ const std::vector kPrivilegesAll = { AuthQuery::Privilege::INDEX, AuthQuery::Privilege::STATS, AuthQuery::Privilege::AUTH, AuthQuery::Privilege::CONSTRAINT, AuthQuery::Privilege::DUMP, - AuthQuery::Privilege::REPLICATION}; + AuthQuery::Privilege::REPLICATION, + AuthQuery::Privilege::LOCK_PATH}; cpp<# (lcp:define-class info-query (query) @@ -2333,4 +2334,23 @@ cpp<# (:serialize (:slk)) (:clone)) +(lcp:define-class lock-path-query (query) + ((action "Action" :scope :public)) + + (:public + (lcp:define-enum action + (lock-path unlock-path) + (:serialize)) + #>cpp + LockPathQuery() = default; + + DEFVISITABLE(QueryVisitor); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:pop-namespace) ;; namespace query diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 3e5fa61d8..6eb7221a4 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -73,6 +73,7 @@ class ConstraintQuery; class RegexMatch; class DumpQuery; class ReplicationQuery; +class LockPathQuery; using TreeCompositeVisitor = ::utils::CompositeVisitor< SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, @@ -117,6 +118,6 @@ template class QueryVisitor : public ::utils::Visitor {}; + DumpQuery, ReplicationQuery, LockPathQuery> {}; } // namespace query diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index fe7c95bc4..4bfa088ba 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -285,6 +285,21 @@ antlrcpp::Any CypherMainVisitor::visitShowReplicas( return replication_query; } +antlrcpp::Any CypherMainVisitor::visitLockPathQuery( + MemgraphCypher::LockPathQueryContext *ctx) { + auto *lock_query = storage_->Create(); + if (ctx->LOCK()) { + lock_query->action_ = LockPathQuery::Action::LOCK_PATH; + } else if (ctx->UNLOCK()) { + lock_query->action_ = LockPathQuery::Action::UNLOCK_PATH; + } else { + throw SyntaxException("Expected LOCK or UNLOCK"); + } + + query_ = lock_query; + return lock_query; +} + antlrcpp::Any CypherMainVisitor::visitCypherUnion( MemgraphCypher::CypherUnionContext *ctx) { bool distinct = !ctx->ALL(); diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 5c5da2e6a..95dbc9278 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -222,6 +222,12 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { antlrcpp::Any visitShowReplicas( MemgraphCypher::ShowReplicasContext *ctx) override; + /** + * @return LockPathQuery* + */ + antlrcpp::Any visitLockPathQuery( + MemgraphCypher::LockPathQueryContext *ctx) override; + /** * @return CypherUnion* */ @@ -253,7 +259,8 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { /** * @return CallProcedure* */ - antlrcpp::Any visitCallProcedure(MemgraphCypher::CallProcedureContext *ctx) override; + antlrcpp::Any visitCallProcedure( + MemgraphCypher::CallProcedureContext *ctx) override; /** * @return std::string diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index a87f7febb..c3d3c75e7 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -19,6 +19,7 @@ memgraphCypherKeyword : cypherKeyword | FROM | GRANT | IDENTIFIED + | LOCK | MAIN | MODE | PASSWORD @@ -35,6 +36,7 @@ memgraphCypherKeyword : cypherKeyword | SYNC | TIMEOUT | TO + | UNLOCK | USER | USERS ; @@ -53,6 +55,7 @@ query : cypherQuery | authQuery | dumpQuery | replicationQuery + | lockPathQuery ; authQuery : createRole @@ -135,3 +138,6 @@ registerReplica : REGISTER REPLICA replicaName ( SYNC | ASYNC ) dropReplica : DROP REPLICA replicaName ; showReplicas : SHOW REPLICAS ; + +lockPathQuery : ( LOCK | UNLOCK ) DATA DIRECTORY ; + diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index 340003e28..15923dcf2 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -14,8 +14,10 @@ ALTER : A L T E R ; ASYNC : A S Y N C ; AUTH : A U T H ; CLEAR : C L E A R ; +DATA : D A T A ; DATABASE : D A T A B A S E ; DENY : D E N Y ; +DIRECTORY : D I R E C T O R Y ; DROP : D R O P ; DUMP : D U M P ; FOR : F O R ; @@ -23,6 +25,7 @@ FROM : F R O M ; GRANT : G R A N T ; GRANTS : G R A N T S ; IDENTIFIED : I D E N T I F I E D ; +LOCK : L O C K ; MAIN : M A I N ; MODE : M O D E ; PASSWORD : P A S S W O R D ; @@ -39,5 +42,6 @@ STATS : S T A T S ; SYNC : S Y N C ; TIMEOUT : T I M E O U T ; TO : T O ; +UNLOCK : U N L O C K ; USER : U S E R ; USERS : U S E R S ; diff --git a/src/query/frontend/semantic/required_privileges.cpp b/src/query/frontend/semantic/required_privileges.cpp index 03b7bcdc0..b40549a9c 100644 --- a/src/query/frontend/semantic/required_privileges.cpp +++ b/src/query/frontend/semantic/required_privileges.cpp @@ -59,6 +59,10 @@ class PrivilegeExtractor : public QueryVisitor, AddPrivilege(AuthQuery::Privilege::DUMP); } + void Visit(LockPathQuery &lock_path_query) override { + AddPrivilege(AuthQuery::Privilege::LOCK_PATH); + } + void Visit(ReplicationQuery &replication_query) override { switch (replication_query.action_) { case ReplicationQuery::Action::SET_REPLICATION_ROLE: diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index cf563e800..19818d231 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -573,7 +573,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const auto &replicas = handler->ShowReplicas(); auto typed_replicas = std::vector>{}; typed_replicas.reserve(replicas.size()); - for (auto &replica : replicas) { + for (const auto &replica : replicas) { std::vector typed_replica; typed_replica.reserve(replica_nfields); @@ -1222,6 +1222,49 @@ PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, RWType::NONE}; } +PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, + const bool in_explicit_transaction, + InterpreterContext *interpreter_context, + DbAccessor *dba) { + if (in_explicit_transaction) { + throw LockPathModificationInMulticommandTxException(); + } + + auto *lock_path_query = utils::Downcast(parsed_query.query); + + Frame frame(0); + SymbolTable symbol_table; + EvaluationContext evaluation_context; + evaluation_context.timestamp = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + evaluation_context.parameters = parsed_query.parameters; + ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, dba, + storage::View::OLD); + + Callback callback; + switch (lock_path_query->action_) { + case LockPathQuery::Action::LOCK_PATH: + if (!interpreter_context->db->LockPath()) { + throw QueryRuntimeException("Failed to lock the data directory"); + } + break; + case LockPathQuery::Action::UNLOCK_PATH: + if (!interpreter_context->db->UnlockPath()) { + throw QueryRuntimeException("Failed to unlock the data directory"); + } + break; + } + + return PreparedQuery{ + callback.header, std::move(parsed_query.required_privileges), + [](AnyStream *stream, + std::optional n) -> std::optional { + return QueryHandlerResult::COMMIT; + }}; +} + PreparedQuery PrepareInfoQuery( ParsedQuery parsed_query, bool in_explicit_transaction, std::map *summary, @@ -1601,6 +1644,10 @@ Interpreter::PrepareResult Interpreter::Prepare( prepared_query = PrepareReplicationQuery( std::move(parsed_query), in_explicit_transaction_, interpreter_context_, &*execution_db_accessor_); + } else if (utils::Downcast(parsed_query.query)) { + prepared_query = PrepareLockPathQuery( + std::move(parsed_query), in_explicit_transaction_, + interpreter_context_, &*execution_db_accessor_); } else { LOG(FATAL) << "Should not get here -- unknown query type!"; } diff --git a/src/storage/v2/replication/replication_client.cpp b/src/storage/v2/replication/replication_client.cpp index ac93bfbc4..0f3d9caed 100644 --- a/src/storage/v2/replication/replication_client.cpp +++ b/src/storage/v2/replication/replication_client.cpp @@ -425,7 +425,7 @@ Storage::ReplicationClient::GetRecoverySteps( recovery_steps.emplace_back(RecoveryCurrentWal{*current_wal_seq_num}); } else { CHECK(latest_snapshot); - locker_acc.AddFile(latest_snapshot->path); + locker_acc.AddPath(latest_snapshot->path); recovery_steps.emplace_back( RecoveryFinalSnapshot{latest_snapshot->start_timestamp}); } @@ -444,7 +444,7 @@ Storage::ReplicationClient::GetRecoverySteps( recovery_steps.emplace_back(RecoveryCurrentWal{*current_wal_seq_num}); } else { CHECK(latest_snapshot); - locker_acc.AddFile(latest_snapshot->path); + locker_acc.AddPath(latest_snapshot->path); recovery_steps.emplace_back( RecoveryFinalSnapshot{latest_snapshot->start_timestamp}); } @@ -475,7 +475,7 @@ Storage::ReplicationClient::GetRecoverySteps( // We need to lock these files and add them to the chain for (auto result_wal_it = wal_files->begin() + distance_from_first; result_wal_it != wal_files->end(); ++result_wal_it) { - locker_acc.AddFile(result_wal_it->path); + locker_acc.AddPath(result_wal_it->path); wal_chain.push_back(std::move(result_wal_it->path)); } @@ -494,7 +494,7 @@ Storage::ReplicationClient::GetRecoverySteps( CHECK(latest_snapshot) << "Invalid durability state, missing snapshot"; // We didn't manage to find a WAL chain, we need to send the latest snapshot // with its WALs - locker_acc.AddFile(latest_snapshot->path); + locker_acc.AddPath(latest_snapshot->path); recovery_steps.emplace_back(std::in_place_type_t{}, std::move(latest_snapshot->path)); @@ -514,13 +514,13 @@ Storage::ReplicationClient::GetRecoverySteps( } for (; wal_it != wal_files->end(); ++wal_it) { - locker_acc.AddFile(wal_it->path); + locker_acc.AddPath(wal_it->path); recovery_wal_files.push_back(std::move(wal_it->path)); } // We only have a WAL before the snapshot if (recovery_wal_files.empty()) { - locker_acc.AddFile(wal_files->back().path); + locker_acc.AddPath(wal_files->back().path); recovery_wal_files.push_back(std::move(wal_files->back().path)); } diff --git a/src/storage/v2/storage.cpp b/src/storage/v2/storage.cpp index 7056e6c99..2077bba09 100644 --- a/src/storage/v2/storage.cpp +++ b/src/storage/v2/storage.cpp @@ -1,5 +1,4 @@ #include "storage/v2/storage.hpp" - #include #include #include @@ -325,7 +324,8 @@ Storage::Storage(Config config) lock_file_path_(config_.durability.storage_directory / durability::kLockFile), uuid_(utils::GenerateUUID()), - epoch_id_(utils::GenerateUUID()) { + epoch_id_(utils::GenerateUUID()), + global_locker_(file_retainer_.AddLocker()) { if (config_.durability.snapshot_wal_mode != Config::Durability::SnapshotWalMode::DISABLED || config_.durability.snapshot_on_exit || @@ -1905,6 +1905,25 @@ void Storage::CreateSnapshot() { commit_log_.MarkFinished(transaction.start_timestamp); } +bool Storage::LockPath() { + auto locker_accessor = global_locker_.Access(); + return locker_accessor.AddPath(config_.durability.storage_directory); +} + +bool Storage::UnlockPath() { + { + auto locker_accessor = global_locker_.Access(); + if (!locker_accessor.RemovePath(config_.durability.storage_directory)) { + return false; + } + } + + // We use locker accessor in seperate scope so we don't produce deadlock + // after we call clean queue. + file_retainer_.CleanQueue(); + return true; +} + uint64_t Storage::CommitTimestamp( const std::optional desired_commit_timestamp) { #ifdef MG_ENTERPRISE diff --git a/src/storage/v2/storage.hpp b/src/storage/v2/storage.hpp index 3d1291e20..365821e5c 100644 --- a/src/storage/v2/storage.hpp +++ b/src/storage/v2/storage.hpp @@ -415,8 +415,10 @@ class Storage final { StorageInfo GetInfo() const; -#if MG_ENTERPRISE + bool LockPath(); + bool UnlockPath(); +#if MG_ENTERPRISE bool SetReplicaRole(io::network::Endpoint endpoint, const replication::ReplicationServerConfig &config = {}); @@ -568,6 +570,9 @@ class Storage final { utils::FileRetainer file_retainer_; + // Global locker that is used for clients file locking + utils::FileRetainer::FileLocker global_locker_; + // Replication #ifdef MG_ENTERPRISE // Last commited timestamp diff --git a/src/utils/file_locker.cpp b/src/utils/file_locker.cpp index 2ad635a4f..12f138c7b 100644 --- a/src/utils/file_locker.cpp +++ b/src/utils/file_locker.cpp @@ -1,4 +1,5 @@ #include "utils/file_locker.hpp" +#include namespace utils { @@ -12,18 +13,25 @@ void DeleteFromSystem(const std::filesystem::path &path) { ////// FileRetainer ////// void FileRetainer::DeleteFile(const std::filesystem::path &path) { + if (!std::filesystem::exists(path)) { + LOG(INFO) << "File " << path << " doesn't exist."; + return; + } + + auto absolute_path = std::filesystem::absolute(path); if (active_accessors_.load()) { - files_for_deletion_.WithLock([&](auto &files) { files.emplace(path); }); + files_for_deletion_.WithLock( + [&](auto &files) { files.emplace(std::move(absolute_path)); }); return; } std::unique_lock guard(main_lock_); - DeleteOrAddToQueue(path); + DeleteOrAddToQueue(absolute_path); } FileRetainer::FileLocker FileRetainer::AddLocker() { const size_t current_locker_id = next_locker_id_.fetch_add(1); lockers_.WithLock([&](auto &lockers) { - lockers.emplace(current_locker_id, std::set{}); + lockers.emplace(current_locker_id, LockerEntry{}); }); return FileLocker{this, current_locker_id}; } @@ -34,8 +42,8 @@ FileRetainer::~FileRetainer() { [[nodiscard]] bool FileRetainer::FileLocked(const std::filesystem::path &path) { return lockers_.WithLock([&](auto &lockers) { - for (const auto &[_, paths] : lockers) { - if (paths.count(path)) { + for (const auto &[_, locker_entry] : lockers) { + if (locker_entry.LocksFile(path)) { return true; } } @@ -52,6 +60,7 @@ void FileRetainer::DeleteOrAddToQueue(const std::filesystem::path &path) { } void FileRetainer::CleanQueue() { + std::unique_lock guard(main_lock_); files_for_deletion_.WithLock([&](auto &files) { for (auto it = files.cbegin(); it != files.cend();) { if (!FileLocked(*it)) { @@ -64,11 +73,57 @@ void FileRetainer::CleanQueue() { }); } +////// LockerEntry ////// +void FileRetainer::LockerEntry::LockPath(const std::filesystem::path &path) { + auto absolute_path = std::filesystem::absolute(path); + if (std::filesystem::is_directory(absolute_path)) { + directories_.emplace(std::move(absolute_path)); + return; + } + files_.emplace(std::move(absolute_path)); +} + +bool FileRetainer::LockerEntry::RemovePath(const std::filesystem::path &path) { + auto absolute_path = std::filesystem::absolute(path); + if (std::filesystem::is_directory(absolute_path)) { + return directories_.erase(absolute_path); + } + + return files_.erase(absolute_path); +} + +bool FileRetainer::LockerEntry::LocksFile( + const std::filesystem::path &path) const { + CHECK(path.is_absolute()) + << "Absolute path needed to check if the file is locked."; + + if (files_.count(path)) { + return true; + } + + for (const auto &directory : directories_) { + auto directory_path_it = directory.begin(); + auto path_it = path.begin(); + while (directory_path_it != directory.end() && path_it != path.end()) { + if (*directory_path_it != *path_it) { + break; + } + ++directory_path_it; + ++path_it; + } + + if (directory_path_it == directory.end()) { + return true; + } + } + + return false; +} + ////// FileLocker ////// FileRetainer::FileLocker::~FileLocker() { file_retainer_->lockers_.WithLock( [this](auto &lockers) { lockers.erase(locker_id_); }); - std::unique_lock guard(file_retainer_->main_lock_); file_retainer_->CleanQueue(); } @@ -85,14 +140,20 @@ FileRetainer::FileLockerAccessor::FileLockerAccessor(FileRetainer *retainer, file_retainer_->active_accessors_.fetch_add(1); } -bool FileRetainer::FileLockerAccessor::AddFile( +bool FileRetainer::FileLockerAccessor::AddPath( const std::filesystem::path &path) { if (!std::filesystem::exists(path)) return false; file_retainer_->lockers_.WithLock( - [&](auto &lockers) { lockers[locker_id_].emplace(path); }); + [&](auto &lockers) { lockers[locker_id_].LockPath(path); }); return true; } +bool FileRetainer::FileLockerAccessor::RemovePath( + const std::filesystem::path &path) { + return file_retainer_->lockers_.WithLock( + [&](auto &lockers) { return lockers[locker_id_].RemovePath(path); }); +} + FileRetainer::FileLockerAccessor::~FileLockerAccessor() { file_retainer_->active_accessors_.fetch_sub(1); } diff --git a/src/utils/file_locker.hpp b/src/utils/file_locker.hpp index 3645a2344..17fe7017b 100644 --- a/src/utils/file_locker.hpp +++ b/src/utils/file_locker.hpp @@ -2,10 +2,9 @@ #include #include #include -#include -#include #include #include +#include #include "utils/file.hpp" #include "utils/rw_lock.hpp" @@ -30,6 +29,8 @@ namespace utils { * - FileLockerAccessor prevents deletion of any file, so you can safely add * multiple files to the locker with no risk of having files deleted during * the process. + * - You can also add directories to the locker which prevents deletion + * of ANY files in that directory. * - After a FileLocker or FileLockerAccessor is destroyed, FileRetainer scans * the list of the files that wait to be deleted, and deletes all the files * that are not inside any of currently present lockers. @@ -49,8 +50,8 @@ namespace utils { * // Accesor prevents deletion of any files * // so you safely add multiple files in atomic way * auto accessor = locker.Access(); - * accessor.AddFile(file1); - * accessor.AddFile(file2); + * accessor.AddPath(file1); + * accessor.AddPath(file2); * } * // DO SOMETHING WITH THE FILES * } @@ -104,9 +105,14 @@ class FileRetainer { friend FileLocker; /** - * Add a single file to the current locker. + * Add a single path to the current locker. */ - bool AddFile(const std::filesystem::path &path); + bool AddPath(const std::filesystem::path &path); + + /** + * Remove a single path form the current locker. + */ + bool RemovePath(const std::filesystem::path &path); FileLockerAccessor(const FileLockerAccessor &) = delete; FileLockerAccessor(FileLockerAccessor &&) = default; @@ -136,6 +142,17 @@ class FileRetainer { */ FileLocker AddLocker(); + /** + * Delete the files that were queued for deletion. + * This is already called after a locker is destroyed. + * Call this only if you want to trigger cleaning of the + * queue before a locker is destroyed (e.g. a file was removed + * from a locker). + * This method CANNOT be called from a thread which has an active + * accessor as it will produce a deadlock. + */ + void CleanQueue(); + explicit FileRetainer() = default; FileRetainer(const FileRetainer &) = delete; FileRetainer(FileRetainer &&) = delete; @@ -147,16 +164,25 @@ class FileRetainer { private: [[nodiscard]] bool FileLocked(const std::filesystem::path &path); void DeleteOrAddToQueue(const std::filesystem::path &path); - void CleanQueue(); utils::RWLock main_lock_{RWLock::Priority::WRITE}; std::atomic active_accessors_{0}; std::atomic next_locker_id_{0}; - utils::Synchronized>, - utils::SpinLock> - lockers_; + class LockerEntry { + public: + void LockPath(const std::filesystem::path &path); + bool RemovePath(const std::filesystem::path &path); + [[nodiscard]] bool LocksFile(const std::filesystem::path &path) const; + + private: + std::set directories_; + std::set files_; + }; + + utils::Synchronized, utils::SpinLock> + lockers_; utils::Synchronized, utils::SpinLock> files_for_deletion_; }; diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index cf9ad37bd..d4ec70faf 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -3246,4 +3246,39 @@ TEST_P(CypherMainVisitorTest, IncorrectCallProcedure) { SyntaxException); } +TEST_P(CypherMainVisitorTest, TestLockPathQuery) { + auto &ast_generator = *GetParam(); + + const auto test_lock_path_query = [&](const std::string_view command, + const LockPathQuery::Action action) { + ASSERT_THROW(ast_generator.ParseQuery(command.data()), SyntaxException); + + { + const std::string query = fmt::format("{} ME", command); + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = fmt::format("{} DATA", command); + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = fmt::format("{} DATA STUFF", command); + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = fmt::format("{} DATA DIRECTORY", command); + auto *parsed_query = + dynamic_cast(ast_generator.ParseQuery(query)); + ASSERT_TRUE(parsed_query); + EXPECT_EQ(parsed_query->action_, action); + } + }; + + test_lock_path_query("LOCK", LockPathQuery::Action::LOCK_PATH); + test_lock_path_query("UNLOCK", LockPathQuery::Action::UNLOCK_PATH); +} + } // namespace diff --git a/tests/unit/utils_file_locker.cpp b/tests/unit/utils_file_locker.cpp index 17bdfbba1..ac262becc 100644 --- a/tests/unit/utils_file_locker.cpp +++ b/tests/unit/utils_file_locker.cpp @@ -27,7 +27,7 @@ class FileLockerTest : public ::testing::Test { std::filesystem::current_path(testing_directory); for (auto i = 1; i <= files_number; ++i) { - std::ofstream file(fmt::format("{}", i)); + std::ofstream(fmt::format("{}", i)); } std::filesystem::current_path(save_path); @@ -40,52 +40,158 @@ class FileLockerTest : public ::testing::Test { } }; -TEST_F(FileLockerTest, DeleteWhileLocking) { +// Test are parameterized based on the type of path used for locking and +// deleting. We test all of the combinations for absolute/relative paths for +// locking path and absolute/relative paths for deleting +// Parameter is represented by tuple (lock_absolute, delete_absolute). +class FileLockerParameterizedTest + : public FileLockerTest, + public ::testing::WithParamInterface> {}; + +TEST_P(FileLockerParameterizedTest, DeleteWhileLocking) { CreateFiles(1); utils::FileRetainer file_retainer; - auto t1 = std::thread([&]() { + const auto save_path = std::filesystem::current_path(); + std::filesystem::current_path(testing_directory); + const auto file = std::filesystem::path("1"); + const auto file_absolute = std::filesystem::absolute(file); + const auto [lock_absolute, delete_absolute] = GetParam(); + { auto locker = file_retainer.AddLocker(); { auto acc = locker.Access(); - std::this_thread::sleep_for(100ms); + file_retainer.DeleteFile(delete_absolute ? file_absolute : file); + ASSERT_TRUE(std::filesystem::exists(file)); } - }); - const auto file = testing_directory / "1"; - auto t2 = std::thread([&]() { - std::this_thread::sleep_for(50ms); - file_retainer.DeleteFile(file); - ASSERT_TRUE(std::filesystem::exists(file)); - }); - - t1.join(); - t2.join(); + } ASSERT_FALSE(std::filesystem::exists(file)); + + std::filesystem::current_path(save_path); } -TEST_F(FileLockerTest, DeleteWhileInLocker) { +TEST_P(FileLockerParameterizedTest, DeleteWhileInLocker) { CreateFiles(1); utils::FileRetainer file_retainer; - const auto file = testing_directory / "1"; - auto t1 = std::thread([&]() { + const auto save_path = std::filesystem::current_path(); + std::filesystem::current_path(testing_directory); + const auto file = std::filesystem::path("1"); + const auto file_absolute = std::filesystem::absolute(file); + const auto [lock_absolute, delete_absolute] = GetParam(); + { auto locker = file_retainer.AddLocker(); { auto acc = locker.Access(); - acc.AddFile(file); + acc.AddPath(lock_absolute ? file_absolute : file); } - std::this_thread::sleep_for(100ms); - }); - auto t2 = std::thread([&]() { - std::this_thread::sleep_for(50ms); - file_retainer.DeleteFile(file); + file_retainer.DeleteFile(delete_absolute ? file_absolute : file); ASSERT_TRUE(std::filesystem::exists(file)); - }); + } - t1.join(); - t2.join(); ASSERT_FALSE(std::filesystem::exists(file)); + std::filesystem::current_path(save_path); } +TEST_P(FileLockerParameterizedTest, DirectoryLock) { + utils::FileRetainer file_retainer; + // For this test we create the following file structure + // testing_directory + // 1 + // additional + // 2 + // We check 2 cases: + // - locking the subdirectory "additional", only "2" should be preserved + // - locking the directory testing_directory, all of the files shold be + // preserved + ASSERT_TRUE(std::filesystem::create_directory(testing_directory)); + const auto save_path = std::filesystem::current_path(); + std::filesystem::current_path(testing_directory); + + // Create additional directory inside the testing directory with a single file + const auto additional_directory = std::filesystem::path("additional"); + ASSERT_TRUE(std::filesystem::create_directory(additional_directory)); + + const auto nested_file = + std::filesystem::path(fmt::format("{}/2", additional_directory.string())); + const auto nested_file_absolute = std::filesystem::absolute(nested_file); + + const auto file = std::filesystem::path("1"); + const auto file_absolute = std::filesystem::absolute(file); + const auto directory_lock_test = [&](const bool lock_nested_directory) { + const auto directory_to_lock = + lock_nested_directory ? additional_directory : testing_directory; + const auto [lock_absolute, delete_absolute] = GetParam(); + std::ofstream(file.string()); + std::ofstream(nested_file.string()); + { + auto locker = file_retainer.AddLocker(); + { + auto acc = locker.Access(); + acc.AddPath(lock_absolute ? std::filesystem::absolute(directory_to_lock) + : directory_to_lock); + } + + file_retainer.DeleteFile(delete_absolute ? file_absolute : file); + ASSERT_NE(std::filesystem::exists(file), lock_nested_directory); + file_retainer.DeleteFile(delete_absolute ? nested_file_absolute + : nested_file); + ASSERT_TRUE(std::filesystem::exists(nested_file)); + } + ASSERT_FALSE(std::filesystem::exists(file)); + ASSERT_FALSE(std::filesystem::exists(nested_file)); + }; + + directory_lock_test(true); + directory_lock_test(false); + + std::filesystem::current_path(save_path); +} + +TEST_P(FileLockerParameterizedTest, RemovePath) { + utils::FileRetainer file_retainer; + ASSERT_TRUE(std::filesystem::create_directory(testing_directory)); + const auto save_path = std::filesystem::current_path(); + std::filesystem::current_path(testing_directory); + const auto file = std::filesystem::path("1"); + const auto file_absolute = std::filesystem::absolute(file); + auto remove_path_test = [&](const bool delete_explicitly_file) { + const auto [lock_absolute, delete_absolute] = GetParam(); + // Create the file + std::ofstream(file.string()); + auto locker = file_retainer.AddLocker(); + { + auto acc = locker.Access(); + acc.AddPath(lock_absolute ? file_absolute : file); + } + + file_retainer.DeleteFile(delete_absolute ? file_absolute : file); + ASSERT_TRUE(std::filesystem::exists(file)); + + { + auto acc = locker.Access(); + // If absolute was sent to AddPath method, use relative now + // to test those combinations. + acc.RemovePath(lock_absolute ? file : file_absolute); + } + if (delete_explicitly_file) { + file_retainer.DeleteFile(delete_absolute ? file_absolute : file); + } else { + file_retainer.CleanQueue(); + } + ASSERT_FALSE(std::filesystem::exists(file)); + }; + + remove_path_test(true); + remove_path_test(false); + std::filesystem::current_path(save_path); +} + +INSTANTIATE_TEST_CASE_P(FileLockerPathVariantTests, FileLockerParameterizedTest, + ::testing::Values(std::make_tuple(false, false), + std::make_tuple(false, true), + std::make_tuple(true, false), + std::make_tuple(true, true))); + TEST_F(FileLockerTest, MultipleLockers) { CreateFiles(3); utils::FileRetainer file_retainer; @@ -97,17 +203,18 @@ TEST_F(FileLockerTest, MultipleLockers) { auto locker = file_retainer.AddLocker(); { auto acc = locker.Access(); - acc.AddFile(file1); - acc.AddFile(common_file); + acc.AddPath(file1); + acc.AddPath(common_file); } + std::this_thread::sleep_for(200ms); }); auto t2 = std::thread([&]() { auto locker = file_retainer.AddLocker(); { auto acc = locker.Access(); - acc.AddFile(file2); - acc.AddFile(common_file); + acc.AddPath(file2); + acc.AddPath(common_file); } std::this_thread::sleep_for(200ms); }); @@ -117,7 +224,7 @@ TEST_F(FileLockerTest, MultipleLockers) { file_retainer.DeleteFile(file1); file_retainer.DeleteFile(file2); file_retainer.DeleteFile(common_file); - ASSERT_FALSE(std::filesystem::exists(file1)); + ASSERT_TRUE(std::filesystem::exists(file1)); ASSERT_TRUE(std::filesystem::exists(file2)); ASSERT_TRUE(std::filesystem::exists(common_file)); }); @@ -168,7 +275,7 @@ TEST_F(FileLockerTest, MultipleLockersAndDeleters) { auto acc = locker.Access(); for (auto i = 0; i < file_access_num; ++i) { auto file = random_file(); - if (acc.AddFile(file)) { + if (acc.AddPath(file)) { ASSERT_TRUE(std::filesystem::exists(file)); locked_files.emplace_back(std::move(file)); } else {