Skip to content

Commit

Permalink
Fix GPT4All bug w/ "n_ctx" param (#7093)
Browse files Browse the repository at this point in the history
Running `GPT4All` per the
[docs](https://python.langchain.com/docs/modules/model_io/models/llms/integrations/gpt4all),
I see:

```
$ from langchain.llms import GPT4All
$ model = GPT4All(model=local_path)
$ model("The capital of France is ", max_tokens=10)
TypeError: generate() got an unexpected keyword argument 'n_ctx'
```

It appears `n_ctx` is [no longer a supported
param](https://docs.gpt4all.io/gpt4all_python.html#gpt4all.gpt4all.GPT4All.generate)
in the GPT4All API from nomic-ai/gpt4all#1090.

It now uses `max_tokens`, so I set this.

And I also set other defaults used in GPT4All client
[here](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-bindings/python/gpt4all/gpt4all.py).

Confirm it now works:
```
$ from langchain.llms import GPT4All
$ model = GPT4All(model=local_path)
$ model("The capital of France is ", max_tokens=10)
< Model logging > 
"....Paris."
```

---------

Co-authored-by: R. Lance Martin <rlm@Rs-MacBook-Pro.local>
  • Loading branch information
rlancemartin and R. Lance Martin authored Jul 4, 2023
1 parent 6631fd5 commit 265c285
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 37 deletions.
41 changes: 18 additions & 23 deletions docs/extras/modules/model_io/models/llms/integrations/gpt4all.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {
"tags": []
},
Expand All @@ -45,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {
"tags": []
},
Expand All @@ -64,13 +64,20 @@
"source": [
"### Specify Model\n",
"\n",
"To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/gpt4all\n",
"To run locally, download a compatible ggml-formatted model. \n",
" \n",
"**Download option 1**: The [gpt4all page](https://gpt4all.io/index.html) has a useful `Model Explorer` section:\n",
"\n",
"For full installation instructions go [here](https://gpt4all.io/index.html).\n",
"* Select a model of interest\n",
"* Download using the UI and move the `.bin` to the `local_path` (noted below)\n",
"\n",
"The GPT4All Chat installer needs to decompress a 3GB LLM model during the installation process!\n",
"For more info, visit https://github.com/nomic-ai/gpt4all.\n",
"\n",
"Note that new models are uploaded regularly - check the link above for the most recent `.bin` URL"
"--- \n",
"\n",
"**Download option 2**: Uncomment the below block to download a model. \n",
"\n",
"* You may want to update `url` to a new version, whih can be browsed using the [gpt4all page](https://gpt4all.io/index.html)."
]
},
{
Expand All @@ -81,22 +88,8 @@
"source": [
"local_path = (\n",
" \"./models/ggml-gpt4all-l13b-snoozy.bin\" # replace with your desired local file path\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Uncomment the below block to download a model. You may want to update `url` to a new version."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
")\n",
"\n",
"# import requests\n",
"\n",
"# from pathlib import Path\n",
Expand Down Expand Up @@ -126,8 +119,10 @@
"source": [
"# Callbacks support token-wise streaming\n",
"callbacks = [StreamingStdOutCallbackHandler()]\n",
"\n",
"# Verbose is required to pass to the callback manager\n",
"llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)\n",
"\n",
"# If you want to use a custom model add the backend parameter\n",
"# Check https://docs.gpt4all.io/gpt4all_python.html for supported backends\n",
"llm = GPT4All(model=local_path, backend=\"gptj\", callbacks=callbacks, verbose=True)"
Expand Down Expand Up @@ -170,7 +165,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
"version": "3.9.16"
}
},
"nbformat": 4,
Expand Down
22 changes: 8 additions & 14 deletions langchain/llms/gpt4all.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class GPT4All(LLM):
.. code-block:: python
from langchain.llms import GPT4All
model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8)
model = GPT4All(model="./models/gpt4all-model.bin", n_threads=8)
# Simplest invocation
response = model("Once upon a time, ")
Expand All @@ -30,7 +30,7 @@ class GPT4All(LLM):

backend: Optional[str] = Field(None, alias="backend")

n_ctx: int = Field(512, alias="n_ctx")
max_tokens: int = Field(200, alias="max_tokens")
"""Token context window."""

n_parts: int = Field(-1, alias="n_parts")
Expand Down Expand Up @@ -61,10 +61,10 @@ class GPT4All(LLM):
n_predict: Optional[int] = 256
"""The maximum number of tokens to generate."""

temp: Optional[float] = 0.8
temp: Optional[float] = 0.7
"""The temperature to use for sampling."""

top_p: Optional[float] = 0.95
top_p: Optional[float] = 0.1
"""The top-p value to use for sampling."""

top_k: Optional[int] = 40
Expand All @@ -79,19 +79,15 @@ class GPT4All(LLM):
repeat_last_n: Optional[int] = 64
"Last n tokens to penalize"

repeat_penalty: Optional[float] = 1.3
repeat_penalty: Optional[float] = 1.18
"""The penalty to apply to repeated tokens."""

n_batch: int = Field(1, alias="n_batch")
n_batch: int = Field(8, alias="n_batch")
"""Batch size for prompt processing."""

streaming: bool = False
"""Whether to stream the results or not."""

context_erase: float = 0.5
"""Leave (n_ctx * context_erase) tokens
starting from beginning if the context has run out."""

allow_download: bool = False
"""If model does not exist in ~/.cache/gpt4all/, download it."""

Expand All @@ -105,28 +101,26 @@ class Config:
@staticmethod
def _model_param_names() -> Set[str]:
return {
"n_ctx",
"max_tokens",
"n_predict",
"top_k",
"top_p",
"temp",
"n_batch",
"repeat_penalty",
"repeat_last_n",
"context_erase",
}

def _default_params(self) -> Dict[str, Any]:
return {
"n_ctx": self.n_ctx,
"max_tokens": self.max_tokens,
"n_predict": self.n_predict,
"top_k": self.top_k,
"top_p": self.top_p,
"temp": self.temp,
"n_batch": self.n_batch,
"repeat_penalty": self.repeat_penalty,
"repeat_last_n": self.repeat_last_n,
"context_erase": self.context_erase,
}

@root_validator()
Expand Down

0 comments on commit 265c285

Please sign in to comment.