From c660a2d4ddd5f8dc4a85e3ea911c19b813095688 Mon Sep 17 00:00:00 2001 From: Xiao Ning Date: Wed, 5 Feb 2025 12:48:27 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(ai):=20=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E6=A8=A1=E5=9E=8B=E9=AA=8C=E8=AF=81=E5=92=8C=E9=80=89?= =?UTF-8?q?=E6=8B=A9=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 【重构】将模型验证和选择逻辑从命令类中抽离到独立模块 - 【新增】创建 modelValidation 工具类统一处理模型验证逻辑 - 【优化】简化命令类中的模型配置处理流程 - 【移动】迁移 CodeReviewReportGenerator 到 services 目录 - 【文档】新增 AI 模块文档说明其功能和使用方式 --- src/ai/providers/BaseOpenAIProvider.ts | 16 +-- src/ai/providers/VscodeProvider.ts | 2 +- src/commands/BaseCommand.ts | 86 ++++-------- src/commands/GenerateCommitCommand.ts | 104 +------------- src/commands/ReviewCodeCommand.ts | 11 +- .../CodeReviewReportGenerator.ts | 4 +- src/utils/ai/ai.md | 131 ++++++++++++++++++ src/utils/ai/index.ts | 1 + src/utils/ai/modelValidation.ts | 108 +++++++++++++++ src/utils/review/index.ts | 2 +- 10 files changed, 286 insertions(+), 179 deletions(-) rename src/{utils/review => services}/CodeReviewReportGenerator.ts (97%) create mode 100644 src/utils/ai/ai.md create mode 100644 src/utils/ai/index.ts create mode 100644 src/utils/ai/modelValidation.ts diff --git a/src/ai/providers/BaseOpenAIProvider.ts b/src/ai/providers/BaseOpenAIProvider.ts index 6182716..7b00ddc 100644 --- a/src/ai/providers/BaseOpenAIProvider.ts +++ b/src/ai/providers/BaseOpenAIProvider.ts @@ -15,7 +15,7 @@ import { } from "../utils/generateHelper"; import { getWeeklyReportPrompt } from "../../prompt/weeklyReport"; -import { CodeReviewReportGenerator } from "../../utils/review/CodeReviewReportGenerator"; +import { CodeReviewReportGenerator } from "../../services/CodeReviewReportGenerator"; import { formatMessage } from "../../utils/i18n/LocalizationManager"; /** @@ -193,10 +193,9 @@ export abstract class BaseOpenAIProvider implements AIProvider { }, }; } catch (error) { - const message = formatMessage( - "codeReview.generation.failed", - [error instanceof Error ? error.message : String(error)] - ); + const message = formatMessage("codeReview.generation.failed", [ + error instanceof Error ? error.message : String(error), + ]); throw new Error(message); } }, @@ -244,10 +243,9 @@ export abstract class BaseOpenAIProvider implements AIProvider { }; } catch (error) { throw new Error( - formatMessage( - "weeklyReport.generation.failed", - [error instanceof Error ? error.message : String(error)] - ) + formatMessage("weeklyReport.generation.failed", [ + error instanceof Error ? error.message : String(error), + ]) ); } } diff --git a/src/ai/providers/VscodeProvider.ts b/src/ai/providers/VscodeProvider.ts index 091f82b..5d0c866 100644 --- a/src/ai/providers/VscodeProvider.ts +++ b/src/ai/providers/VscodeProvider.ts @@ -11,7 +11,7 @@ import { generateCommitMessageSystemPrompt } from "../../prompt/prompt"; import { getCodeReviewPrompt, getSystemPrompt } from "../utils/generateHelper"; import { getWeeklyReportPrompt } from "../../prompt/weeklyReport"; import { getMessage, formatMessage } from "../../utils/i18n"; -import { CodeReviewReportGenerator } from "../../utils/review/CodeReviewReportGenerator"; +import { CodeReviewReportGenerator } from "../../services/CodeReviewReportGenerator"; interface DiffBlock { header: string; diff --git a/src/commands/BaseCommand.ts b/src/commands/BaseCommand.ts index a5846bc..c2c8faf 100644 --- a/src/commands/BaseCommand.ts +++ b/src/commands/BaseCommand.ts @@ -5,6 +5,7 @@ import { SCMFactory } from "../scm/SCMProvider"; import { ModelPickerService } from "../services/ModelPickerService"; import { notify } from "../utils/notification/NotificationManager"; import { getMessage, formatMessage } from "../utils/i18n"; +import { validateAndGetModel } from "../utils/ai/modelValidation"; /** * 基础命令类,提供通用的命令执行功能 @@ -65,81 +66,46 @@ export abstract class BaseCommand { let model = configuration.base.model; if (!provider || !model) { - return this.selectAndUpdateModelConfiguration(provider, model); + return this.selectAndUpdateModelConfiguration(provider, model, true); } return { provider, model }; } - /** - * 获取模型并更新配置 - * @param provider - AI提供商名称 - * @param model - 模型名称 - * @returns 更新后的提供商、模型和AI实例信息 - * @throws Error 当无法获取模型列表或找不到指定模型时 - */ - protected async getModelAndUpdateConfiguration( - provider = "Ollama", - model = "Ollama" - ) { - let aiProvider = AIProviderFactory.getProvider(provider); - let models = await aiProvider.getModels(); - - if (!models || models.length === 0) { - const { provider: newProvider, model: newModel } = - await this.selectAndUpdateModelConfiguration(provider, model); - provider = newProvider; - model = newModel; - - aiProvider = AIProviderFactory.getProvider(provider); - models = await aiProvider.getModels(); - - if (!models || models.length === 0) { - throw new Error(getMessage("model.list.empty")); - } - } - - let selectedModel = models.find((m) => m.name === model); - - if (!selectedModel) { - const { provider: newProvider, model: newModel } = - await this.selectAndUpdateModelConfiguration(provider, model); - provider = newProvider; - model = newModel; - - aiProvider = AIProviderFactory.getProvider(provider); - models = await aiProvider.getModels(); - selectedModel = models.find((m) => m.name === model); - - if (!selectedModel) { - throw new Error(getMessage("model.notFound")); - } - } - - return { provider, model, selectedModel, aiProvider }; - } - /** * 选择模型并更新配置 * @param provider - 当前AI提供商 * @param model - 当前模型名称 + * @param throwError - 是否抛出错误,默认为false * @returns 更新后的提供商和模型信息 */ protected async selectAndUpdateModelConfiguration( provider = "Ollama", - model = "Ollama" + model = "Ollama", + throwError = false ) { - const modelSelection = await this.showModelPicker(provider, model); - if (!modelSelection) { - return { provider, model }; + try { + const result = await validateAndGetModel(provider, model); + return { + provider: result.provider, + model: result.model, + selectedModel: result.selectedModel, + aiProvider: result.aiProvider, + }; + } catch (error: any) { + if (throwError) { + await notify.error(error.message); + throw error; + } + // 如果不抛出错误,返回原始值 + const aiProvider = AIProviderFactory.getProvider(provider); + return { + provider, + model, + selectedModel: undefined, + aiProvider, + }; } - - const config = ConfigurationManager.getInstance(); - await config.updateAIConfiguration( - modelSelection.provider, - modelSelection.model - ); - return { provider: modelSelection.provider, model: modelSelection.model }; } /** diff --git a/src/commands/GenerateCommitCommand.ts b/src/commands/GenerateCommitCommand.ts index 059f090..c04fea2 100644 --- a/src/commands/GenerateCommitCommand.ts +++ b/src/commands/GenerateCommitCommand.ts @@ -8,98 +8,12 @@ import { ModelPickerService } from "../services/ModelPickerService"; import { notify } from "../utils/notification"; import { getMessage, formatMessage } from "../utils/i18n"; import { ProgressHandler } from "../utils/notification/ProgressHandler"; +import { validateAndGetModel } from "../utils/ai/modelValidation"; /** * 提交信息生成命令类 */ export class GenerateCommitCommand extends BaseCommand { - /** - * 获取模型并更新配置 - * @param provider - 当前AI提供商 - * @param model - 当前模型名称 - * @returns 更新后的提供商、模型和AI实例信息 - * @throws Error 当无法获取模型列表或找不到指定模型时 - */ - protected async getModelAndUpdateConfiguration( - provider = "Ollama", - model = "Ollama" - ) { - let aiProvider = AIProviderFactory.getProvider(provider); - // 获取模型列表 - let models = await aiProvider.getModels(); - - // 如果模型为空或无法获取,直接让用户选择模型 - if (!models || models.length === 0) { - const { provider: newProvider, model: newModel } = - await this.selectAndUpdateModelConfiguration(provider, model); - provider = newProvider; - model = newModel; - - // 获取更新后的模型列表 - aiProvider = AIProviderFactory.getProvider(provider); - models = await aiProvider.getModels(); - - // 如果新的模型列表仍然为空,则抛出错误 - if (!models || models.length === 0) { - throw new Error(getMessage("model.list.empty")); - } - } - - // 查找已选择的模型 - let selectedModel = models.find((m) => m.name === model); - - // 如果没有找到对应的模型,弹窗让用户重新选择 - if (!selectedModel) { - const { provider: newProvider, model: newModel } = - await this.selectAndUpdateModelConfiguration(provider, model); - provider = newProvider; - model = newModel; - - // 获取更新后的模型列表 - aiProvider = AIProviderFactory.getProvider(provider); - models = await aiProvider.getModels(); - - // 选择有效的模型 - selectedModel = models.find((m) => m.name === model); - - // 如果依然没有找到对应的模型,抛出错误 - if (!selectedModel) { - throw new Error(getMessage("model.notFound")); - } - } - - return { provider, model, selectedModel, aiProvider }; - } - - /** - * 选择模型并更新配置 - * @param provider - 当前AI提供商 - * @param model - 当前模型名称 - * @returns 更新后的提供商和模型信息 - */ - protected async selectAndUpdateModelConfiguration( - provider = "Ollama", - model = "Ollama" - ) { - // 获取模型选择 - const modelSelection = await this.showModelPicker(provider, model); - - // 如果没有选择模型,则直接返回当前的 provider 和 model - if (!modelSelection) { - return { provider, model }; - } - - const config = ConfigurationManager.getInstance(); - // 使用新的封装方法更新配置 - await config.updateAIConfiguration( - modelSelection.provider, - modelSelection.model - ); - - // 返回更新后的 provider 和 model - return { provider: modelSelection.provider, model: modelSelection.model }; - } - /** * 处理AI配置 * @returns AI提供商和模型信息 @@ -142,6 +56,7 @@ export class GenerateCommitCommand extends BaseCommand { if (!configResult) { return; } + const { provider, model } = configResult; try { // 检测SCM提供程序 @@ -154,21 +69,10 @@ export class GenerateCommitCommand extends BaseCommand { // 获取当前提交输入框内容 const currentInput = await scmProvider.getCommitInput(); - // 获取配置信息 + // 获取配置信息以用于后续操作 const config = ConfigurationManager.getInstance(); const configuration = config.getConfiguration(); - // 获取或更新AI提供商和模型配置 - let provider = configuration.base.provider; - let model = configuration.base.model; - - if (!provider || !model) { - const { provider: newProvider, model: newModel } = - await this.selectAndUpdateModelConfiguration(provider, model); - provider = newProvider; - model = newModel; - } - // 使用进度提示生成提交信息 const response = await ProgressHandler.withProgress( formatMessage("progress.generating.commit", [ @@ -191,7 +95,7 @@ export class GenerateCommitCommand extends BaseCommand { model: newModel, aiProvider, selectedModel, - } = await this.getModelAndUpdateConfiguration(provider, model); + } = await this.selectAndUpdateModelConfiguration(provider, model); // 生成提交信息 const result = await aiProvider.generateResponse({ diff --git a/src/commands/ReviewCodeCommand.ts b/src/commands/ReviewCodeCommand.ts index cb30497..59481e2 100644 --- a/src/commands/ReviewCodeCommand.ts +++ b/src/commands/ReviewCodeCommand.ts @@ -6,6 +6,7 @@ import { withProgress, } from "../utils/notification/NotificationManager"; import * as path from "path"; +import { validateAndGetModel } from "../utils/ai/modelValidation"; /** * 代码审查命令类 @@ -73,12 +74,10 @@ export class ReviewCodeCommand extends BaseCommand { const { config, configuration } = this.getExtConfig(); let { provider, model } = configResult; - const { - provider: newProvider, - model: newModel, - aiProvider, - selectedModel, - } = await this.getModelAndUpdateConfiguration(provider, model); + const { aiProvider, selectedModel } = await validateAndGetModel( + provider, + model + ); await withProgress(getMessage("reviewing.code"), async (progress) => { // 获取所有选中文件的差异 diff --git a/src/utils/review/CodeReviewReportGenerator.ts b/src/services/CodeReviewReportGenerator.ts similarity index 97% rename from src/utils/review/CodeReviewReportGenerator.ts rename to src/services/CodeReviewReportGenerator.ts index c5054f9..f478991 100644 --- a/src/utils/review/CodeReviewReportGenerator.ts +++ b/src/services/CodeReviewReportGenerator.ts @@ -1,6 +1,6 @@ -import { CodeReviewResult, CodeReviewIssue } from "../../ai/types"; +import { CodeReviewResult, CodeReviewIssue } from "../ai/types"; import * as vscode from "vscode"; -import { getMessage } from "../i18n"; +import { getMessage } from "../utils/i18n"; /** * 代码审查报告生成器,将代码审查结果转换为格式化的 Markdown 文档 diff --git a/src/utils/ai/ai.md b/src/utils/ai/ai.md new file mode 100644 index 0000000..bfa900b --- /dev/null +++ b/src/utils/ai/ai.md @@ -0,0 +1,131 @@ +下面提供对该脚本中各个工具函数的详细文档说明,以帮助开发者更好地理解代码逻辑、参数说明以及返回结果等信息。 + +--- + +## 模块概述 + +本模块主要用于验证并获取 AI 模型的配置。在实际使用过程中,可能会遇到模型列表为空或选中的模型不存在的情况,此时会提示用户重新选择模型,并更新相关配置。模块中包含以下主要功能: + +- **验证并获取模型配置**:根据传入的 provider 与 model 名称,检查对应的 AI 模型是否存在;如果不存在则引导用户重新选择模型并更新配置。 +- **选择模型并更新配置**:通过调用 ModelPickerService 弹出模型选择器,让用户选择合适的模型,并将选择结果更新到全局配置中。 +- **提取 provider 与 selectedModel 信息**:在某些场景中只需要使用 AIProvider 与选中的模型,本工具函数用于简化返回值。 + +--- + +## 1. 函数:`validateAndGetModel` + +### 作用 + +验证并获取 AI 模型配置,确保指定的 provider 与 model 能够正确加载相应的 AI 模型。如果当前模型列表为空或找不到指定模型,则通过用户交互重新选择模型并更新配置。 + +### 参数 + +- `provider`(字符串,默认值 `"Ollama"`):初始的 AI 提供者名称。 +- `model`(字符串,默认值 `"Ollama"`):初始的模型名称。 + +### 返回值 + +返回一个 `Promise` 对象,其中 `ValidatedModelResult` 接口定义如下: + +```ts +interface ValidatedModelResult { + provider: string; // 最终确定的 AI 提供者名称 + model: string; // 最终选中的模型名称 + selectedModel: AIModel; // 选中的 AI 模型对象 + aiProvider: any; // 对应的 AI 提供者实例 +} +``` + +### 执行流程 + +1. **获取 AI 提供者实例** + 根据传入的 `provider` 参数,调用 `AIProviderFactory.getProvider` 获取对应的 AI 提供者实例 `aiProvider`。 + +2. **获取模型列表** + 调用 `aiProvider.getModels()` 获取所有可用的模型列表。 + +3. **判断模型列表是否为空** + + - 如果模型列表为空,则调用 `selectAndUpdateModel` 函数让用户重新选择模型。如果用户取消选择,则抛出错误 `"model.selection.cancelled"`;否则更新 `provider` 和 `model`,并重新获取模型列表。如果依然为空,则抛出错误 `"model.list.empty"`。 + +4. **查找选中的模型** + 在模型列表中查找名称与传入 `model` 参数相同的模型对象 `selectedModel`。 + +5. **判断选中的模型是否存在** + + - 如果找不到匹配的模型,同样调用 `selectAndUpdateModel` 进行重新选择。如果用户取消,则抛出错误 `"model.selection.cancelled"`。更新参数后再次查找模型;如果仍然找不到,则抛出错误 `"model.notFound"`。 + +6. **返回结果** + 返回包含 `provider`、`model`、`selectedModel` 和 `aiProvider` 的对象。 + +### 异常情况 + +- **用户取消模型选择**:抛出错误 `getMessage("model.selection.cancelled")`。 +- **模型列表为空**:抛出错误 `getMessage("model.list.empty")`。 +- **指定模型未找到**:抛出错误 `getMessage("model.notFound")`。 + +--- + +## 2. 函数:`selectAndUpdateModel` + +### 作用 + +引导用户选择模型,并更新全局配置中 AI 模型的设置。该函数主要用于在初始模型不可用或找不到指定模型时,通过交互方式获取正确的模型配置信息。 + +### 参数 + +- `provider`(字符串):当前的 AI 提供者名称。 +- `model`(字符串):当前的模型名称。 + +### 返回值 + +返回一个 `Promise` 对象,其值为用户选择的结果,格式同样为包含 `provider` 和 `model` 的对象。如果用户取消选择,则返回 `undefined`。 + +### 执行流程 + +1. **显示模型选择器** + 调用 `ModelPickerService.showModelPicker(provider, model)` 显示模型选择对话框,让用户选择可用的模型。 + +2. **判断选择结果** + 如果用户取消了选择(返回 `undefined` 或 `null`),则直接返回。 + +3. **更新配置** + 获取配置管理器单例 `ConfigurationManager.getInstance()`,并调用 `updateAIConfiguration` 方法,将选择的 `provider` 和 `model` 更新到全局配置中。 + +4. **返回选择结果** + 返回包含新的 `provider` 和 `model` 的对象。 + +--- + +## 3. 函数:`extractProviderAndModel` + +### 作用 + +用于在仅需要 AI 提供者实例 (`aiProvider`) 与选中的 AI 模型对象 (`selectedModel`) 的场景下,提取 `ValidatedModelResult` 对象中的这两个属性。 + +### 参数 + +- `result`(类型为 `ValidatedModelResult`):包含完整 AI 模型配置的对象。 + +### 返回值 + +返回一个包含以下两个属性的对象: + +- `aiProvider`:AI 提供者实例。 +- `selectedModel`:选中的 AI 模型对象。 + +### 使用场景 + +当调用 `validateAndGetModel` 获得完整配置后,在仅需要调用 AI 提供者实例与选中的模型时,可以通过该函数进行简化提取,避免重复访问整个配置对象。 + +--- + +## 总结 + +整个模块实现了 AI 模型配置的验证与选择过程,能够处理以下情况: + +- **初始模型列表为空**:自动引导用户重新选择模型,并更新配置。 +- **初始指定模型不存在**:通过用户交互获取新的模型选择,并确保返回有效的模型配置。 +- **简化调用**:通过 `extractProviderAndModel` 函数,方便其他模块仅提取所需的 provider 与 selectedModel 信息。 + +通过上述详细文档说明,开发者可以快速了解每个函数的职责、参数、返回值以及异常处理机制,从而在维护或扩展该模块时更加高效。 diff --git a/src/utils/ai/index.ts b/src/utils/ai/index.ts new file mode 100644 index 0000000..0f6732c --- /dev/null +++ b/src/utils/ai/index.ts @@ -0,0 +1 @@ +export * from "./modelValidation"; diff --git a/src/utils/ai/modelValidation.ts b/src/utils/ai/modelValidation.ts new file mode 100644 index 0000000..35a92ab --- /dev/null +++ b/src/utils/ai/modelValidation.ts @@ -0,0 +1,108 @@ +import { AIProviderFactory } from "../../ai/AIProviderFactory"; +import { ConfigurationManager } from "../../config/ConfigurationManager"; +import { ModelPickerService } from "../../services/ModelPickerService"; +import { getMessage } from "../i18n"; +import { AIModel } from "../../ai/types"; + +interface ValidatedModelResult { + provider: string; + model: string; + selectedModel: AIModel; + aiProvider: any; +} + +interface ProviderAndModels { + aiProvider: any; + models: AIModel[]; +} + +/** + * 获取provider实例和models列表 + */ +async function getProviderAndModels( + provider: string +): Promise { + const aiProvider = AIProviderFactory.getProvider(provider); + const models = await aiProvider.getModels(); + return { aiProvider, models }; +} + +/** + * 重新选择并验证模型 + */ +async function revalidateModel( + provider: string, + model: string +): Promise { + const result = await selectAndUpdateModel(provider, model); + if (!result) { + throw new Error(getMessage("model.selection.cancelled")); + } + + const { aiProvider, models } = await getProviderAndModels(result.provider); + + if (!models?.length) { + throw new Error(getMessage("model.list.empty")); + } + + const selectedModel = models.find((m) => m.name === result.model); + if (!selectedModel) { + throw new Error(getMessage("model.notFound")); + } + + return { + provider: result.provider, + model: result.model, + selectedModel, + aiProvider, + }; +} + +/** + * 验证并获取AI模型配置 + */ +export async function validateAndGetModel( + provider = "Ollama", + model = "Ollama" +): Promise { + let { aiProvider, models } = await getProviderAndModels(provider); + + if (!models?.length) { + return revalidateModel(provider, model); + } + + const selectedModel = models.find((m) => m.name === model); + if (!selectedModel) { + return revalidateModel(provider, model); + } + + return { + provider, + model, + selectedModel, + aiProvider, + }; +} + +/** + * 选择模型并更新配置 + */ +async function selectAndUpdateModel(provider: string, model: string) { + const selection = await ModelPickerService.showModelPicker(provider, model); + if (!selection) { + return; + } + + const config = ConfigurationManager.getInstance(); + await config.updateAIConfiguration(selection.provider, selection.model); + + return selection; +} + +// 用于仅需要provider和selectedModel的场景 +export function extractProviderAndModel(result: ValidatedModelResult) { + return { + aiProvider: result.aiProvider, + selectedModel: result.selectedModel, + }; +} diff --git a/src/utils/review/index.ts b/src/utils/review/index.ts index 746079d..54049ad 100644 --- a/src/utils/review/index.ts +++ b/src/utils/review/index.ts @@ -1 +1 @@ -export * from './CodeReviewReportGenerator'; \ No newline at end of file +export * from "../../services/CodeReviewReportGenerator";