
PuterBot-26: Add Guardrails validation to get_completion method in DemoReplayStrategy #33

Closed

Conversation

Mustaballer
Collaborator

@Mustaballer Mustaballer commented Apr 23, 2023

Overview

<rail version="0.1">
<output>
    <object name="bank_run" format="length: 2">
        <string
            name="explanation"
            description="A paragraph about what a bank run is."
            format="length: 200 280"
            on-fail-length="reask"
        />
        <url
            name="follow_up_url"
            description="A web URL where I can read more about bank runs."
            format="valid-url"
            on-fail-valid-url="filter"
        />
    </object>
</output>

<prompt>
Explain what a bank run is in a tweet.

@xml_prefix_prompt

{output_schema}

@json_suffix_prompt_v2_wo_none
</prompt>
</rail>
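For context, this is roughly how a RAIL spec like the one above is consumed (a sketch based on the guardrails README at the time; the engine and parameter values here are placeholders, not decisions):

import guardrails as gd
import openai

RAIL_STR = """..."""  # the <rail> XML spec above

# compile the RAIL spec into a Guard object
guard = gd.Guard.from_rail_string(RAIL_STR)

# wrap the LLM call; validated_output should be a dict matching the <output> schema
raw_llm_output, validated_output = guard(
    openai.Completion.create,
    engine="text-davinci-003",
    max_tokens=1024,
    temperature=0.3,
)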

Unit Tests Created

  • will do after review & iteration

Steps to QA

  • Verify that the existing test cases pass; for end-to-end testing, check the console for Success and Fail print statements.

Questions

  • What format/structure do we want the output to have after getting the raw LLM output and feeding it to the guardrail?
  • Do we want JSON output, and if so, with which keys and values? Right now the output isn't useful, since it uses a sample RAIL configuration.

@atomicrichard

Thank you for getting the ball rolling on this @Mustaballer! 🙏

Ultimately we want to be able to take the output and convert it into a sequence of InputEvents to be replayed. The most direct approach I can think of is to generate a dict that can be fed into an InputEvent constructor, e.g.:

llm_completion = ...
input_event_dict = eval(llm_completion)
new_input_event = InputEvent(**input_event_dict)
play_input_event(new_input_event)
...
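(Note that eval will execute arbitrary code from the model; a safer variant of the same idea, sketched here with a hypothetical completion, is ast.literal_eval, which parses Python literals without executing anything:)

import ast

llm_completion = "{'name': 'move', 'mouse_x': 1.23, 'mouse_y': 4.56}"  # hypothetical
input_event_dict = ast.literal_eval(llm_completion)  # parses literals, never executes code
new_input_event = InputEvent(**input_event_dict)  # InputEvent as above
play_input_event(new_input_event)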

What do you think? Please let me know if anything is unclear!

@Mustaballer
Collaborator Author

Mustaballer commented Apr 25, 2023

@abrichr To clarify: we want to use the guardrails to convert the LLM output (llm_completion), which is a string, into a single InputEvent object (which can contain child InputEvents, like the tree data structure in the image below), and return the parent InputEvent, which has links to subsequent InputEvents. Also, I'm wondering what an example of the llm_completion text would look like for our purposes, just to have a clearer idea of how to parse it and organize the RAIL configuration.

Thanks :)

[image: InputEvent tree structure]

@abrichr
Member

abrichr commented Apr 25, 2023

Great questions!

To clarify: we want to use the guardrails to convert the LLM output (llm_completion), which is a string, into a single InputEvent object (which can contain child InputEvents, like the tree data structure in the image below), and return the parent InputEvent, which has links to subsequent InputEvents.

I think ideally we would not want it to contain children, since the children are mostly just there for historical and debugging purposes. Ultimately what we want is a sequence of InputEvents. Whether we get it as a result of a single completion or multiple is unclear. I think we should test both approaches. Perhaps the number of InputEvents per completion should be a configurable parameter.

Also, I'm wondering what an example of the llm_completion text would look like for our purposes, just to have a clearer idea of how to parse it and organize the RAIL configuration.

This is up for debate, but I think a useful starting point (and more robust than my earlier suggestion) is something like a JSON literal that can be parsed into a Python dict, which can then be passed into the InputEvent constructor as keyword arguments. For example:

prompt = self.get_prompt()  # TODO
completion = self.get_completion(prompt)  # e.g. '{"name": "move", "mouse_x": 1.23, "mouse_y": 4.56, "timestamp": 7.89, ...}'
event_dict = json.loads(completion)
input_event = InputEvent(**event_dict)
play_input_event(input_event)

(It's not clear to me whether we want to include the timestamp in the completion since we will likely need to override it.)
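(If we do override it, a sketch of what that might look like at replay time, assuming InputEvent accepts a timestamp keyword argument:)

import json
import time

event_dict = json.loads(completion)  # completion as in the snippet above
event_dict["timestamp"] = time.time()  # replace any model-generated timestamp with replay time
input_event = InputEvent(**event_dict)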

What do you think?

@Mustaballer
Collaborator Author

Mustaballer commented Apr 25, 2023

@abrichr Thanks for explaining that! So, if I understand correctly, the intention of using guardrails is to organize the raw LLM output into a JSON structure that represents InputEvents. That makes sense to me. Therefore, I plan to use guardrails to create the JSON structure from the raw LLM output, and then parse the JSON into a dict using the method you provided.
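Concretely, I'm imagining something like this sketch (it assumes guardrails accepts an arbitrary callable in place of an LLM API, which I still need to verify; guard, strategy, and MAX_TOKENS are placeholders):

def llm_api(prompt: str, **kwargs) -> str:
    # adapt our local model's completion method to the callable interface guard expects
    return strategy.get_completion(prompt, MAX_TOKENS)

raw_output, validated_output = guard(llm_api)  # validated_output: dict matching the RAIL schema
input_event = InputEvent(**validated_output)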

Please let me know if this is what you had in mind and if you have any feedback or suggestions. Thank you!

@abrichr
Member

abrichr commented Apr 26, 2023 via email

@abrichr abrichr marked this pull request as draft April 29, 2023 12:45
…d custom parser for llm output, and play input events
@Mustaballer
Collaborator Author

I've been having some trouble using guardrails to structure the LLM output. After looking at the logs, I realized that the ocr_text for the Screenshot is quite large and unreadable, and it includes some unrelated text. This could be why I haven't been seeing any changes in the LLM output. In the meantime, I created a custom parser function that converts the text into InputEvents based on the output from the prompt (which I changed to be more relevant), and it successfully played the input events.
[image: Input_events]
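For reference, a simplified sketch of the idea (not the exact code; it assumes the prompt asks the model to emit a JSON array of event dicts):

import json

from puterbot.models import InputEvent  # assumed import path

def parse_input_events(completion: str) -> list:
    # take everything between the first '[' and the last ']' and parse it as JSON
    start, end = completion.find("["), completion.rfind("]")
    event_dicts = json.loads(completion[start:end + 1])
    return [InputEvent(**d) for d in event_dicts]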

My next steps include writing unit tests for this function. I understand that the purpose of gisting is to compress the prompt; in this case it would extract the most relevant info from the screenshot text. Is it also important to compress self.recording_input_events? Finally, my goal is to get guardrails working to produce more structured output.

@abrichr
Member

abrichr commented May 1, 2023

Excellent work @Mustaballer !

My next steps include writing unit tests for this function.

Please do! What should we test?

I understand that the purpose of gisting is to compress the prompt. In this case, the screenshot text would get the most relevant info. Is it also important to compress the self.recording_input_events?

Exactly right.

Finally, my goal is to get the guardrails working to produce more structured output.

Can you please clarify? What is next here?

What do you think about this as an MVP:

# run smoke test (a single test that simply instantiates the data structures and runs the critical paths, without necessarily checking the output)
python -m tests.puterbot.guardrails
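i.e., something like this sketch (the module path, sample completion, and parser name are all placeholders to be decided):

# tests/puterbot/guardrails.py (hypothetical)
from puterbot.strategies.demo import parse_input_events  # assumed location

SAMPLE_COMPLETION = '[{"name": "move", "mouse_x": 1.0, "mouse_y": 2.0}]'

if __name__ == "__main__":
    events = parse_input_events(SAMPLE_COMPLETION)
    print(f"smoke test OK: {events=}")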

@abrichr
Member

abrichr commented May 1, 2023

Also, since this is such a great example, let's add step-by-step instructions for checking out and testing, and for generating related screenshots (and code for reproducing them*), to the PR description and/or README.md, e.g.:

git remote add Mustaballer https://github.com/Mustaballer/puterbot.git
git fetch Mustaballer
git checkout puterbot-16-guardrails
pip install -r requirements.txt
python -m puterbot.tests.<guardrails>  # run smoke test
python -m puterbot.validation.<guardrails>  # generate PR description output

*Ideally a single command, e.g. python -m puterbot.validation.<module>, that generates the text output and/or graphs used in the PR description.

@abrichr
Member

abrichr commented May 1, 2023

@Mustaballer upon attempting to run pip install -r requirements.txt I get the following error:

ERROR: Could not find a version that satisfies the requirement pywin32==306 (from versions: none)
ERROR: No matching distribution found for pywin32==306

git diff shows:

diff --git a/requirements.txt b/requirements.txt
index de64780..84b56ea 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@ mss==6.1.0
 pandas==2.0.0
 pygetwindow==0.0.9
 pyinstaller
-pywin32==306; sys_platform == 'win32'
+pywin32==306
 git+https://github.com/abrichr/pynput.git
 pytest==7.1.3
 rapidocr-onnxruntime==1.2.3
@@ -22,3 +22,4 @@ sqlalchemy==1.4.43
 torch==2.0.0
 tqdm==4.64.0
 transformers==4.28.1
+guardrails-ai

Can you please rebase your changes onto the latest changes in main:

git checkout main
git pull
git checkout puterbot-16-guardrails
git checkout -b puterbot-16-guardrails-backup  # optional backup, just in case
git checkout puterbot-16-guardrails
git rebase main
git push -f  # rebase rewrites history, so a force push is required

Thank you! 🙏

@Mustaballer
Collaborator Author

@abrichr Thanks for the feedback. I will rebase my changes onto main right now :)

@abrichr
Member

abrichr commented May 1, 2023

Thank you @Mustaballer !

This is what I get when I check out your latest changes and run python -m puterbot.replay DemoReplayStrategy:

...
ress Key.tab, ', 'release Key.tab, ', 'press Key.tab, ', 'release Key.tab, ', 'press Key.tab, ', 'release Key.tab, ', 'release Key.cmd, ', 'release <0>, ', 'press Key.ctrl, ', 'press c, ']\n\nUsing the previously recorded input events, generate a sequence of input events to complete the task as a list: "
2023-05-01 15:13:59.200 | WARNING  | puterbot.strategies.llm_mixin:get_completion:44 - Truncating from len(prompt)=124946 to max_input_size=1024
Token indices sequence length is longer than the specified maximum sequence length for this model (47406 > 1024). Running this sequence through the model will result in indexing errors
Traceback (most recent call last):
  File "/usr/local/Cellar/python@3.10/3.10.11/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/Cellar/python@3.10/3.10.11/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/abrichr/MLDSAI/src/puterbot/puterbot/replay.py", line 43, in <module>
    fire.Fire(replay)
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/Users/abrichr/MLDSAI/src/puterbot/puterbot/replay.py", line 39, in replay
    strategy.run()
  File "/Users/abrichr/MLDSAI/src/puterbot/puterbot/strategies/base.py", line 47, in run
    input_event = self.get_next_input_event(screenshot)
  File "/Users/abrichr/MLDSAI/src/puterbot/puterbot/strategies/demo.py", line 194, in get_next_input_event
    completion = self.get_completion(self.prompt, max_tokens)
  File "/Users/abrichr/MLDSAI/src/puterbot/puterbot/strategies/llm_mixin.py", line 52, in get_completion
    output_tokens = self.model.generate(
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 1437, in generate
    return self.greedy_search(
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2248, in greedy_search
    outputs = self(
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1075, in forward
    transformer_outputs = self.transformer(
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 843, in forward
    position_embeds = self.wpe(position_ids)
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self

Any ideas?

@Mustaballer
Collaborator Author

Would you mind running the code again, but this time with a higher value for max_tokens? For instance, you could set it to 300:

max_tokens = 300  # for example

It appears the problem is that we're attempting to encode a sequence larger than the maximum sequence length the model can accommodate. To address this, we can either truncate the sequence or modify the max_length parameter of the tokenizer. For more complex tasks, summarization or gisting can be beneficial, as you previously mentioned. :)
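For reference, truncating at the token level via the tokenizer, rather than slicing the prompt string, would look roughly like this inside get_completion:

# token-level truncation using the Hugging Face tokenizer API
input_tokens = self.tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=self.tokenizer.model_max_length,  # 1024 for GPT-2
)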

@abrichr
Member

abrichr commented May 2, 2023

@Mustaballer thank you for the suggestion!

Please define all constants at the top of the file in UPPER_SNAKE_CASE, e.g.:

MAX_TOKENS = 300
RAIL_STR = """
...
"""
...

@abrichr
Member

abrichr commented May 2, 2023

@Mustaballer I implemented the following changes and am still getting the same error:

MAX_TOKENS = 300 
...
        max_tokens = MAX_TOKENS
...
        max_tokens = MAX_TOKENS
...

@abrichr
Member

abrichr commented May 2, 2023

@Mustaballer thank you for the great work! Please grab some time on my calendar at your earliest convenience on Weds/Thurs this week: https://www.getclockwise.com/c/richard-abrich/quick-meeting

@abrichr
Member

abrichr commented May 2, 2023

@Mustaballer I had a bug in my code that I had to fix. Here's the git diff:

diff --git a/puterbot/strategies/llm_mixin.py b/puterbot/strategies/llm_mixin.py
index d88864e..4e6105a 100644
--- a/puterbot/strategies/llm_mixin.py
+++ b/puterbot/strategies/llm_mixin.py
@@ -44,7 +44,11 @@ class LLMReplayStrategyMixin(BaseReplayStrategy):
             logger.warning(
                 f"Truncating from {len(prompt)=} to {max_input_size=}"
             )
-            prompt = prompt[max_input_size:]
+            prompt = prompt[-max_input_size:]
+            logger.warning(
+                f"Truncated {len(prompt)=}"
+            )
+
         logger.debug(f"{prompt=} {max_tokens=}")
         input_tokens = self.tokenizer(prompt, return_tensors="pt")
         pad_token_id = self.tokenizer.eos_token_id
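(For clarity: with len(prompt)=124946 and max_input_size=1024, the old slice prompt[max_input_size:] dropped the first 1024 characters and kept the remaining 123922, while prompt[-max_input_size:] keeps only the last 1024.)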

python -m puterbot.replay DemoReplayStrategy:

2023-05-02 02:36:49.530 | WARNING  | puterbot.strategies.llm_mixin:get_completion:44 - Truncating from len(prompt)=124946 to max_input_size=1024
2023-05-02 02:36:49.530 | WARNING  | puterbot.strategies.llm_mixin:get_completion:48 - Truncated len(prompt)=1024
2023-05-02 02:36:52.325 | INFO     | puterbot.strategies.demo:get_next_input_event:195 - completion="~~~\n\n{ 'key' : 'key', 'value' : 'key', 'value' : 'value', 'value' : 'value', 'value' : 'value', 'value' : 'value', 'value' : 'value', 'value' : 'value', 'value' : 'value', 'value' : 'value', 'value' : 'value', 'value' : 'value', 'value' : 'value', 'value' :"
Traceback (most recent call last):
  File "/usr/local/Cellar/python@3.10/3.10.11/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/Cellar/python@3.10/3.10.11/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/abrichr/MLDSAI/src/puterbot/puterbot/replay.py", line 43, in <module>
    fire.Fire(replay)
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/Users/abrichr/MLDSAI/src/puterbot/.venv/lib/python3.10/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/Users/abrichr/MLDSAI/src/puterbot/puterbot/replay.py", line 39, in replay
    strategy.run()
  File "/Users/abrichr/MLDSAI/src/puterbot/puterbot/strategies/base.py", line 47, in run
    input_event = self.get_next_input_event(screenshot)
  File "/Users/abrichr/MLDSAI/src/puterbot/puterbot/strategies/demo.py", line 197, in get_next_input_event
    parsed_events = self.parse_input_event(completion)
  File "/Users/abrichr/MLDSAI/src/puterbot/puterbot/strategies/demo.py", line 99, in parse_input_event
    s = completion[1].replace(", ", "").replace(",", "").replace("'", "")
IndexError: list index out of range

Have you gotten this to run? Can you please provide the output of git status? Thank you! 🙏

@Mustaballer
Collaborator Author

This is what I get when I do git status:
[screenshot: git status output]

The completion generated on your end starts with "~~~\n\n{ 'key' : 'key', 'value' : 'key', 'value, ...", whereas on my end the completion looks like this:
[screenshot: completion output on my machine]

I think one reason could be that the prompt itself is being truncated so heavily that it may never reach the actual instruction, which is:
[screenshot: instruction portion of the prompt]

To address this issue, I plan to improve the parser function to handle various string inputs and avoid throwing errors. Additionally, I will investigate gisting to preserve context data.
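One defensive approach, sketched under the assumption that events appear as Python dict literals somewhere in the completion:

import ast
import re

def parse_input_events_defensive(completion: str) -> list:
    # pull out any {...} fragments and keep only the ones that parse as dicts
    events = []
    for match in re.findall(r"\{[^{}]*\}", completion):
        try:
            event = ast.literal_eval(match)
        except (ValueError, SyntaxError):
            continue  # skip malformed fragments instead of raising
        if isinstance(event, dict):
            events.append(event)
    return events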

@abrichr
Member

abrichr commented May 2, 2023

Hi @Mustaballer , thanks for the info.

What would be required for me to reproduce your results? I think the only thing missing is puterbot.db. There is an open issue here to support this: #46 (comment)

Can you please try running the linked code as the sender? You can send me the wormhole code via LinkedIn or another side channel. Please feel free to include this code (or a modified version) in your PR as well.

(Also please feel free to pick some time today after 2pm for our meeting if that works better for you! 🙏 )

@abrichr
Member

abrichr commented May 28, 2023

@Mustaballer any update on this?

@Mustaballer
Collaborator Author

I had temporarily paused work on this project, but I now plan to conduct additional testing with guardrails. My focus will be on enforcing strict standards for the language model completions. While I will be using the current models initially, I believe guardrails could also be employed in other replay strategies. I will provide a retrospective update on my progress and determine, based on the results, whether it is appropriate to close this PR.

@Mustaballer
Collaborator Author

Closing the pull request because I was unable to achieve the desired outcome while utilizing the "guardrails" library for rigorous parsing of Action Events.

@Mustaballer Mustaballer closed this Jun 5, 2023
Successfully merging this pull request may close these issues.

Implement Guardrails