diff --git a/src/Ollama/src/OllamaChatModel.cs b/src/Ollama/src/OllamaChatModel.cs index 1ad7741..3524c8b 100644 --- a/src/Ollama/src/OllamaChatModel.cs +++ b/src/Ollama/src/OllamaChatModel.cs @@ -26,20 +26,7 @@ public override async IAsyncEnumerable GenerateAsync( { request = request ?? throw new ArgumentNullException(nameof(request)); - try - { - var runningModels = await Provider.Api.Models.ListRunningModelsAsync(cancellationToken).ConfigureAwait(false); - if (runningModels.Models != null && - runningModels.Models.All(x => x.Model?.Contains(Id) != true)) - { - await Provider.Api.Models.PullModelAsync(Id, cancellationToken: cancellationToken) - .EnsureSuccessAsync().ConfigureAwait(false); - } - } - catch (HttpRequestException) - { - // Ignore - } + await Provider.PullModelIfRequiredAndAllowedAsync(Id, cancellationToken).ConfigureAwait(false); var usedSettings = OllamaChatSettings.Calculate( requestSettings: settings, diff --git a/src/Ollama/src/OllamaEmbeddingModel.cs b/src/Ollama/src/OllamaEmbeddingModel.cs index b2134eb..6109051 100644 --- a/src/Ollama/src/OllamaEmbeddingModel.cs +++ b/src/Ollama/src/OllamaEmbeddingModel.cs @@ -25,21 +25,7 @@ public async Task CreateEmbeddingsAsync( { request = request ?? 
throw new ArgumentNullException(nameof(request)); - try - { - var runningModels = - await Provider.Api.Models.ListRunningModelsAsync(cancellationToken).ConfigureAwait(false); - if (runningModels.Models != null && - runningModels.Models.All(x => x.Model?.Contains(Id) != true)) - { - await Provider.Api.Models.PullModelAsync(Id, cancellationToken: cancellationToken) - .EnsureSuccessAsync().ConfigureAwait(false); - } - } - catch (HttpRequestException) - { - // Ignore - } + await Provider.PullModelIfRequiredAndAllowedAsync(Id, cancellationToken).ConfigureAwait(false); var results = new List>(capacity: request.Strings.Count); foreach (var prompt in request.Strings) diff --git a/src/Ollama/src/OllamaProvider.cs b/src/Ollama/src/OllamaProvider.cs index 7cc8d35..c12c9de 100644 --- a/src/Ollama/src/OllamaProvider.cs +++ b/src/Ollama/src/OllamaProvider.cs @@ -17,9 +17,40 @@ public sealed class OllamaProvider( { Timeout = TimeSpan.FromHours(1), }, baseUri: new Uri(url)); + + /// <summary> + /// OllamaChatModel and OllamaEmbeddingModel will pull models automatically if this is true. + /// </summary> + public bool CanPullModelsAutomatically { get; set; } = true; public void Dispose() { Api.Dispose(); } + + public async Task PullModelIfRequiredAndAllowedAsync( + string id, + CancellationToken cancellationToken = default) + { + if (!CanPullModelsAutomatically) + { + return; + } + + try + { + // Pull the model if it is not running + var runningModels = await Api.Models.ListRunningModelsAsync(cancellationToken).ConfigureAwait(false); + if (runningModels.Models != null && + runningModels.Models.All(x => x.Model?.Contains(id) != true)) + { + await Api.Models.PullModelAsync(id, cancellationToken: cancellationToken) + .EnsureSuccessAsync().ConfigureAwait(false); + } + } + catch (HttpRequestException) + { + // Ignore + } + } } \ No newline at end of file