diff --git a/src/shared/lib/ai.ts b/src/shared/lib/ai.ts index ec88e0a..9a1b26b 100644 --- a/src/shared/lib/ai.ts +++ b/src/shared/lib/ai.ts @@ -1,247 +1,11 @@ -import "server-only" - -import { createCipheriv, createDecipheriv, createHash, randomBytes } from "node:crypto" -import OpenAI from "openai" -import type { ChatCompletionMessageParam } from "openai/resources/chat/completions" - -import { env } from "@/env.mjs" -import { db } from "@/shared/db" -import { aiProviders } from "@/shared/db/schema" -import { desc, eq } from "drizzle-orm" - -type ChatRole = "system" | "user" | "assistant" - -type ChatMessage = { - role: ChatRole - content: string -} - -type AiChatRequest = { - messages: ChatCompletionMessageParam[] - model: string - temperature: number - maxTokens?: number - thinking?: Record - providerId?: string -} - -const isRecord = (v: unknown): v is Record => typeof v === "object" && v !== null - -const isChatMessage = (v: unknown): v is ChatMessage => { - if (!isRecord(v)) return false - const role = String(v.role ?? "") - if (role !== "system" && role !== "user" && role !== "assistant") return false - const content = String(v.content ?? "") - return content.trim().length > 0 -} - -const extractText = (value: unknown): string => { - if (typeof value === "string") return value.trim() - if (Array.isArray(value)) { - const joined = value.map((item) => extractText(item)).filter(Boolean).join("\n") - return joined.trim() - } - if (isRecord(value)) { - const candidates = ["text", "content", "output_text", "reasoning", "reasoning_content", "thinking"] - for (const key of candidates) { - const text = extractText(value[key]) - if (text) return text - } - } - return "" -} - -const extractMessageContent = (message: unknown): string => { - if (!isRecord(message)) return "" - const direct = extractText(message.content) - if (direct) return direct - const candidates = ["reasoning", "reasoning_content", "thinking", "output", "text"] - for (const key of candidates) { - const text = extractText(message[key]) - if (text) return text - } - for (const value of Object.values(message)) { - const text = extractText(value) - if (text) return text - } - return "" -} - -export const parseAiChatPayload = (body: unknown): AiChatRequest => { - if (!isRecord(body)) throw new Error("Invalid payload") - - const rawMessages = Array.isArray(body.messages) ? body.messages : [] - const messages = rawMessages - .filter(isChatMessage) - .map((m) => ({ role: m.role, content: m.content })) as ChatCompletionMessageParam[] - - if (messages.length === 0) throw new Error("Messages are required") - - const model = String(body.model ?? env.AI_MODEL ?? "gpt-4o-mini").trim() - const temperatureRaw = Number(body.temperature ?? 0.2) - const temperature = Number.isFinite(temperatureRaw) ? Math.min(Math.max(temperatureRaw, 0), 2) : 0.2 - const maxTokensRaw = Number(body.max_tokens ?? body.maxTokens ?? 0) - const maxTokens = Number.isFinite(maxTokensRaw) && maxTokensRaw > 0 ? Math.floor(maxTokensRaw) : undefined - const thinking = isRecord(body.thinking) ? body.thinking : undefined - const providerId = typeof body.providerId === "string" ? body.providerId.trim() : undefined - - return { - messages, - model, - temperature, - maxTokens, - thinking, - providerId: providerId && providerId.length > 0 ? providerId : undefined, - } -} - -const getEncryptionKey = () => { - const secret = String(env.NEXTAUTH_SECRET ?? "").trim() - if (!secret) throw new Error("AI encryption secret missing") - return createHash("sha256").update(secret).digest() -} - -export const encryptAiApiKey = (value: string) => { - const iv = randomBytes(12) - const key = getEncryptionKey() - const cipher = createCipheriv("aes-256-gcm", key, iv) - const encrypted = Buffer.concat([cipher.update(value, "utf8"), cipher.final()]) - const tag = cipher.getAuthTag() - return Buffer.concat([iv, tag, encrypted]).toString("base64") -} - -export const decryptAiApiKey = (value: string) => { - const raw = Buffer.from(value, "base64") - if (raw.length < 28) throw new Error("Invalid API key payload") - const iv = raw.subarray(0, 12) - const tag = raw.subarray(12, 28) - const encrypted = raw.subarray(28) - const key = getEncryptionKey() - const decipher = createDecipheriv("aes-256-gcm", key, iv) - decipher.setAuthTag(tag) - const decrypted = Buffer.concat([decipher.update(encrypted), decipher.final()]) - return decrypted.toString("utf8") -} - -const getAiProviderConfig = async (providerId?: string) => { - if (providerId) { - const [selected] = await db - .select({ - apiKeyEncrypted: aiProviders.apiKeyEncrypted, - baseUrl: aiProviders.baseUrl, - model: aiProviders.model, - }) - .from(aiProviders) - .where(eq(aiProviders.id, providerId)) - .limit(1) - if (!selected) throw new Error("AI provider not configured") - return { - apiKey: decryptAiApiKey(selected.apiKeyEncrypted), - baseUrl: selected.baseUrl ?? undefined, - model: selected.model, - } - } - - const [active] = await db - .select({ - apiKeyEncrypted: aiProviders.apiKeyEncrypted, - baseUrl: aiProviders.baseUrl, - model: aiProviders.model, - }) - .from(aiProviders) - .where(eq(aiProviders.isDefault, true)) - .orderBy(desc(aiProviders.updatedAt)) - .limit(1) - if (active) { - return { - apiKey: decryptAiApiKey(active.apiKeyEncrypted), - baseUrl: active.baseUrl ?? undefined, - model: active.model, - } - } - - const [fallback] = await db - .select({ - apiKeyEncrypted: aiProviders.apiKeyEncrypted, - baseUrl: aiProviders.baseUrl, - model: aiProviders.model, - }) - .from(aiProviders) - .orderBy(desc(aiProviders.updatedAt)) - .limit(1) - if (!fallback) throw new Error("AI provider not configured") - - return { - apiKey: decryptAiApiKey(fallback.apiKeyEncrypted), - baseUrl: fallback.baseUrl ?? undefined, - model: fallback.model, - } -} - -const getAiClient = async (config: { apiKey: string; baseUrl?: string }) => { - const baseUrl = String(config.baseUrl ?? "https://api.openai.com").replace(/\/+$/, "") - return new OpenAI({ - apiKey: config.apiKey, - baseURL: baseUrl.length ? baseUrl : undefined, - }) -} - -export const testAiProviderConfig = async (input: { apiKey: string; baseUrl?: string; model: string }) => { - const client = await getAiClient({ apiKey: input.apiKey, baseUrl: input.baseUrl }) - const result = await client.chat.completions.create({ - model: input.model, - messages: [{ role: "user", content: "ping" }], - temperature: 0, - max_tokens: 1, - } as Parameters[0]) - const hasChoices = "choices" in result && Array.isArray(result.choices) && result.choices.length > 0 - if (!hasChoices) throw new Error("Empty response from provider. Check API URL, model, and API key.") - return true -} - -export const testAiProviderById = async ( - providerId: string, - overrides?: { baseUrl?: string; model?: string } -) => { - const config = await getAiProviderConfig(providerId) - const client = await getAiClient({ apiKey: config.apiKey, baseUrl: overrides?.baseUrl ?? config.baseUrl }) - const result = await client.chat.completions.create({ - model: overrides?.model ?? config.model, - messages: [{ role: "user", content: "ping" }], - temperature: 0, - max_tokens: 1, - } as Parameters[0]) - const hasChoices = "choices" in result && Array.isArray(result.choices) && result.choices.length > 0 - if (!hasChoices) throw new Error("Empty response from provider. Check API URL, model, and API key.") - return true -} - -export const createAiChatCompletion = async (input: AiChatRequest) => { - const config = await getAiProviderConfig(input.providerId) - const client = await getAiClient(config) - const result = (await client.chat.completions.create({ - model: config.model || input.model, - messages: input.messages, - temperature: input.temperature, - ...(typeof input.maxTokens === "number" ? { max_tokens: input.maxTokens } : {}), - ...(input.thinking ? { thinking: input.thinking } : {}), - } as Parameters[0])) as Awaited< - ReturnType - > - - const hasChoices = "choices" in result && Array.isArray(result.choices) && result.choices.length > 0 - if (!hasChoices) throw new Error("Empty response from provider. Check API URL, model, and API key.") - - const content = extractMessageContent(result.choices?.[0]?.message) - if (!content.trim()) throw new Error("Empty response content. Check model output settings.") - - const usage = "usage" in result ? result.usage ?? null : null - return { content, usage } -} - -export const getAiErrorMessage = (v: unknown) => { - if (v instanceof Error) return v.message - if (!isRecord(v)) return "AI request failed" - const message = String(v.message ?? "") - return message.trim().length ? message : "AI request failed" -} +// 此文件为向后兼容的重导出入口,实际实现已按职责拆分到 ./ai/ 目录: +// - payload-parser.ts 请求负载解析 +// - api-key-crypto.ts API Key 加密/解密 +// - provider-config.ts Provider 配置查询 +// - client.ts AI 客户端创建与调用 +// - errors.ts 错误格式化 +export { encryptAiApiKey, decryptAiApiKey } from "./ai/api-key-crypto" +export { createAiChatCompletion, testAiProviderById, testAiProviderConfig } from "./ai/client" +export { getAiErrorMessage } from "./ai/errors" +export { parseAiChatPayload, isRecord } from "./ai/payload-parser" +export type { AiChatRequest, ChatMessage, ChatRole } from "./ai/payload-parser" diff --git a/src/shared/lib/ai/api-key-crypto.ts b/src/shared/lib/ai/api-key-crypto.ts new file mode 100644 index 0000000..c5066e0 --- /dev/null +++ b/src/shared/lib/ai/api-key-crypto.ts @@ -0,0 +1,33 @@ +import "server-only" + +import { createCipheriv, createDecipheriv, createHash, randomBytes } from "node:crypto" + +import { env } from "@/env.mjs" + +const getEncryptionKey = () => { + const secret = String(env.NEXTAUTH_SECRET ?? "").trim() + if (!secret) throw new Error("AI encryption secret missing") + return createHash("sha256").update(secret).digest() +} + +export const encryptAiApiKey = (value: string) => { + const iv = randomBytes(12) + const key = getEncryptionKey() + const cipher = createCipheriv("aes-256-gcm", key, iv) + const encrypted = Buffer.concat([cipher.update(value, "utf8"), cipher.final()]) + const tag = cipher.getAuthTag() + return Buffer.concat([iv, tag, encrypted]).toString("base64") +} + +export const decryptAiApiKey = (value: string) => { + const raw = Buffer.from(value, "base64") + if (raw.length < 28) throw new Error("Invalid API key payload") + const iv = raw.subarray(0, 12) + const tag = raw.subarray(12, 28) + const encrypted = raw.subarray(28) + const key = getEncryptionKey() + const decipher = createDecipheriv("aes-256-gcm", key, iv) + decipher.setAuthTag(tag) + const decrypted = Buffer.concat([decipher.update(encrypted), decipher.final()]) + return decrypted.toString("utf8") +} diff --git a/src/shared/lib/ai/client.ts b/src/shared/lib/ai/client.ts new file mode 100644 index 0000000..184c871 --- /dev/null +++ b/src/shared/lib/ai/client.ts @@ -0,0 +1,67 @@ +import "server-only" + +import OpenAI from "openai" + +import { extractMessageContent, type AiChatRequest } from "./payload-parser" +import { getAiProviderConfig } from "./provider-config" + +const getAiClient = async (config: { apiKey: string; baseUrl?: string }) => { + const baseUrl = String(config.baseUrl ?? "https://api.openai.com").replace(/\/+$/, "") + return new OpenAI({ + apiKey: config.apiKey, + baseURL: baseUrl.length ? baseUrl : undefined, + }) +} + +export const testAiProviderConfig = async (input: { apiKey: string; baseUrl?: string; model: string }) => { + const client = await getAiClient({ apiKey: input.apiKey, baseUrl: input.baseUrl }) + const result = await client.chat.completions.create({ + model: input.model, + messages: [{ role: "user", content: "ping" }], + temperature: 0, + max_tokens: 1, + } as Parameters[0]) + const hasChoices = "choices" in result && Array.isArray(result.choices) && result.choices.length > 0 + if (!hasChoices) throw new Error("Empty response from provider. Check API URL, model, and API key.") + return true +} + +export const testAiProviderById = async ( + providerId: string, + overrides?: { baseUrl?: string; model?: string } +) => { + const config = await getAiProviderConfig(providerId) + const client = await getAiClient({ apiKey: config.apiKey, baseUrl: overrides?.baseUrl ?? config.baseUrl }) + const result = await client.chat.completions.create({ + model: overrides?.model ?? config.model, + messages: [{ role: "user", content: "ping" }], + temperature: 0, + max_tokens: 1, + } as Parameters[0]) + const hasChoices = "choices" in result && Array.isArray(result.choices) && result.choices.length > 0 + if (!hasChoices) throw new Error("Empty response from provider. Check API URL, model, and API key.") + return true +} + +export const createAiChatCompletion = async (input: AiChatRequest) => { + const config = await getAiProviderConfig(input.providerId) + const client = await getAiClient(config) + const result = (await client.chat.completions.create({ + model: config.model || input.model, + messages: input.messages, + temperature: input.temperature, + ...(typeof input.maxTokens === "number" ? { max_tokens: input.maxTokens } : {}), + ...(input.thinking ? { thinking: input.thinking } : {}), + } as Parameters[0])) as Awaited< + ReturnType + > + + const hasChoices = "choices" in result && Array.isArray(result.choices) && result.choices.length > 0 + if (!hasChoices) throw new Error("Empty response from provider. Check API URL, model, and API key.") + + const content = extractMessageContent(result.choices?.[0]?.message) + if (!content.trim()) throw new Error("Empty response content. Check model output settings.") + + const usage = "usage" in result ? result.usage ?? null : null + return { content, usage } +} diff --git a/src/shared/lib/ai/errors.ts b/src/shared/lib/ai/errors.ts new file mode 100644 index 0000000..e9f0633 --- /dev/null +++ b/src/shared/lib/ai/errors.ts @@ -0,0 +1,10 @@ +import "server-only" + +import { isRecord } from "./payload-parser" + +export const getAiErrorMessage = (v: unknown) => { + if (v instanceof Error) return v.message + if (!isRecord(v)) return "AI request failed" + const message = String(v.message ?? "") + return message.trim().length ? message : "AI request failed" +} diff --git a/src/shared/lib/ai/index.ts b/src/shared/lib/ai/index.ts new file mode 100644 index 0000000..8acd4e7 --- /dev/null +++ b/src/shared/lib/ai/index.ts @@ -0,0 +1,5 @@ +export { encryptAiApiKey, decryptAiApiKey } from "./api-key-crypto" +export { createAiChatCompletion, testAiProviderById, testAiProviderConfig } from "./client" +export { getAiErrorMessage } from "./errors" +export { parseAiChatPayload, isRecord } from "./payload-parser" +export type { AiChatRequest, ChatMessage, ChatRole } from "./payload-parser" diff --git a/src/shared/lib/ai/payload-parser.ts b/src/shared/lib/ai/payload-parser.ts new file mode 100644 index 0000000..1027bc0 --- /dev/null +++ b/src/shared/lib/ai/payload-parser.ts @@ -0,0 +1,93 @@ +import "server-only" + +import type { ChatCompletionMessageParam } from "openai/resources/chat/completions" + +import { env } from "@/env.mjs" + +export type ChatRole = "system" | "user" | "assistant" + +export type ChatMessage = { + role: ChatRole + content: string +} + +export type AiChatRequest = { + messages: ChatCompletionMessageParam[] + model: string + temperature: number + maxTokens?: number + thinking?: Record + providerId?: string +} + +export const isRecord = (v: unknown): v is Record => typeof v === "object" && v !== null + +const isChatMessage = (v: unknown): v is ChatMessage => { + if (!isRecord(v)) return false + const role = String(v.role ?? "") + if (role !== "system" && role !== "user" && role !== "assistant") return false + const content = String(v.content ?? "") + return content.trim().length > 0 +} + +const extractText = (value: unknown): string => { + if (typeof value === "string") return value.trim() + if (Array.isArray(value)) { + const joined = value.map((item) => extractText(item)).filter(Boolean).join("\n") + return joined.trim() + } + if (isRecord(value)) { + const candidates = ["text", "content", "output_text", "reasoning", "reasoning_content", "thinking"] + for (const key of candidates) { + const text = extractText(value[key]) + if (text) return text + } + } + return "" +} + +const extractMessageContent = (message: unknown): string => { + if (!isRecord(message)) return "" + const direct = extractText(message.content) + if (direct) return direct + const candidates = ["reasoning", "reasoning_content", "thinking", "output", "text"] + for (const key of candidates) { + const text = extractText(message[key]) + if (text) return text + } + for (const value of Object.values(message)) { + const text = extractText(value) + if (text) return text + } + return "" +} + +export const parseAiChatPayload = (body: unknown): AiChatRequest => { + if (!isRecord(body)) throw new Error("Invalid payload") + + const rawMessages = Array.isArray(body.messages) ? body.messages : [] + const messages = rawMessages + .filter(isChatMessage) + .map((m) => ({ role: m.role, content: m.content })) as ChatCompletionMessageParam[] + + if (messages.length === 0) throw new Error("Messages are required") + + const model = String(body.model ?? env.AI_MODEL ?? "gpt-4o-mini").trim() + const temperatureRaw = Number(body.temperature ?? 0.2) + const temperature = Number.isFinite(temperatureRaw) ? Math.min(Math.max(temperatureRaw, 0), 2) : 0.2 + const maxTokensRaw = Number(body.max_tokens ?? body.maxTokens ?? 0) + const maxTokens = Number.isFinite(maxTokensRaw) && maxTokensRaw > 0 ? Math.floor(maxTokensRaw) : undefined + const thinking = isRecord(body.thinking) ? body.thinking : undefined + const providerId = typeof body.providerId === "string" ? body.providerId.trim() : undefined + + return { + messages, + model, + temperature, + maxTokens, + thinking, + providerId: providerId && providerId.length > 0 ? providerId : undefined, + } +} + +export { extractMessageContent } diff --git a/src/shared/lib/ai/provider-config.ts b/src/shared/lib/ai/provider-config.ts new file mode 100644 index 0000000..16192f7 --- /dev/null +++ b/src/shared/lib/ai/provider-config.ts @@ -0,0 +1,69 @@ +import "server-only" + +import { desc, eq } from "drizzle-orm" + +import { db } from "@/shared/db" +import { aiProviders } from "@/shared/db/schema" + +import { decryptAiApiKey } from "./api-key-crypto" + +export type AiProviderConfig = { + apiKey: string + baseUrl?: string + model: string +} + +export const getAiProviderConfig = async (providerId?: string): Promise => { + if (providerId) { + const [selected] = await db + .select({ + apiKeyEncrypted: aiProviders.apiKeyEncrypted, + baseUrl: aiProviders.baseUrl, + model: aiProviders.model, + }) + .from(aiProviders) + .where(eq(aiProviders.id, providerId)) + .limit(1) + if (!selected) throw new Error("AI provider not configured") + return { + apiKey: decryptAiApiKey(selected.apiKeyEncrypted), + baseUrl: selected.baseUrl ?? undefined, + model: selected.model, + } + } + + const [active] = await db + .select({ + apiKeyEncrypted: aiProviders.apiKeyEncrypted, + baseUrl: aiProviders.baseUrl, + model: aiProviders.model, + }) + .from(aiProviders) + .where(eq(aiProviders.isDefault, true)) + .orderBy(desc(aiProviders.updatedAt)) + .limit(1) + if (active) { + return { + apiKey: decryptAiApiKey(active.apiKeyEncrypted), + baseUrl: active.baseUrl ?? undefined, + model: active.model, + } + } + + const [fallback] = await db + .select({ + apiKeyEncrypted: aiProviders.apiKeyEncrypted, + baseUrl: aiProviders.baseUrl, + model: aiProviders.model, + }) + .from(aiProviders) + .orderBy(desc(aiProviders.updatedAt)) + .limit(1) + if (!fallback) throw new Error("AI provider not configured") + + return { + apiKey: decryptAiApiKey(fallback.apiKeyEncrypted), + baseUrl: fallback.baseUrl ?? undefined, + model: fallback.model, + } +}