diff --git a/tests/stress/long_running.py b/tests/stress/long_running.py
index a3ad4051e..7461cb16a 100755
--- a/tests/stress/long_running.py
+++ b/tests/stress/long_running.py
@@ -10,140 +10,29 @@ the graph state oscilates.
 """
 
 import logging
+import multiprocessing
+import neo4j.exceptions
 import random
 import time
-from uuid import uuid4
-from threading import Lock, Thread
-from contextlib import contextmanager
 from collections import defaultdict
 
 import common
 
+
 log = logging.getLogger(__name__)
 
-# the label in the database that is indexed
-# used for matching vertices faster
-INDEXED_LABEL = "indexed_label"
+
+INDEX_FORMAT = "indexed_label{}"
 
 
-def rint(upper_exclusive):
-    return random.randint(0, upper_exclusive - 1)
+def random_element(lst):
+    return lst[random.randint(0, len(lst) - 1)]
 
 
 def bernoulli(p):
     return random.random() < p
 
 
-def random_id():
-    return str(uuid4())
-
-
-class QueryExecutionSynchronizer():
-    """
-    Fascilitates running a query with not other queries being
-    concurrently executed.
-
-    Exposes a count of how many queries in total have been
-    executed through `count_total`.
-    """
-
-    def __init__(self, sleep_time=0.2):
-        """
-        Args:
-            sleep_time - Sleep time while awaiting execution rights
-        """
-        self.count_total = 0
-
-        self._lock = Lock()
-        self._count = 0
-        self._can_run = True
-        self._sleep_time = sleep_time
-
-    @contextmanager
-    def run(self):
-        """
-        Provides a context for running a query without isolation.
-        Isolated queries can't be executed while such a context exists.
-        """
-        while True:
-            with self._lock:
-                if self._can_run:
-                    self._count += 1
-                    self.count_total += 1
-                    break
-            time.sleep(self._sleep_time)
-
-        try:
-            yield
-        finally:
-            with self._lock:
-                self._count -= 1
-
-    @contextmanager
-    def run_isolated(self):
-        """
-        Provides a context for runnig a query with isolation. Prevents
-        new queries from executing. Waits till the currently executing
-        queries are done. Once this context exits execution can
-        continue.
-        """
-        with self._lock:
-            self._can_run = False
-
-        while True:
-            with self._lock:
-                if self._count == 0:
-                    break
-            time.sleep(self._sleep_time)
-
-        with self._lock:
-            try:
-                yield
-            finally:
-                self._can_run = True
-
-
-class LabelCounter():
-    """ Encapsulates a label and a thread-safe counter """
-
-    def __init__(self, label):
-        self.label = label
-        self._count = 0
-        self._lock = Lock()
-
-    def increment(self):
-        with self._lock:
-            self._count += 1
-
-    def decrement(self):
-        with self._lock:
-            self._count -= 1
-
-
-class ThreadSafeList():
-    """ Provides a thread-safe access to a list for a few functionalities. """
""" - - def __init__(self): - self._list = [] - self._lock = Lock() - - def append(self, element): - with self._lock: - self._list.append(element) - - def remove(self, element): - with self._lock: - self._list.remove(element) - - def random(self): - with self._lock: - return self._list[rint(len(self._list))] - - def __len__(self): - with self._lock: - return len(self._list) - - class Graph(): """ Exposes functions for working on a graph, and tracks some @@ -162,25 +51,20 @@ class Graph(): self.vertex_count = vertex_count self.edge_count = edge_count - self.query_execution_synchronizer = QueryExecutionSynchronizer() - # storage - self.edges = ThreadSafeList() - self.vertices = ThreadSafeList() - self.labels = {"label%d" % i: ThreadSafeList() for i in range(labels)} + self.edges = [] + self.vertices = [] + self.labels = {"label%d" % i: [] for i in range(labels)} # info about query failures, maps exception string representations into # occurence counts self._query_failure_counts = defaultdict(int) - self._query_failure_counts_lock = Lock() def add_query_failure(self, reason): - with self._query_failure_counts_lock: - self._query_failure_counts[reason] += 1 + self._query_failure_counts[reason] += 1 def query_failures(self): - with self._query_failure_counts_lock: - return dict(self._query_failure_counts) + return dict(self._query_failure_counts) class GraphSession(): @@ -190,9 +74,19 @@ class GraphSession(): verification function. """ - def __init__(self, graph, session): + def __init__(self, sid, graph, session): + self.sid = sid + + # the label in the database that is indexed + # used for matching vertices faster + self.indexed_label = INDEX_FORMAT.format(sid) + + self.vertex_id = 1 + self.edge_id = 1 + self.graph = graph self.session = session + self.executed_queries = 0 self._start_time = time.time() @property @@ -203,28 +97,33 @@ class GraphSession(): def e(self): return self.graph.edges - def execute_basic(self, query): - log.debug("Executing query: %s", query) + def execute(self, query): + log.debug("Runner %d executing query: %s", self.sid, query) + self.executed_queries += 1 try: return self.session.run(query).data() + except neo4j.exceptions.ServiceUnavailable as e: + raise e except Exception as e: self.graph.add_query_failure(str(e)) return None - def execute(self, query): - with self.graph.query_execution_synchronizer.run(): - return self.execute_basic(query) - - def create_vertex(self): - vertex_id = random_id() - self.execute("CREATE (:%s {id: %r})" % (INDEXED_LABEL, vertex_id)) - self.v.append(vertex_id) + def create_vertices(self, vertices_count): + query = "" + if vertices_count == 0: return + for _ in range(vertices_count): + query += "CREATE (:%s {id: %r}) " % (self.indexed_label, + self.vertex_id) + self.v.append(self.vertex_id) + self.vertex_id += 1 + self.execute(query) def remove_vertex(self): - vertex_id = self.v.random() + vertex_id = random_element(self.v) result = self.execute( "MATCH (n:%s {id: %r}) OPTIONAL MATCH (n)-[r]-() " - "DETACH DELETE n RETURN n.id, labels(n), r.id" % (INDEXED_LABEL, vertex_id)) + "DETACH DELETE n RETURN n.id, labels(n), r.id" % + (self.indexed_label, vertex_id)) if result: process_vertex_ids = set() for row in result: @@ -234,84 +133,126 @@ class GraphSession(): process_vertex_ids.add(vertex_id) self.v.remove(vertex_id) for label in row['labels(n)']: - if (label != INDEXED_LABEL): + if (label != self.indexed_label): self.graph.labels[label].remove(vertex_id) # remove edge edge_id = row['r.id'] - if edge_id: + if edge_id != None: 
                     self.e.remove(edge_id)
 
     def create_edge(self):
-        eid = random_id()
         creation = self.execute(
             "MATCH (from:%s {id: %r}), (to:%s {id: %r}) "
             "CREATE (from)-[e:EdgeType {id: %r}]->(to) RETURN e" % (
-                INDEXED_LABEL, self.v.random(), INDEXED_LABEL, self.v.random(), eid))
+                self.indexed_label, random_element(self.v), self.indexed_label,
+                random_element(self.v), self.edge_id))
         if creation:
-            self.e.append(eid)
+            self.e.append(self.edge_id)
+        self.edge_id += 1
 
     def remove_edge(self):
-        edge_id = self.e.random()
-        result = self.execute(
-            "MATCH ()-[e {id: %r}]->() DELETE e RETURN e.id" % edge_id)
+        edge_id = random_element(self.e)
+        result = self.execute("MATCH (:%s)-[e {id: %r}]->(:%s) DELETE e "
+                              "RETURN e.id" % (self.indexed_label, edge_id,
+                                               self.indexed_label))
         if result:
             self.e.remove(edge_id)
 
     def add_label(self):
-        vertex_id = self.v.random()
+        vertex_id = random_element(self.v)
         label = random.choice(list(self.graph.labels.keys()))
         # add a label on a vertex that didn't have that label
         # yet (we need that for book-keeping)
-        result = self.execute(
-            "MATCH (v {id: %r}) WHERE not v:%s SET v:%s RETURN v.id" % (
-                vertex_id, label, label))
+        result = self.execute("MATCH (v:%s {id: %r}) WHERE not v:%s SET v:%s "
+                              "RETURN v.id" % (self.indexed_label, vertex_id,
+                                               label, label))
         if result:
             self.graph.labels[label].append(vertex_id)
 
+    def update_global_vertices(self):
+        lo = random.randint(0, self.vertex_id)
+        hi = lo + int(self.vertex_id * 0.01)
+        num = random.randint(0, 2 ** 20)
+        self.execute("MATCH (n) WHERE n.id > %d AND n.id < %d "
+                     "SET n.value = %d" % (lo, hi, num))
+
+    def update_global_edges(self):
+        lo = random.randint(0, self.edge_id)
+        hi = lo + int(self.edge_id * 0.01)
+        num = random.randint(0, 2 ** 20)
+        self.execute("MATCH ()-[e]->() WHERE e.id > %d AND e.id < %d "
+                     "SET e.value = %d" % (lo, hi, num))
+
     def verify_graph(self):
         """ Checks if the local info corresponds to DB state """
-        def test(a, b, message):
-            assert set(a) == set(b), message % (len(a), len(b))
+        def test(obj, length, message):
+            assert len(obj) == length, message % (len(obj), length)
 
         def get(query, key):
-            return [row[key] for row in self.execute_basic(query)]
+            ret = self.execute(query)
+            assert ret != None, "Query '{}' returned 'None'!".format(query)
+            return [row[key] for row in ret]
 
-        # graph state verification must be run in isolation
-        with self.graph.query_execution_synchronizer.run_isolated():
-            test(self.v._list, get("MATCH (n) RETURN n.id", "n.id"),
-                 "Expected %d vertices, found %d")
-            test(self.e._list, get("MATCH ()-[r]->() RETURN r.id", "r.id"),
-                 "Expected %d edges, found %d")
-            for lab, exp in self.graph.labels.items():
-                test(get("MATCH (n:%s) RETURN n.id" % lab, "n.id"), exp._list,
-                     "Expected %d vertices with label '{}', found %d".format(
-                         lab))
+        test(self.v, get("MATCH (n:{}) RETURN count(n)".format(
+            self.indexed_label), "count(n)")[0],
+            "Expected %d vertices, found %d")
+        test(self.e, get("MATCH (:{0})-[r]->(:{0}) RETURN count(r)".format(
+            self.indexed_label), "count(r)")[0],
+            "Expected %d edges, found %d")
+        for lab, exp in self.graph.labels.items():
+            test(exp, get("MATCH (n:%s:%s) RETURN count(n)" % (
+                self.indexed_label, lab), "count(n)")[0],
+                "Expected %d vertices with label '{}', found %d".format(
+                    lab))
 
-            log.info("Graph verification success:")
-            log.info("\tExecuted %d queries in %.2f seconds",
-                     self.graph.query_execution_synchronizer.count_total,
-                     time.time() - self._start_time)
-            log.info("\tGraph has %d vertices and %d edges",
-                     len(self.v), len(self.e))
-            for label in sorted(self.graph.labels.keys()):
-                log.info("\tVertices with label '%s': %d",
-                         label, len(self.graph.labels[label]))
-            failures = self.graph.query_failures()
-            if failures:
-                log.info("\tQuery failed (reason: count)")
-                for reason, count in failures.items():
-                    log.info("\t\t'%s': %d", reason, count)
+        log.info("Runner %d graph verification success:", self.sid)
+        log.info("\tExecuted %d queries in %.2f seconds",
+                 self.executed_queries, time.time() - self._start_time)
+        log.info("\tGraph has %d vertices and %d edges",
+                 len(self.v), len(self.e))
+        for label in sorted(self.graph.labels.keys()):
+            log.info("\tVertices with label '%s': %d",
+                     label, len(self.graph.labels[label]))
+        failures = self.graph.query_failures()
+        if failures:
+            log.info("\tQuery failed (reason: count)")
+            for reason, count in failures.items():
+                log.info("\t\t'%s': %d", reason, count)
 
-    def run_loop(self, query_count, max_time):
-        start_time = time.time()
-        for _ in range(query_count):
-            if (time.time() - start_time) / 60 > max_time:
+    def run_loop(self, vertex_batch, query_count, max_time, verify):
+        # start the test
+        start_time = last_verify = time.time()
+
+        # initial batched vertex creation
+        for _ in range(self.graph.vertex_count // vertex_batch):
+            if (time.time() - start_time) / 60 > max_time \
+                    or self.executed_queries > query_count:
                 break
+            self.create_vertices(vertex_batch)
+        self.create_vertices(self.graph.vertex_count % vertex_batch)
+
+        # run rest
+        while self.executed_queries < query_count:
+            now_time = time.time()
+            if (now_time - start_time) / 60 > max_time:
+                break
+
+            if verify > 0 and (now_time - last_verify) > verify:
+                self.verify_graph()
+                last_verify = now_time
 
             ratio_e = len(self.e) / self.graph.edge_count
             ratio_v = len(self.v) / self.graph.vertex_count
 
+            # try to edit vertices globally
+            if bernoulli(0.01):
+                self.update_global_vertices()
+
+            # try to edit edges globally
+            if bernoulli(0.01):
+                self.update_global_edges()
+
             # prefer adding/removing edges whenever there is an edge
             # disbalance and there is enough vertices
            if ratio_v > 0.5 and abs(1 - ratio_e) > 0.2:
@@ -330,7 +271,20 @@ class GraphSession():
             if bernoulli(ratio_v / 2.0):
                 self.remove_vertex()
             else:
-                self.create_vertex()
+                self.create_vertices(1)
+
+
+def runner(params):
+    num, args = params
+    driver = common.argument_driver(args)
+    graph = Graph(args.vertex_count // args.thread_count,
+                  args.edge_count // args.thread_count)
+    log.info("Starting query runner process")
+    session = GraphSession(num, graph, driver.session())
+    session.run_loop(args.vertex_batch, args.max_queries // args.thread_count,
+                     args.max_time, args.verify)
+    log.info("Runner %d executed %d queries", num, session.executed_queries)
+    driver.close()
 
 
 def parse_args():
@@ -342,6 +296,9 @@
                       help="The average number of vertices in the graph")
     argp.add_argument("--edge-count", type=int, required=True,
                       help="The average number of edges in the graph")
+    argp.add_argument("--vertex-batch", type=int, default=200,
+                      help="The number of vertices to be created "
+                           "simultaneously")
     argp.add_argument("--prop-count", type=int, default=5,
                       help="The max number of properties on a node")
     argp.add_argument("--max-queries", type=int, default=2 ** 30,
@@ -351,7 +308,7 @@
     argp.add_argument("--verify", type=int, default=0,
                       help="Interval (seconds) between checking local info")
     argp.add_argument("--thread-count", type=int, default=1,
-                      help="The number of threads that operate on the graph"
+                      help="The number of threads that operate on the graph "
"independently") return argp.parse_args() @@ -363,33 +320,18 @@ def main(): logging.getLogger("neo4j").setLevel(logging.WARNING) log.info("Starting Memgraph long running test") - graph = Graph(args.vertex_count, args.edge_count) + # cleanup and create indexes driver = common.argument_driver(args) - - # cleanup driver.session().run("MATCH (n) DETACH DELETE n").consume() - driver.session().run("CREATE INDEX ON :%s(id)" % INDEXED_LABEL).consume() - - if args.verify > 0: - log.info("Creating veification session") - verififaction_session = GraphSession(graph, driver.session()) - common.periodically_execute(verififaction_session.verify_graph, (), - args.verify) - # TODO better verification failure handling - - threads = [] - for _ in range(args.thread_count): - log.info("Creating query runner thread") - session = GraphSession(graph, driver.session()) - thread = Thread(target=session.run_loop, - args=(args.max_queries // args.thread_count, - args.max_time), - daemon=True) - threads.append(thread) - list(map(Thread.start, threads)) - - list(map(Thread.join, threads)) + for i in range(args.thread_count): + label = INDEX_FORMAT.format(i) + driver.session().run("CREATE INDEX ON :%s(id)" % label).consume() driver.close() + + params = [(i, args) for i in range(args.thread_count)] + with multiprocessing.Pool(args.thread_count) as p: + p.map(runner, params, 1) + log.info("All query runners done")