openai-compat-endpoint

src/generator.ts

import { type Chat, type GeneratorController, type InferParsedConfig } from "@lmstudio/sdk";
import OpenAI from "openai";
import {
  type ChatCompletionMessageParam,
  type ChatCompletionMessageToolCall,
  type ChatCompletionTool,
  type ChatCompletionToolMessageParam,
} from "openai/resources/index";
import { configSchematics, globalConfigSchematics, placeholderModelValue } from "./config";
import { writeCachedState } from "./modelCache";

/* -------------------------------------------------------------------------- */
/*                                   Types                                    */
/* -------------------------------------------------------------------------- */

type ToolCallState = {
  id: string;
  name: string | null;
  index: number;
  arguments: string;
};

type ModelsResponse = {
  data?: Array<{ id?: string }>;
};

type ModelRefreshResult = {
  models: string[];
  didUpdate: boolean;
};

const modelsCache = {
  baseUrl: "",
  models: [] as string[],
  lastFetchedAt: 0,
  inFlight: null as Promise<void> | null,
};
let refreshNotified = false;

const modelsRefreshIntervalMs = 15_000;

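/**
 * Normalize an override base URL into its `/v1/models` listing endpoint.
 * Illustrative examples (hypothetical URLs, derived from the logic below):
 *   buildModelsUrl("http://localhost:1234")     -> "http://localhost:1234/v1/models"
 *   buildModelsUrl("http://localhost:1234/v1/") -> "http://localhost:1234/v1/models"
 */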
function buildModelsUrl(overrideBaseUrl: string) {
  const trimmed = overrideBaseUrl.replace(/\/+$/g, "");
  if (trimmed.endsWith("/v1")) {
    return `${trimmed}/models`;
  }
  return `${trimmed}/v1/models`;
}

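/** Order-sensitive equality check for two model-id lists. */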
function sameModelList(a: string[], b: string[]) {
  if (a.length !== b.length) return false;
  for (let i = 0; i < a.length; i += 1) {
    if (a[i] !== b[i]) return false;
  }
  return true;
}

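/**
 * Refresh the cached model list from the override base URL's `/v1/models` endpoint,
 * sharing any in-flight request, and persist the result via `writeCachedState`.
 * Returns the (possibly unchanged) cached list and whether this call updated it.
 */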
async function refreshModelOptionsIfNeeded(
  overrideBaseUrl: string,
  currentModel: string,
  forceRefresh: boolean,
): Promise<ModelRefreshResult> {
  if (!overrideBaseUrl) {
    console.info("Model refresh skipped: overrideBaseUrl is empty.");
    return { models: modelsCache.models, didUpdate: false };
  }
  if (typeof fetch !== "function") {
    console.warn("Fetch is not available; skipping model list refresh.");
    return { models: modelsCache.models, didUpdate: false };
  }

  const modelsUrl = buildModelsUrl(overrideBaseUrl);
  const now = Date.now();

  // Reuse an in-flight request instead of issuing a second fetch.
  if (modelsCache.inFlight) {
    await modelsCache.inFlight;
    return { models: modelsCache.models, didUpdate: false };
  }

  // Serve from the cache while it is still fresh, unless a refresh is forced.
  if (
    !forceRefresh &&
    modelsCache.baseUrl === modelsUrl &&
    now - modelsCache.lastFetchedAt < modelsRefreshIntervalMs
  ) {
    return { models: modelsCache.models, didUpdate: false };
  }

  console.info(`Fetching model list from ${modelsUrl}`);

  let didUpdate = false;
  modelsCache.inFlight = (async () => {
    try {
      const response = await fetch(modelsUrl, { method: "GET" });
      if (!response.ok) {
        console.info(`Model refresh skipped (HTTP ${response.status}) from ${modelsUrl}`);
        return;
      }
      const payload = (await response.json()) as ModelsResponse;
      const rawModels = Array.isArray(payload?.data) ? payload.data : [];
      const modelIds = rawModels
        .map(model => model?.id)
        .filter((id): id is string => typeof id === "string" && id.length > 0);

      if (
        currentModel &&
        currentModel !== placeholderModelValue &&
        !modelIds.includes(currentModel)
      ) {
        modelIds.unshift(currentModel);
      }

      if (!modelIds.length) {
        return;
      }

      if (
        !forceRefresh &&
        modelsCache.baseUrl === modelsUrl &&
        sameModelList(modelsCache.models, modelIds)
      ) {
        modelsCache.lastFetchedAt = now;
        return;
      }

      modelsCache.baseUrl = modelsUrl;
      modelsCache.models = modelIds;
      modelsCache.lastFetchedAt = now;

      const lastSelected =
        currentModel && currentModel !== placeholderModelValue ? currentModel : undefined;
      await writeCachedState(modelIds, lastSelected);
      didUpdate = true;
      console.info(`Model options refreshed from ${modelsUrl}`);
    } catch (error) {
      console.warn("Failed to refresh model list:", error);
    } finally {
      modelsCache.inFlight = null;
    }
  })();

  await modelsCache.inFlight;
  return { models: modelsCache.models, didUpdate };
}
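
// Usage sketch (hypothetical URL and model id, not values shipped with the plugin):
//   const { models, didUpdate } = await refreshModelOptionsIfNeeded(
//     "http://localhost:1234", "qwen2.5-7b-instruct", /* forceRefresh */ true,
//   );
//   if (didUpdate) console.info(`Tracking ${models.length} model(s)`);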

/* -------------------------------------------------------------------------- */
/*                               Build helpers                                */
/* -------------------------------------------------------------------------- */

/** Build a pre-configured OpenAI client. */
function createOpenAI(globalConfig: InferParsedConfig<typeof globalConfigSchematics>) {
  const overrideBaseUrl = globalConfig.get("overrideBaseUrl");

  // Use the override URL if provided, otherwise fall back to the official OpenAI endpoint.
  const baseURL = overrideBaseUrl || "https://api.openai.com/v1";

  // Local OpenAI-compatible servers generally ignore the key, but the SDK requires
  // a non-empty value, so fall back to a placeholder.
  const apiKey = globalConfig.get("openaiApiKey") || "local";

  return new OpenAI({
    apiKey,
    baseURL,
  });
}
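
// Illustrative wiring for `createOpenAI` (hypothetical values): an overrideBaseUrl of
// "http://localhost:1234/v1" with no openaiApiKey set produces a client that targets
// the local server and authenticates with the placeholder key "local".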

/** Convert internal chat history to the format expected by OpenAI. */
function toOpenAIMessages(history: Chat): ChatCompletionMessageParam[] {
  const messages: ChatCompletionMessageParam[] = [];

  for (const message of history) {
    switch (message.getRole()) {
      case "system":
        messages.push({ role: "system", content: message.getText() });
        break;

      case "user":
        messages.push({ role: "user", content: message.getText() });
        break;

      case "assistant": {
        const toolCalls: ChatCompletionMessageToolCall[] = message
          .getToolCallRequests()
          .map(toolCall => ({
            id: toolCall.id ?? "",
            type: "function",
            function: {
              name: toolCall.name,
              arguments: JSON.stringify(toolCall.arguments ?? {}),
            },
          }));

        messages.push({
          role: "assistant",
          content: message.getText(),
          ...(toolCalls.length ? { tool_calls: toolCalls } : {}),
        });
        break;
      }

      case "tool": {
        message.getToolCallResults().forEach(toolCallResult => {
          messages.push({
            role: "tool",
            tool_call_id: toolCallResult.toolCallId ?? "",
            content: toolCallResult.content,
          } as ChatCompletionToolMessageParam);
        });
        break;
      }
    }
  }

  return messages;
}
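
// Illustrative `toOpenAIMessages` output for an assistant turn that requested a single
// tool call (all values are made up; the shape follows the mapping above):
//   {
//     role: "assistant",
//     content: "Let me check that for you.",
//     tool_calls: [
//       {
//         id: "call_123",
//         type: "function",
//         function: { name: "get_weather", arguments: '{"city":"Paris"}' },
//       },
//     ],
//   }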

/** Convert LM Studio tool definitions to OpenAI function-tool descriptors. */
function toOpenAITools(ctl: GeneratorController): ChatCompletionTool[] | undefined {
  const tools = ctl.getToolDefinitions().map<ChatCompletionTool>(t => ({
    type: "function",
    function: {
      name: t.function.name,
      description: t.function.description,
      parameters: t.function.parameters ?? {},
    },
  }));
  return tools.length ? tools : undefined;
}
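
// Illustrative `toOpenAITools` output for one registered tool (hypothetical names):
//   [{
//     type: "function",
//     function: {
//       name: "get_weather",
//       description: "Look up current weather",
//       parameters: { type: "object", properties: { city: { type: "string" } } },
//     },
//   }]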

/* -------------------------------------------------------------------------- */
/*                            Stream-handling utils                           */
/* -------------------------------------------------------------------------- */

function wireAbort(ctl: GeneratorController, stream: { controller: AbortController }) {
  ctl.onAborted(() => {
    console.info("Generation aborted by user.");
    stream.controller.abort();
  });
}

async function consumeStream(stream: AsyncIterable<any>, ctl: GeneratorController) {
  let current: ToolCallState | null = null;

  function maybeFlushCurrentToolCall() {
    if (current === null || current.name === null) {
      return;
    }
    ctl.toolCallGenerationEnded({
      type: "function",
      name: current.name,
      // Guard against parameterless tools that stream an empty arguments string,
      // which would make JSON.parse throw.
      arguments: current.arguments ? JSON.parse(current.arguments) : {},
      id: current.id,
    });
    current = null;
  }

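  // Typical delta sequence this loop handles for a streamed tool call
  // (illustrative values, not captured from a real run):
  //   { tool_calls: [{ index: 0, id: "call_123" }] }                        -> generation started
  //   { tool_calls: [{ index: 0, function: { name: "get_weather" } }] }     -> name received
  //   { tool_calls: [{ index: 0, function: { arguments: '{"city":' } }] }   -> argument fragment
  //   { tool_calls: [{ index: 0, function: { arguments: '"Paris"}' } }] }   -> argument fragment
  //   finish_reason: "tool_calls"                                           -> generation ended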
  for await (const chunk of stream) {
    console.info("Received chunk:", JSON.stringify(chunk));
    const delta = chunk.choices?.[0]?.delta as
      | {
          content?: string;
          tool_calls?: Array<{
            index: number;
            id?: string;
            function?: { name?: string; arguments?: string };
          }>;
        }
      | undefined;

    if (!delta) continue;

    /* Text streaming */
    if (delta.content) {
      ctl.fragmentGenerated(delta.content);
    }

    /* Tool-call streaming */
    for (const toolCall of delta.tool_calls ?? []) {
      if (toolCall.id !== undefined) {
        maybeFlushCurrentToolCall();
        current = { id: toolCall.id, name: null, index: toolCall.index, arguments: "" };
        ctl.toolCallGenerationStarted();
      }

      if (toolCall.function?.name && current) {
        current.name = toolCall.function.name;
        ctl.toolCallGenerationNameReceived(toolCall.function.name);
      }

      if (toolCall.function?.arguments && current) {
        current.arguments += toolCall.function.arguments;
        ctl.toolCallGenerationArgumentFragmentGenerated(toolCall.function.arguments);
      }
    }

    /* Finalize buffered tool calls when the provider reports that finish reason */
    if (chunk.choices?.[0]?.finish_reason === "tool_calls") {
      maybeFlushCurrentToolCall();
    }
  }

  // Flush any tool call still buffered in case the stream ended without an
  // explicit "tool_calls" finish reason.
  maybeFlushCurrentToolCall();

  console.info("Generation completed.");
}

/* -------------------------------------------------------------------------- */
/*                                     API                                    */
/* -------------------------------------------------------------------------- */

export async function generate(ctl: GeneratorController, history: Chat) {
  const config = ctl.getPluginConfig(configSchematics);
  const globalConfig = ctl.getGlobalPluginConfig(globalConfigSchematics);
  const overrideBaseUrl = globalConfig.get("overrideBaseUrl");

  const selectedModel = config.get("model");
  console.info(
    `Generate: model=${selectedModel ?? ""} baseUrl=${overrideBaseUrl ?? ""} cached=${modelsCache.models.length}`,
  );
  if (!overrideBaseUrl && (!selectedModel || selectedModel === placeholderModelValue)) {
    ctl.fragmentGenerated("Set Override Base URL, then run once to load models.");
    return;
  }

  const shouldRefreshModels =
    !selectedModel || selectedModel === placeholderModelValue || modelsCache.models.length === 0;

  if (overrideBaseUrl && shouldRefreshModels) {
    const isPlaceholder = !selectedModel || selectedModel === placeholderModelValue;
    if (modelsCache.models.length > 0 && isPlaceholder) {
      if (!refreshNotified) {
        ctl.fragmentGenerated("Models cached. Restart LM Studio to update the dropdown.");
        refreshNotified = true;
      }
      return;
    }

    const refreshResult = await refreshModelOptionsIfNeeded(
      overrideBaseUrl,
      selectedModel,
      true,
    );

    if (isPlaceholder) {
      if (refreshResult.didUpdate) {
        if (!refreshNotified) {
          ctl.fragmentGenerated("Model list refreshed. Restart LM Studio to update the dropdown.");
          refreshNotified = true;
        }
      } else {
        ctl.fragmentGenerated("Model list not updated. Verify Base URL and try again.");
      }
      return;
    }
    // If a model is already selected, refresh silently.
  }
  const resolvedModel = selectedModel;

  if (!resolvedModel || resolvedModel === placeholderModelValue) {
    throw new Error("Model list not loaded yet. Set base URL, then retry to refresh models.");
  }

  /* 1. Setup client & payload */
  const openai = createOpenAI(globalConfig);
  const messages = toOpenAIMessages(history);
  const tools = toOpenAITools(ctl);

  /* 2. Kick off streaming completion */
  const stream = await openai.chat.completions.create({
    model: resolvedModel,
    messages,
    tools,
    stream: true,
  });

  /* 3. Abort wiring & stream processing */
  wireAbort(ctl, stream as any);
  await consumeStream(stream as any, ctl);
}