diff --git a/python-wrapper/src/neo4j_viz/neo4j.py b/python-wrapper/src/neo4j_viz/neo4j.py index 3a301ca..44867bc 100644 --- a/python-wrapper/src/neo4j_viz/neo4j.py +++ b/python-wrapper/src/neo4j_viz/neo4j.py @@ -4,7 +4,7 @@ from typing import Optional, Union import neo4j.graph -from neo4j import Driver, Result, RoutingControl +from neo4j import Driver, EagerResult, Result, RoutingControl from pydantic import BaseModel, ValidationError from neo4j_viz.colors import NEO4J_COLORS_DISCRETE, ColorSpace @@ -21,12 +21,65 @@ def _parse_validation_error(e: ValidationError, entity_type: type[BaseModel]) -> ) +def _collect_graph_entities( + value: object, + nodes: dict[str, neo4j.graph.Node], + rels: dict[str, neo4j.graph.Relationship], +) -> None: + """Recursively extract Node and Relationship objects from any record value.""" + if isinstance(value, neo4j.graph.Node): + nodes[value.element_id] = value + elif isinstance(value, neo4j.graph.Relationship): + rels[value.element_id] = value + elif isinstance(value, neo4j.graph.Path): + for node in value.nodes: + nodes[node.element_id] = node + for rel in value.relationships: + rels[rel.element_id] = rel + elif isinstance(value, list): + for item in value: + _collect_graph_entities(item, nodes, rels) + elif isinstance(value, dict): + for item in value.values(): + _collect_graph_entities(item, nodes, rels) + + +def _graph_from_eager_result(data: "EagerResult") -> neo4j.graph.Graph: + """Return the bolt hydration Graph shared by all entities in an EagerResult. + + Every Node/Relationship produced by the same query references the same + internal Graph that the driver built during bolt hydration — identical to + what Result.graph() returns. We find the first entity in the records and + return its .graph. If the result contains no graph entities at all we fall + back to walking the records manually. + """ + for record in data.records: + for value in record.values(): + if isinstance(value, (neo4j.graph.Node, neo4j.graph.Relationship)): + return value.graph + if isinstance(value, neo4j.graph.Path) and value.nodes: + return value.nodes[0].graph + + # Fallback: no direct entity columns — walk everything recursively and + # build a synthetic Graph so the rest of from_neo4j can stay uniform. + nodes_dict: dict[str, neo4j.graph.Node] = {} + rels_dict: dict[str, neo4j.graph.Relationship] = {} + for record in data.records: + for value in record.values(): + _collect_graph_entities(value, nodes_dict, rels_dict) + graph = neo4j.graph.Graph() + graph._nodes = nodes_dict + for rel in rels_dict.values(): + graph._relationships[rel.element_id] = rel + return graph + + def from_neo4j( - data: Union[neo4j.graph.Graph, Result, Driver], + data: Union[neo4j.graph.Graph, Result, EagerResult, Driver], row_limit: int = 10_000, ) -> VisualizationGraph: """ - Create a VisualizationGraph from a Neo4j `Graph`, Neo4j `Result` or Neo4j `Driver`. + Create a VisualizationGraph from a Neo4j `Graph`, Neo4j `Result`, Neo4j `EagerResult` or Neo4j `Driver`. By default: @@ -39,9 +92,10 @@ def from_neo4j( Parameters ---------- - data : Union[neo4j.graph.Graph, neo4j.Result, neo4j.Driver] - Either a query result in the shape of a `neo4j.graph.Graph` or `neo4j.Result`, or a `neo4j.Driver` in - which case a simple default query will be executed internally to retrieve the graph data. + data : Union[neo4j.graph.Graph, neo4j.Result, neo4j.EagerResult, neo4j.Driver] + Either a query result in the shape of a `neo4j.graph.Graph`, `neo4j.Result`, or `neo4j.EagerResult` + (as returned by `driver.execute_query()`), or a `neo4j.Driver` in which case a simple default query + will be executed internally to retrieve the graph data. row_limit : int, optional Maximum number of rows to return from the query, by default 10_000. This is only used if a `neo4j.Driver` is passed as `result` argument, otherwise the limit is ignored. @@ -49,8 +103,19 @@ def from_neo4j( if isinstance(data, Result): graph = data.graph() + raw_nodes = graph.nodes + raw_relationships = graph.relationships elif isinstance(data, neo4j.graph.Graph): - graph = data + raw_nodes = data.nodes + raw_relationships = data.relationships + elif isinstance(data, EagerResult): + # Every Node/Relationship hydrated from the same query shares one Graph + # object (the bolt hydration graph). Grabbing it from the first entity + # gives us the complete graph — including start/end nodes of + # relationships that were never returned as explicit columns. + graph = _graph_from_eager_result(data) + raw_nodes = graph.nodes + raw_relationships = graph.relationships elif isinstance(data, Driver): rel_count = data.execute_query( "MATCH ()-[r]->() RETURN count(r) as count", @@ -66,14 +131,18 @@ def from_neo4j( routing_=RoutingControl.READ, result_transformer_=Result.graph, ) + raw_nodes = graph.nodes + raw_relationships = graph.relationships else: - raise ValueError(f"Invalid input type `{type(data)}`. Expected `neo4j.Graph`, `neo4j.Result` or `neo4j.Driver`") + raise ValueError( + f"Invalid input type `{type(data)}`. Expected `neo4j.Graph`, `neo4j.Result`, `neo4j.EagerResult` or `neo4j.Driver`" + ) - nodes = [_map_node(node) for node in graph.nodes] + nodes = [_map_node(node) for node in raw_nodes] relationships = [] - for rel in graph.relationships: + for rel in raw_relationships: mapped_rel = _map_relationship(rel) if mapped_rel: relationships.append(mapped_rel) diff --git a/python-wrapper/tests/test_collect_graph_entities.py b/python-wrapper/tests/test_collect_graph_entities.py new file mode 100644 index 0000000..fb59e47 --- /dev/null +++ b/python-wrapper/tests/test_collect_graph_entities.py @@ -0,0 +1,131 @@ +import neo4j.graph + +from neo4j_viz.neo4j import _collect_graph_entities + + +def _make_graph() -> neo4j.graph.Graph: + return neo4j.graph.Graph() + + +def _make_node( + graph: neo4j.graph.Graph, element_id: str, labels: list[str], props: dict[str, object] +) -> neo4j.graph.Node: + return neo4j.graph.Node(graph, element_id, hash(element_id), labels, props) + + +def _make_rel( + graph: neo4j.graph.Graph, + element_id: str, + rel_type: str, + start: neo4j.graph.Node, + end: neo4j.graph.Node, + props: dict[str, object] | None = None, +) -> neo4j.graph.Relationship: + RelType = graph.relationship_type(rel_type) + rel = RelType.__new__(RelType) + rel.__dict__.update( + { + "_graph": graph, + "_element_id": element_id, + "_id": hash(element_id), + "_properties": props or {}, + "_start_node": start, + "_end_node": end, + } + ) + return rel + + +def test_plain_node() -> None: + g = _make_graph() + node = _make_node(g, "n1", ["A"], {"x": 1}) + nodes: dict[str, neo4j.graph.Node] = {} + rels: dict[str, neo4j.graph.Relationship] = {} + _collect_graph_entities(node, nodes, rels) + assert "n1" in nodes + assert rels == {} + + +def test_plain_relationship() -> None: + g = _make_graph() + a = _make_node(g, "a", ["A"], {}) + b = _make_node(g, "b", ["B"], {}) + rel = _make_rel(g, "r1", "KNOWS", a, b) + nodes: dict[str, neo4j.graph.Node] = {} + rels: dict[str, neo4j.graph.Relationship] = {} + _collect_graph_entities(rel, nodes, rels) + assert "r1" in rels + assert nodes == {} + + +def test_path() -> None: + g = _make_graph() + a = _make_node(g, "a", ["A"], {}) + b = _make_node(g, "b", ["B"], {}) + rel = _make_rel(g, "r1", "KNOWS", a, b) + path = neo4j.graph.Path(a, rel) + nodes: dict[str, neo4j.graph.Node] = {} + rels: dict[str, neo4j.graph.Relationship] = {} + _collect_graph_entities(path, nodes, rels) + assert set(nodes) == {"a", "b"} + assert set(rels) == {"r1"} + + +def test_list_of_nodes() -> None: + g = _make_graph() + a = _make_node(g, "a", ["A"], {}) + b = _make_node(g, "b", ["B"], {}) + nodes: dict[str, neo4j.graph.Node] = {} + rels: dict[str, neo4j.graph.Relationship] = {} + _collect_graph_entities([a, b], nodes, rels) + assert set(nodes) == {"a", "b"} + + +def test_nested_list() -> None: + g = _make_graph() + a = _make_node(g, "a", ["A"], {}) + nodes: dict[str, neo4j.graph.Node] = {} + rels: dict[str, neo4j.graph.Relationship] = {} + _collect_graph_entities([[a]], nodes, rels) + assert "a" in nodes + + +def test_dict_of_nodes() -> None: + g = _make_graph() + a = _make_node(g, "a", ["A"], {}) + nodes: dict[str, neo4j.graph.Node] = {} + rels: dict[str, neo4j.graph.Relationship] = {} + _collect_graph_entities({"key": a}, nodes, rels) + assert "a" in nodes + + +def test_deduplication() -> None: + g = _make_graph() + a = _make_node(g, "a", ["A"], {}) + nodes: dict[str, neo4j.graph.Node] = {} + rels: dict[str, neo4j.graph.Relationship] = {} + _collect_graph_entities([a, a], nodes, rels) + assert len(nodes) == 1 + + +def test_scalar_ignored() -> None: + nodes: dict[str, neo4j.graph.Node] = {} + rels: dict[str, neo4j.graph.Relationship] = {} + _collect_graph_entities("hello", nodes, rels) + _collect_graph_entities(42, nodes, rels) + _collect_graph_entities(None, nodes, rels) + assert nodes == {} and rels == {} + + +def test_mixed_list_with_path_and_node() -> None: + g = _make_graph() + a = _make_node(g, "a", ["A"], {}) + b = _make_node(g, "b", ["B"], {}) + c = _make_node(g, "c", ["C"], {}) + rel = _make_rel(g, "r1", "KNOWS", a, b) + path = neo4j.graph.Path(a, rel) + nodes: dict[str, neo4j.graph.Node] = {} + rels: dict[str, neo4j.graph.Relationship] = {} + _collect_graph_entities([path, c], nodes, rels) + assert set(nodes) == {"a", "b", "c"} + assert set(rels) == {"r1"} diff --git a/python-wrapper/tests/test_neo4j.py b/python-wrapper/tests/test_neo4j.py index 43fe8bc..461ae07 100644 --- a/python-wrapper/tests/test_neo4j.py +++ b/python-wrapper/tests/test_neo4j.py @@ -3,7 +3,7 @@ import neo4j import pytest -from neo4j import Driver, Session +from neo4j import Driver, EagerResult, Session from neo4j_viz.colors import NEO4J_COLORS_DISCRETE from neo4j_viz.neo4j import from_neo4j @@ -123,6 +123,58 @@ def test_from_neo4j_result(neo4j_session: Session) -> None: ] +@pytest.mark.requires_neo4j_and_gds +def test_from_neo4j_eager_result(neo4j_session: Session, neo4j_driver: Driver) -> None: + graph = neo4j_session.run("MATCH (a:_CI_A|_CI_B)-[r]->(b) RETURN a, b, r ORDER BY a").graph() + + eager_result: EagerResult = neo4j_driver.execute_query("MATCH (a:_CI_A|_CI_B)-[r]->(b) RETURN a, b, r ORDER BY a") + assert isinstance(eager_result, EagerResult) + + VG = from_neo4j(eager_result) + + sorted_nodes: list[neo4j.graph.Node] = sorted(graph.nodes, key=lambda x: dict(x.items())["name"]) + node_ids: list[str] = [node.element_id for node in sorted_nodes] + + expected_nodes = [ + Node( + id=node_ids[0], + caption="_CI_A", + color=NEO4J_COLORS_DISCRETE[0], + properties=dict( + labels=["_CI_A"], + name="Alice", + height=20, + id=42, + _id=1337, + caption="hello", + ), + ), + Node( + id=node_ids[1], + caption="_CI_A:_CI_B", + color=NEO4J_COLORS_DISCRETE[1], + properties=dict( + size=11, + labels=["_CI_A", "_CI_B"], + name="Bob", + height=10, + id=84, + __labels=[1, 2], + ), + ), + ] + + assert len(VG.nodes) == 2 + assert sorted(VG.nodes, key=lambda x: x.properties["name"]) == expected_nodes + + assert len(VG.relationships) == 2 + vg_rels = sorted([(e.source, e.target, e.caption) for e in VG.relationships], key=lambda x: x[2] if x[2] else "foo") + assert vg_rels == [ + (node_ids[0], node_ids[1], "KNOWS"), + (node_ids[1], node_ids[0], "RELATED"), + ] + + @pytest.mark.requires_neo4j_and_gds def test_from_neo4j_graph_driver(neo4j_session: Session, neo4j_driver: Driver) -> None: graph = neo4j_session.run("MATCH (a:_CI_A|_CI_B)-[r]->(b) RETURN a, b, r ORDER BY a").graph() diff --git a/python-wrapper/uv.lock b/python-wrapper/uv.lock index 433ed9d..ed1db82 100644 --- a/python-wrapper/uv.lock +++ b/python-wrapper/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", @@ -2389,7 +2389,7 @@ wheels = [ [[package]] name = "neo4j-viz" -version = "1.3.0" +version = "1.4.0" source = { editable = "." } dependencies = [ { name = "anywidget" },