Skip to content

Commit

Permalink
♻️ refactor(ai): 重构模型验证和选择逻辑
Browse files Browse the repository at this point in the history
- 【重构】将模型验证和选择逻辑从命令类中抽离到独立模块
- 【新增】创建 modelValidation 工具类统一处理模型验证逻辑
- 【优化】简化命令类中的模型配置处理流程
- 【移动】迁移 CodeReviewReportGenerator 到 services 目录
- 【文档】新增 AI 模块文档说明其功能和使用方式
  • Loading branch information
littleCareless committed Feb 5, 2025
1 parent 9a2b4cf commit 4bd0e55
Show file tree
Hide file tree
Showing 10 changed files with 286 additions and 179 deletions.
16 changes: 7 additions & 9 deletions src/ai/providers/BaseOpenAIProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import {
} from "../utils/generateHelper";

import { getWeeklyReportPrompt } from "../../prompt/weeklyReport";
import { CodeReviewReportGenerator } from "../../utils/review/CodeReviewReportGenerator";
import { CodeReviewReportGenerator } from "../../services/CodeReviewReportGenerator";
import { formatMessage } from "../../utils/i18n/LocalizationManager";

/**
Expand Down Expand Up @@ -193,10 +193,9 @@ export abstract class BaseOpenAIProvider implements AIProvider {
},
};
} catch (error) {
const message = formatMessage(
"codeReview.generation.failed",
[error instanceof Error ? error.message : String(error)]
);
const message = formatMessage("codeReview.generation.failed", [
error instanceof Error ? error.message : String(error),
]);
throw new Error(message);
}
},
Expand Down Expand Up @@ -244,10 +243,9 @@ export abstract class BaseOpenAIProvider implements AIProvider {
};
} catch (error) {
throw new Error(
formatMessage(
"weeklyReport.generation.failed",
[error instanceof Error ? error.message : String(error)]
)
formatMessage("weeklyReport.generation.failed", [
error instanceof Error ? error.message : String(error),
])
);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/ai/providers/VscodeProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { generateCommitMessageSystemPrompt } from "../../prompt/prompt";
import { getCodeReviewPrompt, getSystemPrompt } from "../utils/generateHelper";
import { getWeeklyReportPrompt } from "../../prompt/weeklyReport";
import { getMessage, formatMessage } from "../../utils/i18n";
import { CodeReviewReportGenerator } from "../../utils/review/CodeReviewReportGenerator";
import { CodeReviewReportGenerator } from "../../services/CodeReviewReportGenerator";

interface DiffBlock {
header: string;
Expand Down
86 changes: 26 additions & 60 deletions src/commands/BaseCommand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { SCMFactory } from "../scm/SCMProvider";
import { ModelPickerService } from "../services/ModelPickerService";
import { notify } from "../utils/notification/NotificationManager";
import { getMessage, formatMessage } from "../utils/i18n";
import { validateAndGetModel } from "../utils/ai/modelValidation";

/**
* 基础命令类,提供通用的命令执行功能
Expand Down Expand Up @@ -65,81 +66,46 @@ export abstract class BaseCommand {
let model = configuration.base.model;

if (!provider || !model) {
return this.selectAndUpdateModelConfiguration(provider, model);
return this.selectAndUpdateModelConfiguration(provider, model, true);
}

return { provider, model };
}

/**
* 获取模型并更新配置
* @param provider - AI提供商名称
* @param model - 模型名称
* @returns 更新后的提供商、模型和AI实例信息
* @throws Error 当无法获取模型列表或找不到指定模型时
*/
protected async getModelAndUpdateConfiguration(
provider = "Ollama",
model = "Ollama"
) {
let aiProvider = AIProviderFactory.getProvider(provider);
let models = await aiProvider.getModels();

if (!models || models.length === 0) {
const { provider: newProvider, model: newModel } =
await this.selectAndUpdateModelConfiguration(provider, model);
provider = newProvider;
model = newModel;

aiProvider = AIProviderFactory.getProvider(provider);
models = await aiProvider.getModels();

if (!models || models.length === 0) {
throw new Error(getMessage("model.list.empty"));
}
}

let selectedModel = models.find((m) => m.name === model);

if (!selectedModel) {
const { provider: newProvider, model: newModel } =
await this.selectAndUpdateModelConfiguration(provider, model);
provider = newProvider;
model = newModel;

aiProvider = AIProviderFactory.getProvider(provider);
models = await aiProvider.getModels();
selectedModel = models.find((m) => m.name === model);

if (!selectedModel) {
throw new Error(getMessage("model.notFound"));
}
}

return { provider, model, selectedModel, aiProvider };
}

/**
* 选择模型并更新配置
* @param provider - 当前AI提供商
* @param model - 当前模型名称
* @param throwError - 是否抛出错误,默认为false
* @returns 更新后的提供商和模型信息
*/
protected async selectAndUpdateModelConfiguration(
provider = "Ollama",
model = "Ollama"
model = "Ollama",
throwError = false
) {
const modelSelection = await this.showModelPicker(provider, model);
if (!modelSelection) {
return { provider, model };
try {
const result = await validateAndGetModel(provider, model);
return {
provider: result.provider,
model: result.model,
selectedModel: result.selectedModel,
aiProvider: result.aiProvider,
};
} catch (error: any) {
if (throwError) {
await notify.error(error.message);
throw error;
}
// 如果不抛出错误,返回原始值
const aiProvider = AIProviderFactory.getProvider(provider);
return {
provider,
model,
selectedModel: undefined,
aiProvider,
};
}

const config = ConfigurationManager.getInstance();
await config.updateAIConfiguration(
modelSelection.provider,
modelSelection.model
);
return { provider: modelSelection.provider, model: modelSelection.model };
}

/**
Expand Down
104 changes: 4 additions & 100 deletions src/commands/GenerateCommitCommand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,98 +8,12 @@ import { ModelPickerService } from "../services/ModelPickerService";
import { notify } from "../utils/notification";
import { getMessage, formatMessage } from "../utils/i18n";
import { ProgressHandler } from "../utils/notification/ProgressHandler";
import { validateAndGetModel } from "../utils/ai/modelValidation";

/**
* 提交信息生成命令类
*/
export class GenerateCommitCommand extends BaseCommand {
/**
* 获取模型并更新配置
* @param provider - 当前AI提供商
* @param model - 当前模型名称
* @returns 更新后的提供商、模型和AI实例信息
* @throws Error 当无法获取模型列表或找不到指定模型时
*/
protected async getModelAndUpdateConfiguration(
provider = "Ollama",
model = "Ollama"
) {
let aiProvider = AIProviderFactory.getProvider(provider);
// 获取模型列表
let models = await aiProvider.getModels();

// 如果模型为空或无法获取,直接让用户选择模型
if (!models || models.length === 0) {
const { provider: newProvider, model: newModel } =
await this.selectAndUpdateModelConfiguration(provider, model);
provider = newProvider;
model = newModel;

// 获取更新后的模型列表
aiProvider = AIProviderFactory.getProvider(provider);
models = await aiProvider.getModels();

// 如果新的模型列表仍然为空,则抛出错误
if (!models || models.length === 0) {
throw new Error(getMessage("model.list.empty"));
}
}

// 查找已选择的模型
let selectedModel = models.find((m) => m.name === model);

// 如果没有找到对应的模型,弹窗让用户重新选择
if (!selectedModel) {
const { provider: newProvider, model: newModel } =
await this.selectAndUpdateModelConfiguration(provider, model);
provider = newProvider;
model = newModel;

// 获取更新后的模型列表
aiProvider = AIProviderFactory.getProvider(provider);
models = await aiProvider.getModels();

// 选择有效的模型
selectedModel = models.find((m) => m.name === model);

// 如果依然没有找到对应的模型,抛出错误
if (!selectedModel) {
throw new Error(getMessage("model.notFound"));
}
}

return { provider, model, selectedModel, aiProvider };
}

/**
* 选择模型并更新配置
* @param provider - 当前AI提供商
* @param model - 当前模型名称
* @returns 更新后的提供商和模型信息
*/
protected async selectAndUpdateModelConfiguration(
provider = "Ollama",
model = "Ollama"
) {
// 获取模型选择
const modelSelection = await this.showModelPicker(provider, model);

// 如果没有选择模型,则直接返回当前的 provider 和 model
if (!modelSelection) {
return { provider, model };
}

const config = ConfigurationManager.getInstance();
// 使用新的封装方法更新配置
await config.updateAIConfiguration(
modelSelection.provider,
modelSelection.model
);

// 返回更新后的 provider 和 model
return { provider: modelSelection.provider, model: modelSelection.model };
}

/**
* 处理AI配置
* @returns AI提供商和模型信息
Expand Down Expand Up @@ -142,6 +56,7 @@ export class GenerateCommitCommand extends BaseCommand {
if (!configResult) {
return;
}
const { provider, model } = configResult;

try {
// 检测SCM提供程序
Expand All @@ -154,21 +69,10 @@ export class GenerateCommitCommand extends BaseCommand {
// 获取当前提交输入框内容
const currentInput = await scmProvider.getCommitInput();

// 获取配置信息
// 获取配置信息以用于后续操作
const config = ConfigurationManager.getInstance();
const configuration = config.getConfiguration();

// 获取或更新AI提供商和模型配置
let provider = configuration.base.provider;
let model = configuration.base.model;

if (!provider || !model) {
const { provider: newProvider, model: newModel } =
await this.selectAndUpdateModelConfiguration(provider, model);
provider = newProvider;
model = newModel;
}

// 使用进度提示生成提交信息
const response = await ProgressHandler.withProgress(
formatMessage("progress.generating.commit", [
Expand All @@ -191,7 +95,7 @@ export class GenerateCommitCommand extends BaseCommand {
model: newModel,
aiProvider,
selectedModel,
} = await this.getModelAndUpdateConfiguration(provider, model);
} = await this.selectAndUpdateModelConfiguration(provider, model);

// 生成提交信息
const result = await aiProvider.generateResponse({
Expand Down
11 changes: 5 additions & 6 deletions src/commands/ReviewCodeCommand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
withProgress,
} from "../utils/notification/NotificationManager";
import * as path from "path";
import { validateAndGetModel } from "../utils/ai/modelValidation";

/**
* 代码审查命令类
Expand Down Expand Up @@ -73,12 +74,10 @@ export class ReviewCodeCommand extends BaseCommand {
const { config, configuration } = this.getExtConfig();
let { provider, model } = configResult;

const {
provider: newProvider,
model: newModel,
aiProvider,
selectedModel,
} = await this.getModelAndUpdateConfiguration(provider, model);
const { aiProvider, selectedModel } = await validateAndGetModel(
provider,
model
);

await withProgress(getMessage("reviewing.code"), async (progress) => {
// 获取所有选中文件的差异
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { CodeReviewResult, CodeReviewIssue } from "../../ai/types";
import { CodeReviewResult, CodeReviewIssue } from "../ai/types";
import * as vscode from "vscode";
import { getMessage } from "../i18n";
import { getMessage } from "../utils/i18n";

/**
* 代码审查报告生成器,将代码审查结果转换为格式化的 Markdown 文档
Expand Down
Loading

0 comments on commit 4bd0e55

Please sign in to comment.