diff --git a/src/matgl/graph/compute.py b/src/matgl/graph/compute.py index 8af2a857..a2ecf546 100644 --- a/src/matgl/graph/compute.py +++ b/src/matgl/graph/compute.py @@ -94,7 +94,7 @@ def create_line_graph(g: dgl.DGLGraph, threebody_cutoff: float, directed: bool = def ensure_line_graph_compatibility( - graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, directed: bool = False, tol: float = 5e-7 + graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, directed: bool = False, tol: float = 5e-6 ) -> dgl.DGLGraph: """Ensure that line graph is compatible with graph. @@ -306,7 +306,7 @@ def _ensure_3body_line_graph_compatibility(graph: dgl.DGLGraph, line_graph: dgl. def _ensure_directed_line_graph_compatibility( - graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, tol: float = 5e-7 + graph: dgl.DGLGraph, line_graph: dgl.DGLGraph, threebody_cutoff: float, tol: float = 5e-6 ) -> dgl.DGLGraph: """Ensure that line graph is compatible with graph. diff --git a/src/matgl/graph/data.py b/src/matgl/graph/data.py index cf3ed746..76529f47 100644 --- a/src/matgl/graph/data.py +++ b/src/matgl/graph/data.py @@ -267,7 +267,11 @@ def __getitem__(self, idx: int): self.graphs[idx], self.lattices[idx], self.state_attr[idx], - {k: torch.tensor(v[idx], dtype=matgl.float_th) for k, v in self.labels.items()}, + { + k: torch.tensor(v[idx], dtype=matgl.float_th) + for k, v in self.labels.items() + if not isinstance(v[idx], str) + }, ] if self.include_line_graph: items.insert(2, self.line_graphs[idx])