Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/sevn-ai/pyspur_dev into bug…
Browse files Browse the repository at this point in the history
…fix/minor-fixes
  • Loading branch information
preet-bhadra committed Feb 13, 2025
2 parents ae06dc1 + a8d7b6c commit 605cdd2
Show file tree
Hide file tree
Showing 23 changed files with 512 additions and 63 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<h1 align="center">
<img src="pyspur-logo.svg" alt="PySpur Logo" width="40" height="40" style="vertical-align: middle; filter: invert(1)" data-dark-mode="filter: invert(0)"> PySpur - Graph UI for building and testing AI Agents
PySpur - Graph UI for AI Agents
</h1>
<p align="center">
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-blue"></a>
Expand Down
8 changes: 0 additions & 8 deletions backend/README.MD

This file was deleted.

4 changes: 4 additions & 0 deletions backend/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from ..nodes.registry import NodeRegistry
NodeRegistry.discover_nodes()

from .node_management import router as node_management_router
from .workflow_management import router as workflow_management_router
Expand All @@ -16,9 +18,11 @@
from .rag_management import router as rag_management_router
from .file_management import router as file_management_router


load_dotenv()



app = FastAPI(root_path="/api")

# Add CORS middleware
Expand Down
4 changes: 4 additions & 0 deletions backend/app/execution/workflow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,10 @@ async def run(
print(
f"[WARNING]: Node {node_id} does not have an output_model defined: {e}\n skipping precomputed output"
)
except KeyError as e:
print(
f"[WARNING]: Node {node_id} not found in the predecessor workflow: {e}\n skipping precomputed output"
)

# Store input in initial inputs to be used by InputNode
input_node = next(
Expand Down
9 changes: 6 additions & 3 deletions backend/app/nodes/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
from hashlib import md5
from typing import Any, Dict, List, Optional, Type
import json

from pydantic import BaseModel, Field, create_model
from ..execution.workflow_execution_context import WorkflowExecutionContext
from ..schemas.workflow_schemas import WorkflowDefinitionSchema

from ..utils import pydantic_utils

class VisualTag(BaseModel):
"""
Expand Down Expand Up @@ -105,9 +106,11 @@ def setup(self) -> None:
For dynamic schema nodes, these can be created based on self.config.
"""
if self._config.has_fixed_output:
self.output_model = self.create_output_model_class(
self._config.output_schema
schema = json.loads(self._config.output_json_schema)
model = pydantic_utils.json_schema_to_model(
schema, model_class_name=self.name, base_class=BaseNodeOutput
)
self.output_model = model # type: ignore

def create_output_model_class(
self, output_schema: Dict[str, str]
Expand Down
71 changes: 51 additions & 20 deletions backend/app/nodes/factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import importlib
from typing import Any, List, Dict
from typing import Any, Dict, List

from ..schemas.node_type_schemas import NodeTypeSchema
from .base import BaseNode
from .registry import NodeRegistry

from .node_types import (
SUPPORTED_NODE_TYPES,
Expand All @@ -14,50 +15,68 @@
class NodeFactory:
"""
Factory for creating node instances from a configuration.
Node type definitions are expected to be in the nodes package.
Supports both decorator-based registration and legacy configured registration.
Conventions:
- The node class should be named <NodeTypeName>Node
- The config model should be named <NodeTypeName>NodeConfig
- The input model should be named <NodeTypeName>NodeInput
- The output model should be named <NodeTypeName>NodeOutput
- There should be only one node type class per module
- The module name should be the snake_case version of the node type name
Example:
- Node type: Example
- Node class: ExampleNode
- Config model: ExampleNodeConfig
- Input model: ExampleNodeInput
- Output model: ExampleNodeOutput
- Module name: example
- Node type: MCTS
- Node class: MCTSNode
- Config model: MCTSNodeConfig
- Input model: MCTSNodeInput
- Output model: MCTSNodeOutput
- Module name: llm.mcts
Nodes can be registered in two ways:
1. Using the @NodeRegistry.register decorator (recommended)
2. Through the legacy configured SUPPORTED_NODE_TYPES in node_types.py
"""

@staticmethod
def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]:
"""
Returns a dictionary of all available node types grouped by category.
Combines both decorator-registered and configured nodes.
"""
return get_all_node_types()
# Get nodes from both sources
configured_nodes = get_all_node_types()
registered_nodes = NodeRegistry.get_registered_nodes()

# Convert registered nodes to NodeTypeSchema
converted_nodes: Dict[str, List[NodeTypeSchema]] = {}
for category, nodes in registered_nodes.items():
if category not in converted_nodes:
converted_nodes[category] = []
for node in nodes:
schema = NodeTypeSchema(
node_type_name=node["node_type_name"],
module=node["module"],
class_name=node["class_name"]
)
converted_nodes[category].append(schema)

# Merge nodes, giving priority to configured ones
result = configured_nodes.copy()
for category, nodes in converted_nodes.items():
if category not in result:
result[category] = []
# Only add nodes that aren't already present
for node in nodes:
if not any(n.node_type_name == node.node_type_name for n in result[category]):
result[category].append(node)

return result

@staticmethod
def create_node(node_name: str, node_type_name: str, config: Any) -> BaseNode:
"""
Creates a node instance from a configuration.
Checks both registration methods for the node type.
"""
if not is_valid_node_type(node_type_name):
raise ValueError(f"Node type '{node_type_name}' is not valid.")

module_name = None
class_name = None
# Use the imported _SUPPORTED_NODE_TYPES

# First check configured nodes
for node_group in SUPPORTED_NODE_TYPES.values():
for node_type in node_group:
if node_type["node_type_name"] == node_type_name:
Expand All @@ -67,6 +86,18 @@ def create_node(node_name: str, node_type_name: str, config: Any) -> BaseNode:
if module_name and class_name:
break

# If not found, check registry
if not module_name or not class_name:
registered_nodes = NodeRegistry.get_registered_nodes()
for nodes in registered_nodes.values():
for node in nodes:
if node["node_type_name"] == node_type_name:
module_name = node["module"]
class_name = node["class_name"]
break
if module_name and class_name:
break

if not module_name or not class_name:
raise ValueError(f"Node type '{node_type_name}' not found.")

Expand Down
1 change: 0 additions & 1 deletion backend/app/nodes/integrations/firecrawl/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@

13 changes: 9 additions & 4 deletions backend/app/nodes/integrations/firecrawl/firecrawl_scrape.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ...base import BaseNode, BaseNodeConfig, BaseNodeInput, BaseNodeOutput
from firecrawl import FirecrawlApp # type: ignore
from ...utils.template_utils import render_template_or_get_first_string
from ...registry import NodeRegistry


class FirecrawlScrapeNodeInput(BaseNodeInput):
Expand Down Expand Up @@ -31,15 +32,19 @@ class FirecrawlScrapeNodeConfig(BaseNodeConfig):
)


@NodeRegistry.register(
category="Integrations",
display_name="Firecrawl Scrape",
logo="/images/firecrawl.png",
subcategory="Web Scraping",
position="after:FirecrawlCrawlNode"
)
class FirecrawlScrapeNode(BaseNode):
name = "firecrawl_scrape_node"
display_name = "Firecrawl Scrape"
logo = "/images/firecrawl.png"
category = "Firecrawl"

config_model = FirecrawlScrapeNodeConfig
input_model = FirecrawlScrapeNodeInput
output_model = FirecrawlScrapeNodeOutput
category = "Firecrawl" # This will be used by the frontend for subcategory grouping

async def run(self, input: BaseModel) -> BaseModel:
"""
Expand Down
23 changes: 17 additions & 6 deletions backend/app/nodes/node_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, List

from ..schemas.node_type_schemas import NodeTypeSchema
from .registry import NodeRegistry

# Simple lists of supported and deprecated node types

Expand Down Expand Up @@ -118,11 +119,11 @@
"module": ".nodes.integrations.firecrawl.firecrawl_crawl",
"class_name": "FirecrawlCrawlNode",
},
{
"node_type_name": "FirecrawlScrapeNode",
"module": ".nodes.integrations.firecrawl.firecrawl_scrape",
"class_name": "FirecrawlScrapeNode",
},
# {
# "node_type_name": "FirecrawlScrapeNode",
# "module": ".nodes.integrations.firecrawl.firecrawl_scrape",
# "class_name": "FirecrawlScrapeNode",
# },
{
"node_type_name": "JinaReaderNode",
"module": ".nodes.integrations.jina.jina_reader",
Expand Down Expand Up @@ -207,13 +208,23 @@ def get_all_node_types() -> Dict[str, List[NodeTypeSchema]]:

def is_valid_node_type(node_type_name: str) -> bool:
"""
Checks if a node type is valid (supported or deprecated).
Checks if a node type is valid (supported, deprecated, or registered via decorator).
"""
# Check configured nodes first
for node_types in SUPPORTED_NODE_TYPES.values():
for node_type in node_types:
if node_type["node_type_name"] == node_type_name:
return True

for node_type in DEPRECATED_NODE_TYPES:
if node_type["node_type_name"] == node_type_name:
return True

# Check registry for decorator-registered nodes
registered_nodes = NodeRegistry.get_registered_nodes()
for nodes in registered_nodes.values():
for node in nodes:
if node["node_type_name"] == node_type_name:
return True

return False
Loading

0 comments on commit 605cdd2

Please sign in to comment.