// src/generator.ts
import { type Chat, type GeneratorController, type InferParsedConfig } from "@lmstudio/sdk";
import OpenAI from "openai";
import {
type ChatCompletionMessageParam,
type ChatCompletionMessageToolCall,
type ChatCompletionTool,
type ChatCompletionToolMessageParam,
} from "openai/resources/index";
import { configSchematics, globalConfigSchematics, placeholderModelValue } from "./config";
import { writeCachedState } from "./modelCache";
/* -------------------------------------------------------------------------- */
/* Types */
/* -------------------------------------------------------------------------- */
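/** Accumulates a single tool call whose id, name, and arguments arrive as streamed fragments. */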
type ToolCallState = {
id: string;
name: string | null;
index: number;
arguments: string;
};
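/** Minimal shape of the `GET /v1/models` response this plugin relies on. */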
type ModelsResponse = {
data?: Array<{ id?: string }>;
};
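/** Result of a model-list refresh: the current list and whether the persisted cache changed. */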
type ModelRefreshResult = {
models: string[];
didUpdate: boolean;
};
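/** Module-level cache of the remote model list, shared across generate() invocations. */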
const modelsCache = {
baseUrl: "",
models: [] as string[],
lastFetchedAt: 0,
inFlight: null as Promise<void> | null,
};
let refreshNotified = false;
const modelsRefreshIntervalMs = 15_000;
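/**
 * Normalize an override base URL into its model-listing endpoint: trailing
 * slashes are stripped and `/v1` is appended only when missing, so (using
 * example values) both `http://localhost:1234` and `http://localhost:1234/v1/`
 * resolve to `http://localhost:1234/v1/models`.
 */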
function buildModelsUrl(overrideBaseUrl: string) {
const trimmed = overrideBaseUrl.replace(/\/+$/g, "");
if (trimmed.endsWith("/v1")) {
return `${trimmed}/models`;
}
return `${trimmed}/v1/models`;
}
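/** Order-sensitive equality check for two lists of model ids. */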
function sameModelList(a: string[], b: string[]) {
if (a.length !== b.length) return false;
for (let i = 0; i < a.length; i += 1) {
if (a[i] !== b[i]) return false;
}
return true;
}
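/**
 * Fetch the model list from the override base URL and persist it via
 * `writeCachedState`. Results are cached in-module for `modelsRefreshIntervalMs`,
 * and concurrent callers share a single in-flight request. Returns the cached
 * model ids together with a flag indicating whether the persisted state changed.
 */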
async function refreshModelOptionsIfNeeded(
overrideBaseUrl: string,
currentModel: string,
forceRefresh: boolean,
): Promise<ModelRefreshResult> {
if (!overrideBaseUrl) {
console.info("Model refresh skipped: overrideBaseUrl is empty.");
return { models: modelsCache.models, didUpdate: false };
}
if (typeof fetch !== "function") {
console.warn("Fetch is not available; skipping model list refresh.");
return { models: modelsCache.models, didUpdate: false };
}
  const modelsUrl = buildModelsUrl(overrideBaseUrl);
  const now = Date.now();
  // If another refresh is already in flight, wait for it instead of issuing a duplicate request.
  if (modelsCache.inFlight) {
    await modelsCache.inFlight;
    return { models: modelsCache.models, didUpdate: false };
  }
  // Serve from the in-memory cache while the list for this base URL is still fresh.
  if (
    !forceRefresh &&
    modelsCache.baseUrl === modelsUrl &&
    now - modelsCache.lastFetchedAt < modelsRefreshIntervalMs
  ) {
    return { models: modelsCache.models, didUpdate: false };
  }
  console.info(`Fetching model list from ${modelsUrl}`);
let didUpdate = false;
modelsCache.inFlight = (async () => {
try {
const response = await fetch(modelsUrl, { method: "GET" });
if (!response.ok) {
console.info(`Model refresh skipped (HTTP ${response.status}) from ${modelsUrl}`);
return;
}
const payload = (await response.json()) as ModelsResponse;
const rawModels = Array.isArray(payload?.data) ? payload.data : [];
const modelIds = rawModels
.map(model => model?.id)
.filter((id): id is string => typeof id === "string" && id.length > 0);
if (
currentModel &&
currentModel !== placeholderModelValue &&
!modelIds.includes(currentModel)
) {
modelIds.unshift(currentModel);
}
if (!modelIds.length) {
return;
}
if (
!forceRefresh &&
modelsCache.baseUrl === modelsUrl &&
sameModelList(modelsCache.models, modelIds)
) {
modelsCache.lastFetchedAt = now;
return;
}
modelsCache.baseUrl = modelsUrl;
modelsCache.models = modelIds;
modelsCache.lastFetchedAt = now;
const lastSelected =
currentModel && currentModel !== placeholderModelValue ? currentModel : undefined;
await writeCachedState(modelIds, lastSelected);
didUpdate = true;
console.info(`Model options refreshed from ${modelsUrl}`);
} catch (error) {
console.warn("Failed to refresh model list:", error);
} finally {
modelsCache.inFlight = null;
}
})();
await modelsCache.inFlight;
return { models: modelsCache.models, didUpdate };
}
/* -------------------------------------------------------------------------- */
/* Build helpers */
/* -------------------------------------------------------------------------- */
/** Build a pre-configured OpenAI client. */
function createOpenAI(globalConfig: InferParsedConfig<typeof globalConfigSchematics>) {
const overrideBaseUrl = globalConfig.get("overrideBaseUrl");
  // Use the override URL if provided; otherwise fall back to the official OpenAI endpoint.
  const baseURL = overrideBaseUrl || "https://api.openai.com/v1";
const apiKey = globalConfig.get("openaiApiKey") || "local";
return new OpenAI({
apiKey,
baseURL,
});
}
/** Convert internal chat history to the format expected by OpenAI. */
function toOpenAIMessages(history: Chat): ChatCompletionMessageParam[] {
const messages: ChatCompletionMessageParam[] = [];
for (const message of history) {
switch (message.getRole()) {
case "system":
messages.push({ role: "system", content: message.getText() });
break;
case "user":
messages.push({ role: "user", content: message.getText() });
break;
case "assistant": {
const toolCalls: ChatCompletionMessageToolCall[] = message
.getToolCallRequests()
.map(toolCall => ({
id: toolCall.id ?? "",
type: "function",
function: {
name: toolCall.name,
arguments: JSON.stringify(toolCall.arguments ?? {}),
},
}));
messages.push({
role: "assistant",
content: message.getText(),
...(toolCalls.length ? { tool_calls: toolCalls } : {}),
});
break;
}
case "tool": {
message.getToolCallResults().forEach(toolCallResult => {
messages.push({
role: "tool",
tool_call_id: toolCallResult.toolCallId ?? "",
content: toolCallResult.content,
} as ChatCompletionToolMessageParam);
});
break;
}
}
}
return messages;
}
/** Convert LM Studio tool definitions to OpenAI function-tool descriptors. */
function toOpenAITools(ctl: GeneratorController): ChatCompletionTool[] | undefined {
const tools = ctl.getToolDefinitions().map<ChatCompletionTool>(t => ({
type: "function",
function: {
name: t.function.name,
description: t.function.description,
parameters: t.function.parameters ?? {},
},
}));
return tools.length ? tools : undefined;
}
/* -------------------------------------------------------------------------- */
/* Stream-handling utils */
/* -------------------------------------------------------------------------- */
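/** Forward LM Studio abort requests to the underlying OpenAI stream. */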
function wireAbort(ctl: GeneratorController, stream: { controller: AbortController }) {
ctl.onAborted(() => {
console.info("Generation aborted by user.");
stream.controller.abort();
});
}
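/**
 * Consume the streamed chat completion, forwarding text fragments and
 * incrementally assembled tool calls to the generator controller.
 */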
async function consumeStream(stream: AsyncIterable<any>, ctl: GeneratorController) {
let current: ToolCallState | null = null;
  function maybeFlushCurrentToolCall() {
    if (current === null || current.name === null) {
      return;
    }
    // Arguments arrive as streamed JSON fragments; guard against an empty or
    // malformed accumulation so a parse failure cannot abort the whole stream.
    let parsedArguments: any = {};
    try {
      parsedArguments = current.arguments ? JSON.parse(current.arguments) : {};
    } catch (error) {
      console.warn("Failed to parse tool call arguments:", error);
    }
    ctl.toolCallGenerationEnded({
      type: "function",
      name: current.name,
      arguments: parsedArguments,
      id: current.id,
    });
    current = null;
  }
for await (const chunk of stream) {
console.info("Received chunk:", JSON.stringify(chunk));
const delta = chunk.choices?.[0]?.delta as
| {
content?: string;
tool_calls?: Array<{
index: number;
id?: string;
function?: { name?: string; arguments?: string };
}>;
}
| undefined;
if (!delta) continue;
/* Text streaming */
if (delta.content) {
ctl.fragmentGenerated(delta.content);
}
/* Tool-call streaming */
for (const toolCall of delta.tool_calls ?? []) {
if (toolCall.id !== undefined) {
maybeFlushCurrentToolCall();
current = { id: toolCall.id, name: null, index: toolCall.index, arguments: "" };
ctl.toolCallGenerationStarted();
}
if (toolCall.function?.name && current) {
current.name = toolCall.function.name;
ctl.toolCallGenerationNameReceived(toolCall.function.name);
}
if (toolCall.function?.arguments && current) {
current.arguments += toolCall.function.arguments;
ctl.toolCallGenerationArgumentFragmentGenerated(toolCall.function.arguments);
}
}
    /* Finalize the tool call once the model signals it has finished emitting tool calls. */
    if (chunk.choices?.[0]?.finish_reason === "tool_calls" && current?.name) {
      maybeFlushCurrentToolCall();
    }
  }
  // Flush any tool call still being accumulated when the stream ends without an explicit finish signal.
  maybeFlushCurrentToolCall();
console.info("Generation completed.");
}
/* -------------------------------------------------------------------------- */
/* API */
/* -------------------------------------------------------------------------- */
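/**
 * Generator entry point: resolves plugin configuration, refreshes the remote
 * model list when no concrete model is selected yet, then streams a chat
 * completion from the OpenAI-compatible endpoint and relays fragments and
 * tool calls back to LM Studio through the controller.
 */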
export async function generate(ctl: GeneratorController, history: Chat) {
const config = ctl.getPluginConfig(configSchematics);
const globalConfig = ctl.getGlobalPluginConfig(globalConfigSchematics);
const overrideBaseUrl = globalConfig.get("overrideBaseUrl");
const selectedModel = config.get("model");
console.info(
`Generate: model=${selectedModel ?? ""} baseUrl=${overrideBaseUrl ?? ""} cached=${modelsCache.models.length}`,
);
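  // Without an override base URL and a concrete model selection there is nothing to call yet.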
if (!overrideBaseUrl && (!selectedModel || selectedModel === placeholderModelValue)) {
ctl.fragmentGenerated("Set Override Base URL, then run once to load models.");
return;
}
const shouldRefreshModels =
!selectedModel || selectedModel === placeholderModelValue || modelsCache.models.length === 0;
if (overrideBaseUrl && shouldRefreshModels) {
const isPlaceholder = !selectedModel || selectedModel === placeholderModelValue;
if (modelsCache.models.length > 0 && isPlaceholder) {
if (!refreshNotified) {
ctl.fragmentGenerated("Models cached. Restart LM Studio to update the dropdown.");
refreshNotified = true;
}
return;
}
const refreshResult = await refreshModelOptionsIfNeeded(
overrideBaseUrl,
selectedModel,
true,
);
if (isPlaceholder) {
if (refreshResult.didUpdate) {
if (!refreshNotified) {
ctl.fragmentGenerated("Model list refreshed. Restart LM Studio to update the dropdown.");
refreshNotified = true;
}
} else {
ctl.fragmentGenerated("Model list not updated. Verify Base URL and try again.");
}
return;
}
// If a model is already selected, refresh silently.
}
const resolvedModel = selectedModel;
if (!resolvedModel || resolvedModel === placeholderModelValue) {
throw new Error("Model list not loaded yet. Set base URL, then retry to refresh models.");
}
/* 1. Setup client & payload */
const openai = createOpenAI(globalConfig);
const messages = toOpenAIMessages(history);
const tools = toOpenAITools(ctl);
/* 2. Kick off streaming completion */
const stream = await openai.chat.completions.create({
model: resolvedModel,
messages,
tools,
stream: true,
});
/* 3. Abort wiring & stream processing */
wireAbort(ctl, stream as any);
await consumeStream(stream as any, ctl);
}