Create an ILLMProvider interface and have our current implementation use it #17394

Merged (23 commits), Sep 6, 2024
Changes from 8 commits
2 changes: 2 additions & 0 deletions .github/actions/spelling/allow/allow.txt
@@ -57,6 +57,7 @@ hyperlink
 hyperlinking
 hyperlinks
 iconify
+ILLM
 ID
 img
 inlined
@@ -68,6 +69,7 @@ libuv
 liga
 lje
 Llast
+llm
 llvm
 Lmid
 locl
194 changes: 194 additions & 0 deletions src/cascadia/QueryExtension/AzureLLMProvider.cpp
@@ -0,0 +1,194 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "pch.h"
#include "AzureLLMProvider.h"
#include "../../types/inc/utils.hpp"
#include "LibraryResources.h"

#include "AzureLLMProvider.g.cpp"
#include "AzureResponse.g.cpp"

using namespace winrt::Windows::Foundation;
using namespace winrt::Windows::Foundation::Collections;
using namespace winrt::Windows::UI::Core;
using namespace winrt::Windows::UI::Xaml;
using namespace winrt::Windows::UI::Xaml::Controls;
using namespace winrt::Windows::System;
namespace WWH = ::winrt::Windows::Web::Http;
namespace WSS = ::winrt::Windows::Storage::Streams;
namespace WDJ = ::winrt::Windows::Data::Json;

static constexpr std::wstring_view acceptedModel{ L"gpt-35-turbo" };
static constexpr std::wstring_view acceptedSeverityLevel{ L"safe" };

// Only accept HTTPS endpoints on the openai.azure.com domain
const std::wregex azureOpenAIEndpointRegex{ LR"(^https.*openai\.azure\.com)" };

namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
    AzureLLMProvider::AzureLLMProvider(const winrt::hstring& endpoint, const winrt::hstring& key)
    {
        _AIEndpoint = endpoint;
        _AIKey = key;
        _httpClient = winrt::Windows::Web::Http::HttpClient{};
        _httpClient.DefaultRequestHeaders().Accept().TryParseAdd(L"application/json");
        _httpClient.DefaultRequestHeaders().Append(L"api-key", _AIKey);
    }

    void AzureLLMProvider::ClearMessageHistory()
    {
        _jsonMessages.Clear();
    }

    void AzureLLMProvider::SetSystemPrompt(const winrt::hstring& systemPrompt)
    {
        WDJ::JsonObject systemMessageObject;
        winrt::hstring systemMessageContent{ systemPrompt };
        systemMessageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"system"));
        systemMessageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(systemMessageContent));
        _jsonMessages.Append(systemMessageObject);
    }

    void AzureLLMProvider::SetContext(const Extension::IContext context)
    {
        _context = context;
    }

    winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> AzureLLMProvider::GetResponseAsync(const winrt::hstring& userPrompt)
    {
        // Use a flag for whether the response the user receives is an error message.
        // We pass this flag back to the caller so they can handle it appropriately
        // (specifically, ExtensionPalette will send the correct telemetry event).
        // Only one case downstream from here sets this flag to false, so start with it set to true.
        bool isError{ true };
        hstring message{};

        // If the AI endpoint is not an Azure OpenAI endpoint, return an error message
        if (!std::regex_search(_AIEndpoint.c_str(), azureOpenAIEndpointRegex))
        {
            message = RS_(L"InvalidEndpointMessage");
        }

        // If we don't have a message string yet, the endpoint exists and matches the
        // regex we allow - now we can actually make the HTTP request
        if (message.empty())
        {
            // Make a copy of the prompt because we are switching threads
            const auto promptCopy{ userPrompt };

            // Make sure we are on a background thread for the HTTP request
            co_await winrt::resume_background();

            WWH::HttpRequestMessage request{ WWH::HttpMethod::Post(), Uri{ _AIEndpoint } };
            request.Headers().Accept().TryParseAdd(L"application/json");

            WDJ::JsonObject jsonContent;
            WDJ::JsonObject messageObject;

            // _ActiveCommandline should already be set; we request it the moment we become visible
            winrt::hstring engineeredPrompt{ promptCopy };
            if (_context && !_context.ActiveCommandline().empty())
            {
                engineeredPrompt = promptCopy + L". The shell I am running is " + _context.ActiveCommandline();
            }
            messageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"user"));
            messageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(engineeredPrompt));
            _jsonMessages.Append(messageObject);
            jsonContent.SetNamedValue(L"messages", _jsonMessages);
            jsonContent.SetNamedValue(L"max_tokens", WDJ::JsonValue::CreateNumberValue(800));
            jsonContent.SetNamedValue(L"temperature", WDJ::JsonValue::CreateNumberValue(0.7));
            jsonContent.SetNamedValue(L"frequency_penalty", WDJ::JsonValue::CreateNumberValue(0));
            jsonContent.SetNamedValue(L"presence_penalty", WDJ::JsonValue::CreateNumberValue(0));
            jsonContent.SetNamedValue(L"top_p", WDJ::JsonValue::CreateNumberValue(0.95));
            jsonContent.SetNamedValue(L"stop", WDJ::JsonValue::CreateStringValue(L"None"));
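            // For reference, the serialized request body below ends up roughly like
            // this (a sketch reconstructed from the SetNamedValue calls above):
            // {
            //     "messages": [{ "role": "system", ... }, { "role": "user", ... }],
            //     "max_tokens": 800, "temperature": 0.7, "frequency_penalty": 0,
            //     "presence_penalty": 0, "top_p": 0.95, "stop": "None"
            // }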
            const auto stringContent = jsonContent.ToString();
            WWH::HttpStringContent requestContent{
                stringContent,
                WSS::UnicodeEncoding::Utf8,
                L"application/json"
            };

            request.Content(requestContent);

            // Send the request
            try
            {
                const auto response = _httpClient.SendRequestAsync(request).get();
                // Review comment (Member): FWIW all these get()s could use co_await
                // in the future instead. That avoids the resume_background hassle.

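                // A minimal sketch of that suggestion (not part of this PR's diff),
                // assuming the surrounding coroutine otherwise keeps its current shape:
                //
                //     const auto response = co_await _httpClient.SendRequestAsync(request);
                //     const auto string = co_await response.Content().ReadAsStringAsync();
                //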
                // Parse out the suggestion from the response
                const auto string{ response.Content().ReadAsStringAsync().get() };
                const auto jsonResult{ WDJ::JsonObject::Parse(string) };
                if (jsonResult.HasKey(L"error"))
                {
                    const auto errorObject = jsonResult.GetNamedObject(L"error");
                    message = errorObject.GetNamedString(L"message");
                }
                else
                {
                    if (_verifyModelIsValidHelper(jsonResult))
                    {
                        const auto choices = jsonResult.GetNamedArray(L"choices");
                        const auto firstChoice = choices.GetAt(0).GetObject();
                        const auto messageObject = firstChoice.GetNamedObject(L"message");
                        message = messageObject.GetNamedString(L"content");
                        isError = false;
                    }
                    else
                    {
                        message = RS_(L"InvalidModelMessage");
                    }
                }
            }
            catch (...)
            {
                message = RS_(L"UnknownErrorMessage");
            }
        }

        // Also make a new entry in our jsonMessages list, so the AI knows the full conversation so far
        WDJ::JsonObject responseMessageObject;
        responseMessageObject.Insert(L"role", WDJ::JsonValue::CreateStringValue(L"assistant"));
        responseMessageObject.Insert(L"content", WDJ::JsonValue::CreateStringValue(message));
        _jsonMessages.Append(responseMessageObject);

        co_return winrt::make<AzureResponse>(message, isError);
    }

    bool AzureLLMProvider::_verifyModelIsValidHelper(const WDJ::JsonObject jsonResponse)
    {
        // Expected response shape (sketch): { "model": "...", "choices": [...],
        //   "prompt_filter_results": [{ "content_filter_results": { "<category>": { "severity": "safe" }, ... } }] }
        if (jsonResponse.GetNamedString(L"model") != acceptedModel)
        {
            return false;
        }
        WDJ::JsonObject contentFiltersObject;
        // For some reason, sometimes the content filter results are in a key called "prompt_filter_results"
        // and sometimes they are in a key called "prompt_annotations". Check for either.
        if (jsonResponse.HasKey(L"prompt_filter_results"))
        {
            contentFiltersObject = jsonResponse.GetNamedArray(L"prompt_filter_results").GetObjectAt(0);
        }
        else if (jsonResponse.HasKey(L"prompt_annotations"))
        {
            contentFiltersObject = jsonResponse.GetNamedArray(L"prompt_annotations").GetObjectAt(0);
        }
        else
        {
            return false;
        }
        const auto contentFilters = contentFiltersObject.GetNamedObject(L"content_filter_results");
        if (Feature_TerminalChatJailbreakFilter::IsEnabled() && !contentFilters.HasKey(L"jailbreak"))
        {
            return false;
        }
        // Every content filter category that reports a severity must report it as "safe"
        for (const auto filterPair : contentFilters)
        {
            const auto filterLevel = filterPair.Value().GetObjectW();
            if (filterLevel.HasKey(L"severity"))
            {
                if (filterLevel.GetNamedString(L"severity") != acceptedSeverityLevel)
                {
                    return false;
                }
            }
        }
        return true;
    }
}
51 changes: 51 additions & 0 deletions src/cascadia/QueryExtension/AzureLLMProvider.h
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#pragma once

#include "AzureLLMProvider.g.h"
#include "AzureResponse.g.h"

namespace winrt::Microsoft::Terminal::Query::Extension::implementation
{
    struct AzureLLMProvider : AzureLLMProviderT<AzureLLMProvider>
    {
        AzureLLMProvider(const winrt::hstring& endpoint, const winrt::hstring& key);

        void ClearMessageHistory();
        void SetSystemPrompt(const winrt::hstring& systemPrompt);
        void SetContext(const Extension::IContext context);

        winrt::Windows::Foundation::IAsyncOperation<Extension::IResponse> GetResponseAsync(const winrt::hstring& userPrompt);

    private:
        winrt::hstring _AIEndpoint;
        winrt::hstring _AIKey;
        winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr };

        Extension::IContext _context;

        winrt::Windows::Data::Json::JsonArray _jsonMessages;

        bool _verifyModelIsValidHelper(const Windows::Data::Json::JsonObject jsonResponse);
    };

    struct AzureResponse : AzureResponseT<AzureResponse>
    {
        AzureResponse(const winrt::hstring& message, const bool isError) :
            _message{ message },
            _isError{ isError } {}
        winrt::hstring Message() { return _message; };
        bool IsError() { return _isError; };

    private:
        winrt::hstring _message;
        bool _isError;
    };
}

namespace winrt::Microsoft::Terminal::Query::Extension::factory_implementation
{
    BASIC_FACTORY(AzureLLMProvider);
    BASIC_FACTORY(AzureResponse);
}
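
For orientation, here is a minimal caller-side sketch (not part of this PR's diff; the function name and surrounding coroutine are assumptions) of how something like ExtensionPalette might drive the provider:

    // Hypothetical usage sketch, assuming a coroutine context since
    // GetResponseAsync returns an IAsyncOperation.
    winrt::Windows::Foundation::IAsyncAction QueryExampleAsync(winrt::hstring endpoint, winrt::hstring key)
    {
        using namespace winrt::Microsoft::Terminal::Query::Extension;

        AzureLLMProvider provider{ endpoint, key };
        provider.SetSystemPrompt(L"You are an assistant that suggests shell commands.");

        const auto response = co_await provider.GetResponseAsync(L"How do I list files?");
        if (!response.IsError())
        {
            // response.Message() holds the model's reply text
        }
    }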
17 changes: 17 additions & 0 deletions src/cascadia/QueryExtension/AzureLLMProvider.idl
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import "ILLMProvider.idl";

namespace Microsoft.Terminal.Query.Extension
{
    [default_interface] runtimeclass AzureLLMProvider : ILLMProvider
    {
        AzureLLMProvider(String endpoint, String key);
    }

    [default_interface] runtimeclass AzureResponse : IResponse
    {
        AzureResponse(String message, Boolean isError);
    }
}
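
ILLMProvider.idl itself is outside this slice of the diff. Inferred from the implementation above, it plausibly declares something like the following; this is a sketch under that assumption, not the actual file:

    // Sketch only - member names inferred from AzureLLMProvider, AzureResponse,
    // and the IContext usage above; the real ILLMProvider.idl may differ.
    namespace Microsoft.Terminal.Query.Extension
    {
        interface IContext
        {
            String ActiveCommandline { get; };
        };

        interface IResponse
        {
            String Message { get; };
            Boolean IsError { get; };
        };

        interface ILLMProvider
        {
            void ClearMessageHistory();
            void SetSystemPrompt(String systemPrompt);
            void SetContext(IContext context);
            Windows.Foundation.IAsyncOperation<IResponse> GetResponseAsync(String userPrompt);
        };
    }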