-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #103 from aitomatic/dev
add Level-2 Planning & Reasoning intelligence capabilities
- Loading branch information
Showing
28 changed files
with
1,021 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Setups" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pprint import pprint\n", | ||
"from IPython.display import display, Markdown, Pretty" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from dotenv import load_dotenv\n", | ||
"load_dotenv()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import nest_asyncio\n", | ||
"nest_asyncio.apply()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Imports" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pathlib import Path\n", | ||
"\n", | ||
"from openssa import (Agent,\n", | ||
" HTP, AutoHTPlanner,\n", | ||
" OodaReasoner,\n", | ||
" FileResource)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Problems & Resources" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"PROBLEM = 'Does AMD have a healthy liquidity profile based on FY22 Quick Ratio?'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"RESOURCE_PATH = Path() / '.FinanceBench' / 'docs' / 'AMD_2022_10K'\n", | ||
"assert RESOURCE_PATH.is_dir()\n", | ||
"\n", | ||
"resource = FileResource(RESOURCE_PATH)\n", | ||
"display(Markdown(resource.overview))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Agent with Planning & Reasoning" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"agent = Agent(planner=AutoHTPlanner(max_depth=3, max_subtasks_per_decomp=9),\n", | ||
" reasoner=OodaReasoner(),\n", | ||
" resources={resource})\n", | ||
"pprint(agent)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Problem-Solving with Automated Planner" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"auto_plan = agent.planner.plan(PROBLEM)\n", | ||
"pprint(auto_plan)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"solution_1 = agent.solve(PROBLEM)\n", | ||
"display(Markdown(solution_1))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Problem-Solving with Expert-Specified Plan" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"expert_plan = HTP.from_dict(\n", | ||
" {\n", | ||
" 'task': PROBLEM,\n", | ||
" 'sub-plans': [\n", | ||
" {\n", | ||
" 'task': 'retrieve data points needed for Quick Ratio',\n", | ||
" 'sub-plans': [\n", | ||
" {\n", | ||
" 'task': 'retrieve Cash & Cash Equivalents'\n", | ||
" },\n", | ||
" {\n", | ||
" 'task': 'retrieve Accounts Receivable'\n", | ||
" },\n", | ||
" {\n", | ||
" 'task': 'retrieve Short-Term Liabilities'\n", | ||
" },\n", | ||
" {\n", | ||
" 'task': 'retrieve Accounts Payable'\n", | ||
" },\n", | ||
" ]\n", | ||
" },\n", | ||
" {\n", | ||
" 'task': 'calculate Quick Ratio'\n", | ||
" },\n", | ||
" {\n", | ||
" 'task': 'see whether Quick Ratio is healthy, i.e. greater than 1'\n", | ||
" },\n", | ||
" ]\n", | ||
" }\n", | ||
")\n", | ||
"pprint(expert_plan)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"solution_2 = agent.solve(PROBLEM, plan=expert_plan)\n", | ||
"display(Markdown(solution_2))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.8" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
"""Abstract agent with planning, reasoning & informational resources.""" | ||
|
||
|
||
from abc import ABC | ||
from dataclasses import dataclass, field | ||
from pprint import pprint | ||
|
||
from openssa.l2.planning.abstract import AbstractPlan, AbstractPlanner | ||
from openssa.l2.reasoning.abstract import AbstractReasoner | ||
from openssa.l2.reasoning.base import BaseReasoner | ||
from openssa.l2.resource.abstract import AbstractResource | ||
|
||
|
||
@dataclass(init=True, | ||
repr=True, | ||
eq=True, | ||
order=False, | ||
unsafe_hash=False, | ||
frozen=False, # mutable | ||
match_args=True, | ||
kw_only=False, | ||
slots=False, | ||
weakref_slot=False) | ||
class AbstractAgent(ABC): | ||
"""Abstract agent with planning, reasoning & informational resources.""" | ||
|
||
planner: AbstractPlanner | ||
reasoner: AbstractReasoner = field(default_factory=BaseReasoner) | ||
resources: set[AbstractResource] = field(default_factory=set, | ||
init=True, | ||
repr=True, | ||
hash=False, # mutable | ||
compare=True, | ||
metadata=None, | ||
kw_only=True) | ||
|
||
@property | ||
def resource_overviews(self) -> dict[str, str]: | ||
return {r.unique_name: r.overview for r in self.resources} | ||
|
||
def solve(self, problem: str, plan: AbstractPlan | None = None) -> str: | ||
"""Solve problem, with an automatically generated plan (default) or explicitly specified plan.""" | ||
plan: AbstractPlan = (self.planner.update_plan_resources(plan, resources=self.resources) | ||
if plan | ||
else self.planner.plan(problem, resources=self.resources)) | ||
pprint(plan) | ||
|
||
return plan.execute(reasoner=self.reasoner) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
"""Agent with planning, reasoning & informational resources.""" | ||
|
||
|
||
from dataclasses import dataclass, field | ||
|
||
from openssa.l2.planning.abstract import AbstractPlanner | ||
from openssa.l2.planning.hierarchical import AutoHTPlanner | ||
|
||
from .abstract import AbstractAgent | ||
|
||
|
||
@dataclass(init=True, | ||
repr=True, | ||
eq=True, | ||
order=False, | ||
unsafe_hash=False, | ||
frozen=False, # mutable | ||
match_args=True, | ||
kw_only=False, | ||
slots=False, | ||
weakref_slot=False) | ||
class Agent(AbstractAgent): | ||
"""Agent with planning, reasoning & informational resources.""" | ||
|
||
planner: AbstractPlanner = field(default_factory=AutoHTPlanner) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""Abstract fact.""" | ||
|
||
|
||
from abc import ABC | ||
|
||
|
||
class AbstractFact(ABC): # noqa: B024 | ||
"""Abstract fact.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""Abstract heuristic.""" | ||
|
||
|
||
from abc import ABC | ||
|
||
|
||
class AbstractHeuristic(ABC): # noqa: B024 | ||
"""Abstract heuristic.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""Abstract inference rule.""" | ||
|
||
|
||
from abc import ABC | ||
|
||
|
||
class AbstractInferenceRule(ABC): # noqa: B024 | ||
"""Abstract inference rule.""" |
Oops, something went wrong.