🔧 chore(ai): enhance the AI provider management system
- Add support for new AI model providers, including Zhipu, Tongyi (DashScope), and Doubao
- Improve the AI provider factory's cache management and add a 30-minute cache cleanup
- Improve error handling and model information display
- Standardize the use of localized messages
- Refactor the OpenAI provider, extracting a base class for code reuse
- Improve code structure and exception handling across all providers
littleCareless committed Dec 10, 2024
1 parent 5603b08 commit 1b36a48
Showing 9 changed files with 620 additions and 295 deletions.
39 changes: 35 additions & 4 deletions src/ai/AIProviderFactory.ts
@@ -5,19 +5,34 @@ import { AIProvider, ConfigKeys } from "../config/types";
import { ConfigurationManager } from "../config/ConfigurationManager";
import { VSCodeProvider } from "./providers/VscodeProvider";
import { LocalizationManager } from "../utils/LocalizationManager";
import { ZhipuAIProvider } from "./providers/ZhipuAIProvider";
import { DashScopeProvider } from "./providers/DashScopeProvider";
import { DoubaoProvider } from "./providers/DoubaoProvider";

export class AIProviderFactory {
private static providers: Map<string, AIProviderInterface> = new Map();
private static readonly PROVIDER_CACHE_TTL = 1000 * 60 * 30; // 30-minute cache TTL
private static providerTimestamps: Map<string, number> = new Map();

private static cleanStaleProviders() {
const now = Date.now();
for (const [id, timestamp] of this.providerTimestamps.entries()) {
if (now - timestamp > this.PROVIDER_CACHE_TTL) {
this.providers.delete(id);
this.providerTimestamps.delete(id);
}
}
}

public static getProvider(type?: string): AIProviderInterface {
// Use the default provider when no type is specified
this.cleanStaleProviders();
const providerType =
type ||
ConfigurationManager.getInstance().getConfig<string>("PROVIDER") ||
AIProvider.OPENAI;

let provider = this.providers.get(providerType);
console.log("provider", providerType);

if (!provider) {
switch (providerType.toLowerCase()) {
case AIProvider.OPENAI:
@@ -29,6 +44,15 @@ export class AIProviderFactory {
case AIProvider.VSCODE:
provider = new VSCodeProvider();
break;
case AIProvider.ZHIPU:
provider = new ZhipuAIProvider();
break;
case AIProvider.DASHSCOPE:
provider = new DashScopeProvider();
break;
case AIProvider.DOUBAO:
provider = new DoubaoProvider();
break;
default:
throw new Error(
LocalizationManager.getInstance().format(
@@ -38,14 +62,21 @@
);
}
this.providers.set(providerType, provider);
this.providerTimestamps.set(providerType, Date.now());
}

return provider;
}

public static getAllProviders(): AIProviderInterface[] {
// Return instances of all available AI providers
return [new OpenAIProvider(), new OllamaProvider(), new VSCodeProvider()];
return [
new OpenAIProvider(),
new OllamaProvider(),
new VSCodeProvider(),
new ZhipuAIProvider(),
new DashScopeProvider(),
new DoubaoProvider(),
];
}

public static reinitializeProvider(providerId: string): void {
Expand Down
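
For context, the effect of the new cache bookkeeping above: repeated calls to getProvider with the same type reuse one cached instance until the 30-minute TTL lapses, after which cleanStaleProviders evicts the stale entry on the next lookup. A minimal usage sketch, not part of this commit (the import path and the "dashscope" type string are assumptions):

import { AIProviderFactory } from "./src/ai/AIProviderFactory";

// First call constructs the provider, caches it, and records a timestamp.
const first = AIProviderFactory.getProvider("dashscope");

// A second call within 30 minutes returns the same cached instance.
const second = AIProviderFactory.getProvider("dashscope");
console.log(first === second); // true while the cache entry is fresh

// Once PROVIDER_CACHE_TTL has elapsed, the next getProvider() call runs
// cleanStaleProviders(), drops the stale entry, and constructs a new instance.
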
100 changes: 100 additions & 0 deletions src/ai/providers/BaseOpenAIProvider.ts
@@ -0,0 +1,100 @@
import OpenAI from "openai";
import { ChatCompletionMessageParam } from "openai/resources";
import { AIProvider, AIRequestParams, AIResponse, AIModel } from "../types";
import { generateWithRetry, getSystemPrompt } from "../utils/generateHelper";

export interface OpenAIProviderConfig {
apiKey: string;
baseURL?: string;
apiVersion?: string;
providerId: string;
providerName: string;
defaultModel?: string;
models: AIModel[];
}

export abstract class BaseOpenAIProvider implements AIProvider {
protected openai: OpenAI;
protected config: OpenAIProviderConfig;
protected provider: { id: string; name: string };

constructor(config: OpenAIProviderConfig) {
this.config = config;
this.provider = {
id: config.providerId,
name: config.providerName,
};
this.openai = this.createClient();
}

protected createClient(): OpenAI {
const config: any = {
apiKey: this.config.apiKey,
};

if (this.config.baseURL) {
config.baseURL = this.config.baseURL;
if (this.config.apiKey) {
// config.defaultQuery = { "api-version": this.config.apiVersion };
config.defaultHeaders = { "api-key": this.config.apiKey };
}
}
console.log("config", config);

return new OpenAI(config);
}

async generateResponse(params: AIRequestParams): Promise<AIResponse> {
return generateWithRetry(
params,
async (truncatedDiff) => {
const messages: ChatCompletionMessageParam[] = [
{
role: "system",
content: getSystemPrompt(params),
},
{
role: "user",
content: truncatedDiff,
},
];

const completion = await this.openai.chat.completions.create({
model:
(params.model && params.model.id) ||
this.config.defaultModel ||
"gpt-3.5-turbo",
messages,
});

return {
content: completion.choices[0]?.message?.content || "",
usage: {
promptTokens: completion.usage?.prompt_tokens,
completionTokens: completion.usage?.completion_tokens,
totalTokens: completion.usage?.total_tokens,
},
};
},
{
initialMaxLength: params.model?.maxTokens?.input || 16385,
provider: this.getId(),
}
);
}

async getModels(): Promise<AIModel[]> {
return Promise.resolve(this.config.models);
}

getName(): string {
return this.provider.name;
}

getId(): string {
return this.provider.id;
}

abstract isAvailable(): Promise<boolean>;
abstract refreshModels(): Promise<string[]>;
}
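
The point of the extracted base class is that an OpenAI-compatible provider now only supplies its configuration and the two abstract methods; DashScopeProvider below is the real instance added in this commit. The following is purely an illustration of that pattern, with a hypothetical provider name, config key, and endpoint:

import { BaseOpenAIProvider } from "./BaseOpenAIProvider";
import { ConfigurationManager } from "../../config/ConfigurationManager";
import { AIModel } from "../types";

// Hypothetical model list; a real provider declares its own models.
const exampleModels: AIModel[] = [];

export class ExampleProvider extends BaseOpenAIProvider {
  constructor() {
    const configManager = ConfigurationManager.getInstance();
    super({
      apiKey: configManager.getConfig<string>("EXAMPLE_API_KEY", false), // assumed config key
      baseURL: "https://api.example.com/v1", // placeholder endpoint
      providerId: "example",
      providerName: "Example",
      models: exampleModels,
    });
  }

  // Availability and model refresh stay provider-specific.
  async isAvailable(): Promise<boolean> {
    return !!this.config.apiKey;
  }

  async refreshModels(): Promise<string[]> {
    return Promise.resolve(exampleModels.map((m) => m.id));
  }
}
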
145 changes: 145 additions & 0 deletions src/ai/providers/DashScopeProvider.ts
@@ -0,0 +1,145 @@
import { BaseOpenAIProvider } from "./BaseOpenAIProvider";
import { ConfigurationManager } from "../../config/ConfigurationManager";
import { AIModel } from "../types";

const dashscopeModels: AIModel[] = [
{
id: "qwen-max",
name: "Qwen Max (稳定版) - 旗舰模型: 强大的理解和生成能力",
maxTokens: { input: 30720, output: 8192 },
provider: { id: "dashscope", name: "DashScope" },
capabilities: {
streaming: true,
functionCalling: true,
},
cost: {
input: 0.02,
output: 0.06,
},
},
{
id: "qwen-max-latest",
name: "Qwen Max (最新版) - 旗舰实验版: 最新的模型改进和优化",
maxTokens: { input: 30720, output: 8192 },
provider: { id: "dashscope", name: "DashScope" },
capabilities: {
streaming: true,
functionCalling: true,
},
cost: {
input: 0.02,
output: 0.06,
},
},
{
id: "qwen-plus",
name: "Qwen Plus (稳定版) - 增强版: 性能与成本的最佳平衡",
maxTokens: { input: 129024, output: 8192 },
provider: { id: "dashscope", name: "DashScope" },
capabilities: {
streaming: true,
functionCalling: true,
},
cost: {
input: 0.0008,
output: 0.002,
},
},
{
id: "qwen-plus-latest",
name: "Qwen Plus (最新版) - 增强实验版: 新特性和优化的测试版本",
maxTokens: { input: 129024, output: 8192 },
provider: { id: "dashscope", name: "DashScope" },
capabilities: {
streaming: true,
functionCalling: true,
},
cost: {
input: 0.0008,
output: 0.002,
},
},
{
id: "qwen-turbo",
name: "Qwen Turbo (稳定版) - 快速版: 高性价比的日常对话模型",
maxTokens: { input: 129024, output: 8192 },
provider: { id: "dashscope", name: "DashScope" },
default: true,
capabilities: {
streaming: true,
functionCalling: true,
},
cost: {
input: 0.0003,
output: 0.0006,
},
},
{
id: "qwen-turbo-latest",
name: "Qwen Turbo (最新版) - 快速实验版: 优化推理速度的最新版本",
maxTokens: { input: 1000000, output: 8192 },
provider: { id: "dashscope", name: "DashScope" },
capabilities: {
streaming: true,
functionCalling: true,
},
cost: {
input: 0.0003,
output: 0.0006,
},
},
{
id: "qwen-coder-turbo",
name: "Qwen Coder Turbo (稳定版) - 编程专用: 代码生成和分析的专业模型",
maxTokens: { input: 129024, output: 8192 },
provider: { id: "dashscope", name: "DashScope" },
capabilities: {
streaming: true,
functionCalling: true,
},
cost: {
input: 0.002,
output: 0.006,
},
},
{
id: "qwen-coder-turbo-latest",
name: "Qwen Coder Turbo (最新版) - 编程实验版: 最新的代码辅助功能",
maxTokens: { input: 129024, output: 8192 },
provider: { id: "dashscope", name: "DashScope" },
capabilities: {
streaming: true,
functionCalling: true,
},
cost: {
input: 0.002,
output: 0.006,
},
},
];

export class DashScopeProvider extends BaseOpenAIProvider {
constructor() {
const configManager = ConfigurationManager.getInstance();
super({
apiKey: configManager.getConfig<string>("DASHSCOPE_API_KEY", false),
baseURL: "https://api.dashscope.com/v1/services/chat/completions",
providerId: "dashscope",
providerName: "DashScope",
models: dashscopeModels,
defaultModel: "qwen-turbo",
});
}

async isAvailable(): Promise<boolean> {
try {
return !!this.config.apiKey;
} catch {
return false;
}
}

async refreshModels(): Promise<string[]> {
return Promise.resolve(dashscopeModels.map((m) => m.id));
}
}
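
Putting the pieces together, the new provider is resolved through the factory like any other. A rough end-to-end sketch, not part of this commit (the import path and the "dashscope" type string are assumptions):

import { AIProviderFactory } from "../AIProviderFactory";

async function listDashScopeModels(): Promise<void> {
  // Assumes AIProvider.DASHSCOPE resolves to the string "dashscope".
  const provider = AIProviderFactory.getProvider("dashscope");

  // isAvailable() only checks that DASHSCOPE_API_KEY is configured.
  if (await provider.isAvailable()) {
    const models = await provider.getModels(); // the static dashscopeModels list
    console.log(models.map((m) => m.id));
  }
}
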
