feat: support mlx plugin

# Conflicts:
#	Makefile
#	web-app/src/routes/settings/providers/$providerName.tsx
This commit is contained in:
Louis
2026-01-29 15:09:54 +07:00
parent 0055fbc442
commit b16b519f4e
58 changed files with 4527 additions and 89 deletions

2
.gitignore vendored
View File

@@ -63,3 +63,5 @@ src-tauri/resources/
test-data
llm-docs
.claude/agents
mlx-server/.build
mlx-server/.swiftpm

View File

@@ -118,6 +118,18 @@ endif
cargo test --manifest-path src-tauri/plugins/tauri-plugin-llamacpp/Cargo.toml
cargo test --manifest-path src-tauri/utils/Cargo.toml
# Build MLX server (macOS Apple Silicon only)
# Produces the standalone mlx-server binary consumed by the mlx extension.
# NOTE(review): recipe indentation (hard tabs) was lost in this capture — confirm
# each recipe line under the target starts with a tab in the committed Makefile.
build-mlx-server:
ifeq ($(shell uname -s),Darwin)
@echo "Building MLX server for Apple Silicon..."
# cd mlx-server && swift build -c release
# NOTE(review): 'platform=OS X' is the legacy destination name; current Xcode
# documents 'platform=macOS' — confirm this resolves on the CI Xcode version.
cd mlx-server && xcodebuild build -scheme mlx-server -destination 'platform=OS X'
# -configuration Release
@echo "MLX server built successfully"
else
# Non-Darwin hosts skip the build rather than failing the pipeline.
@echo "Skipping MLX server build (macOS only)"
endif
# Build
# Full app build: JS deps + Rust targets, then the yarn bundle step.
build: install-and-build install-rust-targets
yarn build

View File

@@ -0,0 +1,40 @@
{
"name": "@janhq/mlx-extension",
"productName": "MLX Inference Engine",
"version": "1.0.0",
"description": "This extension enables MLX-Swift inference on Apple Silicon Macs",
"main": "dist/index.js",
"module": "dist/module.js",
"engine": "mlx",
"author": "Jan <service@jan.ai>",
"license": "AGPL-3.0",
"scripts": {
"build": "rolldown -c rolldown.config.mjs",
"build:publish": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install"
},
"devDependencies": {
"cpx": "1.5.0",
"rimraf": "3.0.2",
"rolldown": "1.0.0-beta.1",
"typescript": "5.9.2"
},
"dependencies": {
"@janhq/core": "../../core/package.tgz",
"@janhq/tauri-plugin-llamacpp-api": "link:../../src-tauri/plugins/tauri-plugin-llamacpp",
"@janhq/tauri-plugin-mlx-api": "link:../../src-tauri/plugins/tauri-plugin-mlx",
"@tauri-apps/api": "2.8.0",
"@tauri-apps/plugin-http": "2.5.0",
"@tauri-apps/plugin-log": "^2.6.0"
},
"engines": {
"node": ">=18.0.0"
},
"files": [
"dist/*",
"package.json"
],
"installConfig": {
"hoistingLimits": "workspaces"
},
"packageManager": "yarn@4.5.3"
}

View File

@@ -0,0 +1,21 @@
import { defineConfig } from 'rolldown'
import pkgJson from './package.json' with { type: 'json' }
import settingJson from './settings.json' with { type: 'json' }
// Bundler config for the MLX extension. `define` inlines build-time constants
// that src/env.d.ts declares (SETTINGS, ENGINE, IS_MAC).
// NOTE(review): env.d.ts also declares IS_WINDOWS / IS_LINUX, which are NOT
// defined here — confirm the extension source never references them.
export default defineConfig({
  input: 'src/index.ts',
  output: {
    format: 'esm',
    file: 'dist/index.js',
  },
  platform: 'browser',
  define: {
    // Settings schema and engine id are baked into the bundle at build time.
    SETTINGS: JSON.stringify(settingJson),
    ENGINE: JSON.stringify(pkgJson.engine),
    IS_MAC: JSON.stringify(process.platform === 'darwin'),
  },
  // In production builds, route fetch through the Tauri HTTP plugin so requests
  // bypass the webview CORS sandbox; dev builds keep the browser fetch.
  inject: process.env.IS_DEV ? {} : {
    fetch: ['@tauri-apps/plugin-http', 'fetch'],
  },
})

View File

@@ -0,0 +1,33 @@
[
{
"key": "ctx_size",
"title": "Context Size",
"description": "Context window size for MLX inference",
"controllerType": "input",
"controllerProps": {
"value": 4096,
"placeholder": "4096",
"type": "number",
"textAlign": "right"
}
},
{
"key": "auto_unload",
"title": "Auto unload model",
"description": "Automatically unload other models when loading a new one",
"controllerType": "checkbox",
"controllerProps": { "value": true }
},
{
"key": "timeout",
"title": "Timeout (seconds)",
"description": "Maximum time to wait for model to load",
"controllerType": "input",
"controllerProps": {
"value": 600,
"placeholder": "600",
"type": "number",
"textAlign": "right"
}
}
]

5
extensions/mlx-extension/src/env.d.ts vendored Normal file
View File

@@ -0,0 +1,5 @@
// Build-time constants inlined by rolldown's `define` option (rolldown.config.mjs).
declare const SETTINGS: SettingComponentProps[]
declare const ENGINE: string
// NOTE(review): only IS_MAC is supplied by this package's rolldown config;
// IS_WINDOWS / IS_LINUX look copied from the llamacpp extension — confirm unused.
declare const IS_WINDOWS: boolean
declare const IS_MAC: boolean
declare const IS_LINUX: boolean

View File

@@ -0,0 +1,825 @@
/**
* MLX Extension - Inference engine for Apple Silicon Macs using MLX-Swift
*
* This extension provides an alternative to llama.cpp for running GGUF models
* locally on Apple Silicon using the MLX framework with Metal GPU acceleration.
*
* It shares the same model directory as llamacpp-extension so users can
* switch between engines without re-downloading models.
*/
import {
AIEngine,
getJanDataFolderPath,
fs,
joinPath,
modelInfo,
SessionInfo,
UnloadResult,
chatCompletion,
chatCompletionChunk,
ImportOptions,
chatCompletionRequest,
events,
AppEvent,
DownloadEvent,
} from '@janhq/core'
import { info, warn, error as logError } from '@tauri-apps/plugin-log'
import { invoke } from '@tauri-apps/api/core'
import {
loadMlxModel,
unloadMlxModel,
MlxConfig,
} from '@janhq/tauri-plugin-mlx-api'
import { readGgufMetadata, ModelConfig } from '@janhq/tauri-plugin-llamacpp-api'
// Error message constant
// Sentinel message thrown when a completion hits the context-window limit
// (detected via finish_reason === 'length').
const OUT_OF_CONTEXT_SIZE = 'the request exceeds the available context size.'
// Dual-sink logger: mirrors every message to the devtools console and to the
// Tauri log plugin (file/stdout). Args are space-joined with a leading space
// per arg for the plugin sink.
const logger = {
  info: function (...args: any[]) {
    console.log(...args)
    info(args.map((arg) => ` ${arg}`).join(` `))
  },
  warn: function (...args: any[]) {
    console.warn(...args)
    warn(args.map((arg) => ` ${arg}`).join(` `))
  },
  error: function (...args: any[]) {
    console.error(...args)
    logError(args.map((arg) => ` ${arg}`).join(` `))
  },
}
/**
 * MLX inference engine provider for Jan on Apple Silicon.
 *
 * Model metadata lives under `{janDataFolder}/mlx/models/{id}/model.yml`;
 * actual weights may be referenced by absolute path or by a path relative to
 * the Jan data folder. Server processes are managed by the `mlx` Tauri plugin.
 */
export default class mlx_extension extends AIEngine {
  /** Provider id used by Jan core to route requests to this engine. */
  provider: string = 'mlx'
  /** When true, other loaded models are unloaded before loading a new one. */
  autoUnload: boolean = true
  /** Model-load and streaming timeout, in seconds. */
  timeout: number = 600
  readonly providerId: string = 'mlx'
  /** Settings loaded at onLoad(), keyed by setting key. */
  private config: any = {}
  /** Cached `{janDataFolder}/mlx` path; always read via getProviderPath(). */
  private providerPath!: string
  /** Secret mixed into per-session API keys (shared scheme with llamacpp). */
  private apiSecret: string = 'JanMLX'
  /** In-flight load promises keyed by model id, to de-duplicate concurrent loads. */
  private loadingModels = new Map<string, Promise<SessionInfo>>()

  /** Register the settings schema and hydrate config from persisted values. */
  override async onLoad(): Promise<void> {
    super.onLoad()
    const settings = structuredClone(SETTINGS)
    this.registerSettings(settings)
    const loadedConfig: any = {}
    for (const item of settings) {
      const defaultValue = item.controllerProps.value
      loadedConfig[item.key] = await this.getSetting<typeof defaultValue>(
        item.key,
        defaultValue
      )
    }
    this.config = loadedConfig
    this.autoUnload = this.config.auto_unload ?? true
    this.timeout = this.config.timeout ?? 600
    // Warm the provider-path cache; callers still await getProviderPath().
    void this.getProviderPath()
  }

  /** Resolve (and cache) the provider root: `{janDataFolder}/mlx`. */
  async getProviderPath(): Promise<string> {
    if (!this.providerPath) {
      this.providerPath = await joinPath([await getJanDataFolderPath(), 'mlx'])
    }
    return this.providerPath
  }

  override async onUnload(): Promise<void> {
    // Server-process cleanup is handled by the Tauri plugin on app exit.
  }

  /** React to a settings change persisted by the UI. */
  onSettingUpdate<T>(key: string, value: T): void {
    this.config[key] = value
    if (key === 'auto_unload') {
      this.autoUnload = value as boolean
    } else if (key === 'timeout') {
      this.timeout = value as number
    }
  }

  /** True if `p` looks absolute (POSIX leading '/' or a drive/scheme colon). */
  private isAbsolutePath(p: string): boolean {
    return p.startsWith('/') || p.includes(':')
  }

  /** Resolve a model_path from model.yml: absolute as-is, else under janData. */
  private async resolveModelPath(modelPath: string): Promise<string> {
    if (this.isAbsolutePath(modelPath)) return modelPath
    return joinPath([await getJanDataFolderPath(), modelPath])
  }

  /**
   * Derive a per-session API key from the model id and port.
   * Reuses the llamacpp plugin's key-generation command.
   */
  private async generateApiKey(modelId: string, port: string): Promise<string> {
    return invoke<string>('plugin:llamacpp|generate_api_key', {
      modelId: modelId + port,
      apiSecret: this.apiSecret,
    })
  }

  /** Look up a single model's metadata; undefined if its model.yml is absent. */
  override async get(modelId: string): Promise<modelInfo | undefined> {
    const modelPath = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
    ])
    const path = await joinPath([modelPath, 'model.yml'])
    if (!(await fs.existsSync(path))) return undefined
    const modelConfig = await invoke<ModelConfig>('read_yaml', { path })
    return {
      id: modelId,
      name: modelConfig.name ?? modelId,
      providerId: this.provider,
      port: 0,
      sizeBytes: modelConfig.size_bytes ?? 0,
      embedding: modelConfig.embedding ?? false,
    } as modelInfo
  }

  /** Enumerate all models by walking the models dir for model.yml files. */
  override async list(): Promise<modelInfo[]> {
    const modelsDir = await joinPath([await this.getProviderPath(), 'models'])
    if (!(await fs.existsSync(modelsDir))) {
      await fs.mkdir(modelsDir)
    }
    const modelIds: string[] = []
    // Iterative DFS; a directory containing model.yml is a model root and is
    // not descended into further.
    const stack: string[] = [modelsDir]
    while (stack.length > 0) {
      const currentDir = stack.pop()!
      const modelConfigPath = await joinPath([currentDir, 'model.yml'])
      if (await fs.existsSync(modelConfigPath)) {
        // Model id is the path relative to modelsDir (may contain '/').
        modelIds.push(currentDir.slice(modelsDir.length + 1))
        continue
      }
      const children = await fs.readdirSync(currentDir)
      for (const child of children) {
        const dirInfo = await fs.fileStat(child)
        if (!dirInfo.isDirectory) continue
        stack.push(child)
      }
    }
    const modelInfos: modelInfo[] = []
    for (const modelId of modelIds) {
      const path = await joinPath([modelsDir, modelId, 'model.yml'])
      const modelConfig = await invoke<ModelConfig>('read_yaml', { path })
      const capabilities: string[] = []
      // An mmproj companion file indicates a vision-capable model.
      if (modelConfig.mmproj_path) {
        capabilities.push('vision')
      }
      try {
        if (await this.isToolSupported(modelId)) {
          capabilities.push('tools')
        }
      } catch (e) {
        logger.warn(`Failed to check tool support for ${modelId}: ${e}`)
      }
      modelInfos.push({
        id: modelId,
        name: modelConfig.name ?? modelId,
        providerId: this.provider,
        port: 0,
        sizeBytes: modelConfig.size_bytes ?? 0,
        embedding: modelConfig.embedding ?? false,
        capabilities: capabilities.length > 0 ? capabilities : undefined,
      } as modelInfo)
    }
    return modelInfos
  }

  /** Ask the plugin for a free localhost port for the server process. */
  private async getRandomPort(): Promise<number> {
    try {
      return await invoke<number>('plugin:mlx|get_mlx_random_port')
    } catch {
      logger.error('Unable to find a suitable port for MLX server')
      throw new Error('Unable to find a suitable port for MLX model')
    }
  }

  /**
   * Load a model, de-duplicating concurrent requests for the same id.
   * @throws if the model is already loaded.
   */
  override async load(
    modelId: string,
    overrideSettings?: any,
    isEmbedding: boolean = false
  ): Promise<SessionInfo> {
    const sInfo = await this.findSessionByModel(modelId)
    if (sInfo) {
      throw new Error('Model already loaded!')
    }
    // Join an in-flight load for the same model instead of starting another.
    if (this.loadingModels.has(modelId)) {
      return this.loadingModels.get(modelId)!
    }
    const loadingPromise = this.performLoad(modelId, overrideSettings, isEmbedding)
    this.loadingModels.set(modelId, loadingPromise)
    try {
      return await loadingPromise
    } finally {
      this.loadingModels.delete(modelId)
    }
  }

  /** Actual load path: auto-unload peers, resolve paths, spawn the server. */
  private async performLoad(
    modelId: string,
    overrideSettings?: any,
    isEmbedding: boolean = false
  ): Promise<SessionInfo> {
    const loadedModels = await this.getLoadedModels()
    // Wait for any other in-flight loads, then unload everything else, so at
    // most one (non-embedding) model occupies memory when autoUnload is on.
    const otherLoadingPromises = Array.from(this.loadingModels.entries())
      .filter(([id, _]) => id !== modelId)
      .map(([_, promise]) => promise)
    if (
      this.autoUnload &&
      !isEmbedding &&
      (loadedModels.length > 0 || otherLoadingPromises.length > 0)
    ) {
      if (otherLoadingPromises.length > 0) {
        await Promise.all(otherLoadingPromises)
      }
      const allLoadedModels = await this.getLoadedModels()
      if (allLoadedModels.length > 0) {
        await Promise.all(allLoadedModels.map((id) => this.unload(id)))
      }
    }
    const cfg = { ...this.config, ...(overrideSettings ?? {}) }
    // Use the accessor: providerPath may not be cached yet since onLoad does
    // not await the warm-up call.
    const modelConfigPath = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
      'model.yml',
    ])
    const modelConfig = await invoke<ModelConfig>('read_yaml', {
      path: modelConfigPath,
    })
    const port = await this.getRandomPort()
    const api_key = await this.generateApiKey(modelId, String(port))
    const envs: Record<string, string> = {
      MLX_API_KEY: api_key,
    }
    const modelPath = await this.resolveModelPath(modelConfig.model_path)
    const mlxServerPath = await this.getMlxServerBinaryPath()
    const mlxConfig: MlxConfig = {
      ctx_size: cfg.ctx_size ?? 4096,
      n_predict: cfg.n_predict ?? 0,
      threads: cfg.threads ?? 0,
      chat_template: cfg.chat_template ?? '',
    }
    logger.info(
      'Loading MLX model:',
      modelId,
      'with config:',
      JSON.stringify(mlxConfig)
    )
    try {
      return await loadMlxModel(
        mlxServerPath,
        modelId,
        modelPath,
        port,
        mlxConfig,
        envs,
        isEmbedding,
        Number(this.timeout)
      )
    } catch (error) {
      logger.error('Error loading MLX model:', error)
      throw error
    }
  }

  /**
   * Locate the mlx-server binary under `{janDataFolder}/mlx/mlx-server`.
   * @throws when the binary is missing.
   */
  private async getMlxServerBinaryPath(): Promise<string> {
    const janDataFolderPath = await getJanDataFolderPath()
    const binaryPath = await joinPath([janDataFolderPath, 'mlx', 'mlx-server'])
    if (await fs.existsSync(binaryPath)) {
      return binaryPath
    }
    throw new Error(
      'MLX server binary not found. Please ensure mlx-server is installed at ' +
        binaryPath
    )
  }

  /** Stop the server process backing `modelId`. Failure is returned, not thrown. */
  override async unload(modelId: string): Promise<UnloadResult> {
    const sInfo = await this.findSessionByModel(modelId)
    if (!sInfo) {
      throw new Error(`No active MLX session found for model: ${modelId}`)
    }
    try {
      const result = await unloadMlxModel(sInfo.pid)
      if (result.success) {
        logger.info(`Successfully unloaded MLX model with PID ${sInfo.pid}`)
      } else {
        logger.warn(`Failed to unload MLX model: ${result.error}`)
      }
      return result
    } catch (error) {
      logger.error('Error unloading MLX model:', error)
      return {
        success: false,
        error: `Failed to unload model: ${error}`,
      }
    }
  }

  /** Look up the running session for a model via the plugin. */
  private async findSessionByModel(modelId: string): Promise<SessionInfo> {
    try {
      return await invoke<SessionInfo>('plugin:mlx|find_mlx_session_by_model', {
        modelId,
      })
    } catch (e) {
      logger.error(e)
      throw new Error(String(e))
    }
  }

  /**
   * Run a chat completion against the model's local OpenAI-compatible server.
   * Returns an async iterable of chunks when `opts.stream` is set.
   */
  override async chat(
    opts: chatCompletionRequest,
    abortController?: AbortController
  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
    const sessionInfo = await this.findSessionByModel(opts.model)
    if (!sessionInfo) {
      throw new Error(`No active MLX session found for model: ${opts.model}`)
    }
    // Verify the process is alive AND its HTTP endpoint responds.
    const isAlive = await invoke<boolean>('plugin:mlx|is_mlx_process_running', {
      pid: sessionInfo.pid,
    })
    if (!isAlive) {
      throw new Error('MLX model has crashed! Please reload!')
    }
    try {
      await fetch(`http://localhost:${sessionInfo.port}/health`)
    } catch (e) {
      // Best-effort cleanup; swallow unload rejection so it cannot surface
      // as an unhandled promise rejection and mask the crash error below.
      void this.unload(sessionInfo.model_id).catch((err) =>
        logger.warn('Cleanup unload failed:', err)
      )
      throw new Error('MLX model appears to have crashed! Please reload!')
    }
    const baseUrl = `http://localhost:${sessionInfo.port}/v1`
    const url = `${baseUrl}/chat/completions`
    const headers = {
      'Content-Type': 'application/json',
      Authorization: `Bearer ${sessionInfo.api_key}`,
    }
    const body = JSON.stringify(opts)
    if (opts.stream) {
      return this.handleStreamingResponse(url, headers, body, abortController)
    }
    const response = await fetch(url, {
      method: 'POST',
      headers,
      body,
      signal: abortController?.signal,
    })
    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `MLX API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
      )
    }
    const completionResponse = (await response.json()) as chatCompletion
    if (completionResponse.choices?.[0]?.finish_reason === 'length') {
      throw new Error(OUT_OF_CONTEXT_SIZE)
    }
    return completionResponse
  }

  /** Consume an SSE response body and yield parsed completion chunks. */
  private async *handleStreamingResponse(
    url: string,
    headers: HeadersInit,
    body: string,
    abortController?: AbortController
  ): AsyncIterable<chatCompletionChunk> {
    // Build the signal list defensively: AbortSignal.any() throws a TypeError
    // if the iterable contains undefined (abortController is optional).
    const signals: AbortSignal[] = [AbortSignal.timeout(this.timeout * 1000)]
    if (abortController) signals.push(abortController.signal)
    const response = await fetch(url, {
      method: 'POST',
      headers,
      body,
      signal: AbortSignal.any(signals),
    })
    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `MLX API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
      )
    }
    if (!response.body) {
      throw new Error('Response body is null')
    }
    const reader = response.body.getReader()
    const decoder = new TextDecoder('utf-8')
    let buffer = ''
    try {
      while (true) {
        const { done, value } = await reader.read()
        if (done) break
        buffer += decoder.decode(value, { stream: true })
        // SSE events are newline-delimited; keep the trailing partial line.
        const lines = buffer.split('\n')
        buffer = lines.pop() || ''
        for (const line of lines) {
          const trimmedLine = line.trim()
          if (!trimmedLine || trimmedLine === 'data: [DONE]') continue
          if (trimmedLine.startsWith('data: ')) {
            const jsonStr = trimmedLine.slice(6)
            let data: chatCompletionChunk
            try {
              data = JSON.parse(jsonStr) as chatCompletionChunk
            } catch (e) {
              logger.error('Error parsing MLX stream JSON:', e)
              throw e
            }
            // Checked outside the parse try/catch so a context-size overflow
            // is not mislabeled/logged as a JSON parse error.
            if (data.choices?.[0]?.finish_reason === 'length') {
              throw new Error(OUT_OF_CONTEXT_SIZE)
            }
            yield data
          } else if (trimmedLine.startsWith('error: ')) {
            const jsonStr = trimmedLine.slice(7)
            const error = JSON.parse(jsonStr)
            throw new Error(error.message)
          }
        }
      }
    } finally {
      reader.releaseLock()
    }
  }

  /**
   * Delete a model: its metadata folder, and (for models stored under the
   * Jan data folder) the folder containing the weights.
   */
  override async delete(modelId: string): Promise<void> {
    const modelDir = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
    ])
    const modelConfigPath = await joinPath([modelDir, 'model.yml'])
    if (!(await fs.existsSync(modelConfigPath))) {
      throw new Error(`Model ${modelId} does not exist`)
    }
    const modelConfig = await invoke<ModelConfig>('read_yaml', {
      path: modelConfigPath,
    })
    // Only data-folder-relative weights are owned (and removed) by Jan;
    // absolute paths point at user files that must be left alone.
    if (!this.isAbsolutePath(modelConfig.model_path)) {
      const modelPath = await this.resolveModelPath(modelConfig.model_path)
      const parentDir = modelPath.substring(0, modelPath.lastIndexOf('/'))
      // Avoid a double-delete when weights live inside the metadata folder.
      if (parentDir !== modelDir) {
        await fs.rm(parentDir)
      }
    }
    await fs.rm(modelDir)
  }

  /**
   * Rename a model (move its folder and rewrite model_path in model.yml).
   * @throws when `model.id` is missing or already taken.
   */
  override async update(
    modelId: string,
    model: Partial<modelInfo>
  ): Promise<void> {
    const newId = model.id
    if (!newId) {
      // Without a target id the move below would create an 'undefined' folder.
      throw new Error('Model update requires a new model id')
    }
    const modelFolderPath = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
    ])
    const modelConfig = await invoke<ModelConfig>('read_yaml', {
      path: await joinPath([modelFolderPath, 'model.yml']),
    })
    const newFolderPath = await joinPath([
      await this.getProviderPath(),
      'models',
      newId,
    ])
    if (await fs.existsSync(newFolderPath)) {
      throw new Error(`Model with ID ${newId} already exists`)
    }
    const newModelConfigPath = await joinPath([newFolderPath, 'model.yml'])
    await fs.mv(modelFolderPath, newFolderPath).then(() =>
      invoke('write_yaml', {
        data: {
          ...modelConfig,
          model_path: modelConfig?.model_path?.replace(
            `mlx/models/${modelId}`,
            `mlx/models/${newId}`
          ),
        },
        savePath: newModelConfigPath,
      })
    )
  }

  /**
   * Import a model either by downloading from an https URL or by referencing
   * a local file in place (no copy). Writes a model.yml either way.
   */
  override async import(modelId: string, opts: ImportOptions): Promise<void> {
    const isValidModelId = (id: string) => {
      // Only alphanumerics plus / _ - . — and no empty or traversal segments.
      if (!/^[a-zA-Z0-9/_\-\.]+$/.test(id)) return false
      const parts = id.split('/')
      return parts.every((s) => s !== '' && s !== '.' && s !== '..')
    }
    if (!isValidModelId(modelId))
      throw new Error(
        `Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
      )
    const configPath = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
      'model.yml',
    ])
    if (await fs.existsSync(configPath))
      throw new Error(`Model ${modelId} already exists`)
    const sourcePath = opts.modelPath
    if (sourcePath.startsWith('https://')) {
      // Remote: download into the mlx models folder and record a relative path.
      const janDataFolderPath = await getJanDataFolderPath()
      const modelDir = await joinPath([janDataFolderPath, 'mlx', 'models', modelId])
      const localPath = await joinPath([modelDir, 'model.safetensors'])
      const downloadManager = window.core.extensionManager.getByName(
        '@janhq/download-extension'
      )
      await downloadManager.downloadFiles(
        [
          {
            url: sourcePath,
            save_path: localPath,
            sha256: opts.modelSha256,
            size: opts.modelSize,
            model_id: modelId,
          },
        ],
        `mlx/${modelId}`,
        (transferred: number, total: number) => {
          events.emit(DownloadEvent.onFileDownloadUpdate, {
            modelId,
            percent: transferred / total,
            size: { transferred, total },
            downloadType: 'Model',
          })
        }
      )
      const modelConfig = {
        model_path: `mlx/models/${modelId}/model.safetensors`,
        name: modelId,
        size_bytes: opts.modelSize ?? 0,
      }
      await fs.mkdir(modelDir)
      await invoke<void>('write_yaml', {
        data: modelConfig,
        savePath: configPath,
      })
      events.emit(AppEvent.onModelImported, {
        modelId,
        modelPath: modelConfig.model_path,
        size_bytes: modelConfig.size_bytes,
      })
    } else {
      // Local: reference the file by absolute path, no copy of the weights.
      if (!(await fs.existsSync(sourcePath))) {
        throw new Error(`File not found: ${sourcePath}`)
      }
      const stat = await fs.fileStat(sourcePath)
      const size_bytes = stat.size
      const modelConfig = {
        model_path: sourcePath,
        name: modelId,
        size_bytes,
      }
      // Model folder holds only model.yml in this case.
      const modelDir = await joinPath([
        await this.getProviderPath(),
        'models',
        modelId,
      ])
      await fs.mkdir(modelDir)
      await invoke<void>('write_yaml', {
        data: modelConfig,
        savePath: configPath,
      })
      events.emit(AppEvent.onModelImported, {
        modelId,
        modelPath: sourcePath,
        size_bytes,
      })
    }
  }

  override async abortImport(modelId: string): Promise<void> {
    // Not applicable for MLX - imports go through llamacpp extension
  }

  /** Ids of models with a running server process, per the plugin. */
  override async getLoadedModels(): Promise<string[]> {
    try {
      return await invoke<string[]>('plugin:mlx|get_mlx_loaded_models')
    } catch (e) {
      logger.error(e)
      // e is unknown; Error expects a string message (see findSessionByModel).
      throw new Error(String(e))
    }
  }

  /**
   * Heuristic tool/function-calling detection: chat-template inspection for
   * safetensors models, GGUF metadata otherwise.
   */
  async isToolSupported(modelId: string): Promise<boolean> {
    const modelConfigPath = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
      'model.yml',
    ])
    const modelConfig = await invoke<ModelConfig>('read_yaml', {
      path: modelConfigPath,
    })
    const modelPath = await this.resolveModelPath(modelConfig.model_path)
    const isSafetensors = modelPath.endsWith('.safetensors')
    const modelDir = modelPath.substring(0, modelPath.lastIndexOf('/'))
    if (isSafetensors) {
      // Check 1: tokenizer_config.json for tool/function-calling indicators.
      // (Deliberately does NOT match 'assistant' — that substring appears in
      // virtually every chat template and would make this always true.)
      const tokenizerConfigPath = await joinPath([modelDir, 'tokenizer_config.json'])
      if (await fs.existsSync(tokenizerConfigPath)) {
        try {
          const tokenizerConfigContent = await invoke<string>('read_file_sync', {
            args: [tokenizerConfigPath],
          })
          const tcLower = tokenizerConfigContent.toLowerCase()
          if (
            tcLower.includes('function_call') ||
            tcLower.includes('tool_use') ||
            tcLower.includes('tools')
          ) {
            logger.info(`Tool support detected from tokenizer_config.json for ${modelId}`)
            return true
          }
        } catch (e) {
          logger.warn(`Failed to read tokenizer_config.json: ${e}`)
        }
      }
      // Check 2: chat_template.jinja for common tool-calling patterns.
      const chatTemplatePath = await joinPath([modelDir, 'chat_template.jinja'])
      if (await fs.existsSync(chatTemplatePath)) {
        try {
          const chatTemplateContent = await invoke<string>('read_file_sync', {
            args: [chatTemplatePath],
          })
          const toolPatterns = [
            /\{\%.*tool.*\%\}/, // {% tool ... %}
            /\{\%.*function.*\%\}/, // {% function ... %}
            /\{\%.*tool_call/,
            /\{\%.*tools\./,
            /\{[-]?#.*tool/,
            /\{[-]?%.*tool/,
            /"tool_calls"/, // "tool_calls" JSON key
            /'tool_calls'/, // 'tool_calls' JSON key
            /function_call/,
            /tool_use/,
          ]
          for (const pattern of toolPatterns) {
            if (pattern.test(chatTemplateContent)) {
              logger.info(`Tool support detected from chat_template.jinja for ${modelId}`)
              return true
            }
          }
        } catch (e) {
          logger.warn(`Failed to read chat_template.jinja: ${e}`)
        }
      }
      // Check 3: dedicated tool-template files next to the weights.
      const toolFiles = ['tools.jinja', 'tool_use.jinja', 'function_calling.jinja']
      for (const toolFile of toolFiles) {
        const toolPath = await joinPath([modelDir, toolFile])
        if (await fs.existsSync(toolPath)) {
          logger.info(`Tool support detected from ${toolFile} for ${modelId}`)
          return true
        }
      }
      logger.info(`No tool support detected for safetensors model ${modelId}`)
      return false
    } else {
      // GGUF: inspect the embedded chat template metadata.
      try {
        const metadata = await readGgufMetadata(modelPath)
        const chatTemplate = metadata.metadata?.['tokenizer.chat_template']
        return chatTemplate?.includes('tools') ?? false
      } catch (e) {
        logger.warn(`Failed to read GGUF metadata: ${e}`)
        return false
      }
    }
  }
}

View File

@@ -0,0 +1,15 @@
{
"compilerOptions": {
"target": "es2018",
"module": "ES6",
"moduleResolution": "node",
"outDir": "./dist",
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"strict": false,
"skipLibCheck": true,
"rootDir": "./src"
},
"include": ["./src"],
"exclude": ["**/*.test.ts"]
}

266
mlx-server/Package.resolved Normal file
View File

@@ -0,0 +1,266 @@
{
"pins" : [
{
"identity" : "async-http-client",
"kind" : "remoteSourceControl",
"location" : "https://github.com/swift-server/async-http-client.git",
"state" : {
"revision" : "4b99975677236d13f0754339864e5360142ff5a1",
"version" : "1.30.3"
}
},
{
"identity" : "hummingbird",
"kind" : "remoteSourceControl",
"location" : "https://github.com/hummingbird-project/hummingbird",
"state" : {
"revision" : "daf66bfd4b46c1f3f080a1f3438d8fbecee7ace5",
"version" : "2.19.0"
}
},
{
"identity" : "mlx-swift",
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift",
"state" : {
"revision" : "4dccaeda1d83cf8697f235d2786c2d72ad4bb925",
"version" : "0.30.3"
}
},
{
"identity" : "mlx-swift-lm",
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift-lm",
"state" : {
"branch" : "main",
"revision" : "2c700546340c37f275d23302163701b77c4dcbd9"
}
},
{
"identity" : "swift-algorithms",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-algorithms.git",
"state" : {
"revision" : "87e50f483c54e6efd60e885f7f5aa946cee68023",
"version" : "1.2.1"
}
},
{
"identity" : "swift-argument-parser",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-argument-parser",
"state" : {
"revision" : "c5d11a805e765f52ba34ec7284bd4fcd6ba68615",
"version" : "1.7.0"
}
},
{
"identity" : "swift-asn1",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-asn1.git",
"state" : {
"revision" : "810496cf121e525d660cd0ea89a758740476b85f",
"version" : "1.5.1"
}
},
{
"identity" : "swift-async-algorithms",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-async-algorithms.git",
"state" : {
"revision" : "6c050d5ef8e1aa6342528460db614e9770d7f804",
"version" : "1.1.1"
}
},
{
"identity" : "swift-atomics",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-atomics.git",
"state" : {
"revision" : "b601256eab081c0f92f059e12818ac1d4f178ff7",
"version" : "1.3.0"
}
},
{
"identity" : "swift-certificates",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-certificates.git",
"state" : {
"revision" : "7d5f6124c91a2d06fb63a811695a3400d15a100e",
"version" : "1.17.1"
}
},
{
"identity" : "swift-collections",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-collections.git",
"state" : {
"revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e",
"version" : "1.3.0"
}
},
{
"identity" : "swift-configuration",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-configuration.git",
"state" : {
"revision" : "6ffef195ed4ba98ee98029970c94db7edc60d4c6",
"version" : "1.0.1"
}
},
{
"identity" : "swift-crypto",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-crypto.git",
"state" : {
"revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095",
"version" : "4.2.0"
}
},
{
"identity" : "swift-distributed-tracing",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-distributed-tracing.git",
"state" : {
"revision" : "baa932c1336f7894145cbaafcd34ce2dd0b77c97",
"version" : "1.3.1"
}
},
{
"identity" : "swift-http-structured-headers",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-http-structured-headers.git",
"state" : {
"revision" : "76d7627bd88b47bf5a0f8497dd244885960dde0b",
"version" : "1.6.0"
}
},
{
"identity" : "swift-http-types",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-http-types.git",
"state" : {
"revision" : "45eb0224913ea070ec4fba17291b9e7ecf4749ca",
"version" : "1.5.1"
}
},
{
"identity" : "swift-jinja",
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-jinja.git",
"state" : {
"revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0",
"version" : "2.3.1"
}
},
{
"identity" : "swift-log",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-log.git",
"state" : {
"revision" : "2778fd4e5a12a8aaa30a3ee8285f4ce54c5f3181",
"version" : "1.9.1"
}
},
{
"identity" : "swift-metrics",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-metrics.git",
"state" : {
"revision" : "0743a9364382629da3bf5677b46a2c4b1ce5d2a6",
"version" : "2.7.1"
}
},
{
"identity" : "swift-nio",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio.git",
"state" : {
"revision" : "5e72fc102906ebe75a3487595a653e6f43725552",
"version" : "2.94.0"
}
},
{
"identity" : "swift-nio-extras",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-extras.git",
"state" : {
"revision" : "3df009d563dc9f21a5c85b33d8c2e34d2e4f8c3b",
"version" : "1.32.1"
}
},
{
"identity" : "swift-nio-http2",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-http2.git",
"state" : {
"revision" : "c2ba4cfbb83f307c66f5a6df6bb43e3c88dfbf80",
"version" : "1.39.0"
}
},
{
"identity" : "swift-nio-ssl",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-ssl.git",
"state" : {
"revision" : "173cc69a058623525a58ae6710e2f5727c663793",
"version" : "2.36.0"
}
},
{
"identity" : "swift-nio-transport-services",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-transport-services.git",
"state" : {
"revision" : "60c3e187154421171721c1a38e800b390680fb5d",
"version" : "1.26.0"
}
},
{
"identity" : "swift-numerics",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-numerics",
"state" : {
"revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2",
"version" : "1.1.1"
}
},
{
"identity" : "swift-service-context",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-service-context.git",
"state" : {
"revision" : "1983448fefc717a2bc2ebde5490fe99873c5b8a6",
"version" : "1.2.1"
}
},
{
"identity" : "swift-service-lifecycle",
"kind" : "remoteSourceControl",
"location" : "https://github.com/swift-server/swift-service-lifecycle.git",
"state" : {
"revision" : "1de37290c0ab3c5a96028e0f02911b672fd42348",
"version" : "2.9.1"
}
},
{
"identity" : "swift-system",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-system",
"state" : {
"revision" : "7c6ad0fc39d0763e0b699210e4124afd5041c5df",
"version" : "1.6.4"
}
},
{
"identity" : "swift-transformers",
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers",
"state" : {
"revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0",
"version" : "1.1.6"
}
}
],
"version" : 2
}

36
mlx-server/Package.swift Normal file
View File

@@ -0,0 +1,36 @@
// swift-tools-version: 5.9
import PackageDescription
// SwiftPM manifest for the standalone mlx-server executable: an HTTP inference
// server built on mlx-swift/mlx-swift-lm (model runtime), Hummingbird (HTTP),
// and ArgumentParser (CLI). Requires macOS 14+.
let package = Package(
    name: "mlx-server",
    platforms: [
        .macOS(.v14)
    ],
    products: [
        .executable(name: "mlx-server", targets: ["MLXServer"])
    ],
    dependencies: [
        .package(url: "https://github.com/ml-explore/mlx-swift", from: "0.30.3"),
        // NOTE(review): mlx-swift-lm tracks `branch: "main"` — unpinned; builds
        // are only reproducible via the committed Package.resolved revision.
        .package(url: "https://github.com/ml-explore/mlx-swift-lm", branch: "main"),
        .package(url: "https://github.com/apple/swift-argument-parser", from: "1.7.0"),
        .package(url: "https://github.com/hummingbird-project/hummingbird", from: "2.19.0"),
    ],
    targets: [
        .executableTarget(
            name: "MLXServer",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "MLXLLM", package: "mlx-swift-lm"),
                .product(name: "MLXVLM", package: "mlx-swift-lm"),
                .product(name: "MLXLMCommon", package: "mlx-swift-lm"),
                .product(name: "ArgumentParser", package: "swift-argument-parser"),
                .product(name: "Hummingbird", package: "hummingbird"),
            ],
            path: "Sources/MLXServer"
        )
    ]
)

View File

@@ -0,0 +1,67 @@
import ArgumentParser
import Foundation
import Hummingbird
@main
/// CLI entry point for the MLX inference server.
/// Loads one model, then serves an OpenAI-compatible HTTP API on localhost.
struct MLXServerCommand: AsyncParsableCommand {
    static let configuration = CommandConfiguration(
        commandName: "mlx-server",
        abstract: "MLX-Swift inference server with OpenAI-compatible API"
    )
    // NOTE(review): help text says "GGUF model file", but the Jan extension
    // points this at safetensors model directories too — confirm wording.
    @Option(name: [.long, .short], help: "Path to the GGUF model file")
    var model: String
    @Option(name: .long, help: "Port to listen on")
    var port: Int = 8080
    @Option(name: .long, help: "Context window size")
    var ctxSize: Int = 4096
    // Empty string means authentication is disabled.
    @Option(name: .long, help: "API key for authentication (optional)")
    var apiKey: String = ""
    @Option(name: .long, help: "Chat template to use (optional)")
    var chatTemplate: String = ""
    @Flag(name: .long, help: "Run in embedding mode")
    var embedding: Bool = false
    func run() async throws {
        // Print startup info
        print("[mlx] MLX-Swift Server starting...")
        print("[mlx] Model path: \(model)")
        print("[mlx] Port: \(port)")
        print("[mlx] Context size: \(ctxSize)")
        // Extract model ID from path (file name without extension).
        let modelURL = URL(fileURLWithPath: model)
        let modelId = modelURL.deletingPathExtension().lastPathComponent
        // Load the model before binding the port, so the readiness lines below
        // are only printed once inference is actually available.
        let modelRunner = ModelRunner()
        do {
            try await modelRunner.load(modelPath: model, modelId: modelId)
        } catch {
            print("[mlx] Failed to load model: \(error)")
            throw error
        }
        // Set up the HTTP server; binds to loopback only.
        let server = MLXHTTPServer(
            modelRunner: modelRunner,
            modelId: modelId,
            apiKey: apiKey
        )
        let router = server.buildRouter()
        let app = Application(router: router, configuration: .init(address: .hostname("127.0.0.1", port: port)))
        // Print readiness signal (monitored by Tauri plugin) — do not change
        // or reorder these two lines without updating the plugin's matcher.
        print("[mlx] http server listening on http://127.0.0.1:\(port)")
        print("[mlx] server is listening on 127.0.0.1:\(port)")
        try await app.run()
    }
}

View File

@@ -0,0 +1,364 @@
import Foundation
import MLX
import MLXLLM
import MLXLMCommon
import MLXRandom
import MLXVLM
/// Manages loading and running inference with MLX models.
///
/// All mutable state (the loaded `ModelContainer` and its id) is
/// actor-isolated, so one runner can be shared by concurrent HTTP handlers.
actor ModelRunner {
    private var container: ModelContainer?
    private var modelId: String = ""

    /// True once `load` has completed successfully.
    var isLoaded: Bool {
        container != nil
    }

    /// Identifier of the currently loaded model ("" if none).
    var currentModelId: String {
        modelId
    }

    /// Load a model from the given path, trying LLM first then VLM.
    ///
    /// - Parameters:
    ///   - modelPath: Path to a file in the model directory; the parent
    ///     directory is what the factories actually load.
    ///   - modelId: Identifier reported to API clients.
    /// - Throws: The VLM factory's error when both factories fail.
    func load(modelPath: String, modelId: String) async throws {
        print("[mlx] Loading model from: \(modelPath)")
        let modelURL = URL(fileURLWithPath: modelPath)
        let modelDir = modelURL.deletingLastPathComponent()
        let configuration = ModelConfiguration(directory: modelDir, defaultPrompt: "")
        // Try LLM factory first, fall back to VLM factory
        do {
            self.container = try await LLMModelFactory.shared.loadContainer(
                configuration: configuration
            ) { progress in
                print("[mlx] Loading progress: \(Int(progress.fractionCompleted * 100))%")
            }
            print("[mlx] Model loaded as LLM: \(modelId)")
        } catch {
            print("[mlx] LLM loading failed (\(error.localizedDescription)), trying VLM factory...")
            do {
                self.container = try await VLMModelFactory.shared.loadContainer(
                    configuration: configuration
                ) { progress in
                    print("[mlx] Loading progress: \(Int(progress.fractionCompleted * 100))%")
                }
                print("[mlx] Model loaded as VLM: \(modelId)")
            } catch {
                print("[mlx] Error: Failed to load model with both LLM and VLM factories: \(error.localizedDescription)")
                throw error
            }
        }
        self.modelId = modelId
        print("[mlx] Model ready: \(modelId)")
    }

    /// Build Chat.Message array from ChatMessages, including images and videos.
    /// Unknown roles default to `.user`; unparseable media URLs are skipped
    /// with a warning rather than failing the whole request.
    private func buildChat(from messages: [ChatMessage]) -> [Chat.Message] {
        messages.map { message in
            let role: Chat.Message.Role =
                switch message.role {
                case "assistant":
                    .assistant
                case "user":
                    .user
                case "system":
                    .system
                case "tool":
                    .tool
                default:
                    .user
                }
            let images: [UserInput.Image] = (message.images ?? []).compactMap { urlString in
                guard let url = URL(string: urlString) else {
                    print("[mlx] Warning: Invalid image URL: \(urlString)")
                    return nil
                }
                return .url(url)
            }
            let videos: [UserInput.Video] = (message.videos ?? []).compactMap { urlString in
                guard let url = URL(string: urlString) else {
                    print("[mlx] Warning: Invalid video URL: \(urlString)")
                    return nil
                }
                return .url(url)
            }
            if !images.isEmpty {
                print("[mlx] Message has \(images.count) image(s)")
            }
            if !videos.isEmpty {
                print("[mlx] Message has \(videos.count) video(s)")
            }
            return Chat.Message(
                role: role, content: message.content ?? "", images: images, videos: videos)
        }
    }

    /// Convert AnyCodable tools array to ToolSpec format.
    /// NOTE(review): the `as!` force cast will crash on a tool entry that is
    /// not a JSON object — confirm upstream always sends objects.
    private func buildToolSpecs(from tools: [AnyCodable]?) -> [[String: any Sendable]]? {
        guard let tools = tools, !tools.isEmpty else { return nil }
        let specs = tools.map { tool in
            tool.toSendable() as! [String: any Sendable]
        }
        print("[mlx] Tools provided: \(specs.count)")
        return specs
    }

    /// Generate a chat completion (non-streaming).
    ///
    /// - Returns: generated text (stop-sequence suffix trimmed), collected
    ///   tool calls, and token usage.
    /// - Throws: `MLXServerError.modelNotLoaded` or generation errors.
    func generate(
        messages: [ChatMessage],
        temperature: Float = 0.7,
        topP: Float = 1.0,
        maxTokens: Int = 2048,
        repetitionPenalty: Float = 1.0,
        stop: [String] = [],
        tools: [AnyCodable]? = nil
    ) async throws -> (String, [ToolCallInfo], UsageInfo) {
        guard let container = container else {
            print("[mlx] Error: generate() called but no model is loaded")
            throw MLXServerError.modelNotLoaded
        }
        print("[mlx] Generate: \(messages.count) messages, temp=\(temperature), topP=\(topP), maxTokens=\(maxTokens)")
        let chat = buildChat(from: messages)
        let toolSpecs = buildToolSpecs(from: tools)
        let generateParameters = GenerateParameters(
            maxTokens: maxTokens,
            temperature: temperature,
            topP: topP,
            repetitionPenalty: repetitionPenalty
        )
        let userInput = UserInput(
            chat: chat,
            // Images are capped at 1024x1024 — presumably to bound VLM
            // memory use; TODO confirm this matches the model's expectations.
            processing: .init(resize: .init(width: 1024, height: 1024)),
            tools: toolSpecs
        )
        let result: (String, [ToolCallInfo], UsageInfo) = try await container.perform { context in
            let input = try await context.processor.prepare(input: userInput)
            let promptTokenCount = input.text.tokens.size
            var output = ""
            var completionTokenCount = 0
            var collectedToolCalls: [ToolCallInfo] = []
            var completionInfo: GenerateCompletionInfo?
            do {
                // FIX: the loop is labeled because a plain `break` inside the
                // `switch` below only exits the switch, not the loop, so stop
                // sequences previously never terminated generation.
                generationLoop: for await generation in try MLXLMCommon.generate(
                    input: input, parameters: generateParameters, context: context
                ) {
                    switch generation {
                    case .chunk(let chunk):
                        output += chunk
                        completionTokenCount += 1
                        // Check stop sequences; trim the matched suffix from
                        // the returned text.
                        var hitStop = false
                        for s in stop where output.hasSuffix(s) {
                            output = String(output.dropLast(s.count))
                            hitStop = true
                            print("[mlx] Hit stop sequence: \"\(s)\"")
                            break
                        }
                        if hitStop { break generationLoop }
                    case .info(let info):
                        completionInfo = info
                        print("[mlx] Generation info: \(info.promptTokenCount) prompt tokens, \(info.generationTokenCount) generated tokens")
                        print("[mlx] Prompt: \(String(format: "%.1f", info.promptTokensPerSecond)) tokens/sec")
                        print("[mlx] Generation: \(String(format: "%.1f", info.tokensPerSecond)) tokens/sec")
                    case .toolCall(let toolCall):
                        let argsData = try JSONSerialization.data(
                            withJSONObject: toolCall.function.arguments.mapValues { $0.anyValue },
                            options: [.sortedKeys]
                        )
                        let argsString = String(data: argsData, encoding: .utf8) ?? "{}"
                        let info = ToolCallInfo(
                            id: generateToolCallId(),
                            type: "function",
                            function: FunctionCall(
                                name: toolCall.function.name,
                                arguments: argsString
                            )
                        )
                        collectedToolCalls.append(info)
                        print("[mlx] Tool call: \(toolCall.function.name)(\(argsString))")
                    }
                }
            } catch {
                print("[mlx] Error during generation: \(error.localizedDescription)")
                throw error
            }
            // Prefer the library-reported counts; fall back to our own chunk
            // count when generation stopped before .info arrived.
            let usage: UsageInfo
            if let info = completionInfo {
                usage = UsageInfo(
                    prompt_tokens: info.promptTokenCount,
                    completion_tokens: info.generationTokenCount,
                    total_tokens: info.promptTokenCount + info.generationTokenCount
                )
            } else {
                usage = UsageInfo(
                    prompt_tokens: promptTokenCount,
                    completion_tokens: completionTokenCount,
                    total_tokens: promptTokenCount + completionTokenCount
                )
            }
            print("[mlx] Generate complete: \(output.count) chars, \(collectedToolCalls.count) tool call(s)")
            return (output, collectedToolCalls, usage)
        }
        return result
    }

    /// Generate a streaming chat completion.
    ///
    /// Yields `.chunk` per token, `.toolCall` per tool invocation, and exactly
    /// one `.done` (either from the library's `.info` or a fallback), then
    /// finishes. Errors finish the stream with the thrown error.
    func generateStream(
        messages: [ChatMessage],
        temperature: Float = 0.7,
        topP: Float = 1.0,
        maxTokens: Int = 2048,
        repetitionPenalty: Float = 1.0,
        stop: [String] = [],
        tools: [AnyCodable]? = nil
    ) -> AsyncThrowingStream<StreamEvent, Error> {
        AsyncThrowingStream { continuation in
            Task {
                guard let container = self.container else {
                    print("[mlx] Error: generateStream() called but no model is loaded")
                    continuation.finish(throwing: MLXServerError.modelNotLoaded)
                    return
                }
                print("[mlx] Stream generate: \(messages.count) messages, temp=\(temperature), topP=\(topP), maxTokens=\(maxTokens)")
                let chat = self.buildChat(from: messages)
                let toolSpecs = self.buildToolSpecs(from: tools)
                let userInput = UserInput(
                    chat: chat,
                    processing: .init(resize: .init(width: 1024, height: 1024)),
                    tools: toolSpecs
                )
                do {
                    try await container.perform { context in
                        let generateParameters = GenerateParameters(
                            maxTokens: maxTokens,
                            temperature: temperature,
                            topP: topP,
                            repetitionPenalty: repetitionPenalty
                        )
                        let input = try await context.processor.prepare(input: userInput)
                        var completionTokenCount = 0
                        var accumulated = ""
                        var hasToolCalls = false
                        var infoReceived = false
                        do {
                            // FIX: labeled loop — a plain `break` in the
                            // switch below only exited the switch, so stop
                            // sequences never stopped streaming generation.
                            streamLoop: for await generation in try MLXLMCommon.generate(
                                input: input, parameters: generateParameters, context: context
                            ) {
                                switch generation {
                                case .chunk(let chunk):
                                    accumulated += chunk
                                    completionTokenCount += 1
                                    continuation.yield(.chunk(chunk))
                                    // Check stop sequences. Note the matching
                                    // chunk has already been yielded; clients
                                    // may see the stop text.
                                    var hitStop = false
                                    for s in stop where accumulated.hasSuffix(s) {
                                        hitStop = true
                                        print("[mlx] Hit stop sequence: \"\(s)\"")
                                        break
                                    }
                                    if hitStop { break streamLoop }
                                case .info(let info):
                                    infoReceived = true
                                    print("[mlx] Stream generation info: \(info.promptTokenCount) prompt tokens, \(info.generationTokenCount) generated tokens")
                                    print("[mlx] Prompt: \(String(format: "%.1f", info.promptTokensPerSecond)) tokens/sec")
                                    print("[mlx] Generation: \(String(format: "%.1f", info.tokensPerSecond)) tokens/sec")
                                    let usage = UsageInfo(
                                        prompt_tokens: info.promptTokenCount,
                                        completion_tokens: info.generationTokenCount,
                                        total_tokens: info.promptTokenCount + info.generationTokenCount
                                    )
                                    let timings = TimingsInfo(
                                        prompt_n: info.promptTokenCount,
                                        predicted_n: info.generationTokenCount,
                                        predicted_per_second: info.tokensPerSecond,
                                        prompt_per_second: info.promptTokensPerSecond
                                    )
                                    continuation.yield(.done(usage: usage, timings: timings, hasToolCalls: hasToolCalls))
                                case .toolCall(let toolCall):
                                    hasToolCalls = true
                                    let argsData = try JSONSerialization.data(
                                        withJSONObject: toolCall.function.arguments.mapValues { $0.anyValue },
                                        options: [.sortedKeys]
                                    )
                                    let argsString = String(data: argsData, encoding: .utf8) ?? "{}"
                                    let info = ToolCallInfo(
                                        id: generateToolCallId(),
                                        type: "function",
                                        function: FunctionCall(
                                            name: toolCall.function.name,
                                            arguments: argsString
                                        )
                                    )
                                    print("[mlx] Stream tool call: \(toolCall.function.name)(\(argsString))")
                                    continuation.yield(.toolCall(info))
                                }
                            }
                        } catch {
                            print("[mlx] Error during stream generation: \(error.localizedDescription)")
                            throw error
                        }
                        // FIX: if no .info was received (e.g. an early stop),
                        // emit a fallback .done so the HTTP layer can still
                        // send finish_reason and the [DONE] sentinel. The
                        // original comment promised this but never did it.
                        if !infoReceived {
                            let promptTokenCount = input.text.tokens.size
                            let usage = UsageInfo(
                                prompt_tokens: promptTokenCount,
                                completion_tokens: completionTokenCount,
                                total_tokens: promptTokenCount + completionTokenCount
                            )
                            continuation.yield(.done(usage: usage, timings: nil, hasToolCalls: hasToolCalls))
                        }
                        print("[mlx] Stream complete: \(accumulated.count) chars")
                        continuation.finish()
                    }
                } catch {
                    print("[mlx] Error in stream: \(error.localizedDescription)")
                    continuation.finish(throwing: error)
                }
            }
        }
    }
}
/// Events emitted during streaming generation.
/// Consumed by MLXHTTPServer, which maps them to OpenAI SSE chunks.
enum StreamEvent {
    /// A text chunk
    case chunk(String)
    /// A tool call from the model
    case toolCall(ToolCallInfo)
    /// Generation complete with usage and timing info
    case done(usage: UsageInfo, timings: TimingsInfo?, hasToolCalls: Bool)
}
/// Errors surfaced by the MLX server, with human-readable descriptions
/// via `LocalizedError`.
enum MLXServerError: Error, LocalizedError {
    /// Raised when inference is requested before `load` has succeeded.
    case modelNotLoaded
    /// Raised for malformed client requests; carries a detail message.
    case invalidRequest(String)

    var errorDescription: String? {
        if case .invalidRequest(let msg) = self {
            return "Invalid request: \(msg)"
        }
        return "No model is currently loaded"
    }
}

View File

@@ -0,0 +1,225 @@
import Foundation
// MARK: - Chat Completion Request

/// Incoming body for POST /v1/chat/completions (OpenAI-compatible).
struct ChatCompletionRequest: Codable {
    let model: String
    let messages: [ChatMessage]
    var temperature: Float?
    // OpenAI snake_case field names are kept so no CodingKeys are needed.
    var top_p: Float?
    var max_tokens: Int?
    var stream: Bool?
    var stop: [String]?
    // llama.cpp-style alias for max_tokens; used as a fallback by the server.
    var n_predict: Int?
    var repetition_penalty: Float?
    // Tool schemas are arbitrary JSON, hence AnyCodable.
    var tools: [AnyCodable]?
}
/// A type-erased Codable wrapper for arbitrary JSON values (used for tool
/// schemas whose shape is not known at compile time).
///
/// Supports String, Int, Double, Bool, arrays, objects, and — FIX — JSON
/// `null` (stored as NSNull); the original threw on null anywhere in a tool
/// schema. `@unchecked Sendable` because `Any` cannot be statically checked.
struct AnyCodable: Codable, @unchecked Sendable {
    let value: Any

    init(_ value: Any) {
        self.value = value
    }

    func encode(to encoder: Encoder) throws {
        var container = encoder.singleValueContainer()
        if value is NSNull {
            // FIX: round-trip JSON null instead of throwing.
            try container.encodeNil()
        } else if let string = value as? String {
            try container.encode(string)
        } else if let int = value as? Int {
            try container.encode(int)
        } else if let double = value as? Double {
            try container.encode(double)
        } else if let bool = value as? Bool {
            try container.encode(bool)
        } else if let array = value as? [Any] {
            try container.encode(array.map { AnyCodable($0) })
        } else if let dict = value as? [String: Any] {
            try container.encode(dict.mapValues { AnyCodable($0) })
        } else {
            throw EncodingError.invalidValue(value, EncodingError.Context(codingPath: encoder.codingPath, debugDescription: "Unsupported type"))
        }
    }

    init(from decoder: Decoder) throws {
        let container = try decoder.singleValueContainer()
        // Order matters: Int is tried before Double, so "1.0" vs "1" keep
        // their distinction where the decoder allows it.
        if container.decodeNil() {
            // FIX: accept JSON null (previously dataCorrupted).
            value = NSNull()
        } else if let string = try? container.decode(String.self) {
            value = string
        } else if let int = try? container.decode(Int.self) {
            value = int
        } else if let double = try? container.decode(Double.self) {
            value = double
        } else if let bool = try? container.decode(Bool.self) {
            value = bool
        } else if let array = try? container.decode([AnyCodable].self) {
            value = array.map { $0.value }
        } else if let dict = try? container.decode([String: AnyCodable].self) {
            value = dict.mapValues { $0.value }
        } else {
            throw DecodingError.dataCorrupted(DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Unsupported type"))
        }
    }

    /// Recursively convert the underlying value to `[String: any Sendable]` or primitive Sendable types
    func toSendable() -> any Sendable {
        switch value {
        case is NSNull:
            // Nulls become a nil optional, which is Sendable.
            return String?.none
        case let string as String:
            return string
        case let int as Int:
            return int
        case let double as Double:
            return double
        case let bool as Bool:
            return bool
        case let array as [Any]:
            return array.map { AnyCodable($0).toSendable() }
        case let dict as [String: Any]:
            return dict.mapValues { AnyCodable($0).toSendable() }
        default:
            // Last resort: stringify unknown payloads rather than crash.
            return String(describing: value)
        }
    }
}
/// A single chat message; doubles as request input and response output,
/// hence every field except `role` being optional.
struct ChatMessage: Codable {
    let role: String
    var content: String?
    // Media attachments as URL strings (resolved by ModelRunner.buildChat).
    var images: [String]?
    var videos: [String]?
    // Present on assistant messages that invoke tools.
    var tool_calls: [ToolCallInfo]?
    // Present on role=="tool" messages replying to a tool call.
    var tool_call_id: String?
    var name: String?
}
// MARK: - Tool Call Types (OpenAI-compatible)

/// A complete tool call as it appears in non-streaming responses.
struct ToolCallInfo: Codable {
    let id: String
    // Always "function" in this server.
    let type: String
    let function: FunctionCall
}

/// Function name plus its arguments serialized as a JSON string
/// (OpenAI convention: arguments are a string, not an object).
struct FunctionCall: Codable {
    let name: String
    let arguments: String
}

/// Incremental tool-call fragment for streaming chunks.
struct ToolCallDelta: Codable {
    let index: Int
    var id: String?
    var type: String?
    var function: FunctionCallDelta?
}

/// Partial function payload within a ToolCallDelta.
struct FunctionCallDelta: Codable {
    var name: String?
    var arguments: String?
}
// MARK: - Chat Completion Response (non-streaming)

/// Top-level body for a non-streaming /v1/chat/completions response.
struct ChatCompletionResponse: Codable {
    let id: String
    // Always "chat.completion".
    let object: String
    let created: Int
    let model: String
    let choices: [ChatChoice]
    var usage: UsageInfo?
}

struct ChatChoice: Codable {
    let index: Int
    let message: ChatMessage
    // "stop" or "tool_calls" in this server.
    let finish_reason: String?
}

/// Token accounting, OpenAI field names.
struct UsageInfo: Codable {
    let prompt_tokens: Int
    let completion_tokens: Int
    let total_tokens: Int
}

// MARK: - Chat Completion Chunk (streaming)

/// One SSE "data:" payload for a streaming response.
struct ChatCompletionChunk: Codable {
    let id: String
    // Always "chat.completion.chunk".
    let object: String
    let created: Int
    let model: String
    let choices: [ChatChunkChoice]
    // Only populated on the final chunk.
    var usage: UsageInfo?
    var timings: TimingsInfo?
}

struct ChatChunkChoice: Codable {
    let index: Int
    let delta: ChatDelta
    let finish_reason: String?
}

/// Incremental message fragment inside a streaming chunk.
struct ChatDelta: Codable {
    var role: String?
    var content: String?
    var tool_calls: [ToolCallDelta]?
}

/// llama.cpp-style timing block attached to the final streaming chunk.
struct TimingsInfo: Codable {
    var prompt_n: Int?
    var predicted_n: Int?
    var predicted_per_second: Double?
    var prompt_per_second: Double?
}
// MARK: - Models List Response

/// Body for GET /v1/models.
struct ModelsResponse: Codable {
    // Always "list".
    let object: String
    let data: [ModelInfo]
}

struct ModelInfo: Codable {
    let id: String
    // Always "model".
    let object: String
    let created: Int
    let owned_by: String
}

// MARK: - Health Response

/// Body for GET /health.
struct HealthResponse: Codable {
    let status: String
}

// MARK: - Error Response

/// OpenAI-style error envelope.
struct ErrorResponse: Codable {
    let error: ErrorDetail
}

struct ErrorDetail: Codable {
    let message: String
    // Serialized as "type" ("type" is awkward as a Swift identifier).
    let type_name: String
    let code: String?
    enum CodingKeys: String, CodingKey {
        case message
        case type_name = "type"
        case code
    }
}
// MARK: - Helpers

/// Build an OpenAI-style completion id: "chatcmpl-" plus 12 UUID characters.
func generateResponseId() -> String {
    let suffix = UUID().uuidString.prefix(12)
    return "chatcmpl-\(suffix)"
}

/// Build an OpenAI-style tool-call id: "call_" plus 24 hex-ish UUID characters.
func generateToolCallId() -> String {
    let compact = UUID().uuidString.replacingOccurrences(of: "-", with: "")
    return "call_\(compact.prefix(24))"
}

/// Current wall-clock time as a Unix timestamp in whole seconds.
func currentTimestamp() -> Int {
    return Int(Date().timeIntervalSince1970)
}

View File

@@ -0,0 +1,340 @@
import Foundation
import Hummingbird
/// HTTP server that exposes an OpenAI-compatible API backed by MLX
///
/// Routes: GET /health, GET /v1/models, POST /v1/chat/completions
/// (streaming via SSE and non-streaming). All inference is delegated to
/// `modelRunner`; this layer only translates HTTP <-> StreamEvent.
struct MLXHTTPServer {
    let modelRunner: ModelRunner
    let modelId: String
    // When non-empty, requests must carry "Authorization: Bearer <apiKey>".
    let apiKey: String

    /// Build the Hummingbird router with all endpoints registered.
    func buildRouter() -> Router<BasicRequestContext> {
        let router = Router()
        // Health check
        router.get("/health") { _, _ in
            let response = HealthResponse(status: "ok")
            return try encodeJSON(response)
        }
        // List models — always exactly the single model this process serves.
        router.get("/v1/models") { _, _ in
            let response = ModelsResponse(
                object: "list",
                data: [
                    ModelInfo(
                        id: self.modelId,
                        object: "model",
                        created: currentTimestamp(),
                        owned_by: "mlx"
                    )
                ]
            )
            return try encodeJSON(response)
        }
        // Chat completions
        router.post("/v1/chat/completions") { request, context in
            // Validate API key if set
            if !self.apiKey.isEmpty {
                let authHeader =
                    request.headers[.authorization]
                let expectedAuth = "Bearer \(self.apiKey)"
                if authHeader != expectedAuth {
                    let error = ErrorResponse(
                        error: ErrorDetail(
                            message: "Unauthorized",
                            type_name: "authentication_error",
                            code: "unauthorized"
                        )
                    )
                    let response = try Response(
                        status: .unauthorized,
                        headers: [.contentType: "application/json"],
                        body: .init(byteBuffer: encodeJSONBuffer(error))
                    )
                    return response
                }
            }
            // Parse request body
            let body = try await request.body.collect(upTo: 10 * 1024 * 1024) // 10MB max
            let chatRequest = try JSONDecoder().decode(ChatCompletionRequest.self, from: body)
            // Defaults mirror ModelRunner's; n_predict is the llama.cpp
            // alias for max_tokens.
            let temperature = chatRequest.temperature ?? 0.7
            let topP = chatRequest.top_p ?? 1.0
            let maxTokens = chatRequest.max_tokens ?? chatRequest.n_predict ?? 2048
            let repetitionPenalty = chatRequest.repetition_penalty ?? 1.0
            let stop = chatRequest.stop ?? []
            let isStreaming = chatRequest.stream ?? false
            let tools = chatRequest.tools
            print("[mlx] Request: model=\(chatRequest.model), messages=\(chatRequest.messages.count), stream=\(isStreaming), tools=\(tools?.count ?? 0)")
            if isStreaming {
                return try await self.handleStreamingRequest(
                    chatRequest: chatRequest,
                    temperature: temperature,
                    topP: topP,
                    maxTokens: maxTokens,
                    repetitionPenalty: repetitionPenalty,
                    stop: stop,
                    tools: tools
                )
            } else {
                return try await self.handleNonStreamingRequest(
                    chatRequest: chatRequest,
                    temperature: temperature,
                    topP: topP,
                    maxTokens: maxTokens,
                    repetitionPenalty: repetitionPenalty,
                    stop: stop,
                    tools: tools
                )
            }
        }
        return router
    }

    /// Run inference to completion and return one JSON response.
    private func handleNonStreamingRequest(
        chatRequest: ChatCompletionRequest,
        temperature: Float,
        topP: Float,
        maxTokens: Int,
        repetitionPenalty: Float,
        stop: [String],
        tools: [AnyCodable]? = nil
    ) async throws -> Response {
        let (text, toolCalls, usage) = try await modelRunner.generate(
            messages: chatRequest.messages,
            temperature: temperature,
            topP: topP,
            maxTokens: maxTokens,
            repetitionPenalty: repetitionPenalty,
            stop: stop,
            tools: tools
        )
        let finishReason = toolCalls.isEmpty ? "stop" : "tool_calls"
        // OpenAI convention: content is null (not "") when the assistant
        // responds only with tool calls.
        let message = ChatMessage(
            role: "assistant",
            content: text.isEmpty && !toolCalls.isEmpty ? nil : text,
            tool_calls: toolCalls.isEmpty ? nil : toolCalls
        )
        let response = ChatCompletionResponse(
            id: generateResponseId(),
            object: "chat.completion",
            created: currentTimestamp(),
            model: chatRequest.model,
            choices: [
                ChatChoice(
                    index: 0,
                    message: message,
                    finish_reason: finishReason
                )
            ],
            usage: usage
        )
        print("[mlx] Response: \(text.count) chars, \(toolCalls.count) tool call(s), finish=\(finishReason)")
        return try Response(
            status: .ok,
            headers: [.contentType: "application/json"],
            body: .init(byteBuffer: encodeJSONBuffer(response))
        )
    }

    /// Run inference and stream OpenAI-style SSE chunks:
    /// role chunk, then content/tool-call chunks, then a finish chunk with
    /// usage/timings, then the "data: [DONE]" sentinel.
    private func handleStreamingRequest(
        chatRequest: ChatCompletionRequest,
        temperature: Float,
        topP: Float,
        maxTokens: Int,
        repetitionPenalty: Float,
        stop: [String],
        tools: [AnyCodable]? = nil
    ) async throws -> Response {
        let responseId = generateResponseId()
        let created = currentTimestamp()
        let model = chatRequest.model
        let stream = await modelRunner.generateStream(
            messages: chatRequest.messages,
            temperature: temperature,
            topP: topP,
            maxTokens: maxTokens,
            repetitionPenalty: repetitionPenalty,
            stop: stop,
            tools: tools
        )
        // Build SSE response body
        let responseStream = AsyncStream<ByteBuffer> { continuation in
            Task {
                // Send initial role chunk
                let initialChunk = ChatCompletionChunk(
                    id: responseId,
                    object: "chat.completion.chunk",
                    created: created,
                    model: model,
                    choices: [
                        ChatChunkChoice(
                            index: 0,
                            delta: ChatDelta(role: "assistant", content: nil),
                            finish_reason: nil
                        )
                    ]
                )
                if let data = try? encodeJSONData(initialChunk) {
                    var buffer = ByteBufferAllocator().buffer(capacity: data.count + 8)
                    buffer.writeString("data: ")
                    buffer.writeBytes(data)
                    buffer.writeString("\n\n")
                    continuation.yield(buffer)
                }
                do {
                    for try await event in stream {
                        switch event {
                        case .chunk(let token):
                            let chunk = ChatCompletionChunk(
                                id: responseId,
                                object: "chat.completion.chunk",
                                created: created,
                                model: model,
                                choices: [
                                    ChatChunkChoice(
                                        index: 0,
                                        delta: ChatDelta(role: nil, content: token),
                                        finish_reason: nil
                                    )
                                ]
                            )
                            if let data = try? encodeJSONData(chunk) {
                                var buffer = ByteBufferAllocator().buffer(
                                    capacity: data.count + 8)
                                buffer.writeString("data: ")
                                buffer.writeBytes(data)
                                buffer.writeString("\n\n")
                                continuation.yield(buffer)
                            }
                        case .toolCall(let toolCallInfo):
                            // Tool calls are emitted as a single complete
                            // delta (index 0) rather than argument fragments.
                            let chunk = ChatCompletionChunk(
                                id: responseId,
                                object: "chat.completion.chunk",
                                created: created,
                                model: model,
                                choices: [
                                    ChatChunkChoice(
                                        index: 0,
                                        delta: ChatDelta(
                                            role: nil,
                                            content: nil,
                                            tool_calls: [
                                                ToolCallDelta(
                                                    index: 0,
                                                    id: toolCallInfo.id,
                                                    type: toolCallInfo.type,
                                                    function: FunctionCallDelta(
                                                        name: toolCallInfo.function.name,
                                                        arguments: toolCallInfo.function.arguments
                                                    )
                                                )
                                            ]
                                        ),
                                        finish_reason: nil
                                    )
                                ]
                            )
                            if let data = try? encodeJSONData(chunk) {
                                var buffer = ByteBufferAllocator().buffer(
                                    capacity: data.count + 8)
                                buffer.writeString("data: ")
                                buffer.writeBytes(data)
                                buffer.writeString("\n\n")
                                continuation.yield(buffer)
                            }
                        case .done(let usage, let timings, let hasToolCalls):
                            let finishReason = hasToolCalls ? "tool_calls" : "stop"
                            // Final chunk with finish_reason
                            let finalChunk = ChatCompletionChunk(
                                id: responseId,
                                object: "chat.completion.chunk",
                                created: created,
                                model: model,
                                choices: [
                                    ChatChunkChoice(
                                        index: 0,
                                        delta: ChatDelta(role: nil, content: nil),
                                        finish_reason: finishReason
                                    )
                                ],
                                usage: usage,
                                timings: timings
                            )
                            if let data = try? encodeJSONData(finalChunk) {
                                var buffer = ByteBufferAllocator().buffer(
                                    capacity: data.count + 8)
                                buffer.writeString("data: ")
                                buffer.writeBytes(data)
                                buffer.writeString("\n\n")
                                continuation.yield(buffer)
                            }
                            // Send [DONE]
                            var doneBuffer = ByteBufferAllocator().buffer(capacity: 16)
                            doneBuffer.writeString("data: [DONE]\n\n")
                            continuation.yield(doneBuffer)
                        }
                    }
                } catch {
                    print("[mlx] Error in SSE stream: \(error.localizedDescription)")
                    // Send error as SSE event
                    // NOTE(review): the description is interpolated into JSON
                    // unescaped — quotes in the message produce invalid JSON;
                    // also "error:" is a non-standard SSE field. Confirm the
                    // client actually parses this.
                    var buffer = ByteBufferAllocator().buffer(capacity: 256)
                    buffer.writeString(
                        "error: {\"message\":\"\(error.localizedDescription)\"}\n\n")
                    continuation.yield(buffer)
                }
                continuation.finish()
            }
        }
        return Response(
            status: .ok,
            headers: [
                .contentType: "text/event-stream",
                .init("Cache-Control")!: "no-cache",
                .init("Connection")!: "keep-alive",
            ],
            body: .init(asyncSequence: responseStream)
        )
    }
}
// MARK: - JSON Encoding Helpers

/// Encode a value as JSON and wrap it in a 200 response with the
/// application/json content type.
private func encodeJSON<T: Encodable>(_ value: T) throws -> Response {
    Response(
        status: .ok,
        headers: [.contentType: "application/json"],
        body: .init(byteBuffer: try encodeJSONBuffer(value))
    )
}

/// Encode a value as JSON into a ByteBuffer (for custom response statuses).
private func encodeJSONBuffer<T: Encodable>(_ value: T) throws -> ByteBuffer {
    let encoded = try JSONEncoder().encode(value)
    var buf = ByteBufferAllocator().buffer(capacity: encoded.count)
    buf.writeBytes(encoded)
    return buf
}

/// Encode a value as raw JSON Data (used for SSE payloads).
private func encodeJSONData<T: Encodable>(_ value: T) throws -> Data {
    return try JSONEncoder().encode(value)
}

16
src-tauri/Cargo.lock generated
View File

@@ -41,6 +41,7 @@ dependencies = [
"tauri-plugin-http",
"tauri-plugin-llamacpp",
"tauri-plugin-log",
"tauri-plugin-mlx",
"tauri-plugin-opener",
"tauri-plugin-os",
"tauri-plugin-rag",
@@ -6311,6 +6312,21 @@ dependencies = [
"time",
]
[[package]]
name = "tauri-plugin-mlx"
version = "0.6.599"
dependencies = [
"jan-utils",
"log",
"nix",
"serde",
"sysinfo",
"tauri",
"tauri-plugin",
"thiserror 2.0.17",
"tokio",
]
[[package]]
name = "tauri-plugin-opener"
version = "2.5.0"

View File

@@ -27,9 +27,11 @@ default = [
]
hardware = ["dep:tauri-plugin-hardware"]
deep-link = ["dep:tauri-plugin-deep-link"]
mlx = ["dep:tauri-plugin-mlx"]
desktop = [
"deep-link",
"hardware"
"hardware",
"mlx"
]
mobile = [
"tauri/protocol-asset",
@@ -82,6 +84,7 @@ zip = "0.6"
tauri-plugin-deep-link = { version = "2", optional = true }
tauri-plugin-hardware = { path = "./plugins/tauri-plugin-hardware", optional = true }
tauri-plugin-llamacpp = { path = "./plugins/tauri-plugin-llamacpp" }
tauri-plugin-mlx = { path = "./plugins/tauri-plugin-mlx", optional = true }
tauri-plugin-vector-db = { path = "./plugins/tauri-plugin-vector-db" }
tauri-plugin-rag = { path = "./plugins/tauri-plugin-rag" }
tauri-plugin-http = { version = "2", features = ["unsafe-headers"] }

View File

@@ -26,6 +26,7 @@
"hardware:default",
"deep-link:default",
"llamacpp:default",
"mlx:default",
"updater:default",
"updater:allow-check",
{

View File

@@ -27,6 +27,7 @@
"vector-db:default",
"rag:default",
"llamacpp:default",
"mlx:default",
"deep-link:default",
"hardware:default",
{

View File

@@ -79,6 +79,7 @@ export interface DownloadItem {
proxy?: Record<string, string | string[] | boolean>
sha256?: string
size?: number
model_id?: string
}
export interface ModelConfig {
@@ -90,6 +91,7 @@ export interface ModelConfig {
sha256?: string
mmproj_sha256?: string
mmproj_size_bytes?: number
embedding?: boolean
}
export interface EmbeddingResponse {

View File

@@ -0,0 +1,27 @@
[package]
name = "tauri-plugin-mlx"
version = "0.6.599"
authors = ["Jan <service@jan.ai>"]
description = "Tauri plugin for managing MLX-Swift server processes and model loading on Apple Silicon"
license = "MIT"
repository = "https://github.com/janhq/jan"
edition = "2021"
rust-version = "1.77.2"
exclude = ["/examples", "/dist-js", "/guest-js", "/node_modules"]
links = "tauri-plugin-mlx"
[dependencies]
log = "0.4"
serde = { version = "1.0", features = ["derive"] }
sysinfo = "0.34.2"
tauri = { version = "2.5.0", default-features = false, features = [] }
thiserror = "2.0.12"
tokio = { version = "1", features = ["full"] }
jan-utils = { path = "../../utils" }
# Unix-specific dependencies (macOS)
[target.'cfg(unix)'.dependencies]
nix = { version = "=0.30.1", features = ["signal", "process"] }
[build-dependencies]
tauri-plugin = { version = "2.3.1", features = ["build"] }

View File

@@ -0,0 +1,14 @@
// Tauri command identifiers exposed by this plugin; must stay in sync with
// the `plugin:mlx|...` invoke strings used by the JS/TS guest bindings.
const COMMANDS: &[&str] = &[
    "cleanup_mlx_processes",
    "load_mlx_model",
    "unload_mlx_model",
    "is_mlx_process_running",
    "get_mlx_random_port",
    "find_mlx_session_by_model",
    "get_mlx_loaded_models",
    "get_mlx_all_sessions",
];

// Build script: registers the command allow-list with Tauri's plugin build
// system so permissions/capabilities can reference them.
fn main() {
    tauri_plugin::Builder::new(COMMANDS).build();
}

View File

@@ -0,0 +1,63 @@
'use strict';
var core = require('@tauri-apps/api/core');
/**
 * Coerce an arbitrary value to a finite number.
 * Empty string, null and undefined (and anything non-numeric) fall back
 * to `defaultValue`.
 */
function asNumber(v, defaultValue = 0) {
    if (v === '' || v == null)
        return defaultValue;
    const parsed = Number(v);
    return Number.isFinite(parsed) ? parsed : defaultValue;
}
/**
 * Coerce an arbitrary value to a string.
 * Empty string, null and undefined fall back to `defaultValue`.
 */
function asString(v, defaultValue = '') {
    const missing = v === '' || v === null || v === undefined;
    return missing ? defaultValue : String(v);
}
/**
 * Normalize a loosely-typed config object into a well-typed MlxConfig,
 * coercing numeric fields with asNumber and the template with asString.
 */
function normalizeMlxConfig(config) {
    const { ctx_size, n_predict, threads, chat_template } = config;
    return {
        ctx_size: asNumber(ctx_size),
        n_predict: asNumber(n_predict),
        threads: asNumber(threads),
        chat_template: asString(chat_template),
    };
}
// Tauri command wrappers: each forwards to the Rust tauri-plugin-mlx over
// IPC. Presumably this is a generated CJS bundle of ../guest-js — TODO
// confirm; edit the TypeScript source rather than this file.
// Spawn the mlx-server binary and load a model; resolves with SessionInfo.
async function loadMlxModel(binaryPath, modelId, modelPath, port, cfg, envs, isEmbedding = false, timeout = 600) {
    // Normalize the config before it crosses the IPC boundary.
    const config = normalizeMlxConfig(cfg);
    return await core.invoke('plugin:mlx|load_mlx_model', {
        binaryPath,
        modelId,
        modelPath,
        port,
        config,
        envs,
        isEmbedding,
        timeout,
    });
}
// Terminate the server process with the given pid.
async function unloadMlxModel(pid) {
    return await core.invoke('plugin:mlx|unload_mlx_model', { pid });
}
// Check whether a previously spawned server process is still alive.
async function isMlxProcessRunning(pid) {
    return await core.invoke('plugin:mlx|is_mlx_process_running', { pid });
}
// Ask the plugin for a free local port to bind a new server to.
async function getMlxRandomPort() {
    return await core.invoke('plugin:mlx|get_mlx_random_port');
}
// Look up the active session for a model id, or null if none.
async function findMlxSessionByModel(modelId) {
    return await core.invoke('plugin:mlx|find_mlx_session_by_model', { modelId });
}
// List ids of all currently loaded models.
async function getMlxLoadedModels() {
    return await core.invoke('plugin:mlx|get_mlx_loaded_models');
}
// List all active sessions.
async function getMlxAllSessions() {
    return await core.invoke('plugin:mlx|get_mlx_all_sessions');
}
exports.findMlxSessionByModel = findMlxSessionByModel;
exports.getMlxAllSessions = getMlxAllSessions;
exports.getMlxLoadedModels = getMlxLoadedModels;
exports.getMlxRandomPort = getMlxRandomPort;
exports.isMlxProcessRunning = isMlxProcessRunning;
exports.loadMlxModel = loadMlxModel;
exports.normalizeMlxConfig = normalizeMlxConfig;
exports.unloadMlxModel = unloadMlxModel;

View File

@@ -0,0 +1,10 @@
import { SessionInfo, UnloadResult, MlxConfig } from './types';
export { SessionInfo, UnloadResult, MlxConfig } from './types';
/** Coerce a loosely-typed config object into a well-typed MlxConfig. */
export declare function normalizeMlxConfig(config: any): MlxConfig;
/** Spawn the mlx-server binary and load a model; resolves with the session. */
export declare function loadMlxModel(binaryPath: string, modelId: string, modelPath: string, port: number, cfg: MlxConfig, envs: Record<string, string>, isEmbedding?: boolean, timeout?: number): Promise<SessionInfo>;
/** Terminate the server process with the given pid. */
export declare function unloadMlxModel(pid: number): Promise<UnloadResult>;
/** Whether a previously spawned server process is still alive. */
export declare function isMlxProcessRunning(pid: number): Promise<boolean>;
/** Ask the plugin for a free local port. */
export declare function getMlxRandomPort(): Promise<number>;
/** Look up the active session for a model id, or null if none. */
export declare function findMlxSessionByModel(modelId: string): Promise<SessionInfo | null>;
/** Ids of all currently loaded models. */
export declare function getMlxLoadedModels(): Promise<string[]>;
/** All active sessions. */
export declare function getMlxAllSessions(): Promise<SessionInfo[]>;

View File

@@ -0,0 +1,54 @@
import { invoke } from '@tauri-apps/api/core';
// ESM twin of the CJS bundle — presumably generated from ../guest-js;
// TODO confirm and prefer editing the TypeScript source.
// Coerce to a finite number; '', null, undefined and NaN fall back.
function asNumber(v, defaultValue = 0) {
    if (v === '' || v === null || v === undefined)
        return defaultValue;
    const n = Number(v);
    return isFinite(n) ? n : defaultValue;
}
// Coerce to a string; '', null and undefined fall back.
function asString(v, defaultValue = '') {
    if (v === '' || v === null || v === undefined)
        return defaultValue;
    return String(v);
}
// Normalize a loosely-typed config into a well-typed MlxConfig.
function normalizeMlxConfig(config) {
    return {
        ctx_size: asNumber(config.ctx_size),
        n_predict: asNumber(config.n_predict),
        threads: asNumber(config.threads),
        chat_template: asString(config.chat_template),
    };
}
// Spawn the mlx-server binary and load a model; resolves with SessionInfo.
async function loadMlxModel(binaryPath, modelId, modelPath, port, cfg, envs, isEmbedding = false, timeout = 600) {
    const config = normalizeMlxConfig(cfg);
    return await invoke('plugin:mlx|load_mlx_model', {
        binaryPath,
        modelId,
        modelPath,
        port,
        config,
        envs,
        isEmbedding,
        timeout,
    });
}
// Terminate the server process with the given pid.
async function unloadMlxModel(pid) {
    return await invoke('plugin:mlx|unload_mlx_model', { pid });
}
// Check whether a previously spawned server process is still alive.
async function isMlxProcessRunning(pid) {
    return await invoke('plugin:mlx|is_mlx_process_running', { pid });
}
// Ask the plugin for a free local port.
async function getMlxRandomPort() {
    return await invoke('plugin:mlx|get_mlx_random_port');
}
// Look up the active session for a model id, or null if none.
async function findMlxSessionByModel(modelId) {
    return await invoke('plugin:mlx|find_mlx_session_by_model', { modelId });
}
// List ids of all currently loaded models.
async function getMlxLoadedModels() {
    return await invoke('plugin:mlx|get_mlx_loaded_models');
}
// List all active sessions.
async function getMlxAllSessions() {
    return await invoke('plugin:mlx|get_mlx_all_sessions');
}
export { findMlxSessionByModel, getMlxAllSessions, getMlxLoadedModels, getMlxRandomPort, isMlxProcessRunning, loadMlxModel, normalizeMlxConfig, unloadMlxModel };

View File

@@ -0,0 +1,18 @@
/** A running mlx-server process managed by the plugin. */
export interface SessionInfo {
    /** OS process id of the spawned server. */
    pid: number;
    /** Local port the server is listening on. */
    port: number;
    model_id: string;
    model_path: string;
    /** True when the server was started in embedding mode. */
    is_embedding: boolean;
    /** Bearer key the server expects; empty means no auth. */
    api_key: string;
}
/** Result of unloadMlxModel. */
export interface UnloadResult {
    success: boolean;
    error?: string;
}
/** Normalized server launch configuration (see normalizeMlxConfig). */
export type MlxConfig = {
    ctx_size: number;
    n_predict: number;
    threads: number;
    chat_template: string;
};

View File

@@ -0,0 +1,73 @@
import { invoke } from '@tauri-apps/api/core'
import type { SessionInfo, UnloadResult, MlxConfig } from './types'
export type { SessionInfo, UnloadResult, MlxConfig } from './types'
// Coerce an arbitrary config value to a finite number; ''/null/undefined and
// non-numeric input fall back to `defaultValue`.
function asNumber(v: any, defaultValue = 0): number {
  if (v === '' || v === null || v === undefined) return defaultValue
  const n = Number(v)
  return isFinite(n) ? n : defaultValue
}

// Coerce an arbitrary config value to a string, treating ''/null/undefined as unset.
function asString(v: any, defaultValue = ''): string {
  if (v === '' || v === null || v === undefined) return defaultValue
  return String(v)
}

/**
 * Normalize a loosely-typed config object into the exact shape the Rust
 * `load_mlx_model` command deserializes.
 *
 * Robustness fix: a null/undefined `config` now yields all defaults instead
 * of throwing a TypeError on property access.
 */
export function normalizeMlxConfig(config: any): MlxConfig {
  const src = config ?? {}
  return {
    ctx_size: asNumber(src.ctx_size),
    n_predict: asNumber(src.n_predict),
    threads: asNumber(src.threads),
    chat_template: asString(src.chat_template),
  }
}
/**
 * Spawn the MLX server binary for `modelId` via the mlx plugin and wait until
 * the backend reports it ready (or `timeout` seconds elapse on the Rust side).
 * Resolves with the new session's info.
 */
export async function loadMlxModel(
  binaryPath: string,
  modelId: string,
  modelPath: string,
  port: number,
  cfg: MlxConfig,
  envs: Record<string, string>,
  isEmbedding: boolean = false,
  timeout: number = 600
): Promise<SessionInfo> {
  const payload = {
    binaryPath,
    modelId,
    modelPath,
    port,
    config: normalizeMlxConfig(cfg),
    envs,
    isEmbedding,
    timeout,
  }
  return invoke('plugin:mlx|load_mlx_model', payload)
}
/** Terminate the MLX server process identified by `pid`. */
export async function unloadMlxModel(pid: number): Promise<UnloadResult> {
  return invoke('plugin:mlx|unload_mlx_model', { pid })
}

/** Check whether a previously spawned MLX server process is still alive. */
export async function isMlxProcessRunning(pid: number): Promise<boolean> {
  return invoke('plugin:mlx|is_mlx_process_running', { pid })
}

/** Ask the backend for a random TCP port that is currently free. */
export async function getMlxRandomPort(): Promise<number> {
  return invoke('plugin:mlx|get_mlx_random_port')
}

/** Look up the active session (if any) serving the given model id. */
export async function findMlxSessionByModel(
  modelId: string
): Promise<SessionInfo | null> {
  return invoke('plugin:mlx|find_mlx_session_by_model', { modelId })
}

/** List the model ids of all currently loaded MLX sessions. */
export async function getMlxLoadedModels(): Promise<string[]> {
  return invoke('plugin:mlx|get_mlx_loaded_models')
}

/** Return full session info for every active MLX server process. */
export async function getMlxAllSessions(): Promise<SessionInfo[]> {
  return invoke('plugin:mlx|get_mlx_all_sessions')
}

View File

@@ -0,0 +1,20 @@
/**
 * Metadata for a running MLX server process, as reported by the Rust side.
 * Field names are snake_case to match the Rust struct's serde serialization.
 */
export interface SessionInfo {
  /** OS process id of the spawned server. */
  pid: number
  /** TCP port the server listens on. */
  port: number
  /** Model identifier the session was started for. */
  model_id: string
  /** Filesystem path of the loaded model. */
  model_path: string
  /** True when the server was started in embedding mode. */
  is_embedding: boolean
  /** API key passed to the server; empty string when auth is disabled. */
  api_key: string
}

/** Outcome of an unload request. */
export interface UnloadResult {
  success: boolean
  /** Human-readable failure reason, absent on success. */
  error?: string
}

/** Normalized MLX server configuration (see normalizeMlxConfig). */
export type MlxConfig = {
  ctx_size: number
  n_predict: number
  threads: number
  chat_template: string
}

View File

@@ -0,0 +1,33 @@
{
"name": "@janhq/tauri-plugin-mlx-api",
"version": "0.6.6",
"private": true,
"description": "Tauri plugin API for MLX-Swift inference on Apple Silicon",
"type": "module",
"types": "./dist-js/index.d.ts",
"main": "./dist-js/index.cjs",
"module": "./dist-js/index.js",
"exports": {
"types": "./dist-js/index.d.ts",
"import": "./dist-js/index.js",
"require": "./dist-js/index.cjs"
},
"files": [
"dist-js",
"README.md"
],
"scripts": {
"build": "rollup -c",
"prepublishOnly": "yarn build",
"pretest": "yarn build"
},
"dependencies": {
"@tauri-apps/api": ">=2.0.0-beta.6"
},
"devDependencies": {
"@rollup/plugin-typescript": "^12.0.0",
"rollup": "^4.9.6",
"tslib": "^2.6.2",
"typescript": "^5.3.3"
}
}

View File

@@ -0,0 +1,13 @@
# Automatically generated - DO NOT EDIT!
"$schema" = "../../schemas/schema.json"
[[permission]]
identifier = "allow-cleanup-mlx-processes"
description = "Enables the cleanup_mlx_processes command without any pre-configured scope."
commands.allow = ["cleanup_mlx_processes"]
[[permission]]
identifier = "deny-cleanup-mlx-processes"
description = "Denies the cleanup_mlx_processes command without any pre-configured scope."
commands.deny = ["cleanup_mlx_processes"]

View File

@@ -0,0 +1,13 @@
# Automatically generated - DO NOT EDIT!
"$schema" = "../../schemas/schema.json"
[[permission]]
identifier = "allow-find-mlx-session-by-model"
description = "Enables the find_mlx_session_by_model command without any pre-configured scope."
commands.allow = ["find_mlx_session_by_model"]
[[permission]]
identifier = "deny-find-mlx-session-by-model"
description = "Denies the find_mlx_session_by_model command without any pre-configured scope."
commands.deny = ["find_mlx_session_by_model"]

View File

@@ -0,0 +1,13 @@
# Automatically generated - DO NOT EDIT!
"$schema" = "../../schemas/schema.json"
[[permission]]
identifier = "allow-get-mlx-all-sessions"
description = "Enables the get_mlx_all_sessions command without any pre-configured scope."
commands.allow = ["get_mlx_all_sessions"]
[[permission]]
identifier = "deny-get-mlx-all-sessions"
description = "Denies the get_mlx_all_sessions command without any pre-configured scope."
commands.deny = ["get_mlx_all_sessions"]

View File

@@ -0,0 +1,13 @@
# Automatically generated - DO NOT EDIT!
"$schema" = "../../schemas/schema.json"
[[permission]]
identifier = "allow-get-mlx-loaded-models"
description = "Enables the get_mlx_loaded_models command without any pre-configured scope."
commands.allow = ["get_mlx_loaded_models"]
[[permission]]
identifier = "deny-get-mlx-loaded-models"
description = "Denies the get_mlx_loaded_models command without any pre-configured scope."
commands.deny = ["get_mlx_loaded_models"]

View File

@@ -0,0 +1,13 @@
# Automatically generated - DO NOT EDIT!
"$schema" = "../../schemas/schema.json"
[[permission]]
identifier = "allow-get-mlx-random-port"
description = "Enables the get_mlx_random_port command without any pre-configured scope."
commands.allow = ["get_mlx_random_port"]
[[permission]]
identifier = "deny-get-mlx-random-port"
description = "Denies the get_mlx_random_port command without any pre-configured scope."
commands.deny = ["get_mlx_random_port"]

View File

@@ -0,0 +1,13 @@
# Automatically generated - DO NOT EDIT!
"$schema" = "../../schemas/schema.json"
[[permission]]
identifier = "allow-is-mlx-process-running"
description = "Enables the is_mlx_process_running command without any pre-configured scope."
commands.allow = ["is_mlx_process_running"]
[[permission]]
identifier = "deny-is-mlx-process-running"
description = "Denies the is_mlx_process_running command without any pre-configured scope."
commands.deny = ["is_mlx_process_running"]

View File

@@ -0,0 +1,13 @@
# Automatically generated - DO NOT EDIT!
"$schema" = "../../schemas/schema.json"
[[permission]]
identifier = "allow-load-mlx-model"
description = "Enables the load_mlx_model command without any pre-configured scope."
commands.allow = ["load_mlx_model"]
[[permission]]
identifier = "deny-load-mlx-model"
description = "Denies the load_mlx_model command without any pre-configured scope."
commands.deny = ["load_mlx_model"]

View File

@@ -0,0 +1,13 @@
# Automatically generated - DO NOT EDIT!
"$schema" = "../../schemas/schema.json"
[[permission]]
identifier = "allow-unload-mlx-model"
description = "Enables the unload_mlx_model command without any pre-configured scope."
commands.allow = ["unload_mlx_model"]
[[permission]]
identifier = "deny-unload-mlx-model"
description = "Denies the unload_mlx_model command without any pre-configured scope."
commands.deny = ["unload_mlx_model"]

View File

@@ -0,0 +1,232 @@
## Default Permission
Default permissions for the MLX plugin
#### This default permission set includes the following:
- `allow-cleanup-mlx-processes`
- `allow-load-mlx-model`
- `allow-unload-mlx-model`
- `allow-is-mlx-process-running`
- `allow-get-mlx-random-port`
- `allow-find-mlx-session-by-model`
- `allow-get-mlx-loaded-models`
- `allow-get-mlx-all-sessions`
## Permission Table
<table>
<tr>
<th>Identifier</th>
<th>Description</th>
</tr>
<tr>
<td>
`mlx:allow-cleanup-mlx-processes`
</td>
<td>
Enables the cleanup_mlx_processes command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:deny-cleanup-mlx-processes`
</td>
<td>
Denies the cleanup_mlx_processes command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:allow-find-mlx-session-by-model`
</td>
<td>
Enables the find_mlx_session_by_model command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:deny-find-mlx-session-by-model`
</td>
<td>
Denies the find_mlx_session_by_model command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:allow-get-mlx-all-sessions`
</td>
<td>
Enables the get_mlx_all_sessions command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:deny-get-mlx-all-sessions`
</td>
<td>
Denies the get_mlx_all_sessions command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:allow-get-mlx-loaded-models`
</td>
<td>
Enables the get_mlx_loaded_models command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:deny-get-mlx-loaded-models`
</td>
<td>
Denies the get_mlx_loaded_models command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:allow-get-mlx-random-port`
</td>
<td>
Enables the get_mlx_random_port command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:deny-get-mlx-random-port`
</td>
<td>
Denies the get_mlx_random_port command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:allow-is-mlx-process-running`
</td>
<td>
Enables the is_mlx_process_running command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:deny-is-mlx-process-running`
</td>
<td>
Denies the is_mlx_process_running command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:allow-load-mlx-model`
</td>
<td>
Enables the load_mlx_model command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:deny-load-mlx-model`
</td>
<td>
Denies the load_mlx_model command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:allow-unload-mlx-model`
</td>
<td>
Enables the unload_mlx_model command without any pre-configured scope.
</td>
</tr>
<tr>
<td>
`mlx:deny-unload-mlx-model`
</td>
<td>
Denies the unload_mlx_model command without any pre-configured scope.
</td>
</tr>
</table>

View File

@@ -0,0 +1,12 @@
[default]
description = "Default permissions for the MLX plugin"
permissions = [
"allow-cleanup-mlx-processes",
"allow-load-mlx-model",
"allow-unload-mlx-model",
"allow-is-mlx-process-running",
"allow-get-mlx-random-port",
"allow-find-mlx-session-by-model",
"allow-get-mlx-loaded-models",
"allow-get-mlx-all-sessions",
]

View File

@@ -0,0 +1,402 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "PermissionFile",
"description": "Permission file that can define a default permission, a set of permissions or a list of inlined permissions.",
"type": "object",
"properties": {
"default": {
"description": "The default permission set for the plugin",
"anyOf": [
{
"$ref": "#/definitions/DefaultPermission"
},
{
"type": "null"
}
]
},
"set": {
"description": "A list of permissions sets defined",
"type": "array",
"items": {
"$ref": "#/definitions/PermissionSet"
}
},
"permission": {
"description": "A list of inlined permissions",
"default": [],
"type": "array",
"items": {
"$ref": "#/definitions/Permission"
}
}
},
"definitions": {
"DefaultPermission": {
"description": "The default permission set of the plugin.\n\nWorks similarly to a permission with the \"default\" identifier.",
"type": "object",
"required": [
"permissions"
],
"properties": {
"version": {
"description": "The version of the permission.",
"type": [
"integer",
"null"
],
"format": "uint64",
"minimum": 1.0
},
"description": {
"description": "Human-readable description of what the permission does. Tauri convention is to use `<h4>` headings in markdown content for Tauri documentation generation purposes.",
"type": [
"string",
"null"
]
},
"permissions": {
"description": "All permissions this set contains.",
"type": "array",
"items": {
"type": "string"
}
}
}
},
"PermissionSet": {
"description": "A set of direct permissions grouped together under a new name.",
"type": "object",
"required": [
"description",
"identifier",
"permissions"
],
"properties": {
"identifier": {
"description": "A unique identifier for the permission.",
"type": "string"
},
"description": {
"description": "Human-readable description of what the permission does.",
"type": "string"
},
"permissions": {
"description": "All permissions this set contains.",
"type": "array",
"items": {
"$ref": "#/definitions/PermissionKind"
}
}
}
},
"Permission": {
"description": "Descriptions of explicit privileges of commands.\n\nIt can enable commands to be accessible in the frontend of the application.\n\nIf the scope is defined it can be used to fine grain control the access of individual or multiple commands.",
"type": "object",
"required": [
"identifier"
],
"properties": {
"version": {
"description": "The version of the permission.",
"type": [
"integer",
"null"
],
"format": "uint64",
"minimum": 1.0
},
"identifier": {
"description": "A unique identifier for the permission.",
"type": "string"
},
"description": {
"description": "Human-readable description of what the permission does. Tauri internal convention is to use `<h4>` headings in markdown content for Tauri documentation generation purposes.",
"type": [
"string",
"null"
]
},
"commands": {
"description": "Allowed or denied commands when using this permission.",
"default": {
"allow": [],
"deny": []
},
"allOf": [
{
"$ref": "#/definitions/Commands"
}
]
},
"scope": {
"description": "Allowed or denied scoped when using this permission.",
"allOf": [
{
"$ref": "#/definitions/Scopes"
}
]
},
"platforms": {
"description": "Target platforms this permission applies. By default all platforms are affected by this permission.",
"type": [
"array",
"null"
],
"items": {
"$ref": "#/definitions/Target"
}
}
}
},
"Commands": {
"description": "Allowed and denied commands inside a permission.\n\nIf two commands clash inside of `allow` and `deny`, it should be denied by default.",
"type": "object",
"properties": {
"allow": {
"description": "Allowed command.",
"default": [],
"type": "array",
"items": {
"type": "string"
}
},
"deny": {
"description": "Denied command, which takes priority.",
"default": [],
"type": "array",
"items": {
"type": "string"
}
}
}
},
"Scopes": {
"description": "An argument for fine grained behavior control of Tauri commands.\n\nIt can be of any serde serializable type and is used to allow or prevent certain actions inside a Tauri command. The configured scope is passed to the command and will be enforced by the command implementation.\n\n## Example\n\n```json { \"allow\": [{ \"path\": \"$HOME/**\" }], \"deny\": [{ \"path\": \"$HOME/secret.txt\" }] } ```",
"type": "object",
"properties": {
"allow": {
"description": "Data that defines what is allowed by the scope.",
"type": [
"array",
"null"
],
"items": {
"$ref": "#/definitions/Value"
}
},
"deny": {
"description": "Data that defines what is denied by the scope. This should be prioritized by validation logic.",
"type": [
"array",
"null"
],
"items": {
"$ref": "#/definitions/Value"
}
}
}
},
"Value": {
"description": "All supported ACL values.",
"anyOf": [
{
"description": "Represents a null JSON value.",
"type": "null"
},
{
"description": "Represents a [`bool`].",
"type": "boolean"
},
{
"description": "Represents a valid ACL [`Number`].",
"allOf": [
{
"$ref": "#/definitions/Number"
}
]
},
{
"description": "Represents a [`String`].",
"type": "string"
},
{
"description": "Represents a list of other [`Value`]s.",
"type": "array",
"items": {
"$ref": "#/definitions/Value"
}
},
{
"description": "Represents a map of [`String`] keys to [`Value`]s.",
"type": "object",
"additionalProperties": {
"$ref": "#/definitions/Value"
}
}
]
},
"Number": {
"description": "A valid ACL number.",
"anyOf": [
{
"description": "Represents an [`i64`].",
"type": "integer",
"format": "int64"
},
{
"description": "Represents a [`f64`].",
"type": "number",
"format": "double"
}
]
},
"Target": {
"description": "Platform target.",
"oneOf": [
{
"description": "MacOS.",
"type": "string",
"enum": [
"macOS"
]
},
{
"description": "Windows.",
"type": "string",
"enum": [
"windows"
]
},
{
"description": "Linux.",
"type": "string",
"enum": [
"linux"
]
},
{
"description": "Android.",
"type": "string",
"enum": [
"android"
]
},
{
"description": "iOS.",
"type": "string",
"enum": [
"iOS"
]
}
]
},
"PermissionKind": {
"type": "string",
"oneOf": [
{
"description": "Enables the cleanup_mlx_processes command without any pre-configured scope.",
"type": "string",
"const": "allow-cleanup-mlx-processes",
"markdownDescription": "Enables the cleanup_mlx_processes command without any pre-configured scope."
},
{
"description": "Denies the cleanup_mlx_processes command without any pre-configured scope.",
"type": "string",
"const": "deny-cleanup-mlx-processes",
"markdownDescription": "Denies the cleanup_mlx_processes command without any pre-configured scope."
},
{
"description": "Enables the find_mlx_session_by_model command without any pre-configured scope.",
"type": "string",
"const": "allow-find-mlx-session-by-model",
"markdownDescription": "Enables the find_mlx_session_by_model command without any pre-configured scope."
},
{
"description": "Denies the find_mlx_session_by_model command without any pre-configured scope.",
"type": "string",
"const": "deny-find-mlx-session-by-model",
"markdownDescription": "Denies the find_mlx_session_by_model command without any pre-configured scope."
},
{
"description": "Enables the get_mlx_all_sessions command without any pre-configured scope.",
"type": "string",
"const": "allow-get-mlx-all-sessions",
"markdownDescription": "Enables the get_mlx_all_sessions command without any pre-configured scope."
},
{
"description": "Denies the get_mlx_all_sessions command without any pre-configured scope.",
"type": "string",
"const": "deny-get-mlx-all-sessions",
"markdownDescription": "Denies the get_mlx_all_sessions command without any pre-configured scope."
},
{
"description": "Enables the get_mlx_loaded_models command without any pre-configured scope.",
"type": "string",
"const": "allow-get-mlx-loaded-models",
"markdownDescription": "Enables the get_mlx_loaded_models command without any pre-configured scope."
},
{
"description": "Denies the get_mlx_loaded_models command without any pre-configured scope.",
"type": "string",
"const": "deny-get-mlx-loaded-models",
"markdownDescription": "Denies the get_mlx_loaded_models command without any pre-configured scope."
},
{
"description": "Enables the get_mlx_random_port command without any pre-configured scope.",
"type": "string",
"const": "allow-get-mlx-random-port",
"markdownDescription": "Enables the get_mlx_random_port command without any pre-configured scope."
},
{
"description": "Denies the get_mlx_random_port command without any pre-configured scope.",
"type": "string",
"const": "deny-get-mlx-random-port",
"markdownDescription": "Denies the get_mlx_random_port command without any pre-configured scope."
},
{
"description": "Enables the is_mlx_process_running command without any pre-configured scope.",
"type": "string",
"const": "allow-is-mlx-process-running",
"markdownDescription": "Enables the is_mlx_process_running command without any pre-configured scope."
},
{
"description": "Denies the is_mlx_process_running command without any pre-configured scope.",
"type": "string",
"const": "deny-is-mlx-process-running",
"markdownDescription": "Denies the is_mlx_process_running command without any pre-configured scope."
},
{
"description": "Enables the load_mlx_model command without any pre-configured scope.",
"type": "string",
"const": "allow-load-mlx-model",
"markdownDescription": "Enables the load_mlx_model command without any pre-configured scope."
},
{
"description": "Denies the load_mlx_model command without any pre-configured scope.",
"type": "string",
"const": "deny-load-mlx-model",
"markdownDescription": "Denies the load_mlx_model command without any pre-configured scope."
},
{
"description": "Enables the unload_mlx_model command without any pre-configured scope.",
"type": "string",
"const": "allow-unload-mlx-model",
"markdownDescription": "Enables the unload_mlx_model command without any pre-configured scope."
},
{
"description": "Denies the unload_mlx_model command without any pre-configured scope.",
"type": "string",
"const": "deny-unload-mlx-model",
"markdownDescription": "Denies the unload_mlx_model command without any pre-configured scope."
},
{
"description": "Default permissions for the MLX plugin\n#### This default permission set includes:\n\n- `allow-cleanup-mlx-processes`\n- `allow-load-mlx-model`\n- `allow-unload-mlx-model`\n- `allow-is-mlx-process-running`\n- `allow-get-mlx-random-port`\n- `allow-find-mlx-session-by-model`\n- `allow-get-mlx-loaded-models`\n- `allow-get-mlx-all-sessions`",
"type": "string",
"const": "default",
"markdownDescription": "Default permissions for the MLX plugin\n#### This default permission set includes:\n\n- `allow-cleanup-mlx-processes`\n- `allow-load-mlx-model`\n- `allow-unload-mlx-model`\n- `allow-is-mlx-process-running`\n- `allow-get-mlx-random-port`\n- `allow-find-mlx-session-by-model`\n- `allow-get-mlx-loaded-models`\n- `allow-get-mlx-all-sessions`"
}
]
}
}
}

View File

@@ -0,0 +1,31 @@
import { readFileSync } from 'node:fs'
import { dirname, join } from 'node:path'
import { cwd } from 'node:process'
import typescript from '@rollup/plugin-typescript'
// Read the package manifest so output paths stay in sync with the
// "exports" map declared in package.json.
const pkg = JSON.parse(readFileSync(join(cwd(), 'package.json'), 'utf8'))
// Standard Tauri plugin guest-js build: one ESM and one CJS bundle,
// with .d.ts files emitted next to the ESM output.
export default {
  input: 'guest-js/index.ts',
  output: [
    {
      file: pkg.exports.import,
      format: 'esm'
    },
    {
      file: pkg.exports.require,
      format: 'cjs'
    }
  ],
  plugins: [
    typescript({
      declaration: true,
      declarationDir: dirname(pkg.exports.import)
    })
  ],
  // Keep @tauri-apps/api and all declared dependencies external so the
  // host app provides them at runtime instead of being bundled in.
  external: [
    /^@tauri-apps\/api/,
    ...Object.keys(pkg.dependencies || {}),
    ...Object.keys(pkg.peerDependencies || {})
  ]
}

View File

@@ -0,0 +1,59 @@
use tauri::{Manager, Runtime};
/// Terminate every tracked MLX server process (used during app shutdown).
///
/// Drains the PID -> session map held in `MlxState`. On Unix each child gets
/// SIGTERM, a 2-second grace period, then SIGKILL if it has not exited.
pub async fn cleanup_processes<R: Runtime>(app_handle: &tauri::AppHandle<R>) {
    // The plugin state may be absent (plugin not initialized); bail out quietly.
    let app_state = match app_handle.try_state::<crate::state::MlxState>() {
        Some(state) => state,
        None => {
            log::warn!("MlxState not found in app_handle");
            return;
        }
    };
    let mut map = app_state.mlx_server_process.lock().await;
    // Collect the keys first so entries can be removed while iterating.
    let pids: Vec<i32> = map.keys().cloned().collect();
    for pid in pids {
        if let Some(session) = map.remove(&pid) {
            let mut child = session.child;
            // NOTE(review): on non-unix targets the child is dropped without being
            // killed — presumably fine since MLX is macOS-only; confirm if ported.
            #[cfg(unix)]
            {
                use nix::sys::signal::{kill, Signal};
                use nix::unistd::Pid;
                use tokio::time::{timeout, Duration};
                if let Some(raw_pid) = child.id() {
                    let raw_pid = raw_pid as i32;
                    log::info!("Sending SIGTERM to MLX PID {} during shutdown", raw_pid);
                    let _ = kill(Pid::from_raw(raw_pid), Signal::SIGTERM);
                    // Give the server 2 seconds to exit gracefully before escalating.
                    match timeout(Duration::from_secs(2), child.wait()).await {
                        Ok(Ok(status)) => {
                            log::info!("MLX process {} exited gracefully: {}", raw_pid, status)
                        }
                        Ok(Err(e)) => {
                            log::error!(
                                "Error waiting after SIGTERM for MLX process {}: {}",
                                raw_pid,
                                e
                            )
                        }
                        Err(_) => {
                            log::warn!(
                                "SIGTERM timed out for MLX PID {}; sending SIGKILL",
                                raw_pid
                            );
                            let _ = kill(Pid::from_raw(raw_pid), Signal::SIGKILL);
                            let _ = child.wait().await;
                        }
                    }
                }
            }
        }
    }
}
/// Tauri command wrapper so the frontend can trigger MLX process cleanup explicitly.
#[tauri::command]
pub async fn cleanup_mlx_processes<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<(), String> {
    cleanup_processes(&app_handle).await;
    Ok(())
}

View File

@@ -0,0 +1,360 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::process::Stdio;
use std::time::Duration;
use tauri::{Manager, Runtime, State};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use tokio::sync::mpsc;
use tokio::time::Instant;
use crate::error::{ErrorCode, MlxError, ServerError, ServerResult};
use crate::process::{
find_session_by_model_id, get_all_active_sessions, get_all_loaded_model_ids,
get_random_available_port, is_process_running_by_pid,
};
use crate::state::{MlxBackendSession, MlxState, SessionInfo};
#[cfg(unix)]
use crate::process::graceful_terminate_process;
/// Result payload returned by `unload_mlx_model` to the frontend.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct UnloadResult {
    // True when the unload completed (also true when no session matched the pid).
    success: bool,
    // Human-readable failure reason; never set by the code in this file.
    error: Option<String>,
}
/// MLX server configuration passed from the frontend.
///
/// All fields fall back to their type default when omitted from the payload.
/// NOTE(review): only `ctx_size` and `chat_template` are forwarded as CLI flags
/// by `load_mlx_model`; `n_predict` and `threads` are currently unused — confirm.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MlxConfig {
    // Context window size; only forwarded as --ctx-size when > 0.
    #[serde(default)]
    pub ctx_size: i32,
    #[serde(default)]
    pub n_predict: i32,
    #[serde(default)]
    pub threads: i32,
    // Forwarded as --chat-template when non-empty.
    #[serde(default)]
    pub chat_template: String,
}
/// Load a model using the MLX server binary.
///
/// Spawns `binary_path` with `--model`/`--port` (plus optional flags derived
/// from `config`, `envs["MLX_API_KEY"]` and `is_embedding`), scrapes the child's
/// stdout/stderr for a "listening"/"ready" log line, and returns the registered
/// `SessionInfo` only once the server is ready or `timeout` seconds elapse.
///
/// Errors: `BinaryNotFound`, `ModelFileNotFound`, `ModelLoadTimedOut`, or an
/// error parsed from the server's stderr when the process exits early.
///
/// NOTE(review): the session-map mutex is held for the entire load, including
/// the ready-wait, so concurrent loads are serialized — presumably intended;
/// confirm before allowing parallel model loads.
#[tauri::command]
pub async fn load_mlx_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    binary_path: String,
    model_id: String,
    model_path: String,
    port: u16,
    config: MlxConfig,
    envs: HashMap<String, String>,
    is_embedding: bool,
    timeout: u64,
) -> ServerResult<SessionInfo> {
    let state: State<MlxState> = app_handle.state();
    let mut process_map = state.mlx_server_process.lock().await;
    log::info!("Attempting to launch MLX server at path: {:?}", binary_path);
    log::info!("Using MLX configuration: {:?}", config);
    // Validate binary path
    let bin_path = PathBuf::from(&binary_path);
    if !bin_path.exists() {
        return Err(MlxError::new(
            ErrorCode::BinaryNotFound,
            format!("MLX server binary not found at: {}", binary_path),
            None,
        )
        .into());
    }
    // Validate model path
    let model_path_pb = PathBuf::from(&model_path);
    if !model_path_pb.exists() {
        return Err(MlxError::new(
            ErrorCode::ModelFileNotFound,
            format!("Model file not found at: {}", model_path),
            None,
        )
        .into());
    }
    // API key is optional; an empty key means the --api-key flag is not passed.
    let api_key: String = envs
        .get("MLX_API_KEY")
        .map(|s| s.to_string())
        .unwrap_or_else(|| {
            log::warn!("API key not provided for MLX server");
            String::new()
        });
    // Build command arguments
    let mut args: Vec<String> = vec![
        "--model".to_string(),
        model_path.clone(),
        "--port".to_string(),
        port.to_string(),
    ];
    if config.ctx_size > 0 {
        args.push("--ctx-size".to_string());
        args.push(config.ctx_size.to_string());
    }
    if !api_key.is_empty() {
        args.push("--api-key".to_string());
        args.push(api_key.clone());
    }
    if !config.chat_template.is_empty() {
        args.push("--chat-template".to_string());
        args.push(config.chat_template.clone());
    }
    if is_embedding {
        args.push("--embedding".to_string());
    }
    log::info!("MLX server arguments: {:?}", args);
    // Configure the command; both streams are piped so readiness can be
    // detected by scanning the server's log output.
    let mut command = Command::new(&bin_path);
    command.args(&args);
    command.envs(envs);
    command.stdout(Stdio::piped());
    command.stderr(Stdio::piped());
    // Spawn the child process
    let mut child = command.spawn().map_err(ServerError::Io)?;
    let stderr = child.stderr.take().expect("stderr was piped");
    let stdout = child.stdout.take().expect("stdout was piped");
    // Create channels for communication between tasks
    let (ready_tx, mut ready_rx) = mpsc::channel::<bool>(1);
    // Spawn task to monitor stdout for readiness
    let stdout_ready_tx = ready_tx.clone();
    let _stdout_task = tokio::spawn(async move {
        let mut reader = BufReader::new(stdout);
        let mut byte_buffer = Vec::new();
        loop {
            byte_buffer.clear();
            // read_until + from_utf8_lossy tolerates non-UTF8 bytes in the stream.
            match reader.read_until(b'\n', &mut byte_buffer).await {
                Ok(0) => break,
                Ok(_) => {
                    let line = String::from_utf8_lossy(&byte_buffer);
                    let line = line.trim_end();
                    if !line.is_empty() {
                        log::info!("[mlx stdout] {}", line);
                    }
                    // Heuristic readiness detection: any of these phrases marks
                    // the server as accepting connections.
                    let line_lower = line.to_lowercase();
                    if line_lower.contains("http server listening")
                        || line_lower.contains("server is listening")
                        || line_lower.contains("server started")
                        || line_lower.contains("ready to accept")
                        || line_lower.contains("server started and listening on")
                    {
                        log::info!(
                            "MLX server appears to be ready based on stdout: '{}'",
                            line
                        );
                        let _ = stdout_ready_tx.send(true).await;
                    }
                }
                Err(e) => {
                    log::error!("Error reading MLX stdout: {}", e);
                    break;
                }
            }
        }
    });
    // Spawn task to capture stderr and monitor for errors; the accumulated
    // buffer is returned from the task so it can be attached to error reports.
    let stderr_task = tokio::spawn(async move {
        let mut reader = BufReader::new(stderr);
        let mut byte_buffer = Vec::new();
        let mut stderr_buffer = String::new();
        loop {
            byte_buffer.clear();
            match reader.read_until(b'\n', &mut byte_buffer).await {
                Ok(0) => break,
                Ok(_) => {
                    let line = String::from_utf8_lossy(&byte_buffer);
                    let line = line.trim_end();
                    if !line.is_empty() {
                        stderr_buffer.push_str(line);
                        stderr_buffer.push('\n');
                        log::info!("[mlx] {}", line);
                        // Some servers log the ready line on stderr instead.
                        let line_lower = line.to_lowercase();
                        if line_lower.contains("server is listening")
                            || line_lower.contains("server listening on")
                            || line_lower.contains("server started and listening on")
                        {
                            log::info!(
                                "MLX model appears to be ready based on logs: '{}'",
                                line
                            );
                            let _ = ready_tx.send(true).await;
                        }
                    }
                }
                Err(e) => {
                    log::error!("Error reading MLX logs: {}", e);
                    break;
                }
            }
        }
        stderr_buffer
    });
    // Check if process exited early
    if let Some(status) = child.try_wait()? {
        if !status.success() {
            let stderr_output = stderr_task.await.unwrap_or_default();
            log::error!("MLX server failed early with code {:?}", status);
            log::error!("{}", stderr_output);
            return Err(MlxError::from_stderr(&stderr_output).into());
        }
    }
    // Wait for server to be ready or timeout, polling for early exit every 50ms.
    let timeout_duration = Duration::from_secs(timeout);
    let start_time = Instant::now();
    log::info!("Waiting for MLX model session to be ready...");
    loop {
        tokio::select! {
            Some(true) = ready_rx.recv() => {
                log::info!("MLX model is ready to accept requests!");
                break;
            }
            _ = tokio::time::sleep(Duration::from_millis(50)) => {
                // A clean exit without a ready signal is still a load failure.
                if let Some(status) = child.try_wait()? {
                    let stderr_output = stderr_task.await.unwrap_or_default();
                    if !status.success() {
                        log::error!("MLX server exited with error code {:?}", status);
                        return Err(MlxError::from_stderr(&stderr_output).into());
                    } else {
                        log::error!("MLX server exited successfully but without ready signal");
                        return Err(MlxError::from_stderr(&stderr_output).into());
                    }
                }
                if start_time.elapsed() > timeout_duration {
                    log::error!("Timeout waiting for MLX server to be ready");
                    let _ = child.kill().await;
                    let stderr_output = stderr_task.await.unwrap_or_default();
                    return Err(MlxError::new(
                        ErrorCode::ModelLoadTimedOut,
                        "The MLX model took too long to load and timed out.".into(),
                        Some(format!(
                            "Timeout: {}s\n\nStderr:\n{}",
                            timeout_duration.as_secs(),
                            stderr_output
                        )),
                    )
                    .into());
                }
            }
        }
    }
    // -1 signals an unavailable pid (child.id() is None once the child has exited).
    let pid = child.id().map(|id| id as i32).unwrap_or(-1);
    log::info!("MLX server process started with PID: {} and is ready", pid);
    let session_info = SessionInfo {
        pid,
        port: port.into(),
        model_id,
        model_path: model_path_pb.display().to_string(),
        is_embedding,
        api_key,
    };
    // Register the session; the map still holds the lock taken at function entry.
    process_map.insert(
        pid,
        MlxBackendSession {
            child,
            info: session_info.clone(),
        },
    );
    Ok(session_info)
}
/// Unload an MLX model by terminating its process.
///
/// Idempotent: returns `success: true` even when no session matches `pid`,
/// so repeated unload requests are not treated as errors.
#[tauri::command]
pub async fn unload_mlx_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    pid: i32,
) -> ServerResult<UnloadResult> {
    let state: State<MlxState> = app_handle.state();
    let mut map = state.mlx_server_process.lock().await;
    if let Some(session) = map.remove(&pid) {
        let mut child = session.child;
        // Terminate via the crate::process helper (unix only; MLX is macOS-only).
        #[cfg(unix)]
        {
            graceful_terminate_process(&mut child).await;
        }
        Ok(UnloadResult {
            success: true,
            error: None,
        })
    } else {
        log::warn!("No MLX server with PID '{}' found", pid);
        Ok(UnloadResult {
            success: true,
            error: None,
        })
    }
}
/// Check if a process is still running.
/// Thin command wrapper around the shared helper in `crate::process`.
#[tauri::command]
pub async fn is_mlx_process_running<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    pid: i32,
) -> Result<bool, String> {
    is_process_running_by_pid(app_handle, pid).await
}
/// Get a random available port for a new MLX server to bind to.
#[tauri::command]
pub async fn get_mlx_random_port<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<u16, String> {
    get_random_available_port(app_handle).await
}
/// Find session information by model ID; `None` when the model is not loaded.
#[tauri::command]
pub async fn find_mlx_session_by_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    model_id: String,
) -> Result<Option<SessionInfo>, String> {
    find_session_by_model_id(app_handle, &model_id).await
}
/// Get all loaded model IDs.
#[tauri::command]
pub async fn get_mlx_loaded_models<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<Vec<String>, String> {
    get_all_loaded_model_ids(app_handle).await
}
/// Get all active sessions with their full metadata.
#[tauri::command]
pub async fn get_mlx_all_sessions<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<Vec<SessionInfo>, String> {
    get_all_active_sessions(app_handle).await
}

View File

@@ -0,0 +1,91 @@
use serde::{Deserialize, Serialize};
/// Machine-readable error categories surfaced to the frontend.
/// Serialized as SCREAMING_SNAKE_CASE strings (e.g. "BINARY_NOT_FOUND").
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ErrorCode {
    // The MLX server executable was not found at the given path.
    BinaryNotFound,
    // The model file was not found at the given path.
    ModelFileNotFound,
    ModelLoadFailed,
    // The server did not report readiness within the timeout.
    ModelLoadTimedOut,
    // stderr indicated an allocation failure (see MlxError::from_stderr).
    OutOfMemory,
    // Catch-all for unexpected server process output or exit.
    MlxProcessError,
    IoError,
    InternalError,
}
/// Structured error passed across the Tauri IPC boundary.
#[derive(Debug, Clone, Serialize, thiserror::Error)]
#[error("MlxError {{ code: {code:?}, message: \"{message}\" }}")]
pub struct MlxError {
    // Category the frontend can branch on.
    pub code: ErrorCode,
    // Short human-readable summary.
    pub message: String,
    // Optional raw detail (e.g. captured stderr); omitted from JSON when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<String>,
}
impl MlxError {
    /// Assemble an error from its parts.
    pub fn new(code: ErrorCode, message: String, details: Option<String>) -> Self {
        Self {
            code,
            message,
            details,
        }
    }

    /// Parses stderr from the MLX server and creates a specific MlxError.
    ///
    /// Recognizes out-of-memory signatures; anything else is reported as a
    /// generic process error. The raw stderr is always kept in `details`.
    pub fn from_stderr(stderr: &str) -> Self {
        let lowered = stderr.to_lowercase();
        // Substrings (case-insensitive) that indicate an allocation failure.
        let oom_markers = ["out of memory", "failed to allocate", "insufficient memory"];
        if oom_markers.iter().any(|marker| lowered.contains(marker)) {
            Self::new(
                ErrorCode::OutOfMemory,
                "Out of memory. The model requires more RAM than available.".into(),
                Some(stderr.into()),
            )
        } else {
            Self::new(
                ErrorCode::MlxProcessError,
                "The MLX model process encountered an unexpected error.".into(),
                Some(stderr.into()),
            )
        }
    }
}
/// Top-level error type for MLX server management.
///
/// Wraps the domain error (`MlxError`) plus the I/O and Tauri errors that can
/// occur while spawning or managing the server process. Every variant
/// serializes to an `MlxError` payload (see the `Serialize` impl in this
/// module), so the frontend always receives one uniform shape.
#[derive(Debug, thiserror::Error)]
pub enum ServerError {
    #[error(transparent)]
    Mlx(#[from] MlxError),
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    #[error("Tauri error: {0}")]
    Tauri(#[from] tauri::Error),
}
impl serde::Serialize for ServerError {
    // Serialize every variant as an `MlxError` so the frontend receives a
    // single, uniform error shape regardless of where the failure originated.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let payload = match self {
            Self::Mlx(err) => err.clone(),
            Self::Io(io_err) => MlxError::new(
                ErrorCode::IoError,
                "An input/output error occurred.".into(),
                Some(io_err.to_string()),
            ),
            Self::Tauri(tauri_err) => MlxError::new(
                ErrorCode::InternalError,
                "An internal application error occurred.".into(),
                Some(tauri_err.to_string()),
            ),
        };
        payload.serialize(serializer)
    }
}

/// Convenience alias for results produced by MLX server operations.
pub type ServerResult<T> = Result<T, ServerError>;

View File

@@ -0,0 +1,32 @@
use tauri::{
plugin::{Builder, TauriPlugin},
Manager, Runtime,
};
pub mod cleanup;
mod commands;
mod error;
mod process;
pub mod state;
pub use cleanup::cleanup_mlx_processes;
/// Initializes the MLX plugin.
///
/// Registers every MLX Tauri command and, at setup time, installs a fresh
/// `MlxState` (the shared session map) into the app's managed state so the
/// commands can retrieve it via `app_handle.state()`.
pub fn init<R: Runtime>() -> TauriPlugin<R> {
    Builder::new("mlx")
        .invoke_handler(tauri::generate_handler![
            cleanup::cleanup_mlx_processes,
            commands::load_mlx_model,
            commands::unload_mlx_model,
            commands::is_mlx_process_running,
            commands::get_mlx_random_port,
            commands::find_mlx_session_by_model,
            commands::get_mlx_loaded_models,
            commands::get_mlx_all_sessions,
        ])
        .setup(|app, _api| {
            // Shared, mutex-guarded map of running MLX server sessions.
            app.manage(state::MlxState::new());
            Ok(())
        })
        .build()
}

View File

@@ -0,0 +1,120 @@
use std::collections::HashSet;
use sysinfo::{Pid, System};
use tauri::{Manager, Runtime, State};
use crate::state::{MlxState, SessionInfo};
use jan_utils::generate_random_port;
/// Check if a process is running by PID.
///
/// Refreshes only the target PID instead of the whole process table (the
/// previous implementation scanned every process on the system just to test
/// one), then evicts the stale session entry from plugin state when the
/// process is gone so callers never observe a dead session.
pub async fn is_process_running_by_pid<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    pid: i32,
) -> Result<bool, String> {
    let process_pid = Pid::from(pid as usize);
    let mut system = System::new();
    // Only refresh the PID we care about; refreshing All is O(#processes).
    system.refresh_processes(sysinfo::ProcessesToUpdate::Some(&[process_pid]), true);
    let alive = system.process(process_pid).is_some();
    if !alive {
        // Drop the bookkeeping entry for the dead server, if one exists.
        let state: State<MlxState> = app_handle.state();
        let mut map = state.mlx_server_process.lock().await;
        map.remove(&pid);
    }
    Ok(alive)
}
/// Get a random available port, avoiding ports used by existing sessions.
///
/// Session ports are stored as `i32`; only values that fit in a `u16` are
/// treated as occupied.
pub async fn get_random_available_port<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<u16, String> {
    let state: State<MlxState> = app_handle.state();
    // Scope the lock guard so it is released before we pick a port.
    let used_ports: HashSet<u16> = {
        let sessions = state.mlx_server_process.lock().await;
        sessions
            .values()
            .map(|session| session.info.port)
            .filter(|&port| port > 0 && port <= u16::MAX as i32)
            .map(|port| port as u16)
            .collect()
    };
    generate_random_port(&used_ports)
}
/// Gracefully terminate a process on Unix systems (macOS).
///
/// Termination protocol:
/// 1. Send SIGTERM and give the process up to 5 seconds to exit cleanly.
/// 2. On timeout, escalate to SIGKILL and wait again so the child is reaped
///    and does not linger as a zombie.
///
/// If `child.id()` returns `None` the child has already been reaped
/// (per tokio's `Child::id` contract) and there is nothing to do.
#[cfg(unix)]
pub async fn graceful_terminate_process(child: &mut tokio::process::Child) {
    use nix::sys::signal::{kill, Signal};
    use nix::unistd::Pid;
    use std::time::Duration;
    use tokio::time::timeout;
    if let Some(raw_pid) = child.id() {
        let raw_pid = raw_pid as i32;
        log::info!("Sending SIGTERM to MLX process PID {}", raw_pid);
        // Send errors are ignored: the process may already have exited.
        let _ = kill(Pid::from_raw(raw_pid), Signal::SIGTERM);
        match timeout(Duration::from_secs(5), child.wait()).await {
            Ok(Ok(status)) => log::info!("MLX process exited gracefully: {}", status),
            Ok(Err(e)) => log::error!("Error waiting after SIGTERM for MLX process: {}", e),
            Err(_) => {
                log::warn!(
                    "SIGTERM timed out for MLX PID {}; sending SIGKILL",
                    raw_pid
                );
                let _ = kill(Pid::from_raw(raw_pid), Signal::SIGKILL);
                // Wait once more to reap the force-killed child.
                match child.wait().await {
                    Ok(s) => log::info!("Force-killed MLX process exited: {}", s),
                    Err(e) => log::error!("Error waiting after SIGKILL for MLX process: {}", e),
                }
            }
        }
    }
}
/// Find a session by model ID.
///
/// Returns a clone of the session metadata for the first session whose
/// `model_id` matches, or `None` when the model is not loaded.
pub async fn find_session_by_model_id<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    model_id: &str,
) -> Result<Option<SessionInfo>, String> {
    let state: State<MlxState> = app_handle.state();
    let sessions = state.mlx_server_process.lock().await;
    for session in sessions.values() {
        if session.info.model_id == model_id {
            return Ok(Some(session.info.clone()));
        }
    }
    Ok(None)
}
/// Get all loaded model IDs.
///
/// Collects the `model_id` of every tracked session; order follows the
/// underlying hash map and is unspecified.
pub async fn get_all_loaded_model_ids<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<Vec<String>, String> {
    let state: State<MlxState> = app_handle.state();
    let sessions = state.mlx_server_process.lock().await;
    let mut ids = Vec::with_capacity(sessions.len());
    for session in sessions.values() {
        ids.push(session.info.model_id.clone());
    }
    Ok(ids)
}
/// Get all active sessions.
///
/// Returns a snapshot (clone) of each session's metadata; the process
/// handles themselves stay in plugin state.
pub async fn get_all_active_sessions<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<Vec<SessionInfo>, String> {
    let state: State<MlxState> = app_handle.state();
    let guard = state.mlx_server_process.lock().await;
    let mut sessions = Vec::with_capacity(guard.len());
    for backend in guard.values() {
        sessions.push(backend.info.clone());
    }
    Ok(sessions)
}

View File

@@ -0,0 +1,39 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::process::Child;
use tokio::sync::Mutex;
/// Metadata describing one running MLX server session.
///
/// This is the serializable view handed to the frontend; the live process
/// handle lives alongside it in `MlxBackendSession`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    // OS process id of the spawned MLX server.
    pub pid: i32,
    // TCP port the server listens on (stored as i32; valid values fit u16).
    pub port: i32,
    // Identifier of the loaded model.
    pub model_id: String,
    // Filesystem path the model was loaded from.
    pub model_path: String,
    // Whether this session serves an embedding model.
    pub is_embedding: bool,
    // Bearer token the local server expects in the Authorization header.
    pub api_key: String,
}

/// A live MLX session: the spawned child process plus its metadata.
pub struct MlxBackendSession {
    // Handle used for graceful shutdown / reaping of the server process.
    pub child: Child,
    pub info: SessionInfo,
}
/// MLX plugin state
pub struct MlxState {
pub mlx_server_process: Arc<Mutex<HashMap<i32, MlxBackendSession>>>,
}
impl Default for MlxState {
fn default() -> Self {
Self {
mlx_server_process: Arc::new(Mutex::new(HashMap::new())),
}
}
}
impl MlxState {
pub fn new() -> Self {
Self::default()
}
}

View File

@@ -0,0 +1,14 @@
{
"compilerOptions": {
"target": "es2021",
"module": "esnext",
"moduleResolution": "bundler",
"skipLibCheck": true,
"strict": true,
"noUnusedLocals": true,
"noImplicitAny": true,
"noEmit": true
},
"include": ["guest-js/*.ts"],
"exclude": ["dist-js", "node_modules"]
}

View File

@@ -52,6 +52,18 @@ __metadata:
languageName: unknown
linkType: soft
"@janhq/tauri-plugin-mlx-api@workspace:tauri-plugin-mlx":
version: 0.0.0-use.local
resolution: "@janhq/tauri-plugin-mlx-api@workspace:tauri-plugin-mlx"
dependencies:
"@rollup/plugin-typescript": "npm:^12.0.0"
"@tauri-apps/api": "npm:>=2.0.0-beta.6"
rollup: "npm:^4.9.6"
tslib: "npm:^2.6.2"
typescript: "npm:^5.3.3"
languageName: unknown
linkType: soft
"@janhq/tauri-plugin-rag-api@workspace:tauri-plugin-rag":
version: 0.0.0-use.local
resolution: "@janhq/tauri-plugin-rag-api@workspace:tauri-plugin-rag"

View File

@@ -41,6 +41,11 @@ pub fn run() {
app_builder = app_builder.plugin(tauri_plugin_deep_link::init());
}
#[cfg(feature = "mlx")]
{
app_builder = app_builder.plugin(tauri_plugin_mlx::init());
}
#[cfg(not(any(target_os = "android", target_os = "ios")))]
{
app_builder = app_builder.plugin(tauri_plugin_hardware::init());
@@ -328,6 +333,17 @@ pub fn run() {
} else {
log::info!("Llama processes cleaned up successfully");
}
#[cfg(feature = "mlx")]
{
use tauri_plugin_mlx::cleanup_mlx_processes;
if let Err(e) = cleanup_mlx_processes(app_handle.clone()).await {
log::warn!("Failed to cleanup MLX processes: {}", e);
} else {
log::info!("MLX processes cleaned up successfully");
}
}
log::info!("App cleanup completed");
});
});

Binary file not shown.

After

Width:  |  Height:  |  Size: 63 KiB

View File

@@ -41,7 +41,7 @@ export const DialogDeleteModel = ({
deleteModelCache(selectedModelId)
serviceHub
.models()
.deleteModel(selectedModelId)
.deleteModel(selectedModelId, provider.provider)
.then(() => {
serviceHub
.providers()

View File

@@ -0,0 +1,264 @@
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
DialogTrigger,
} from '@/components/ui/dialog'
import { Button } from '@/components/ui/button'
import { useServiceHub } from '@/hooks/useServiceHub'
import { useState } from 'react'
import { toast } from 'sonner'
import {
IconLoader2,
IconCheck,
} from '@tabler/icons-react'
import { ExtensionManager } from '@/lib/extension'
/** Props for {@link ImportMlxModelDialog}. */
type ImportMlxModelDialogProps = {
  /** The MLX provider the imported model will be registered under. */
  provider: ModelProvider
  /** Optional element that opens the dialog when clicked. */
  trigger?: React.ReactNode
  /** Called after a successful import with the imported model's name. */
  onSuccess?: (importedModelName?: string) => void
}

/**
 * Dialog for importing a local safetensors model into the MLX provider.
 *
 * Flow: the user picks a .safetensors file, a model name is derived from the
 * file name (and remains editable), and on confirm the MLX engine's
 * `import()` is invoked with the chosen name and path.
 *
 * Note: `/` is a valid model-name character (it supports namespaced names
 * such as `mlx-community/model`, matching the `mlx/models/{name}/` preview);
 * the user-facing messages below list it accordingly.
 */
export const ImportMlxModelDialog = ({
  provider,
  trigger,
  onSuccess,
}: ImportMlxModelDialogProps) => {
  const serviceHub = useServiceHub()
  const [open, setOpen] = useState(false)
  const [importing, setImporting] = useState(false)
  const [selectedPath, setSelectedPath] = useState<string | null>(null)
  const [modelName, setModelName] = useState('')

  /** Open the native file picker and derive a default model name. */
  const handleFileSelect = async () => {
    const result = await serviceHub.dialog().open({
      multiple: false,
      directory: false,
      filters: [
        {
          name: 'Safetensor Files',
          extensions: ['safetensors'],
        },
        {
          name: 'All Files',
          extensions: ['*'],
        },
      ],
    })
    if (result && typeof result === 'string') {
      setSelectedPath(result)
      // Extract model name from path (last segment; handles \ and /)
      const pathParts = result.split(/[\\/]/)
      const nameFromPath = pathParts[pathParts.length - 1] || 'mlx-model'
      // Normalize: spaces -> dashes, then strip characters outside the allowed set
      const sanitizedName = nameFromPath
        .replace(/\s/g, '-')
        .replace(/[^a-zA-Z0-9/_.\-]/g, '')
      setModelName(sanitizedName)
    }
  }

  /** Validate inputs, then delegate the actual import to the MLX engine. */
  const handleImport = async () => {
    if (!selectedPath) {
      toast.error('Please select a safetensor file or folder')
      return
    }
    if (!modelName) {
      toast.error('Please enter a model name')
      return
    }
    // Validate model name - only allow alphanumeric, slash, underscore, hyphen, and dot
    if (!/^[a-zA-Z0-9/_.\-]+$/.test(modelName)) {
      toast.error('Invalid model name. Only alphanumeric and / _ - . characters are allowed.')
      return
    }
    // Reject duplicates against the provider's current model list
    const modelExists = provider.models.some(
      (model) => model.id === modelName
    )
    if (modelExists) {
      toast.error('Model already exists', {
        description: `${modelName} already imported`,
      })
      return
    }
    setImporting(true)
    try {
      console.log('[MLX Import] Starting import:', { modelName, selectedPath })
      // Get the MLX engine and call its import method
      const engine = ExtensionManager.getInstance().getEngine('mlx')
      if (!engine) {
        throw new Error('MLX engine not found')
      }
      console.log('[MLX Import] Calling engine.import()...')
      await engine.import(modelName, {
        modelPath: selectedPath,
      })
      console.log('[MLX Import] Import completed')
      toast.success('Model imported successfully', {
        description: `${modelName} has been imported`,
      })
      // Reset form and close dialog
      setSelectedPath(null)
      setModelName('')
      setOpen(false)
      onSuccess?.(modelName)
    } catch (error) {
      console.error('[MLX Import] Import model error:', error)
      toast.error('Failed to import model', {
        description:
          error instanceof Error ? error.message : String(error),
      })
    } finally {
      setImporting(false)
    }
  }

  /** Clear the selected file and the model-name field. */
  const resetForm = () => {
    setSelectedPath(null)
    setModelName('')
  }

  /** Block closing while an import is in flight; clear the form on close. */
  const handleOpenChange = (newOpen: boolean) => {
    if (!importing) {
      setOpen(newOpen)
      if (!newOpen) {
        resetForm()
      }
    }
  }

  // Show just the file/folder name, not the full path
  const displayPath = selectedPath
    ? selectedPath.split(/[\\/]/).pop() || selectedPath
    : null

  return (
    <Dialog open={open} onOpenChange={handleOpenChange}>
      <DialogTrigger asChild>{trigger}</DialogTrigger>
      <DialogContent
        onInteractOutside={(e) => {
          // Prevent accidental dismissal by clicking outside
          e.preventDefault()
        }}
      >
        <DialogHeader>
          <DialogTitle className="flex items-center gap-2">
            Import MLX Model
          </DialogTitle>
          <DialogDescription>
            Import a safetensor model file or folder for use with MLX. MLX models
            are typically downloaded from HuggingFace and use the safetensors format.
          </DialogDescription>
        </DialogHeader>
        <div className="space-y-6">
          {/* Model Name Input */}
          <div className="space-y-2">
            <label className="text-sm font-medium">
              Model Name
            </label>
            <input
              type="text"
              value={modelName}
              onChange={(e) => setModelName(e.target.value)}
              placeholder="my-mlx-model"
              className="w-full px-3 py-2 bg-background border rounded-lg focus:outline-none focus:ring-2 focus:ring-accent/50"
            />
            <p className="text-xs text-muted-foreground">
              Only alphanumeric and / _ - . characters are allowed
            </p>
          </div>
          {/* File Selection Area */}
          <div className="border rounded-lg p-4 space-y-3">
            <div className="flex items-center gap-2">
              <h3 className="font-medium">
                Safetensor File or Folder
              </h3>
              <span className="text-xs bg-secondary px-2 py-1 rounded-sm">
                Required
              </span>
            </div>
            {displayPath ? (
              <div className="bg-accent/10 border rounded-lg p-3">
                <div className="flex items-center justify-between">
                  <div className="flex items-center gap-2">
                    <IconCheck size={16} className="text-accent" />
                    <span className="text-sm font-medium">
                      {displayPath}
                    </span>
                  </div>
                  <Button
                    variant="secondary"
                    size="sm"
                    onClick={handleFileSelect}
                    disabled={importing}
                  >
                    Change
                  </Button>
                </div>
              </div>
            ) : (
              <Button
                type="button"
                variant="link"
                onClick={handleFileSelect}
                disabled={importing}
                className="w-full h-12 border border-dashed text-muted-foreground"
              >
                Select Safetensor File or Folder
              </Button>
            )}
          </div>
          {/* Preview */}
          {modelName && (
            <div className="rounded-lg p-3">
              <div className="flex items-center gap-2">
                <span className="text-sm font-medium text-muted-foreground">
                  Model will be saved as:
                </span>
              </div>
              <p className="text-sm font-mono mt-1">
                mlx/models/{modelName}/
              </p>
            </div>
          )}
        </div>
        <div className="flex gap-2 pt-4 justify-end">
          <Button
            variant="ghost"
            size="sm"
            onClick={() => handleOpenChange(false)}
            disabled={importing}
          >
            Cancel
          </Button>
          <Button
            onClick={handleImport}
            size="sm"
            disabled={importing || !selectedPath || !modelName}
          >
            {importing && <IconLoader2 className="mr-2 size-4 animate-spin" />}
            {importing ? 'Importing...' : 'Import Model'}
          </Button>
        </div>
      </DialogContent>
    </Dialog>
  )
}

View File

@@ -55,10 +55,7 @@ export class CustomChatTransport implements ChatTransport<UIMessage> {
private serviceHub: ServiceHub | null
private threadId?: string
constructor(
systemMessage?: string,
threadId?: string
) {
constructor(systemMessage?: string, threadId?: string) {
this.systemMessage = systemMessage
this.threadId = threadId
this.serviceHub = useServiceStore.getState().serviceHub
@@ -248,17 +245,17 @@ export class CustomChatTransport implements ChatTransport<UIMessage> {
return result.toUIMessageStream({
messageMetadata: ({ part }) => {
if (!streamStartTime) {
// Track stream start time on start
if (part.type === 'start' && !streamStartTime) {
streamStartTime = Date.now()
}
// Track stream start time on first text delta
if (part.type === 'text-delta') {
// Count text deltas as a rough token approximation
// Each delta typically represents one token in streaming
textDeltaCount++
// Report streaming token speed in real-time
if (this.onStreamingTokenSpeed) {
if (this.onStreamingTokenSpeed && streamStartTime) {
const elapsedMs = Date.now() - streamStartTime
this.onStreamingTokenSpeed(textDeltaCount, elapsedMs)
}
@@ -279,22 +276,18 @@ export class CustomChatTransport implements ChatTransport<UIMessage> {
}
}
const usage = finishPart.totalUsage
const llamacppMeta = finishPart.providerMetadata?.llamacpp
const durationMs = streamStartTime ? Date.now() - streamStartTime : 0
const durationSec = durationMs / 1000
// Use provider's outputTokens, or llama.cpp completionTokens, or fall back to text delta count
const outputTokens =
usage?.outputTokens ??
llamacppMeta?.completionTokens ??
textDeltaCount
const inputTokens = usage?.inputTokens ?? llamacppMeta?.promptTokens
const inputTokens = usage?.inputTokens
// Use llama.cpp's tokens per second if available, otherwise calculate from duration
let tokenSpeed: number
if (llamacppMeta?.tokensPerSecond != null) {
tokenSpeed = llamacppMeta.tokensPerSecond
} else if (durationSec > 0 && outputTokens > 0) {
if (durationSec > 0 && outputTokens > 0) {
tokenSpeed = outputTokens / durationSec
} else {
tokenSpeed = 0

View File

@@ -7,6 +7,7 @@
*
* Supported Providers:
* - llamacpp: Local models via llama.cpp (requires running session)
* - mlx: Local models via MLX-Swift on Apple Silicon (requires running session)
* - anthropic: Claude models via Anthropic API (@ai-sdk/anthropic v2.0)
* - google/gemini: Gemini models via Google Generative AI API (@ai-sdk/google v2.0)
* - openai: OpenAI models via OpenAI API (@ai-sdk/openai)
@@ -30,6 +31,7 @@ import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import { createAnthropic } from '@ai-sdk/anthropic'
import { invoke } from '@tauri-apps/api/core'
import { SessionInfo } from '@janhq/core'
import { fetch } from '@tauri-apps/plugin-http'
/**
* Llama.cpp timings structure from the response
@@ -109,6 +111,9 @@ export class ModelFactory {
case 'llamacpp':
return this.createLlamaCppModel(modelId, provider)
case 'mlx':
return this.createMlxModel(modelId, provider)
case 'anthropic':
return this.createAnthropicModel(modelId, provider)
@@ -173,6 +178,59 @@ export class ModelFactory {
Origin: 'tauri://localhost',
},
includeUsage: true,
fetch: fetch,
})
return openAICompatible.languageModel(modelId, {
metadataExtractor: llamaCppMetadataExtractor,
})
}
/**
* Create an MLX model by starting the model and finding the running session.
* MLX uses the same OpenAI-compatible API pattern as llamacpp.
*/
private static async createMlxModel(
modelId: string,
provider?: ProviderObject
): Promise<LanguageModel> {
// Start the model first if provider is available
if (provider) {
try {
const { useServiceStore } = await import('@/hooks/useServiceHub')
const serviceHub = useServiceStore.getState().serviceHub
if (serviceHub) {
await serviceHub.models().startModel(provider, modelId)
}
} catch (error) {
console.error('Failed to start MLX model:', error)
throw new Error(
`Failed to start model: ${error instanceof Error ? error.message : JSON.stringify(error)}`
)
}
}
// Get session info which includes port and api_key
const sessionInfo = await invoke<SessionInfo | null>(
'plugin:mlx|find_mlx_session_by_model',
{ modelId }
)
if (!sessionInfo) {
throw new Error(`No running MLX session found for model: ${modelId}`)
}
// Create OpenAI-compatible client for MLX server
const openAICompatible = createOpenAICompatible({
name: 'mlx',
baseURL: `http://localhost:${sessionInfo.port}/v1`,
headers: {
Authorization: `Bearer ${sessionInfo.api_key}`,
Origin: 'tauri://localhost',
},
includeUsage: true,
fetch: fetch,
})
return openAICompatible.languageModel(modelId, {

View File

@@ -68,6 +68,8 @@ export function getProviderLogo(provider: string) {
return '/images/model-provider/jan.png'
case 'llamacpp':
return '/images/model-provider/llamacpp.svg'
case 'mlx':
return '/images/model-provider/mlx.png'
case 'anthropic':
return '/images/model-provider/anthropic.svg'
case 'huggingface':
@@ -97,6 +99,8 @@ export const getProviderTitle = (provider: string) => {
return 'Jan'
case 'llamacpp':
return 'Llama.cpp'
case 'mlx':
return 'MLX'
case 'openai':
return 'OpenAI'
case 'openrouter':

View File

@@ -4,17 +4,14 @@ import HeaderPage from '@/containers/HeaderPage'
import SettingsMenu from '@/containers/SettingsMenu'
import { useModelProvider } from '@/hooks/useModelProvider'
import { cn, getProviderTitle, getModelDisplayName } from '@/lib/utils'
import {
createFileRoute,
Link,
useParams,
} from '@tanstack/react-router'
import { createFileRoute, Link, useParams } from '@tanstack/react-router'
import { useTranslation } from '@/i18n/react-i18next-compat'
import Capabilities from '@/containers/Capabilities'
import { DynamicControllerSetting } from '@/containers/dynamicControllerSetting'
import { RenderMarkdown } from '@/containers/RenderMarkdown'
import { DialogEditModel } from '@/containers/dialogs/EditModel'
import { ImportVisionModelDialog } from '@/containers/dialogs/ImportVisionModelDialog'
import { ImportMlxModelDialog } from '@/containers/dialogs/ImportMlxModelDialog'
import { ModelSetting } from '@/containers/ModelSetting'
import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel'
import { FavoriteModelAction } from '@/containers/FavoriteModelAction'
@@ -68,9 +65,9 @@ function ProviderDetail() {
const { getProviderByName, setProviders, updateProvider } = useModelProvider()
const provider = getProviderByName(providerName)
// Check if llamacpp provider needs backend configuration
// Check if llamacpp/mlx provider needs backend configuration
const needsBackendConfig =
provider?.provider === 'llamacpp' &&
(provider?.provider === 'llamacpp' || provider?.provider === 'mlx') &&
provider.settings?.some(
(setting) =>
setting.key === 'version_backend' &&
@@ -144,12 +141,14 @@ function ProviderDetail() {
}
useEffect(() => {
// Initial data fetch
serviceHub
.models()
.getActiveModels()
.then((models) => setActiveModels(models || []))
}, [serviceHub, setActiveModels])
// Initial data fetch - load active models for the current provider
if (provider?.provider) {
serviceHub
.models()
.getActiveModels(provider.provider)
.then((models) => setActiveModels(models || []))
}
}, [serviceHub, setActiveModels, provider?.provider])
// Clear importing state when model appears in the provider's model list
useEffect(() => {
@@ -270,10 +269,10 @@ function ProviderDetail() {
// Start the model with plan result
await serviceHub.models().startModel(provider, modelId)
// Refresh active models after starting
// Refresh active models after starting (pass provider to get correct engine's loaded models)
serviceHub
.models()
.getActiveModels()
.getActiveModels(provider.provider)
.then((models) => setActiveModels(models || []))
} catch (error) {
setModelLoadError(error as ErrorObject)
@@ -288,12 +287,12 @@ function ProviderDetail() {
// Original: stopModel(modelId).then(() => { setActiveModels((prevModels) => prevModels.filter((model) => model !== modelId)) })
serviceHub
.models()
.stopModel(modelId)
.stopModel(modelId, provider?.provider)
.then(() => {
// Refresh active models after stopping
// Refresh active models after stopping (pass provider to get correct engine's loaded models)
serviceHub
.models()
.getActiveModels()
.getActiveModels(provider?.provider)
.then((models) => setActiveModels(models || []))
})
.catch((error) => {
@@ -302,7 +301,8 @@ function ProviderDetail() {
}
const handleCheckForBackendUpdate = useCallback(async () => {
if (provider?.provider !== 'llamacpp') return
if (provider?.provider !== 'llamacpp' && provider?.provider !== 'mlx')
return
setIsCheckingBackendUpdate(true)
try {
@@ -320,7 +320,8 @@ function ProviderDetail() {
}, [provider, checkForBackendUpdate, t])
const handleInstallBackendFromFile = useCallback(async () => {
if (provider?.provider !== 'llamacpp') return
if (provider?.provider !== 'llamacpp' && provider?.provider !== 'mlx')
return
setIsInstallingBackend(true)
try {
@@ -345,8 +346,12 @@ function ProviderDetail() {
// Extract filename from the selected file path and replace spaces with dashes
const fileName = basenameNoExt(selectedFile).replace(/\s+/g, '-')
// Capitalize provider name for display
const providerDisplayName =
provider?.provider === 'llamacpp' ? 'Llamacpp' : 'MLX'
toast.success(t('settings:backendInstallSuccess'), {
description: `Llamacpp ${fileName} installed`,
description: `${providerDisplayName} ${fileName} installed`,
})
// Refresh settings to update backend configuration
@@ -367,7 +372,9 @@ function ProviderDetail() {
<div className="flex flex-col h-svh w-full">
<HeaderPage>
<div className="flex items-center gap-2 w-full">
<span className='font-medium text-base font-studio'>{t('common:settings')}</span>
<span className="font-medium text-base font-studio">
{t('common:settings')}
</span>
</div>
</HeaderPage>
<div className="flex h-[calc(100%-60px)]">
@@ -384,8 +391,9 @@ function ProviderDetail() {
className={cn(
'flex flex-col gap-3',
provider &&
provider.provider === 'llamacpp' &&
'flex-col-reverse'
(provider.provider === 'llamacpp' ||
provider.provider === 'mlx') &&
'flex-col-reverse'
)}
>
{/* Settings */}
@@ -395,7 +403,7 @@ function ProviderDetail() {
const actionComponent = (
<div className="mt-2">
{needsBackendConfig &&
setting.key === 'version_backend' ? (
setting.key === 'version_backend' ? (
<div className="flex items-center gap-1 text-sm">
<IconLoader size={16} className="animate-spin" />
<span>loading</span>
@@ -404,21 +412,18 @@ function ProviderDetail() {
<DynamicControllerSetting
controllerType={setting.controller_type}
controllerProps={setting.controller_props}
className={cn(
setting.key === 'device' && 'hidden'
)}
className={cn(setting.key === 'device' && 'hidden')}
onChange={(newValue) => {
if (provider) {
const newSettings = [...provider.settings]
// Handle different value types by forcing the type
// Use type assertion to bypass type checking
// Handle different value types by forcing the type
// Use type assertion to bypass type checking
; (
newSettings[settingIndex]
.controller_props as {
value: string | boolean | number
}
).value = newValue
;(
newSettings[settingIndex].controller_props as {
value: string | boolean | number
}
).value = newValue
// Create update object with updated settings
const updateObj: Partial<ModelProvider> = {
@@ -446,11 +451,11 @@ function ProviderDetail() {
)
if (deviceSettingIndex !== -1) {
(
;(
newSettings[deviceSettingIndex]
.controller_props as {
value: string
}
value: string
}
).value = ''
}
@@ -480,9 +485,7 @@ function ProviderDetail() {
serviceHub
.models()
.getActiveModels()
.then((models) =>
setActiveModels(models || [])
)
.then((models) => setActiveModels(models || []))
}
}}
/>
@@ -497,7 +500,7 @@ function ProviderDetail() {
className={cn(setting.key === 'device' && 'hidden')}
column={
setting.controller_type === 'input' &&
setting.controller_props.type !== 'number'
setting.controller_props.type !== 'number'
? true
: false
}
@@ -535,7 +538,8 @@ function ProviderDetail() {
</div>
)}
{setting.key === 'version_backend' &&
provider?.provider === 'llamacpp' && (
(provider?.provider === 'llamacpp' ||
provider?.provider === 'mlx') && (
<div className="mt-2 flex flex-wrap gap-2">
<Button
variant="outline"
@@ -543,27 +547,22 @@ function ProviderDetail() {
className={cn(
'p-0',
isCheckingBackendUpdate &&
'pointer-events-none'
'pointer-events-none'
)}
onClick={handleCheckForBackendUpdate}
>
<IconRefresh
size={12}
className={cn(
'text-muted-foreground',
isCheckingBackendUpdate &&
'animate-spin'
)}
/>
<span>
{isCheckingBackendUpdate
? t(
'settings:checkingForBackendUpdates'
)
: t(
'settings:checkForBackendUpdates'
)}
</span>
size={12}
className={cn(
'text-muted-foreground',
isCheckingBackendUpdate && 'animate-spin'
)}
/>
<span>
{isCheckingBackendUpdate
? t('settings:checkingForBackendUpdates')
: t('settings:checkForBackendUpdates')}
</span>
</Button>
<Button
variant="outline"
@@ -605,7 +604,7 @@ function ProviderDetail() {
{t('providers:models')}
</h1>
<div className="flex items-center gap-2">
{provider && provider.provider !== 'llamacpp' && (
{provider && provider.provider !== 'llamacpp' && provider.provider !== 'mlx' && (
<>
<Button
variant="secondary"
@@ -648,6 +647,21 @@ function ProviderDetail() {
}
/>
)}
{provider && provider.provider === 'mlx' && (
<ImportMlxModelDialog
provider={provider}
onSuccess={handleModelImportSuccess}
trigger={
<Button variant="secondary" size="sm">
<IconFolderPlus
size={18}
className="text-muted-foreground"
/>
<span>{t('providers:import')}</span>
</Button>
}
/>
)}
</div>
</div>
}
@@ -676,10 +690,7 @@ function ProviderDetail() {
modelId={model.id}
/>
{model.settings && (
<ModelSetting
provider={provider}
model={model}
/>
<ModelSetting provider={provider} model={model} />
)}
{((provider &&
!predefinedProviders.some(
@@ -690,8 +701,8 @@ function ProviderDetail() {
(p) => p.provider === provider.provider
) &&
Boolean(provider.api_key?.length))) && (
<FavoriteModelAction model={model} />
)}
<FavoriteModelAction model={model} />
)}
<DialogDeleteModel
provider={provider}
modelId={model.id}
@@ -711,9 +722,7 @@ function ProviderDetail() {
) : (
<Button
size="sm"
disabled={loadingModels.includes(
model.id
)}
disabled={loadingModels.includes(model.id)}
onClick={() => handleStartModel(model.id)}
>
{loadingModels.includes(model.id) ? (

View File

@@ -282,8 +282,8 @@ export class DefaultModelsService implements ModelsService {
}
}
async deleteModel(id: string): Promise<void> {
return this.getEngine()?.delete(id)
async deleteModel(id: string, provider?: string): Promise<void> {
return this.getEngine(provider)?.delete(id)
}
async getActiveModels(provider?: string): Promise<string[]> {

View File

@@ -125,7 +125,7 @@ export interface ModelsService {
skipVerification?: boolean
): Promise<void>
abortDownload(id: string): Promise<void>
deleteModel(id: string): Promise<void>
deleteModel(id: string, provider?: string): Promise<void>
getActiveModels(provider?: string): Promise<string[]>
stopModel(model: string, provider?: string): Promise<UnloadResult | undefined>
stopAllModels(): Promise<void>