feat: Add tool calling support in ChatOllama
davidmigloz committed Jul 24, 2024
1 parent 1ffdb41 commit a6a9202
Showing 8 changed files with 343 additions and 85 deletions.
3 changes: 2 additions & 1 deletion docs/modules/model_io/models/chat_models/how_to/tools.md
@@ -4,9 +4,10 @@
> Tool calling is currently supported by:
> - [`ChatAnthropic`](/modules/model_io/models/chat_models/integrations/anthropic.md)
> - [`ChatOpenAI`](/modules/model_io/models/chat_models/integrations/openai.md)
> - [`ChatFirebaseVertexAI`](/modules/model_io/models/chat_models/integrations/firebase_vertex_ai.md)
> - [`ChatGoogleGenerativeAI`](/modules/model_io/models/chat_models/integrations/googleai.md)
> - [`ChatOllama`](/modules/model_io/models/chat_models/integrations/ollama.md)
> - [`ChatOpenAI`](/modules/model_io/models/chat_models/integrations/openai.md)
Tool calling allows a model to respond to a given prompt by generating output that matches a user-defined schema. While the name implies that the model is performing some action, this is actually not the case! The model is coming up with the arguments to a tool, and actually running the tool (or not) is up to the user - for example, if you want to extract output matching some schema from unstructured text, you could give the model an “extraction” tool that takes parameters matching the desired schema, then treat the generated output as your final result.
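
For illustration, here is a minimal sketch of the extraction pattern described above. It uses `ChatOllama` (any of the chat models listed above works the same way); the `extract_person` tool name and its schema are purely illustrative assumptions, and only the `ToolSpec`, `tools` option, and `toolCalls` APIs shown in the integration examples are relied on.

```dart
// A sketch of the "extraction" pattern: the model fills in the tool arguments
// and we read them back as structured output; the tool itself is never run.
// The `extract_person` tool name and its schema are illustrative examples.
const extractionTool = ToolSpec(
  name: 'extract_person',
  description: 'Extract the name and age of the person mentioned in the text',
  inputJsonSchema: {
    'type': 'object',
    'properties': {
      'name': {'type': 'string', 'description': 'The name of the person'},
      'age': {'type': 'integer', 'description': 'The age of the person in years'},
    },
    'required': ['name', 'age'],
  },
);

final chatModel = ChatOllama(
  defaultOptions: const ChatOllamaOptions(
    model: 'llama3.1',
    temperature: 0,
    tools: [extractionTool],
  ),
);

final res = await chatModel.invoke(
  PromptValue.string('Alice is a 29-year-old engineer living in Lisbon.'),
);
for (final call in res.output.toolCalls) {
  // The generated arguments are the extraction result.
  print('${call.name}: ${call.arguments}');
  // e.g. extract_person: {name: Alice, age: 29}
}
```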

@@ -112,7 +112,7 @@ await stream.forEach(print);

`ChatAnthropic` supports tool calling.

Check the [docs](https://langchaindart.dev/#/modules/model_io/models/chat_models/how_to/tools) for more information on how to use tools.
Check the [docs](/modules/model_io/models/chat_models/how_to/tools.md) for more information on how to use tools.

Example:
```dart
@@ -124,7 +124,7 @@ const tool = ToolSpec(
    'properties': {
      'location': {
        'type': 'string',
        'description': 'The city and state, e.g. San Francisco, CA',
        'description': 'The city and country, e.g. San Francisco, US',
      },
    },
    'required': ['location'],
```
51 changes: 51 additions & 0 deletions docs/modules/model_io/models/chat_models/integrations/ollama.md
@@ -125,6 +125,57 @@ print(res.output.content);
// -> 'An Apple'
```

## Tool calling

`ChatOllama` supports tool calling.

Check the [docs](/modules/model_io/models/chat_models/how_to/tools.md) for more information on how to use tools.

**Notes:**
- Tool calling requires Ollama 0.2.8 or newer.
- Streaming tool calls is not supported at the moment.
- Not all models support tool calls. Check the Ollama catalogue for models that have the `Tools` tag (e.g. [`llama3.1`](https://ollama.com/library/llama3.1)).

```dart
const tool = ToolSpec(
  name: 'get_current_weather',
  description: 'Get the current weather in a given location',
  inputJsonSchema: {
    'type': 'object',
    'properties': {
      'location': {
        'type': 'string',
        'description': 'The city and country, e.g. San Francisco, US',
      },
    },
    'required': ['location'],
  },
);
final chatModel = ChatOllama(
  defaultOptions: ChatOllamaOptions(
    model: 'llama3.1',
    temperature: 0,
    tools: [tool],
  ),
);
final res = await chatModel.invoke(
  PromptValue.string('What’s the weather like in Boston and Madrid right now in celsius?'),
);
print(res.output.toolCalls);
// [AIChatMessageToolCall{
//   id: a621064b-03b3-4ca6-8278-f37504901034,
//   name: get_current_weather,
//   arguments: {location: Boston, US},
// },
// AIChatMessageToolCall{
//   id: f160d9ba-ae7d-4abc-a910-2b6cd503ec53,
//   name: get_current_weather,
//   arguments: {location: Madrid, ES},
// }]
```
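
Running the tools is left to the caller. The following is a minimal sketch of dispatching the returned calls to a local function; `getCurrentWeather` is a hypothetical stand-in for a real weather lookup and is not part of the package API.

```dart
// Hypothetical local implementation of the tool; a real application would
// call an actual weather API here.
String getCurrentWeather(final String location) =>
    '22°C and sunny in $location';

// Dispatch each generated tool call using the arguments produced by the model.
for (final call in res.output.toolCalls) {
  if (call.name == 'get_current_weather') {
    final location = call.arguments['location'] as String;
    print('$location: ${getCurrentWeather(location)}');
    // e.g. Boston, US: 22°C and sunny in Boston, US
  }
}
```

If a final natural-language answer is needed, the tool results can then be sent back to the model in a follow-up request.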

## RAG (Retrieval-Augmented Generation) pipeline

We can easily create a fully local RAG pipeline using `OllamaEmbeddings` and `ChatOllama`.
@@ -10,6 +10,7 @@ void main(final List<String> arguments) async {
  await _chatOllamaStreaming();
  await _chatOllamaJsonMode();
  await _chatOllamaMultimodal();
  await _chatOllamaToolCalling();
  await _rag();
}

@@ -94,6 +95,47 @@ Future<void> _chatOllamaJsonMode() async {
  // {Spain: 46735727, The Netherlands: 17398435, France: 65273538}
}

Future<void> _chatOllamaToolCalling() async {
  const tool = ToolSpec(
    name: 'get_current_weather',
    description: 'Get the current weather in a given location',
    inputJsonSchema: {
      'type': 'object',
      'properties': {
        'location': {
          'type': 'string',
          'description': 'The city and country, e.g. San Francisco, US',
        },
      },
      'required': ['location'],
    },
  );

  final chatModel = ChatOllama(
    defaultOptions: const ChatOllamaOptions(
      model: 'llama3.1',
      temperature: 0,
      tools: [tool],
    ),
  );

  final res = await chatModel.invoke(
    PromptValue.string(
        'What’s the weather like in Boston and Madrid right now in celsius?'),
  );
  print(res.output.toolCalls);
  // [AIChatMessageToolCall{
  //   id: a621064b-03b3-4ca6-8278-f37504901034,
  //   name: get_current_weather,
  //   arguments: {location: Boston, US},
  // },
  // AIChatMessageToolCall{
  //   id: f160d9ba-ae7d-4abc-a910-2b6cd503ec53,
  //   name: get_current_weather,
  //   arguments: {location: Madrid, ES},
  // }]
}

Future<void> _chatOllamaMultimodal() async {
  final chatModel = ChatOllama(
    defaultOptions: const ChatOllamaOptions(
@@ -5,15 +5,14 @@ import 'package:langchain_tiktoken/langchain_tiktoken.dart';
import 'package:ollama_dart/ollama_dart.dart';
import 'package:uuid/uuid.dart';

import '../../llms/mappers.dart';
import 'mappers.dart';
import 'types.dart';

/// Wrapper around the [Ollama](https://ollama.ai) Chat API that enables you
/// to interact with LLMs in a chat-like fashion.
///
/// Ollama allows you to run open-source large language models,
/// such as Llama 3 or LLaVA, locally.
/// such as Llama 3.1, Gemma 2 or LLaVA, locally.
///
/// For a complete list of supported models and model variants, see the
/// [Ollama model library](https://ollama.ai/library).
@@ -37,7 +36,7 @@ import 'types.dart';
///
/// 1. Download and install [Ollama](https://ollama.ai)
/// 2. Fetch a model via `ollama pull <model family>`
/// * e.g., for Llama 3: `ollama pull llama3`
/// * e.g., for Llama 3.1: `ollama pull llama3.1`
///
/// ### Ollama base URL
///
@@ -188,9 +187,10 @@ class ChatOllama extends BaseChatModel<ChatOllamaOptions> {
  }) async {
    final id = _uuid.v4();
    final completion = await _client.generateChatCompletion(
      request: _generateCompletionRequest(
      request: generateChatCompletionRequest(
        input.toChatMessages(),
        options: options,
        defaultOptions: defaultOptions,
      ),
    );
    return completion.toChatResult(id);
@@ -204,65 +204,18 @@ class ChatOllama extends BaseChatModel<ChatOllamaOptions> {
    final id = _uuid.v4();
    return _client
        .generateChatCompletionStream(
          request: _generateCompletionRequest(
          request: generateChatCompletionRequest(
            input.toChatMessages(),
            options: options,
            defaultOptions: defaultOptions,
            stream: true,
          ),
        )
        .map(
          (final completion) => completion.toChatResult(id, streaming: true),
        );
  }

  /// Creates a [GenerateChatCompletionRequest] from the given input.
  GenerateChatCompletionRequest _generateCompletionRequest(
    final List<ChatMessage> messages, {
    final bool stream = false,
    final ChatOllamaOptions? options,
  }) {
    return GenerateChatCompletionRequest(
      model: options?.model ?? defaultOptions.model ?? defaultModel,
      messages: messages.toMessages(),
      format: (options?.format ?? defaultOptions.format)?.toResponseFormat(),
      keepAlive: options?.keepAlive ?? defaultOptions.keepAlive,
      stream: stream,
      options: RequestOptions(
        numKeep: options?.numKeep ?? defaultOptions.numKeep,
        seed: options?.seed ?? defaultOptions.seed,
        numPredict: options?.numPredict ?? defaultOptions.numPredict,
        topK: options?.topK ?? defaultOptions.topK,
        topP: options?.topP ?? defaultOptions.topP,
        tfsZ: options?.tfsZ ?? defaultOptions.tfsZ,
        typicalP: options?.typicalP ?? defaultOptions.typicalP,
        repeatLastN: options?.repeatLastN ?? defaultOptions.repeatLastN,
        temperature: options?.temperature ?? defaultOptions.temperature,
        repeatPenalty: options?.repeatPenalty ?? defaultOptions.repeatPenalty,
        presencePenalty:
            options?.presencePenalty ?? defaultOptions.presencePenalty,
        frequencyPenalty:
            options?.frequencyPenalty ?? defaultOptions.frequencyPenalty,
        mirostat: options?.mirostat ?? defaultOptions.mirostat,
        mirostatTau: options?.mirostatTau ?? defaultOptions.mirostatTau,
        mirostatEta: options?.mirostatEta ?? defaultOptions.mirostatEta,
        penalizeNewline:
            options?.penalizeNewline ?? defaultOptions.penalizeNewline,
        stop: options?.stop ?? defaultOptions.stop,
        numa: options?.numa ?? defaultOptions.numa,
        numCtx: options?.numCtx ?? defaultOptions.numCtx,
        numBatch: options?.numBatch ?? defaultOptions.numBatch,
        numGpu: options?.numGpu ?? defaultOptions.numGpu,
        mainGpu: options?.mainGpu ?? defaultOptions.mainGpu,
        lowVram: options?.lowVram ?? defaultOptions.lowVram,
        f16Kv: options?.f16KV ?? defaultOptions.f16KV,
        logitsAll: options?.logitsAll ?? defaultOptions.logitsAll,
        vocabOnly: options?.vocabOnly ?? defaultOptions.vocabOnly,
        useMmap: options?.useMmap ?? defaultOptions.useMmap,
        useMlock: options?.useMlock ?? defaultOptions.useMlock,
        numThread: options?.numThread ?? defaultOptions.numThread,
      ),
    );
  }

  /// Tokenizes the given prompt using tiktoken.
  ///
  /// Currently Ollama does not provide a tokenizer for the models it supports.