Skip to content

Commit

Permalink
few fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
khushaljethava committed Nov 26, 2024
1 parent 4a40372 commit 51b33bd
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 19 deletions.
4 changes: 2 additions & 2 deletions graphviz2drawio/graphviz2drawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def convert(graph_to_convert: AGraph | str | IO, layout_prog: str = "dot") -> st
graph = graph_to_convert
else:
graph = AGraph(graph_to_convert)

graph_edges: dict[str, dict] = {
f"{e[0]}->{e[1]}-"
+ (e.attr.get("xlabel") or e.attr.get("label") or ""): e.attr.to_dict()
Expand All @@ -22,7 +22,7 @@ def convert(graph_to_convert: AGraph | str | IO, layout_prog: str = "dot") -> st
graph_nodes: dict[str, dict] = {n: n.attr.to_dict() for n in graph.nodes_iter()}

svg_graph = graph.draw(prog=layout_prog, format="svg")

nodes, edges, clusters = parse_nodes_edges_clusters(
svg_data=svg_graph,
is_directed=graph.directed,
Expand Down
38 changes: 38 additions & 0 deletions graphviz2drawio/models/SVG.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
from xml.etree.ElementTree import Element
import xml.etree.ElementTree as ET
import re

NS_SVG = "{http://www.w3.org/2000/svg}"


def parse_svg_path(path_data):
# Extract all numbers from the path data
numbers = list(map(float, re.findall(r'-?\d+\.?\d*', path_data)))

# Separate x and y coordinates
x_coords = numbers[0::2] # Every other value starting from index 0
y_coords = numbers[1::2] # Every other value starting from index 1
# Calculate the bounding box
min_x = min(x_coords)
max_x = max(x_coords)
min_y = min(y_coords)
max_y = max(y_coords)

width = max_x - min_x
height = max_y - min_y

# Single x and y as the center of the bounding box
single_x = (min_x + max_x) / 2
single_y = (min_y + max_y) / 2

return single_x, single_y, width, height

def svg_tag(tag: str) -> str:
return f"{NS_SVG}{tag}"


def get_first(g: Element, tag: str) -> Element | None:

return g.find(f"./{NS_SVG}{tag}")


Expand All @@ -26,6 +51,19 @@ def get_text(g: Element) -> str | None:
return text_el.text
return None

def get_d(g: Element) -> str | None:
polygon_str = ET.tostring(g, encoding='unicode')
# Regular expression to find the <ns0:polygon> tag
polygon_elements = re.findall(r'<ns0:path[^>]*>', polygon_str)

# Join the results (in case there are multiple matches)
polygon_str = ''.join(polygon_elements)
match = re.search(r'd="([^"]+)"', polygon_str)
# If a match is found, format it as a dictionary
if match:
points = match.group(1)
x_coords, y_coords, width, height = parse_svg_path(points)
return x_coords, y_coords, width, height

def is_tag(g: Element, tag: str) -> bool:
return g.tag == svg_tag(tag)
4 changes: 2 additions & 2 deletions graphviz2drawio/models/SvgParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .commented_tree_builder import COMMENT, CommentedTreeBuilder
from .CoordsTranslate import CoordsTranslate
from .Errors import MissingTitleError
import xml.etree.ElementTree as ET


def parse_nodes_edges_clusters(
Expand All @@ -26,7 +27,6 @@ def parse_nodes_edges_clusters(
svg_data,
parser=ElementTree.XMLParser(target=CommentedTreeBuilder()),
)[0]

coords = CoordsTranslate.from_svg_transform(root.attrib["transform"])
node_factory = NodeFactory(coords)
edge_factory = EdgeFactory(coords=coords, is_directed=is_directed)
Expand All @@ -35,7 +35,6 @@ def parse_nodes_edges_clusters(
edges: OrderedDict[str, Edge] = OrderedDict()
clusters: OrderedDict[str, Node] = OrderedDict()
gradients = dict[str, Gradient]()

prev_comment = None
for g in root:
if g.tag == COMMENT:
Expand Down Expand Up @@ -68,6 +67,7 @@ def parse_nodes_edges_clusters(
else:
edges[edge.key_for_label] = edge
elif g.attrib["class"] == "cluster":

clusters[title] = node_factory.from_svg(
g,
labelloc="t",
Expand Down
24 changes: 14 additions & 10 deletions graphviz2drawio/mx/MxGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from graphviz2drawio.mx.Edge import Edge
from graphviz2drawio.mx.Node import Node
from graphviz2drawio.mx.Styles import Styles
from graphviz2drawio.models import SVG


class MxGraph:
Expand All @@ -16,14 +17,14 @@ def __init__(
clusters: OrderedDict[str, Node],
nodes: OrderedDict[str, Node],
edges: list[Edge],

) -> None:
self.nodes = nodes
self.edges = edges
self.graph = Element(MxConst.GRAPH, attrib={"grid": "0"})
self.root = SubElement(self.graph, MxConst.ROOT)
SubElement(self.root, MxConst.CELL, attrib={"id": "0"})
SubElement(self.root, MxConst.CELL, attrib={"id": "1", "parent": "0"})

# Add nodes first so edges are drawn on top
for cluster in clusters.values():
self.add_node(cluster)
Expand Down Expand Up @@ -103,18 +104,21 @@ def add_mx_geo(
attributes["as"] = "geometry"
SubElement(element, MxConst.GEO, attributes)
elif text_offset is not None:
# Calculate width and height based on text_offset
width = abs(text_offset.real) * 0.25 # Example calculation
height = abs(text_offset.imag) * 0.25 # Example calculation
x = str(text_offset.real / 2)
y = str(text_offset.imag / 2)
geo = SubElement(
element,
MxConst.GEO,
attrib={"as": "geometry", "relative": "1"},
)
SubElement(
geo,
MxConst.POINT,
attrib={
"x": str(text_offset.real),
"y": str(text_offset.imag),
"as": "offset",
"relative": "1",
"x": x,
"y": y,
"width": str(width),
"height": str(height),
"as": "geometry"
},
)
else:
Expand Down Expand Up @@ -167,4 +171,4 @@ def __str__(self) -> str:
return self.value()

def __repr__(self) -> str:
return self.value()
return self.value()
9 changes: 4 additions & 5 deletions graphviz2drawio/mx/NodeFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .RectFactory import rect_from_ellipse_svg, rect_from_image, rect_from_svg_points
from .Text import Text
from .utils import adjust_color_opacity

from graphviz2drawio.models.Rect import Rect

class NodeFactory:
def __init__(self, coords: CoordsTranslate) -> None:
Expand All @@ -34,13 +34,12 @@ def from_svg(

if sid is None or gid is None:
raise MissingIdentifiersError(sid, gid)

if (inner_g := SVG.get_first(g, "g")) is not None:
if (inner_a := SVG.get_first(inner_g, "a")) is not None:
g = inner_a

if (polygon := SVG.get_first(g, "polygon")) is not None:
rect = rect_from_svg_points(self.coords, polygon.attrib["points"])
if (polygon := SVG.get_first(g, "path")) is not None:
x_coords, y_coords, width, height = SVG.get_d(polygon)
rect = Rect(x=x_coords,y=y_coords,width=width,height=height)
shape = Shape.RECT
fill = self._extract_fill(polygon, gradients)
stroke = self._extract_stroke(polygon)
Expand Down

0 comments on commit 51b33bd

Please sign in to comment.