diff --git a/query_modules/louvain/src/data_structures/graph.cpp b/query_modules/louvain/src/data_structures/graph.cpp index c645784a0..23508234b 100644 --- a/query_modules/louvain/src/data_structures/graph.cpp +++ b/query_modules/louvain/src/data_structures/graph.cpp @@ -78,14 +78,23 @@ double Graph::Modularity() const { if (total_w_ == 0) return 0; + std::unordered_map weight_c; + std::unordered_map degree_c; + for (uint32_t i = 0; i < n_nodes_; ++i) { + degree_c[Community(i)] += + static_cast(IncidentWeight(i)); for (const auto &neigh : adj_list_[i]) { uint32_t j = neigh.dest; double w = neigh.weight; if (Community(i) != Community(j)) continue; - ret += w - (IncidentWeight(i) * IncidentWeight(j) / (2.0 * total_w_)); + weight_c[Community(i)] += w; } } + + for (const auto &p : degree_c) + ret += weight_c[p.first] - (p.second * p.second) / (2 * total_w_); + ret /= 2 * total_w_; return ret; } diff --git a/query_modules/louvain/test/unit/graph.cpp b/query_modules/louvain/test/unit/graph.cpp index b80a98773..3cfba6bb4 100644 --- a/query_modules/louvain/test/unit/graph.cpp +++ b/query_modules/louvain/test/unit/graph.cpp @@ -273,7 +273,7 @@ TEST(Graph, Modularity) { {3, 4, 4.2}}); std::vector c = {0, 1, 1, 2, 2}; SetCommunities(&graph, c); - EXPECT_NEAR(graph.Modularity(), 0.37452886332076973, 1e-6); + EXPECT_NEAR(graph.Modularity(), 0.036798254314620096, 1e-6); // Tree // (0)--(3) @@ -289,7 +289,7 @@ TEST(Graph, Modularity) { {2, 6, 0.7}}); c = {0, 0, 1, 0, 0, 1, 2}; SetCommunities(&graph, c); - EXPECT_NEAR(graph.Modularity(), 0.6945087219651122, 1e-6); + EXPECT_NEAR(graph.Modularity(), 0.4424617301530794, 1e-6); // Graph without self-loops // (0)--(1) @@ -305,7 +305,7 @@ TEST(Graph, Modularity) { {3, 4, 0.7}}); c = {0, 1, 1, 1, 1}; SetCommunities(&graph, c); - EXPECT_NEAR(graph.Modularity(), 0.32653061224489793, 1e-6); + EXPECT_NEAR(graph.Modularity(), -0.022959183673469507, 1e-6); // Graph with self loop [*nodes have self loops] // (0)--(1*) @@ -324,5 +324,42 @@ TEST(Graph, Modularity) { {4, 4, 1}}); c = {0, 0, 0, 0, 1}; SetCommunities(&graph, c); - EXPECT_NEAR(graph.Modularity(), 0.2754545454545455, 1e-6); + EXPECT_NEAR(graph.Modularity(), 0.188842975206611, 1e-6); + + // Neo4j example graph + // (0)--(1)---(3)--(4) + // \ / \ / + // (2) (5) + graph = BuildGraph(6, {{0, 1, 1}, + {1, 2, 1}, + {0, 2, 1}, + {1, 3, 1}, + {3, 5, 1}, + {5, 4, 1}, + {3, 4, 1}}); + c = {0, 0, 0, 1, 1, 1}; + SetCommunities(&graph, c); + EXPECT_NEAR(graph.Modularity(), 0.3571428571428571, 1e-6); + + // Example graph from wikipedia + // (0)--(1)--(3)--(4)--(5) + // \ / | \ / + // (2) (7) (6) + // / \ + // (8)--(9) + graph = BuildGraph(10, {{0, 1, 1}, + {1, 2, 1}, + {0, 2, 1}, + {1, 3, 1}, + {3, 4, 1}, + {4, 5, 1}, + {5, 6, 1}, + {6, 4, 1}, + {3, 7, 1}, + {7, 8, 1}, + {7, 9, 1}, + {8, 9, 1}}); + c = {0, 0, 0, 0, 1, 1, 1, 2, 2, 2}; + SetCommunities(&graph, c); + EXPECT_NEAR(graph.Modularity(), 0.4896, 1e-4); }