feat: support mlx plugin
# Conflicts: # Makefile # web-app/src/routes/settings/providers/$providerName.tsx
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -63,3 +63,5 @@ src-tauri/resources/
|
||||
test-data
|
||||
llm-docs
|
||||
.claude/agents
|
||||
mlx-server/.build
|
||||
mlx-server/.swiftpm
|
||||
|
||||
12
Makefile
12
Makefile
@@ -118,6 +118,18 @@ endif
|
||||
cargo test --manifest-path src-tauri/plugins/tauri-plugin-llamacpp/Cargo.toml
|
||||
cargo test --manifest-path src-tauri/utils/Cargo.toml
|
||||
|
||||
# Build MLX server (macOS Apple Silicon only)
|
||||
build-mlx-server:
|
||||
ifeq ($(shell uname -s),Darwin)
|
||||
@echo "Building MLX server for Apple Silicon..."
|
||||
# cd mlx-server && swift build -c release
|
||||
cd mlx-server && xcodebuild build -scheme mlx-server -destination 'platform=OS X'
|
||||
# -configuration Release
|
||||
@echo "MLX server built successfully"
|
||||
else
|
||||
@echo "Skipping MLX server build (macOS only)"
|
||||
endif
|
||||
|
||||
# Build
|
||||
build: install-and-build install-rust-targets
|
||||
yarn build
|
||||
|
||||
40
extensions/mlx-extension/package.json
Normal file
40
extensions/mlx-extension/package.json
Normal file
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"name": "@janhq/mlx-extension",
|
||||
"productName": "MLX Inference Engine",
|
||||
"version": "1.0.0",
|
||||
"description": "This extension enables MLX-Swift inference on Apple Silicon Macs",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/module.js",
|
||||
"engine": "mlx",
|
||||
"author": "Jan <service@jan.ai>",
|
||||
"license": "AGPL-3.0",
|
||||
"scripts": {
|
||||
"build": "rolldown -c rolldown.config.mjs",
|
||||
"build:publish": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install"
|
||||
},
|
||||
"devDependencies": {
|
||||
"cpx": "1.5.0",
|
||||
"rimraf": "3.0.2",
|
||||
"rolldown": "1.0.0-beta.1",
|
||||
"typescript": "5.9.2"
|
||||
},
|
||||
"dependencies": {
|
||||
"@janhq/core": "../../core/package.tgz",
|
||||
"@janhq/tauri-plugin-llamacpp-api": "link:../../src-tauri/plugins/tauri-plugin-llamacpp",
|
||||
"@janhq/tauri-plugin-mlx-api": "link:../../src-tauri/plugins/tauri-plugin-mlx",
|
||||
"@tauri-apps/api": "2.8.0",
|
||||
"@tauri-apps/plugin-http": "2.5.0",
|
||||
"@tauri-apps/plugin-log": "^2.6.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
},
|
||||
"files": [
|
||||
"dist/*",
|
||||
"package.json"
|
||||
],
|
||||
"installConfig": {
|
||||
"hoistingLimits": "workspaces"
|
||||
},
|
||||
"packageManager": "yarn@4.5.3"
|
||||
}
|
||||
21
extensions/mlx-extension/rolldown.config.mjs
Normal file
21
extensions/mlx-extension/rolldown.config.mjs
Normal file
@@ -0,0 +1,21 @@
|
||||
|
||||
import { defineConfig } from 'rolldown'
|
||||
import pkgJson from './package.json' with { type: 'json' }
|
||||
import settingJson from './settings.json' with { type: 'json' }
|
||||
|
||||
export default defineConfig({
|
||||
input: 'src/index.ts',
|
||||
output: {
|
||||
format: 'esm',
|
||||
file: 'dist/index.js',
|
||||
},
|
||||
platform: 'browser',
|
||||
define: {
|
||||
SETTINGS: JSON.stringify(settingJson),
|
||||
ENGINE: JSON.stringify(pkgJson.engine),
|
||||
IS_MAC: JSON.stringify(process.platform === 'darwin'),
|
||||
},
|
||||
inject: process.env.IS_DEV ? {} : {
|
||||
fetch: ['@tauri-apps/plugin-http', 'fetch'],
|
||||
},
|
||||
})
|
||||
33
extensions/mlx-extension/settings.json
Normal file
33
extensions/mlx-extension/settings.json
Normal file
@@ -0,0 +1,33 @@
|
||||
[
|
||||
{
|
||||
"key": "ctx_size",
|
||||
"title": "Context Size",
|
||||
"description": "Context window size for MLX inference",
|
||||
"controllerType": "input",
|
||||
"controllerProps": {
|
||||
"value": 4096,
|
||||
"placeholder": "4096",
|
||||
"type": "number",
|
||||
"textAlign": "right"
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "auto_unload",
|
||||
"title": "Auto unload model",
|
||||
"description": "Automatically unload other models when loading a new one",
|
||||
"controllerType": "checkbox",
|
||||
"controllerProps": { "value": true }
|
||||
},
|
||||
{
|
||||
"key": "timeout",
|
||||
"title": "Timeout (seconds)",
|
||||
"description": "Maximum time to wait for model to load",
|
||||
"controllerType": "input",
|
||||
"controllerProps": {
|
||||
"value": 600,
|
||||
"placeholder": "600",
|
||||
"type": "number",
|
||||
"textAlign": "right"
|
||||
}
|
||||
}
|
||||
]
|
||||
5
extensions/mlx-extension/src/env.d.ts
vendored
Normal file
5
extensions/mlx-extension/src/env.d.ts
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
declare const SETTINGS: SettingComponentProps[]
|
||||
declare const ENGINE: string
|
||||
declare const IS_WINDOWS: boolean
|
||||
declare const IS_MAC: boolean
|
||||
declare const IS_LINUX: boolean
|
||||
825
extensions/mlx-extension/src/index.ts
Normal file
825
extensions/mlx-extension/src/index.ts
Normal file
@@ -0,0 +1,825 @@
|
||||
/**
|
||||
* MLX Extension - Inference engine for Apple Silicon Macs using MLX-Swift
|
||||
*
|
||||
* This extension provides an alternative to llama.cpp for running GGUF models
|
||||
* locally on Apple Silicon using the MLX framework with Metal GPU acceleration.
|
||||
*
|
||||
* It shares the same model directory as llamacpp-extension so users can
|
||||
* switch between engines without re-downloading models.
|
||||
*/
|
||||
|
||||
import {
|
||||
AIEngine,
|
||||
getJanDataFolderPath,
|
||||
fs,
|
||||
joinPath,
|
||||
modelInfo,
|
||||
SessionInfo,
|
||||
UnloadResult,
|
||||
chatCompletion,
|
||||
chatCompletionChunk,
|
||||
ImportOptions,
|
||||
chatCompletionRequest,
|
||||
events,
|
||||
AppEvent,
|
||||
DownloadEvent,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { info, warn, error as logError } from '@tauri-apps/plugin-log'
|
||||
import { invoke } from '@tauri-apps/api/core'
|
||||
import {
|
||||
loadMlxModel,
|
||||
unloadMlxModel,
|
||||
MlxConfig,
|
||||
} from '@janhq/tauri-plugin-mlx-api'
|
||||
import { readGgufMetadata, ModelConfig } from '@janhq/tauri-plugin-llamacpp-api'
|
||||
|
||||
// Error message constant
|
||||
const OUT_OF_CONTEXT_SIZE = 'the request exceeds the available context size.'
|
||||
|
||||
const logger = {
|
||||
info: function (...args: any[]) {
|
||||
console.log(...args)
|
||||
info(args.map((arg) => ` ${arg}`).join(` `))
|
||||
},
|
||||
warn: function (...args: any[]) {
|
||||
console.warn(...args)
|
||||
warn(args.map((arg) => ` ${arg}`).join(` `))
|
||||
},
|
||||
error: function (...args: any[]) {
|
||||
console.error(...args)
|
||||
logError(args.map((arg) => ` ${arg}`).join(` `))
|
||||
},
|
||||
}
|
||||
|
||||
export default class mlx_extension extends AIEngine {
|
||||
provider: string = 'mlx'
|
||||
autoUnload: boolean = true
|
||||
timeout: number = 600
|
||||
readonly providerId: string = 'mlx'
|
||||
|
||||
private config: any = {}
|
||||
private providerPath!: string
|
||||
private apiSecret: string = 'JanMLX'
|
||||
private loadingModels = new Map<string, Promise<SessionInfo>>()
|
||||
|
||||
override async onLoad(): Promise<void> {
|
||||
super.onLoad()
|
||||
|
||||
let settings = structuredClone(SETTINGS)
|
||||
this.registerSettings(settings)
|
||||
|
||||
let loadedConfig: any = {}
|
||||
for (const item of settings) {
|
||||
const defaultValue = item.controllerProps.value
|
||||
loadedConfig[item.key] = await this.getSetting<typeof defaultValue>(
|
||||
item.key,
|
||||
defaultValue
|
||||
)
|
||||
}
|
||||
this.config = loadedConfig
|
||||
|
||||
this.autoUnload = this.config.auto_unload ?? true
|
||||
this.timeout = this.config.timeout ?? 600
|
||||
|
||||
this.getProviderPath()
|
||||
}
|
||||
|
||||
async getProviderPath(): Promise<string> {
|
||||
if (!this.providerPath) {
|
||||
// Use mlx folder for models
|
||||
this.providerPath = await joinPath([
|
||||
await getJanDataFolderPath(),
|
||||
'mlx',
|
||||
])
|
||||
}
|
||||
return this.providerPath
|
||||
}
|
||||
|
||||
override async onUnload(): Promise<void> {
|
||||
// Cleanup handled by Tauri plugin on app exit
|
||||
}
|
||||
|
||||
onSettingUpdate<T>(key: string, value: T): void {
|
||||
this.config[key] = value
|
||||
|
||||
if (key === 'auto_unload') {
|
||||
this.autoUnload = value as boolean
|
||||
} else if (key === 'timeout') {
|
||||
this.timeout = value as number
|
||||
}
|
||||
}
|
||||
|
||||
private async generateApiKey(
|
||||
modelId: string,
|
||||
port: string
|
||||
): Promise<string> {
|
||||
// Reuse the llamacpp plugin's API key generation
|
||||
const hash = await invoke<string>('plugin:llamacpp|generate_api_key', {
|
||||
modelId: modelId + port,
|
||||
apiSecret: this.apiSecret,
|
||||
})
|
||||
return hash
|
||||
}
|
||||
|
||||
override async get(modelId: string): Promise<modelInfo | undefined> {
|
||||
const modelPath = await joinPath([
|
||||
await this.getProviderPath(),
|
||||
'models',
|
||||
modelId,
|
||||
])
|
||||
const path = await joinPath([modelPath, 'model.yml'])
|
||||
|
||||
if (!(await fs.existsSync(path))) return undefined
|
||||
|
||||
const modelConfig = await invoke<ModelConfig>('read_yaml', { path })
|
||||
|
||||
return {
|
||||
id: modelId,
|
||||
name: modelConfig.name ?? modelId,
|
||||
providerId: this.provider,
|
||||
port: 0,
|
||||
sizeBytes: modelConfig.size_bytes ?? 0,
|
||||
embedding: modelConfig.embedding ?? false,
|
||||
} as modelInfo
|
||||
}
|
||||
|
||||
override async list(): Promise<modelInfo[]> {
|
||||
const modelsDir = await joinPath([await this.getProviderPath(), 'models'])
|
||||
if (!(await fs.existsSync(modelsDir))) {
|
||||
await fs.mkdir(modelsDir)
|
||||
}
|
||||
|
||||
let modelIds: string[] = []
|
||||
|
||||
// DFS to find all model.yml files
|
||||
let stack = [modelsDir]
|
||||
while (stack.length > 0) {
|
||||
const currentDir = stack.pop()
|
||||
|
||||
const modelConfigPath = await joinPath([currentDir, 'model.yml'])
|
||||
if (await fs.existsSync(modelConfigPath)) {
|
||||
modelIds.push(currentDir.slice(modelsDir.length + 1))
|
||||
continue
|
||||
}
|
||||
|
||||
const children = await fs.readdirSync(currentDir)
|
||||
for (const child of children) {
|
||||
const dirInfo = await fs.fileStat(child)
|
||||
if (!dirInfo.isDirectory) continue
|
||||
stack.push(child)
|
||||
}
|
||||
}
|
||||
|
||||
let modelInfos: modelInfo[] = []
|
||||
for (const modelId of modelIds) {
|
||||
const path = await joinPath([modelsDir, modelId, 'model.yml'])
|
||||
const modelConfig = await invoke<ModelConfig>('read_yaml', { path })
|
||||
|
||||
const capabilities: string[] = []
|
||||
if (modelConfig.mmproj_path) {
|
||||
capabilities.push('vision')
|
||||
}
|
||||
|
||||
// Check for tool support
|
||||
try {
|
||||
if (await this.isToolSupported(modelId)) {
|
||||
capabilities.push('tools')
|
||||
}
|
||||
} catch (e) {
|
||||
logger.warn(`Failed to check tool support for ${modelId}: ${e}`)
|
||||
}
|
||||
|
||||
modelInfos.push({
|
||||
id: modelId,
|
||||
name: modelConfig.name ?? modelId,
|
||||
providerId: this.provider,
|
||||
port: 0,
|
||||
sizeBytes: modelConfig.size_bytes ?? 0,
|
||||
embedding: modelConfig.embedding ?? false,
|
||||
capabilities: capabilities.length > 0 ? capabilities : undefined,
|
||||
} as modelInfo)
|
||||
}
|
||||
|
||||
return modelInfos
|
||||
}
|
||||
|
||||
private async getRandomPort(): Promise<number> {
|
||||
try {
|
||||
return await invoke<number>('plugin:mlx|get_mlx_random_port')
|
||||
} catch {
|
||||
logger.error('Unable to find a suitable port for MLX server')
|
||||
throw new Error('Unable to find a suitable port for MLX model')
|
||||
}
|
||||
}
|
||||
|
||||
override async load(
|
||||
modelId: string,
|
||||
overrideSettings?: any,
|
||||
isEmbedding: boolean = false
|
||||
): Promise<SessionInfo> {
|
||||
const sInfo = await this.findSessionByModel(modelId)
|
||||
if (sInfo) {
|
||||
throw new Error('Model already loaded!')
|
||||
}
|
||||
|
||||
if (this.loadingModels.has(modelId)) {
|
||||
return this.loadingModels.get(modelId)!
|
||||
}
|
||||
|
||||
const loadingPromise = this.performLoad(
|
||||
modelId,
|
||||
overrideSettings,
|
||||
isEmbedding
|
||||
)
|
||||
this.loadingModels.set(modelId, loadingPromise)
|
||||
|
||||
try {
|
||||
return await loadingPromise
|
||||
} finally {
|
||||
this.loadingModels.delete(modelId)
|
||||
}
|
||||
}
|
||||
|
||||
private async performLoad(
|
||||
modelId: string,
|
||||
overrideSettings?: any,
|
||||
isEmbedding: boolean = false
|
||||
): Promise<SessionInfo> {
|
||||
const loadedModels = await this.getLoadedModels()
|
||||
|
||||
// Auto-unload other models if needed
|
||||
const otherLoadingPromises = Array.from(this.loadingModels.entries())
|
||||
.filter(([id, _]) => id !== modelId)
|
||||
.map(([_, promise]) => promise)
|
||||
|
||||
if (
|
||||
this.autoUnload &&
|
||||
!isEmbedding &&
|
||||
(loadedModels.length > 0 || otherLoadingPromises.length > 0)
|
||||
) {
|
||||
if (otherLoadingPromises.length > 0) {
|
||||
await Promise.all(otherLoadingPromises)
|
||||
}
|
||||
|
||||
const allLoadedModels = await this.getLoadedModels()
|
||||
if (allLoadedModels.length > 0) {
|
||||
await Promise.all(allLoadedModels.map((id) => this.unload(id)))
|
||||
}
|
||||
}
|
||||
|
||||
const cfg = { ...this.config, ...(overrideSettings ?? {}) }
|
||||
|
||||
const janDataFolderPath = await getJanDataFolderPath()
|
||||
const modelConfigPath = await joinPath([
|
||||
this.providerPath,
|
||||
'models',
|
||||
modelId,
|
||||
'model.yml',
|
||||
])
|
||||
const modelConfig = await invoke<ModelConfig>('read_yaml', {
|
||||
path: modelConfigPath,
|
||||
})
|
||||
const port = await this.getRandomPort()
|
||||
|
||||
const api_key = await this.generateApiKey(modelId, String(port))
|
||||
const envs: Record<string, string> = {
|
||||
MLX_API_KEY: api_key,
|
||||
}
|
||||
|
||||
// Resolve model path - could be absolute or relative
|
||||
let modelPath: string
|
||||
if (modelConfig.model_path.startsWith('/') || modelConfig.model_path.includes(':')) {
|
||||
// Absolute path
|
||||
modelPath = modelConfig.model_path
|
||||
} else {
|
||||
// Relative path - resolve from Jan data folder
|
||||
modelPath = await joinPath([
|
||||
janDataFolderPath,
|
||||
modelConfig.model_path,
|
||||
])
|
||||
}
|
||||
|
||||
// Resolve the MLX server binary path
|
||||
const mlxServerPath = await this.getMlxServerBinaryPath()
|
||||
|
||||
const mlxConfig: MlxConfig = {
|
||||
ctx_size: cfg.ctx_size ?? 4096,
|
||||
n_predict: cfg.n_predict ?? 0,
|
||||
threads: cfg.threads ?? 0,
|
||||
chat_template: cfg.chat_template ?? '',
|
||||
}
|
||||
|
||||
logger.info(
|
||||
'Loading MLX model:',
|
||||
modelId,
|
||||
'with config:',
|
||||
JSON.stringify(mlxConfig)
|
||||
)
|
||||
|
||||
try {
|
||||
const sInfo = await loadMlxModel(
|
||||
mlxServerPath,
|
||||
modelId,
|
||||
modelPath,
|
||||
port,
|
||||
mlxConfig,
|
||||
envs,
|
||||
isEmbedding,
|
||||
Number(this.timeout)
|
||||
)
|
||||
return sInfo
|
||||
} catch (error) {
|
||||
logger.error('Error loading MLX model:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
private async getMlxServerBinaryPath(): Promise<string> {
|
||||
const janDataFolderPath = await getJanDataFolderPath()
|
||||
// Look for the MLX server binary in the Jan data folder
|
||||
const binaryPath = await joinPath([
|
||||
janDataFolderPath,
|
||||
'mlx',
|
||||
'mlx-server',
|
||||
])
|
||||
|
||||
if (await fs.existsSync(binaryPath)) {
|
||||
return binaryPath
|
||||
}
|
||||
|
||||
// Fallback: check in the app resources
|
||||
throw new Error(
|
||||
'MLX server binary not found. Please ensure mlx-server is installed at ' +
|
||||
binaryPath
|
||||
)
|
||||
}
|
||||
|
||||
override async unload(modelId: string): Promise<UnloadResult> {
|
||||
const sInfo = await this.findSessionByModel(modelId)
|
||||
if (!sInfo) {
|
||||
throw new Error(`No active MLX session found for model: ${modelId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await unloadMlxModel(sInfo.pid)
|
||||
if (result.success) {
|
||||
logger.info(`Successfully unloaded MLX model with PID ${sInfo.pid}`)
|
||||
} else {
|
||||
logger.warn(`Failed to unload MLX model: ${result.error}`)
|
||||
}
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.error('Error unloading MLX model:', error)
|
||||
return {
|
||||
success: false,
|
||||
error: `Failed to unload model: ${error}`,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async findSessionByModel(modelId: string): Promise<SessionInfo> {
|
||||
try {
|
||||
return await invoke<SessionInfo>(
|
||||
'plugin:mlx|find_mlx_session_by_model',
|
||||
{ modelId }
|
||||
)
|
||||
} catch (e) {
|
||||
logger.error(e)
|
||||
throw new Error(String(e))
|
||||
}
|
||||
}
|
||||
|
||||
override async chat(
|
||||
opts: chatCompletionRequest,
|
||||
abortController?: AbortController
|
||||
): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
|
||||
const sessionInfo = await this.findSessionByModel(opts.model)
|
||||
if (!sessionInfo) {
|
||||
throw new Error(`No active MLX session found for model: ${opts.model}`)
|
||||
}
|
||||
|
||||
// Check if the process is alive
|
||||
const isAlive = await invoke<boolean>('plugin:mlx|is_mlx_process_running', {
|
||||
pid: sessionInfo.pid,
|
||||
})
|
||||
|
||||
if (isAlive) {
|
||||
try {
|
||||
await fetch(`http://localhost:${sessionInfo.port}/health`)
|
||||
} catch (e) {
|
||||
this.unload(sessionInfo.model_id)
|
||||
throw new Error('MLX model appears to have crashed! Please reload!')
|
||||
}
|
||||
} else {
|
||||
throw new Error('MLX model has crashed! Please reload!')
|
||||
}
|
||||
|
||||
const baseUrl = `http://localhost:${sessionInfo.port}/v1`
|
||||
const url = `${baseUrl}/chat/completions`
|
||||
const headers = {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${sessionInfo.api_key}`,
|
||||
}
|
||||
|
||||
const body = JSON.stringify(opts)
|
||||
|
||||
if (opts.stream) {
|
||||
return this.handleStreamingResponse(url, headers, body, abortController)
|
||||
}
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body,
|
||||
signal: abortController?.signal,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => null)
|
||||
throw new Error(
|
||||
`MLX API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
|
||||
)
|
||||
}
|
||||
|
||||
const completionResponse = (await response.json()) as chatCompletion
|
||||
|
||||
if (completionResponse.choices?.[0]?.finish_reason === 'length') {
|
||||
throw new Error(OUT_OF_CONTEXT_SIZE)
|
||||
}
|
||||
|
||||
return completionResponse
|
||||
}
|
||||
|
||||
private async *handleStreamingResponse(
|
||||
url: string,
|
||||
headers: HeadersInit,
|
||||
body: string,
|
||||
abortController?: AbortController
|
||||
): AsyncIterable<chatCompletionChunk> {
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body,
|
||||
signal: AbortSignal.any([
|
||||
AbortSignal.timeout(this.timeout * 1000),
|
||||
abortController?.signal,
|
||||
]),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => null)
|
||||
throw new Error(
|
||||
`MLX API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
|
||||
)
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error('Response body is null')
|
||||
}
|
||||
|
||||
const reader = response.body.getReader()
|
||||
const decoder = new TextDecoder('utf-8')
|
||||
let buffer = ''
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmedLine = line.trim()
|
||||
if (!trimmedLine || trimmedLine === 'data: [DONE]') continue
|
||||
|
||||
if (trimmedLine.startsWith('data: ')) {
|
||||
const jsonStr = trimmedLine.slice(6)
|
||||
try {
|
||||
const data = JSON.parse(jsonStr) as chatCompletionChunk
|
||||
|
||||
if (data.choices?.[0]?.finish_reason === 'length') {
|
||||
throw new Error(OUT_OF_CONTEXT_SIZE)
|
||||
}
|
||||
|
||||
yield data
|
||||
} catch (e) {
|
||||
logger.error('Error parsing MLX stream JSON:', e)
|
||||
throw e
|
||||
}
|
||||
} else if (trimmedLine.startsWith('error: ')) {
|
||||
const jsonStr = trimmedLine.slice(7)
|
||||
const error = JSON.parse(jsonStr)
|
||||
throw new Error(error.message)
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
override async delete(modelId: string): Promise<void> {
|
||||
const modelDir = await joinPath([
|
||||
await this.getProviderPath(),
|
||||
'models',
|
||||
modelId,
|
||||
])
|
||||
|
||||
const modelConfigPath = await joinPath([modelDir, 'model.yml'])
|
||||
if (!(await fs.existsSync(modelConfigPath))) {
|
||||
throw new Error(`Model ${modelId} does not exist`)
|
||||
}
|
||||
|
||||
const modelConfig = await invoke<ModelConfig>('read_yaml', {
|
||||
path: modelConfigPath,
|
||||
})
|
||||
|
||||
// Check if model_path is a relative path within mlx folder
|
||||
if (!modelConfig.model_path.startsWith('/') && !modelConfig.model_path.includes(':')) {
|
||||
// Model file is at {janDataFolder}/{model_path}
|
||||
// Delete the parent folder containing the actual model file
|
||||
const janDataFolderPath = await getJanDataFolderPath()
|
||||
const modelPath = await joinPath([janDataFolderPath, modelConfig.model_path])
|
||||
const parentDir = modelPath.substring(0, modelPath.lastIndexOf('/'))
|
||||
// Only delete if it's different from modelDir (i.e., not the same folder)
|
||||
if (parentDir !== modelDir) {
|
||||
await fs.rm(parentDir)
|
||||
}
|
||||
}
|
||||
|
||||
// Always delete the model.yml folder
|
||||
await fs.rm(modelDir)
|
||||
}
|
||||
|
||||
override async update(
|
||||
modelId: string,
|
||||
model: Partial<modelInfo>
|
||||
): Promise<void> {
|
||||
// Delegate to the same logic as llamacpp since they share the model dir
|
||||
const modelFolderPath = await joinPath([
|
||||
await this.getProviderPath(),
|
||||
'models',
|
||||
modelId,
|
||||
])
|
||||
const modelConfig = await invoke<ModelConfig>('read_yaml', {
|
||||
path: await joinPath([modelFolderPath, 'model.yml']),
|
||||
})
|
||||
const newFolderPath = await joinPath([
|
||||
await this.getProviderPath(),
|
||||
'models',
|
||||
model.id,
|
||||
])
|
||||
if (await fs.existsSync(newFolderPath)) {
|
||||
throw new Error(`Model with ID ${model.id} already exists`)
|
||||
}
|
||||
const newModelConfigPath = await joinPath([newFolderPath, 'model.yml'])
|
||||
await fs.mv(modelFolderPath, newFolderPath).then(() =>
|
||||
invoke('write_yaml', {
|
||||
data: {
|
||||
...modelConfig,
|
||||
model_path: modelConfig?.model_path?.replace(
|
||||
`mlx/models/${modelId}`,
|
||||
`mlx/models/${model.id}`
|
||||
),
|
||||
},
|
||||
savePath: newModelConfigPath,
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
override async import(modelId: string, opts: ImportOptions): Promise<void> {
|
||||
const isValidModelId = (id: string) => {
|
||||
// only allow alphanumeric, underscore, hyphen, and dot characters in modelId
|
||||
if (!/^[a-zA-Z0-9/_\-\.]+$/.test(id)) return false
|
||||
|
||||
// check for empty parts or path traversal
|
||||
const parts = id.split('/')
|
||||
return parts.every((s) => s !== '' && s !== '.' && s !== '..')
|
||||
}
|
||||
|
||||
if (!isValidModelId(modelId))
|
||||
throw new Error(
|
||||
`Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
|
||||
)
|
||||
|
||||
const configPath = await joinPath([
|
||||
await this.getProviderPath(),
|
||||
'models',
|
||||
modelId,
|
||||
'model.yml',
|
||||
])
|
||||
if (await fs.existsSync(configPath))
|
||||
throw new Error(`Model ${modelId} already exists`)
|
||||
|
||||
const sourcePath = opts.modelPath
|
||||
|
||||
if (sourcePath.startsWith('https://')) {
|
||||
// Download from URL to mlx models folder
|
||||
const janDataFolderPath = await getJanDataFolderPath()
|
||||
const modelDir = await joinPath([janDataFolderPath, 'mlx', 'models', modelId])
|
||||
const localPath = await joinPath([modelDir, 'model.safetensors'])
|
||||
|
||||
const downloadManager = window.core.extensionManager.getByName(
|
||||
'@janhq/download-extension'
|
||||
)
|
||||
await downloadManager.downloadFiles(
|
||||
[
|
||||
{
|
||||
url: sourcePath,
|
||||
save_path: localPath,
|
||||
sha256: opts.modelSha256,
|
||||
size: opts.modelSize,
|
||||
model_id: modelId,
|
||||
},
|
||||
],
|
||||
`mlx/${modelId}`,
|
||||
(transferred: number, total: number) => {
|
||||
events.emit(DownloadEvent.onFileDownloadUpdate, {
|
||||
modelId,
|
||||
percent: transferred / total,
|
||||
size: { transferred, total },
|
||||
downloadType: 'Model',
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
// Create model.yml with relative path
|
||||
const modelConfig = {
|
||||
model_path: `mlx/models/${modelId}/model.safetensors`,
|
||||
name: modelId,
|
||||
size_bytes: opts.modelSize ?? 0,
|
||||
}
|
||||
|
||||
await fs.mkdir(modelDir)
|
||||
await invoke<void>('write_yaml', {
|
||||
data: modelConfig,
|
||||
savePath: configPath,
|
||||
})
|
||||
|
||||
events.emit(AppEvent.onModelImported, {
|
||||
modelId,
|
||||
modelPath: modelConfig.model_path,
|
||||
size_bytes: modelConfig.size_bytes,
|
||||
})
|
||||
} else {
|
||||
// Local file - use absolute path directly
|
||||
if (!(await fs.existsSync(sourcePath))) {
|
||||
throw new Error(`File not found: ${sourcePath}`)
|
||||
}
|
||||
|
||||
// Get file size
|
||||
const stat = await fs.fileStat(sourcePath)
|
||||
const size_bytes = stat.size
|
||||
|
||||
// Create model.yml with absolute path
|
||||
const modelConfig = {
|
||||
model_path: sourcePath,
|
||||
name: modelId,
|
||||
size_bytes,
|
||||
}
|
||||
|
||||
// Create model folder for model.yml only (no copying of safetensors)
|
||||
const modelDir = await joinPath([
|
||||
await this.getProviderPath(),
|
||||
'models',
|
||||
modelId,
|
||||
])
|
||||
await fs.mkdir(modelDir)
|
||||
|
||||
await invoke<void>('write_yaml', {
|
||||
data: modelConfig,
|
||||
savePath: configPath,
|
||||
})
|
||||
|
||||
events.emit(AppEvent.onModelImported, {
|
||||
modelId,
|
||||
modelPath: sourcePath,
|
||||
size_bytes,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
override async abortImport(modelId: string): Promise<void> {
|
||||
// Not applicable for MLX - imports go through llamacpp extension
|
||||
}
|
||||
|
||||
override async getLoadedModels(): Promise<string[]> {
|
||||
try {
|
||||
return await invoke<string[]>('plugin:mlx|get_mlx_loaded_models')
|
||||
} catch (e) {
|
||||
logger.error(e)
|
||||
throw new Error(e)
|
||||
}
|
||||
}
|
||||
|
||||
async isToolSupported(modelId: string): Promise<boolean> {
|
||||
// Check GGUF/safetensors metadata for tool support
|
||||
const modelConfigPath = await joinPath([
|
||||
this.providerPath,
|
||||
'models',
|
||||
modelId,
|
||||
'model.yml',
|
||||
])
|
||||
const modelConfig = await invoke<ModelConfig>('read_yaml', {
|
||||
path: modelConfigPath,
|
||||
})
|
||||
|
||||
// model_path could be absolute or relative
|
||||
let modelPath: string
|
||||
if (modelConfig.model_path.startsWith('/') || modelConfig.model_path.includes(':')) {
|
||||
// Absolute path
|
||||
modelPath = modelConfig.model_path
|
||||
} else {
|
||||
// Relative path - resolve from Jan data folder
|
||||
const janDataFolderPath = await getJanDataFolderPath()
|
||||
modelPath = await joinPath([janDataFolderPath, modelConfig.model_path])
|
||||
}
|
||||
|
||||
// Check if model is safetensors or GGUF
|
||||
const isSafetensors = modelPath.endsWith('.safetensors')
|
||||
const modelDir = modelPath.substring(0, modelPath.lastIndexOf('/'))
|
||||
|
||||
// For safetensors models, check multiple sources for tool support
|
||||
if (isSafetensors) {
|
||||
// Check 1: tokenizer_config.json (common for tool-capable models)
|
||||
const tokenizerConfigPath = await joinPath([modelDir, 'tokenizer_config.json'])
|
||||
if (await fs.existsSync(tokenizerConfigPath)) {
|
||||
try {
|
||||
const tokenizerConfigContent = await invoke<string>('read_file_sync', {
|
||||
args: [tokenizerConfigPath],
|
||||
})
|
||||
// Check for tool/function calling indicators
|
||||
const tcLower = tokenizerConfigContent.toLowerCase()
|
||||
if (tcLower.includes('function_call') ||
|
||||
tcLower.includes('tool_use') ||
|
||||
tcLower.includes('tools') ||
|
||||
tcLower.includes('assistant')) {
|
||||
logger.info(`Tool support detected from tokenizer_config.json for ${modelId}`)
|
||||
return true
|
||||
}
|
||||
} catch (e) {
|
||||
logger.warn(`Failed to read tokenizer_config.json: ${e}`)
|
||||
}
|
||||
}
|
||||
|
||||
// Check 2: chat_template.jinja for tool patterns
|
||||
const chatTemplatePath = await joinPath([modelDir, 'chat_template.jinja'])
|
||||
if (await fs.existsSync(chatTemplatePath)) {
|
||||
try {
|
||||
const chatTemplateContent = await invoke<string>('read_file_sync', {
|
||||
args: [chatTemplatePath],
|
||||
})
|
||||
// Common tool/function calling template patterns
|
||||
const ctLower = chatTemplateContent.toLowerCase()
|
||||
const toolPatterns = [
|
||||
/\{\%.*tool.*\%\}/, // {% tool ... %}
|
||||
/\{\%.*function.*\%\}/, // {% function ... %}
|
||||
/\{\%.*tool_call/,
|
||||
/\{\%.*tools\./,
|
||||
/\{[-]?#.*tool/,
|
||||
/\{[-]?%.*tool/,
|
||||
/"tool_calls"/, // "tool_calls" JSON key
|
||||
/'tool_calls'/, // 'tool_calls' JSON key
|
||||
/function_call/,
|
||||
/tool_use/,
|
||||
]
|
||||
for (const pattern of toolPatterns) {
|
||||
if (pattern.test(chatTemplateContent)) {
|
||||
logger.info(`Tool support detected from chat_template.jinja for ${modelId}`)
|
||||
return true
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.warn(`Failed to read chat_template.jinja: ${e}`)
|
||||
}
|
||||
}
|
||||
|
||||
// Check 3: Look for tool-related files
|
||||
const toolFiles = ['tools.jinja', 'tool_use.jinja', 'function_calling.jinja']
|
||||
for (const toolFile of toolFiles) {
|
||||
const toolPath = await joinPath([modelDir, toolFile])
|
||||
if (await fs.existsSync(toolPath)) {
|
||||
logger.info(`Tool support detected from ${toolFile} for ${modelId}`)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`No tool support detected for safetensors model ${modelId}`)
|
||||
return false
|
||||
} else {
|
||||
// For GGUF models, check metadata
|
||||
try {
|
||||
const metadata = await readGgufMetadata(modelPath)
|
||||
const chatTemplate = metadata.metadata?.['tokenizer.chat_template']
|
||||
return chatTemplate?.includes('tools') ?? false
|
||||
} catch (e) {
|
||||
logger.warn(`Failed to read GGUF metadata: ${e}`)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
15
extensions/mlx-extension/tsconfig.json
Normal file
15
extensions/mlx-extension/tsconfig.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "es2018",
|
||||
"module": "ES6",
|
||||
"moduleResolution": "node",
|
||||
"outDir": "./dist",
|
||||
"esModuleInterop": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"strict": false,
|
||||
"skipLibCheck": true,
|
||||
"rootDir": "./src"
|
||||
},
|
||||
"include": ["./src"],
|
||||
"exclude": ["**/*.test.ts"]
|
||||
}
|
||||
266
mlx-server/Package.resolved
Normal file
266
mlx-server/Package.resolved
Normal file
@@ -0,0 +1,266 @@
|
||||
{
|
||||
"pins" : [
|
||||
{
|
||||
"identity" : "async-http-client",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/swift-server/async-http-client.git",
|
||||
"state" : {
|
||||
"revision" : "4b99975677236d13f0754339864e5360142ff5a1",
|
||||
"version" : "1.30.3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "hummingbird",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/hummingbird-project/hummingbird",
|
||||
"state" : {
|
||||
"revision" : "daf66bfd4b46c1f3f080a1f3438d8fbecee7ace5",
|
||||
"version" : "2.19.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "mlx-swift",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/ml-explore/mlx-swift",
|
||||
"state" : {
|
||||
"revision" : "4dccaeda1d83cf8697f235d2786c2d72ad4bb925",
|
||||
"version" : "0.30.3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "mlx-swift-lm",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/ml-explore/mlx-swift-lm",
|
||||
"state" : {
|
||||
"branch" : "main",
|
||||
"revision" : "2c700546340c37f275d23302163701b77c4dcbd9"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-algorithms",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-algorithms.git",
|
||||
"state" : {
|
||||
"revision" : "87e50f483c54e6efd60e885f7f5aa946cee68023",
|
||||
"version" : "1.2.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-argument-parser",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-argument-parser",
|
||||
"state" : {
|
||||
"revision" : "c5d11a805e765f52ba34ec7284bd4fcd6ba68615",
|
||||
"version" : "1.7.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-asn1",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-asn1.git",
|
||||
"state" : {
|
||||
"revision" : "810496cf121e525d660cd0ea89a758740476b85f",
|
||||
"version" : "1.5.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-async-algorithms",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-async-algorithms.git",
|
||||
"state" : {
|
||||
"revision" : "6c050d5ef8e1aa6342528460db614e9770d7f804",
|
||||
"version" : "1.1.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-atomics",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-atomics.git",
|
||||
"state" : {
|
||||
"revision" : "b601256eab081c0f92f059e12818ac1d4f178ff7",
|
||||
"version" : "1.3.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-certificates",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-certificates.git",
|
||||
"state" : {
|
||||
"revision" : "7d5f6124c91a2d06fb63a811695a3400d15a100e",
|
||||
"version" : "1.17.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-collections",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-collections.git",
|
||||
"state" : {
|
||||
"revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e",
|
||||
"version" : "1.3.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-configuration",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-configuration.git",
|
||||
"state" : {
|
||||
"revision" : "6ffef195ed4ba98ee98029970c94db7edc60d4c6",
|
||||
"version" : "1.0.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-crypto",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-crypto.git",
|
||||
"state" : {
|
||||
"revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095",
|
||||
"version" : "4.2.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-distributed-tracing",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-distributed-tracing.git",
|
||||
"state" : {
|
||||
"revision" : "baa932c1336f7894145cbaafcd34ce2dd0b77c97",
|
||||
"version" : "1.3.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-http-structured-headers",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-http-structured-headers.git",
|
||||
"state" : {
|
||||
"revision" : "76d7627bd88b47bf5a0f8497dd244885960dde0b",
|
||||
"version" : "1.6.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-http-types",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-http-types.git",
|
||||
"state" : {
|
||||
"revision" : "45eb0224913ea070ec4fba17291b9e7ecf4749ca",
|
||||
"version" : "1.5.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-jinja",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/huggingface/swift-jinja.git",
|
||||
"state" : {
|
||||
"revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0",
|
||||
"version" : "2.3.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-log",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-log.git",
|
||||
"state" : {
|
||||
"revision" : "2778fd4e5a12a8aaa30a3ee8285f4ce54c5f3181",
|
||||
"version" : "1.9.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-metrics",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-metrics.git",
|
||||
"state" : {
|
||||
"revision" : "0743a9364382629da3bf5677b46a2c4b1ce5d2a6",
|
||||
"version" : "2.7.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-nio",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-nio.git",
|
||||
"state" : {
|
||||
"revision" : "5e72fc102906ebe75a3487595a653e6f43725552",
|
||||
"version" : "2.94.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-nio-extras",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-nio-extras.git",
|
||||
"state" : {
|
||||
"revision" : "3df009d563dc9f21a5c85b33d8c2e34d2e4f8c3b",
|
||||
"version" : "1.32.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-nio-http2",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-nio-http2.git",
|
||||
"state" : {
|
||||
"revision" : "c2ba4cfbb83f307c66f5a6df6bb43e3c88dfbf80",
|
||||
"version" : "1.39.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-nio-ssl",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-nio-ssl.git",
|
||||
"state" : {
|
||||
"revision" : "173cc69a058623525a58ae6710e2f5727c663793",
|
||||
"version" : "2.36.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-nio-transport-services",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-nio-transport-services.git",
|
||||
"state" : {
|
||||
"revision" : "60c3e187154421171721c1a38e800b390680fb5d",
|
||||
"version" : "1.26.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-numerics",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-numerics",
|
||||
"state" : {
|
||||
"revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2",
|
||||
"version" : "1.1.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-service-context",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-service-context.git",
|
||||
"state" : {
|
||||
"revision" : "1983448fefc717a2bc2ebde5490fe99873c5b8a6",
|
||||
"version" : "1.2.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-service-lifecycle",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/swift-server/swift-service-lifecycle.git",
|
||||
"state" : {
|
||||
"revision" : "1de37290c0ab3c5a96028e0f02911b672fd42348",
|
||||
"version" : "2.9.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-system",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-system",
|
||||
"state" : {
|
||||
"revision" : "7c6ad0fc39d0763e0b699210e4124afd5041c5df",
|
||||
"version" : "1.6.4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-transformers",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/huggingface/swift-transformers",
|
||||
"state" : {
|
||||
"revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0",
|
||||
"version" : "1.1.6"
|
||||
}
|
||||
}
|
||||
],
|
||||
"version" : 2
|
||||
}
|
||||
36
mlx-server/Package.swift
Normal file
36
mlx-server/Package.swift
Normal file
@@ -0,0 +1,36 @@
|
||||
// swift-tools-version: 5.9
|
||||
|
||||
import PackageDescription
|
||||
|
||||
let package = Package(
|
||||
name: "mlx-server",
|
||||
platforms: [
|
||||
.macOS(.v14)
|
||||
],
|
||||
products: [
|
||||
.executable(name: "mlx-server", targets: ["MLXServer"])
|
||||
],
|
||||
dependencies: [
|
||||
.package(url: "https://github.com/ml-explore/mlx-swift", from: "0.30.3"),
|
||||
.package(url: "https://github.com/ml-explore/mlx-swift-lm", branch: "main"),
|
||||
.package(url: "https://github.com/apple/swift-argument-parser", from: "1.7.0"),
|
||||
.package(url: "https://github.com/hummingbird-project/hummingbird", from: "2.19.0"),
|
||||
],
|
||||
targets: [
|
||||
.executableTarget(
|
||||
name: "MLXServer",
|
||||
dependencies: [
|
||||
.product(name: "MLX", package: "mlx-swift"),
|
||||
.product(name: "MLXNN", package: "mlx-swift"),
|
||||
.product(name: "MLXOptimizers", package: "mlx-swift"),
|
||||
.product(name: "MLXRandom", package: "mlx-swift"),
|
||||
.product(name: "MLXLLM", package: "mlx-swift-lm"),
|
||||
.product(name: "MLXVLM", package: "mlx-swift-lm"),
|
||||
.product(name: "MLXLMCommon", package: "mlx-swift-lm"),
|
||||
.product(name: "ArgumentParser", package: "swift-argument-parser"),
|
||||
.product(name: "Hummingbird", package: "hummingbird"),
|
||||
],
|
||||
path: "Sources/MLXServer"
|
||||
)
|
||||
]
|
||||
)
|
||||
67
mlx-server/Sources/MLXServer/MLXServerCommand.swift
Normal file
67
mlx-server/Sources/MLXServer/MLXServerCommand.swift
Normal file
@@ -0,0 +1,67 @@
|
||||
import ArgumentParser
|
||||
import Foundation
|
||||
import Hummingbird
|
||||
|
||||
@main
|
||||
struct MLXServerCommand: AsyncParsableCommand {
|
||||
static let configuration = CommandConfiguration(
|
||||
commandName: "mlx-server",
|
||||
abstract: "MLX-Swift inference server with OpenAI-compatible API"
|
||||
)
|
||||
|
||||
@Option(name: [.long, .short], help: "Path to the GGUF model file")
|
||||
var model: String
|
||||
|
||||
@Option(name: .long, help: "Port to listen on")
|
||||
var port: Int = 8080
|
||||
|
||||
@Option(name: .long, help: "Context window size")
|
||||
var ctxSize: Int = 4096
|
||||
|
||||
@Option(name: .long, help: "API key for authentication (optional)")
|
||||
var apiKey: String = ""
|
||||
|
||||
@Option(name: .long, help: "Chat template to use (optional)")
|
||||
var chatTemplate: String = ""
|
||||
|
||||
@Flag(name: .long, help: "Run in embedding mode")
|
||||
var embedding: Bool = false
|
||||
|
||||
func run() async throws {
|
||||
// Print startup info
|
||||
print("[mlx] MLX-Swift Server starting...")
|
||||
print("[mlx] Model path: \(model)")
|
||||
print("[mlx] Port: \(port)")
|
||||
print("[mlx] Context size: \(ctxSize)")
|
||||
|
||||
// Extract model ID from path
|
||||
let modelURL = URL(fileURLWithPath: model)
|
||||
let modelId = modelURL.deletingPathExtension().lastPathComponent
|
||||
|
||||
// Load the model
|
||||
let modelRunner = ModelRunner()
|
||||
|
||||
do {
|
||||
try await modelRunner.load(modelPath: model, modelId: modelId)
|
||||
} catch {
|
||||
print("[mlx] Failed to load model: \(error)")
|
||||
throw error
|
||||
}
|
||||
|
||||
// Set up the HTTP server
|
||||
let server = MLXHTTPServer(
|
||||
modelRunner: modelRunner,
|
||||
modelId: modelId,
|
||||
apiKey: apiKey
|
||||
)
|
||||
|
||||
let router = server.buildRouter()
|
||||
let app = Application(router: router, configuration: .init(address: .hostname("127.0.0.1", port: port)))
|
||||
|
||||
// Print readiness signal (monitored by Tauri plugin)
|
||||
print("[mlx] http server listening on http://127.0.0.1:\(port)")
|
||||
print("[mlx] server is listening on 127.0.0.1:\(port)")
|
||||
|
||||
try await app.run()
|
||||
}
|
||||
}
|
||||
364
mlx-server/Sources/MLXServer/ModelRunner.swift
Normal file
364
mlx-server/Sources/MLXServer/ModelRunner.swift
Normal file
@@ -0,0 +1,364 @@
|
||||
import Foundation
|
||||
import MLX
|
||||
import MLXLLM
|
||||
import MLXLMCommon
|
||||
import MLXRandom
|
||||
import MLXVLM
|
||||
|
||||
/// Manages loading and running inference with MLX models
|
||||
actor ModelRunner {
|
||||
private var container: ModelContainer?
|
||||
private var modelId: String = ""
|
||||
|
||||
var isLoaded: Bool {
|
||||
container != nil
|
||||
}
|
||||
|
||||
var currentModelId: String {
|
||||
modelId
|
||||
}
|
||||
|
||||
/// Load a model from the given path, trying LLM first then VLM
|
||||
func load(modelPath: String, modelId: String) async throws {
|
||||
print("[mlx] Loading model from: \(modelPath)")
|
||||
|
||||
let modelURL = URL(fileURLWithPath: modelPath)
|
||||
let modelDir = modelURL.deletingLastPathComponent()
|
||||
let configuration = ModelConfiguration(directory: modelDir, defaultPrompt: "")
|
||||
|
||||
// Try LLM factory first, fall back to VLM factory
|
||||
do {
|
||||
self.container = try await LLMModelFactory.shared.loadContainer(
|
||||
configuration: configuration
|
||||
) { progress in
|
||||
print("[mlx] Loading progress: \(Int(progress.fractionCompleted * 100))%")
|
||||
}
|
||||
print("[mlx] Model loaded as LLM: \(modelId)")
|
||||
} catch {
|
||||
print("[mlx] LLM loading failed (\(error.localizedDescription)), trying VLM factory...")
|
||||
do {
|
||||
self.container = try await VLMModelFactory.shared.loadContainer(
|
||||
configuration: configuration
|
||||
) { progress in
|
||||
print("[mlx] Loading progress: \(Int(progress.fractionCompleted * 100))%")
|
||||
}
|
||||
print("[mlx] Model loaded as VLM: \(modelId)")
|
||||
} catch {
|
||||
print("[mlx] Error: Failed to load model with both LLM and VLM factories: \(error.localizedDescription)")
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
self.modelId = modelId
|
||||
print("[mlx] Model ready: \(modelId)")
|
||||
}
|
||||
|
||||
/// Build Chat.Message array from ChatMessages, including images and videos
|
||||
private func buildChat(from messages: [ChatMessage]) -> [Chat.Message] {
|
||||
messages.map { message in
|
||||
let role: Chat.Message.Role =
|
||||
switch message.role {
|
||||
case "assistant":
|
||||
.assistant
|
||||
case "user":
|
||||
.user
|
||||
case "system":
|
||||
.system
|
||||
case "tool":
|
||||
.tool
|
||||
default:
|
||||
.user
|
||||
}
|
||||
|
||||
let images: [UserInput.Image] = (message.images ?? []).compactMap { urlString in
|
||||
guard let url = URL(string: urlString) else {
|
||||
print("[mlx] Warning: Invalid image URL: \(urlString)")
|
||||
return nil
|
||||
}
|
||||
return .url(url)
|
||||
}
|
||||
|
||||
let videos: [UserInput.Video] = (message.videos ?? []).compactMap { urlString in
|
||||
guard let url = URL(string: urlString) else {
|
||||
print("[mlx] Warning: Invalid video URL: \(urlString)")
|
||||
return nil
|
||||
}
|
||||
return .url(url)
|
||||
}
|
||||
|
||||
if !images.isEmpty {
|
||||
print("[mlx] Message has \(images.count) image(s)")
|
||||
}
|
||||
if !videos.isEmpty {
|
||||
print("[mlx] Message has \(videos.count) video(s)")
|
||||
}
|
||||
|
||||
return Chat.Message(
|
||||
role: role, content: message.content ?? "", images: images, videos: videos)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert AnyCodable tools array to ToolSpec format
|
||||
private func buildToolSpecs(from tools: [AnyCodable]?) -> [[String: any Sendable]]? {
|
||||
guard let tools = tools, !tools.isEmpty else { return nil }
|
||||
let specs = tools.map { tool in
|
||||
tool.toSendable() as! [String: any Sendable]
|
||||
}
|
||||
print("[mlx] Tools provided: \(specs.count)")
|
||||
return specs
|
||||
}
|
||||
|
||||
/// Generate a chat completion (non-streaming)
|
||||
func generate(
|
||||
messages: [ChatMessage],
|
||||
temperature: Float = 0.7,
|
||||
topP: Float = 1.0,
|
||||
maxTokens: Int = 2048,
|
||||
repetitionPenalty: Float = 1.0,
|
||||
stop: [String] = [],
|
||||
tools: [AnyCodable]? = nil
|
||||
) async throws -> (String, [ToolCallInfo], UsageInfo) {
|
||||
guard let container = container else {
|
||||
print("[mlx] Error: generate() called but no model is loaded")
|
||||
throw MLXServerError.modelNotLoaded
|
||||
}
|
||||
|
||||
print("[mlx] Generate: \(messages.count) messages, temp=\(temperature), topP=\(topP), maxTokens=\(maxTokens)")
|
||||
|
||||
let chat = buildChat(from: messages)
|
||||
let toolSpecs = buildToolSpecs(from: tools)
|
||||
|
||||
let generateParameters = GenerateParameters(
|
||||
maxTokens: maxTokens,
|
||||
temperature: temperature,
|
||||
topP: topP,
|
||||
repetitionPenalty: repetitionPenalty
|
||||
)
|
||||
let userInput = UserInput(
|
||||
chat: chat,
|
||||
processing: .init(resize: .init(width: 1024, height: 1024)),
|
||||
tools: toolSpecs
|
||||
)
|
||||
|
||||
let result: (String, [ToolCallInfo], UsageInfo) = try await container.perform { context in
|
||||
let input = try await context.processor.prepare(input: userInput)
|
||||
let promptTokenCount = input.text.tokens.size
|
||||
|
||||
var output = ""
|
||||
var completionTokenCount = 0
|
||||
var collectedToolCalls: [ToolCallInfo] = []
|
||||
var completionInfo: GenerateCompletionInfo?
|
||||
|
||||
do {
|
||||
for await generation in try MLXLMCommon.generate(
|
||||
input: input, parameters: generateParameters, context: context
|
||||
) {
|
||||
switch generation {
|
||||
case .chunk(let chunk):
|
||||
output += chunk
|
||||
completionTokenCount += 1
|
||||
|
||||
// Check stop sequences
|
||||
var hitStop = false
|
||||
for s in stop where output.hasSuffix(s) {
|
||||
output = String(output.dropLast(s.count))
|
||||
hitStop = true
|
||||
print("[mlx] Hit stop sequence: \"\(s)\"")
|
||||
break
|
||||
}
|
||||
if hitStop { break }
|
||||
|
||||
case .info(let info):
|
||||
completionInfo = info
|
||||
print("[mlx] Generation info: \(info.promptTokenCount) prompt tokens, \(info.generationTokenCount) generated tokens")
|
||||
print("[mlx] Prompt: \(String(format: "%.1f", info.promptTokensPerSecond)) tokens/sec")
|
||||
print("[mlx] Generation: \(String(format: "%.1f", info.tokensPerSecond)) tokens/sec")
|
||||
|
||||
case .toolCall(let toolCall):
|
||||
let argsData = try JSONSerialization.data(
|
||||
withJSONObject: toolCall.function.arguments.mapValues { $0.anyValue },
|
||||
options: [.sortedKeys]
|
||||
)
|
||||
let argsString = String(data: argsData, encoding: .utf8) ?? "{}"
|
||||
let info = ToolCallInfo(
|
||||
id: generateToolCallId(),
|
||||
type: "function",
|
||||
function: FunctionCall(
|
||||
name: toolCall.function.name,
|
||||
arguments: argsString
|
||||
)
|
||||
)
|
||||
collectedToolCalls.append(info)
|
||||
print("[mlx] Tool call: \(toolCall.function.name)(\(argsString))")
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("[mlx] Error during generation: \(error.localizedDescription)")
|
||||
throw error
|
||||
}
|
||||
|
||||
let usage: UsageInfo
|
||||
if let info = completionInfo {
|
||||
usage = UsageInfo(
|
||||
prompt_tokens: info.promptTokenCount,
|
||||
completion_tokens: info.generationTokenCount,
|
||||
total_tokens: info.promptTokenCount + info.generationTokenCount
|
||||
)
|
||||
} else {
|
||||
usage = UsageInfo(
|
||||
prompt_tokens: promptTokenCount,
|
||||
completion_tokens: completionTokenCount,
|
||||
total_tokens: promptTokenCount + completionTokenCount
|
||||
)
|
||||
}
|
||||
|
||||
print("[mlx] Generate complete: \(output.count) chars, \(collectedToolCalls.count) tool call(s)")
|
||||
return (output, collectedToolCalls, usage)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/// Generate a streaming chat completion
|
||||
func generateStream(
|
||||
messages: [ChatMessage],
|
||||
temperature: Float = 0.7,
|
||||
topP: Float = 1.0,
|
||||
maxTokens: Int = 2048,
|
||||
repetitionPenalty: Float = 1.0,
|
||||
stop: [String] = [],
|
||||
tools: [AnyCodable]? = nil
|
||||
) -> AsyncThrowingStream<StreamEvent, Error> {
|
||||
AsyncThrowingStream { continuation in
|
||||
Task {
|
||||
guard let container = self.container else {
|
||||
print("[mlx] Error: generateStream() called but no model is loaded")
|
||||
continuation.finish(throwing: MLXServerError.modelNotLoaded)
|
||||
return
|
||||
}
|
||||
|
||||
print("[mlx] Stream generate: \(messages.count) messages, temp=\(temperature), topP=\(topP), maxTokens=\(maxTokens)")
|
||||
|
||||
let chat = self.buildChat(from: messages)
|
||||
let toolSpecs = self.buildToolSpecs(from: tools)
|
||||
|
||||
let userInput = UserInput(
|
||||
chat: chat,
|
||||
processing: .init(resize: .init(width: 1024, height: 1024)),
|
||||
tools: toolSpecs
|
||||
)
|
||||
|
||||
do {
|
||||
try await container.perform { context in
|
||||
let generateParameters = GenerateParameters(
|
||||
maxTokens: maxTokens,
|
||||
temperature: temperature,
|
||||
topP: topP,
|
||||
repetitionPenalty: repetitionPenalty
|
||||
)
|
||||
|
||||
let input = try await context.processor.prepare(input: userInput)
|
||||
|
||||
var completionTokenCount = 0
|
||||
var accumulated = ""
|
||||
var hasToolCalls = false
|
||||
|
||||
do {
|
||||
for await generation in try MLXLMCommon.generate(
|
||||
input: input, parameters: generateParameters, context: context
|
||||
) {
|
||||
switch generation {
|
||||
case .chunk(let chunk):
|
||||
accumulated += chunk
|
||||
completionTokenCount += 1
|
||||
|
||||
continuation.yield(.chunk(chunk))
|
||||
|
||||
// Check stop sequences
|
||||
var hitStop = false
|
||||
for s in stop where accumulated.hasSuffix(s) {
|
||||
hitStop = true
|
||||
print("[mlx] Hit stop sequence: \"\(s)\"")
|
||||
break
|
||||
}
|
||||
if hitStop { break }
|
||||
|
||||
case .info(let info):
|
||||
print("[mlx] Stream generation info: \(info.promptTokenCount) prompt tokens, \(info.generationTokenCount) generated tokens")
|
||||
print("[mlx] Prompt: \(String(format: "%.1f", info.promptTokensPerSecond)) tokens/sec")
|
||||
print("[mlx] Generation: \(String(format: "%.1f", info.tokensPerSecond)) tokens/sec")
|
||||
|
||||
let usage = UsageInfo(
|
||||
prompt_tokens: info.promptTokenCount,
|
||||
completion_tokens: info.generationTokenCount,
|
||||
total_tokens: info.promptTokenCount + info.generationTokenCount
|
||||
)
|
||||
let timings = TimingsInfo(
|
||||
prompt_n: info.promptTokenCount,
|
||||
predicted_n: info.generationTokenCount,
|
||||
predicted_per_second: info.tokensPerSecond,
|
||||
prompt_per_second: info.promptTokensPerSecond
|
||||
)
|
||||
continuation.yield(.done(usage: usage, timings: timings, hasToolCalls: hasToolCalls))
|
||||
|
||||
case .toolCall(let toolCall):
|
||||
hasToolCalls = true
|
||||
let argsData = try JSONSerialization.data(
|
||||
withJSONObject: toolCall.function.arguments.mapValues { $0.anyValue },
|
||||
options: [.sortedKeys]
|
||||
)
|
||||
let argsString = String(data: argsData, encoding: .utf8) ?? "{}"
|
||||
let info = ToolCallInfo(
|
||||
id: generateToolCallId(),
|
||||
type: "function",
|
||||
function: FunctionCall(
|
||||
name: toolCall.function.name,
|
||||
arguments: argsString
|
||||
)
|
||||
)
|
||||
print("[mlx] Stream tool call: \(toolCall.function.name)(\(argsString))")
|
||||
continuation.yield(.toolCall(info))
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("[mlx] Error during stream generation: \(error.localizedDescription)")
|
||||
throw error
|
||||
}
|
||||
|
||||
// If no .info was received, send done with fallback usage
|
||||
// The .info case already yields .done, so only send if we haven't
|
||||
print("[mlx] Stream complete: \(accumulated.count) chars")
|
||||
continuation.finish()
|
||||
}
|
||||
} catch {
|
||||
print("[mlx] Error in stream: \(error.localizedDescription)")
|
||||
continuation.finish(throwing: error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Events emitted during streaming generation
|
||||
enum StreamEvent {
|
||||
/// A text chunk
|
||||
case chunk(String)
|
||||
/// A tool call from the model
|
||||
case toolCall(ToolCallInfo)
|
||||
/// Generation complete with usage and timing info
|
||||
case done(usage: UsageInfo, timings: TimingsInfo?, hasToolCalls: Bool)
|
||||
}
|
||||
|
||||
enum MLXServerError: Error, LocalizedError {
|
||||
case modelNotLoaded
|
||||
case invalidRequest(String)
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .modelNotLoaded:
|
||||
return "No model is currently loaded"
|
||||
case .invalidRequest(let msg):
|
||||
return "Invalid request: \(msg)"
|
||||
}
|
||||
}
|
||||
}
|
||||
225
mlx-server/Sources/MLXServer/OpenAITypes.swift
Normal file
225
mlx-server/Sources/MLXServer/OpenAITypes.swift
Normal file
@@ -0,0 +1,225 @@
|
||||
import Foundation
|
||||
|
||||
// MARK: - Chat Completion Request
|
||||
|
||||
struct ChatCompletionRequest: Codable {
|
||||
let model: String
|
||||
let messages: [ChatMessage]
|
||||
var temperature: Float?
|
||||
var top_p: Float?
|
||||
var max_tokens: Int?
|
||||
var stream: Bool?
|
||||
var stop: [String]?
|
||||
var n_predict: Int?
|
||||
var repetition_penalty: Float?
|
||||
var tools: [AnyCodable]?
|
||||
}
|
||||
|
||||
struct AnyCodable: Codable, @unchecked Sendable {
|
||||
let value: Any
|
||||
|
||||
init(_ value: Any) {
|
||||
self.value = value
|
||||
}
|
||||
|
||||
func encode(to encoder: Encoder) throws {
|
||||
var container = encoder.singleValueContainer()
|
||||
|
||||
if let string = value as? String {
|
||||
try container.encode(string)
|
||||
} else if let int = value as? Int {
|
||||
try container.encode(int)
|
||||
} else if let double = value as? Double {
|
||||
try container.encode(double)
|
||||
} else if let bool = value as? Bool {
|
||||
try container.encode(bool)
|
||||
} else if let array = value as? [Any] {
|
||||
try container.encode(array.map { AnyCodable($0) })
|
||||
} else if let dict = value as? [String: Any] {
|
||||
try container.encode(dict.mapValues { AnyCodable($0) })
|
||||
} else {
|
||||
throw EncodingError.invalidValue(value, EncodingError.Context(codingPath: encoder.codingPath, debugDescription: "Unsupported type"))
|
||||
}
|
||||
}
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.singleValueContainer()
|
||||
|
||||
if let string = try? container.decode(String.self) {
|
||||
value = string
|
||||
} else if let int = try? container.decode(Int.self) {
|
||||
value = int
|
||||
} else if let double = try? container.decode(Double.self) {
|
||||
value = double
|
||||
} else if let bool = try? container.decode(Bool.self) {
|
||||
value = bool
|
||||
} else if let array = try? container.decode([AnyCodable].self) {
|
||||
value = array.map { $0.value }
|
||||
} else if let dict = try? container.decode([String: AnyCodable].self) {
|
||||
value = dict.mapValues { $0.value }
|
||||
} else {
|
||||
throw DecodingError.dataCorrupted(DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Unsupported type"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively convert the underlying value to `[String: any Sendable]` or primitive Sendable types
|
||||
func toSendable() -> any Sendable {
|
||||
switch value {
|
||||
case let string as String:
|
||||
return string
|
||||
case let int as Int:
|
||||
return int
|
||||
case let double as Double:
|
||||
return double
|
||||
case let bool as Bool:
|
||||
return bool
|
||||
case let array as [Any]:
|
||||
return array.map { AnyCodable($0).toSendable() }
|
||||
case let dict as [String: Any]:
|
||||
return dict.mapValues { AnyCodable($0).toSendable() }
|
||||
default:
|
||||
return String(describing: value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ChatMessage: Codable {
|
||||
let role: String
|
||||
var content: String?
|
||||
var images: [String]?
|
||||
var videos: [String]?
|
||||
var tool_calls: [ToolCallInfo]?
|
||||
var tool_call_id: String?
|
||||
var name: String?
|
||||
}
|
||||
|
||||
// MARK: - Tool Call Types (OpenAI-compatible)
|
||||
|
||||
struct ToolCallInfo: Codable {
|
||||
let id: String
|
||||
let type: String
|
||||
let function: FunctionCall
|
||||
}
|
||||
|
||||
struct FunctionCall: Codable {
|
||||
let name: String
|
||||
let arguments: String
|
||||
}
|
||||
|
||||
struct ToolCallDelta: Codable {
|
||||
let index: Int
|
||||
var id: String?
|
||||
var type: String?
|
||||
var function: FunctionCallDelta?
|
||||
}
|
||||
|
||||
struct FunctionCallDelta: Codable {
|
||||
var name: String?
|
||||
var arguments: String?
|
||||
}
|
||||
|
||||
// MARK: - Chat Completion Response (non-streaming)
|
||||
|
||||
struct ChatCompletionResponse: Codable {
|
||||
let id: String
|
||||
let object: String
|
||||
let created: Int
|
||||
let model: String
|
||||
let choices: [ChatChoice]
|
||||
var usage: UsageInfo?
|
||||
}
|
||||
|
||||
struct ChatChoice: Codable {
|
||||
let index: Int
|
||||
let message: ChatMessage
|
||||
let finish_reason: String?
|
||||
}
|
||||
|
||||
struct UsageInfo: Codable {
|
||||
let prompt_tokens: Int
|
||||
let completion_tokens: Int
|
||||
let total_tokens: Int
|
||||
}
|
||||
|
||||
// MARK: - Chat Completion Chunk (streaming)
|
||||
|
||||
struct ChatCompletionChunk: Codable {
|
||||
let id: String
|
||||
let object: String
|
||||
let created: Int
|
||||
let model: String
|
||||
let choices: [ChatChunkChoice]
|
||||
var usage: UsageInfo?
|
||||
var timings: TimingsInfo?
|
||||
}
|
||||
|
||||
struct ChatChunkChoice: Codable {
|
||||
let index: Int
|
||||
let delta: ChatDelta
|
||||
let finish_reason: String?
|
||||
}
|
||||
|
||||
struct ChatDelta: Codable {
|
||||
var role: String?
|
||||
var content: String?
|
||||
var tool_calls: [ToolCallDelta]?
|
||||
}
|
||||
|
||||
struct TimingsInfo: Codable {
|
||||
var prompt_n: Int?
|
||||
var predicted_n: Int?
|
||||
var predicted_per_second: Double?
|
||||
var prompt_per_second: Double?
|
||||
}
|
||||
|
||||
// MARK: - Models List Response
|
||||
|
||||
struct ModelsResponse: Codable {
|
||||
let object: String
|
||||
let data: [ModelInfo]
|
||||
}
|
||||
|
||||
struct ModelInfo: Codable {
|
||||
let id: String
|
||||
let object: String
|
||||
let created: Int
|
||||
let owned_by: String
|
||||
}
|
||||
|
||||
// MARK: - Health Response
|
||||
|
||||
struct HealthResponse: Codable {
|
||||
let status: String
|
||||
}
|
||||
|
||||
// MARK: - Error Response
|
||||
|
||||
struct ErrorResponse: Codable {
|
||||
let error: ErrorDetail
|
||||
}
|
||||
|
||||
struct ErrorDetail: Codable {
|
||||
let message: String
|
||||
let type_name: String
|
||||
let code: String?
|
||||
|
||||
enum CodingKeys: String, CodingKey {
|
||||
case message
|
||||
case type_name = "type"
|
||||
case code
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
func generateResponseId() -> String {
|
||||
"chatcmpl-\(UUID().uuidString.prefix(12))"
|
||||
}
|
||||
|
||||
func generateToolCallId() -> String {
|
||||
"call_\(UUID().uuidString.replacingOccurrences(of: "-", with: "").prefix(24))"
|
||||
}
|
||||
|
||||
func currentTimestamp() -> Int {
|
||||
Int(Date().timeIntervalSince1970)
|
||||
}
|
||||
340
mlx-server/Sources/MLXServer/Server.swift
Normal file
340
mlx-server/Sources/MLXServer/Server.swift
Normal file
@@ -0,0 +1,340 @@
|
||||
import Foundation
|
||||
import Hummingbird
|
||||
|
||||
/// HTTP server that exposes an OpenAI-compatible API backed by MLX
|
||||
struct MLXHTTPServer {
|
||||
let modelRunner: ModelRunner
|
||||
let modelId: String
|
||||
let apiKey: String
|
||||
|
||||
func buildRouter() -> Router<BasicRequestContext> {
|
||||
let router = Router()
|
||||
|
||||
// Health check
|
||||
router.get("/health") { _, _ in
|
||||
let response = HealthResponse(status: "ok")
|
||||
return try encodeJSON(response)
|
||||
}
|
||||
|
||||
// List models
|
||||
router.get("/v1/models") { _, _ in
|
||||
let response = ModelsResponse(
|
||||
object: "list",
|
||||
data: [
|
||||
ModelInfo(
|
||||
id: self.modelId,
|
||||
object: "model",
|
||||
created: currentTimestamp(),
|
||||
owned_by: "mlx"
|
||||
)
|
||||
]
|
||||
)
|
||||
return try encodeJSON(response)
|
||||
}
|
||||
|
||||
// Chat completions
|
||||
router.post("/v1/chat/completions") { request, context in
|
||||
// Validate API key if set
|
||||
if !self.apiKey.isEmpty {
|
||||
let authHeader =
|
||||
request.headers[.authorization]
|
||||
let expectedAuth = "Bearer \(self.apiKey)"
|
||||
if authHeader != expectedAuth {
|
||||
let error = ErrorResponse(
|
||||
error: ErrorDetail(
|
||||
message: "Unauthorized",
|
||||
type_name: "authentication_error",
|
||||
code: "unauthorized"
|
||||
)
|
||||
)
|
||||
let response = try Response(
|
||||
status: .unauthorized,
|
||||
headers: [.contentType: "application/json"],
|
||||
body: .init(byteBuffer: encodeJSONBuffer(error))
|
||||
)
|
||||
return response
|
||||
}
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
let body = try await request.body.collect(upTo: 10 * 1024 * 1024) // 10MB max
|
||||
let chatRequest = try JSONDecoder().decode(ChatCompletionRequest.self, from: body)
|
||||
|
||||
let temperature = chatRequest.temperature ?? 0.7
|
||||
let topP = chatRequest.top_p ?? 1.0
|
||||
let maxTokens = chatRequest.max_tokens ?? chatRequest.n_predict ?? 2048
|
||||
let repetitionPenalty = chatRequest.repetition_penalty ?? 1.0
|
||||
let stop = chatRequest.stop ?? []
|
||||
let isStreaming = chatRequest.stream ?? false
|
||||
let tools = chatRequest.tools
|
||||
|
||||
print("[mlx] Request: model=\(chatRequest.model), messages=\(chatRequest.messages.count), stream=\(isStreaming), tools=\(tools?.count ?? 0)")
|
||||
|
||||
if isStreaming {
|
||||
return try await self.handleStreamingRequest(
|
||||
chatRequest: chatRequest,
|
||||
temperature: temperature,
|
||||
topP: topP,
|
||||
maxTokens: maxTokens,
|
||||
repetitionPenalty: repetitionPenalty,
|
||||
stop: stop,
|
||||
tools: tools
|
||||
)
|
||||
} else {
|
||||
return try await self.handleNonStreamingRequest(
|
||||
chatRequest: chatRequest,
|
||||
temperature: temperature,
|
||||
topP: topP,
|
||||
maxTokens: maxTokens,
|
||||
repetitionPenalty: repetitionPenalty,
|
||||
stop: stop,
|
||||
tools: tools
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
private func handleNonStreamingRequest(
|
||||
chatRequest: ChatCompletionRequest,
|
||||
temperature: Float,
|
||||
topP: Float,
|
||||
maxTokens: Int,
|
||||
repetitionPenalty: Float,
|
||||
stop: [String],
|
||||
tools: [AnyCodable]? = nil
|
||||
) async throws -> Response {
|
||||
let (text, toolCalls, usage) = try await modelRunner.generate(
|
||||
messages: chatRequest.messages,
|
||||
temperature: temperature,
|
||||
topP: topP,
|
||||
maxTokens: maxTokens,
|
||||
repetitionPenalty: repetitionPenalty,
|
||||
stop: stop,
|
||||
tools: tools
|
||||
)
|
||||
|
||||
let finishReason = toolCalls.isEmpty ? "stop" : "tool_calls"
|
||||
let message = ChatMessage(
|
||||
role: "assistant",
|
||||
content: text.isEmpty && !toolCalls.isEmpty ? nil : text,
|
||||
tool_calls: toolCalls.isEmpty ? nil : toolCalls
|
||||
)
|
||||
|
||||
let response = ChatCompletionResponse(
|
||||
id: generateResponseId(),
|
||||
object: "chat.completion",
|
||||
created: currentTimestamp(),
|
||||
model: chatRequest.model,
|
||||
choices: [
|
||||
ChatChoice(
|
||||
index: 0,
|
||||
message: message,
|
||||
finish_reason: finishReason
|
||||
)
|
||||
],
|
||||
usage: usage
|
||||
)
|
||||
|
||||
print("[mlx] Response: \(text.count) chars, \(toolCalls.count) tool call(s), finish=\(finishReason)")
|
||||
|
||||
return try Response(
|
||||
status: .ok,
|
||||
headers: [.contentType: "application/json"],
|
||||
body: .init(byteBuffer: encodeJSONBuffer(response))
|
||||
)
|
||||
}
|
||||
|
||||
private func handleStreamingRequest(
|
||||
chatRequest: ChatCompletionRequest,
|
||||
temperature: Float,
|
||||
topP: Float,
|
||||
maxTokens: Int,
|
||||
repetitionPenalty: Float,
|
||||
stop: [String],
|
||||
tools: [AnyCodable]? = nil
|
||||
) async throws -> Response {
|
||||
let responseId = generateResponseId()
|
||||
let created = currentTimestamp()
|
||||
let model = chatRequest.model
|
||||
|
||||
let stream = await modelRunner.generateStream(
|
||||
messages: chatRequest.messages,
|
||||
temperature: temperature,
|
||||
topP: topP,
|
||||
maxTokens: maxTokens,
|
||||
repetitionPenalty: repetitionPenalty,
|
||||
stop: stop,
|
||||
tools: tools
|
||||
)
|
||||
|
||||
// Build SSE response body
|
||||
let responseStream = AsyncStream<ByteBuffer> { continuation in
|
||||
Task {
|
||||
// Send initial role chunk
|
||||
let initialChunk = ChatCompletionChunk(
|
||||
id: responseId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: model,
|
||||
choices: [
|
||||
ChatChunkChoice(
|
||||
index: 0,
|
||||
delta: ChatDelta(role: "assistant", content: nil),
|
||||
finish_reason: nil
|
||||
)
|
||||
]
|
||||
)
|
||||
if let data = try? encodeJSONData(initialChunk) {
|
||||
var buffer = ByteBufferAllocator().buffer(capacity: data.count + 8)
|
||||
buffer.writeString("data: ")
|
||||
buffer.writeBytes(data)
|
||||
buffer.writeString("\n\n")
|
||||
continuation.yield(buffer)
|
||||
}
|
||||
|
||||
do {
|
||||
for try await event in stream {
|
||||
switch event {
|
||||
case .chunk(let token):
|
||||
let chunk = ChatCompletionChunk(
|
||||
id: responseId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: model,
|
||||
choices: [
|
||||
ChatChunkChoice(
|
||||
index: 0,
|
||||
delta: ChatDelta(role: nil, content: token),
|
||||
finish_reason: nil
|
||||
)
|
||||
]
|
||||
)
|
||||
if let data = try? encodeJSONData(chunk) {
|
||||
var buffer = ByteBufferAllocator().buffer(
|
||||
capacity: data.count + 8)
|
||||
buffer.writeString("data: ")
|
||||
buffer.writeBytes(data)
|
||||
buffer.writeString("\n\n")
|
||||
continuation.yield(buffer)
|
||||
}
|
||||
|
||||
case .toolCall(let toolCallInfo):
|
||||
let chunk = ChatCompletionChunk(
|
||||
id: responseId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: model,
|
||||
choices: [
|
||||
ChatChunkChoice(
|
||||
index: 0,
|
||||
delta: ChatDelta(
|
||||
role: nil,
|
||||
content: nil,
|
||||
tool_calls: [
|
||||
ToolCallDelta(
|
||||
index: 0,
|
||||
id: toolCallInfo.id,
|
||||
type: toolCallInfo.type,
|
||||
function: FunctionCallDelta(
|
||||
name: toolCallInfo.function.name,
|
||||
arguments: toolCallInfo.function.arguments
|
||||
)
|
||||
)
|
||||
]
|
||||
),
|
||||
finish_reason: nil
|
||||
)
|
||||
]
|
||||
)
|
||||
if let data = try? encodeJSONData(chunk) {
|
||||
var buffer = ByteBufferAllocator().buffer(
|
||||
capacity: data.count + 8)
|
||||
buffer.writeString("data: ")
|
||||
buffer.writeBytes(data)
|
||||
buffer.writeString("\n\n")
|
||||
continuation.yield(buffer)
|
||||
}
|
||||
|
||||
case .done(let usage, let timings, let hasToolCalls):
|
||||
let finishReason = hasToolCalls ? "tool_calls" : "stop"
|
||||
// Final chunk with finish_reason
|
||||
let finalChunk = ChatCompletionChunk(
|
||||
id: responseId,
|
||||
object: "chat.completion.chunk",
|
||||
created: created,
|
||||
model: model,
|
||||
choices: [
|
||||
ChatChunkChoice(
|
||||
index: 0,
|
||||
delta: ChatDelta(role: nil, content: nil),
|
||||
finish_reason: finishReason
|
||||
)
|
||||
],
|
||||
usage: usage,
|
||||
timings: timings
|
||||
)
|
||||
if let data = try? encodeJSONData(finalChunk) {
|
||||
var buffer = ByteBufferAllocator().buffer(
|
||||
capacity: data.count + 8)
|
||||
buffer.writeString("data: ")
|
||||
buffer.writeBytes(data)
|
||||
buffer.writeString("\n\n")
|
||||
continuation.yield(buffer)
|
||||
}
|
||||
|
||||
// Send [DONE]
|
||||
var doneBuffer = ByteBufferAllocator().buffer(capacity: 16)
|
||||
doneBuffer.writeString("data: [DONE]\n\n")
|
||||
continuation.yield(doneBuffer)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("[mlx] Error in SSE stream: \(error.localizedDescription)")
|
||||
// Send error as SSE event
|
||||
var buffer = ByteBufferAllocator().buffer(capacity: 256)
|
||||
buffer.writeString(
|
||||
"error: {\"message\":\"\(error.localizedDescription)\"}\n\n")
|
||||
continuation.yield(buffer)
|
||||
}
|
||||
|
||||
continuation.finish()
|
||||
}
|
||||
}
|
||||
|
||||
return Response(
|
||||
status: .ok,
|
||||
headers: [
|
||||
.contentType: "text/event-stream",
|
||||
.init("Cache-Control")!: "no-cache",
|
||||
.init("Connection")!: "keep-alive",
|
||||
],
|
||||
body: .init(asyncSequence: responseStream)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - JSON Encoding Helpers
|
||||
|
||||
private func encodeJSON<T: Encodable>(_ value: T) throws -> Response {
|
||||
let data = try JSONEncoder().encode(value)
|
||||
var buffer = ByteBufferAllocator().buffer(capacity: data.count)
|
||||
buffer.writeBytes(data)
|
||||
return Response(
|
||||
status: .ok,
|
||||
headers: [.contentType: "application/json"],
|
||||
body: .init(byteBuffer: buffer)
|
||||
)
|
||||
}
|
||||
|
||||
private func encodeJSONBuffer<T: Encodable>(_ value: T) throws -> ByteBuffer {
|
||||
let data = try JSONEncoder().encode(value)
|
||||
var buffer = ByteBufferAllocator().buffer(capacity: data.count)
|
||||
buffer.writeBytes(data)
|
||||
return buffer
|
||||
}
|
||||
|
||||
private func encodeJSONData<T: Encodable>(_ value: T) throws -> Data {
|
||||
try JSONEncoder().encode(value)
|
||||
}
|
||||
16
src-tauri/Cargo.lock
generated
16
src-tauri/Cargo.lock
generated
@@ -41,6 +41,7 @@ dependencies = [
|
||||
"tauri-plugin-http",
|
||||
"tauri-plugin-llamacpp",
|
||||
"tauri-plugin-log",
|
||||
"tauri-plugin-mlx",
|
||||
"tauri-plugin-opener",
|
||||
"tauri-plugin-os",
|
||||
"tauri-plugin-rag",
|
||||
@@ -6311,6 +6312,21 @@ dependencies = [
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-mlx"
|
||||
version = "0.6.599"
|
||||
dependencies = [
|
||||
"jan-utils",
|
||||
"log",
|
||||
"nix",
|
||||
"serde",
|
||||
"sysinfo",
|
||||
"tauri",
|
||||
"tauri-plugin",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tauri-plugin-opener"
|
||||
version = "2.5.0"
|
||||
|
||||
@@ -27,9 +27,11 @@ default = [
|
||||
]
|
||||
hardware = ["dep:tauri-plugin-hardware"]
|
||||
deep-link = ["dep:tauri-plugin-deep-link"]
|
||||
mlx = ["dep:tauri-plugin-mlx"]
|
||||
desktop = [
|
||||
"deep-link",
|
||||
"hardware"
|
||||
"hardware",
|
||||
"mlx"
|
||||
]
|
||||
mobile = [
|
||||
"tauri/protocol-asset",
|
||||
@@ -82,6 +84,7 @@ zip = "0.6"
|
||||
tauri-plugin-deep-link = { version = "2", optional = true }
|
||||
tauri-plugin-hardware = { path = "./plugins/tauri-plugin-hardware", optional = true }
|
||||
tauri-plugin-llamacpp = { path = "./plugins/tauri-plugin-llamacpp" }
|
||||
tauri-plugin-mlx = { path = "./plugins/tauri-plugin-mlx", optional = true }
|
||||
tauri-plugin-vector-db = { path = "./plugins/tauri-plugin-vector-db" }
|
||||
tauri-plugin-rag = { path = "./plugins/tauri-plugin-rag" }
|
||||
tauri-plugin-http = { version = "2", features = ["unsafe-headers"] }
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
"hardware:default",
|
||||
"deep-link:default",
|
||||
"llamacpp:default",
|
||||
"mlx:default",
|
||||
"updater:default",
|
||||
"updater:allow-check",
|
||||
{
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
"vector-db:default",
|
||||
"rag:default",
|
||||
"llamacpp:default",
|
||||
"mlx:default",
|
||||
"deep-link:default",
|
||||
"hardware:default",
|
||||
{
|
||||
|
||||
@@ -79,6 +79,7 @@ export interface DownloadItem {
|
||||
proxy?: Record<string, string | string[] | boolean>
|
||||
sha256?: string
|
||||
size?: number
|
||||
model_id?: string
|
||||
}
|
||||
|
||||
export interface ModelConfig {
|
||||
@@ -90,6 +91,7 @@ export interface ModelConfig {
|
||||
sha256?: string
|
||||
mmproj_sha256?: string
|
||||
mmproj_size_bytes?: number
|
||||
embedding?: boolean
|
||||
}
|
||||
|
||||
export interface EmbeddingResponse {
|
||||
|
||||
27
src-tauri/plugins/tauri-plugin-mlx/Cargo.toml
Normal file
27
src-tauri/plugins/tauri-plugin-mlx/Cargo.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
[package]
|
||||
name = "tauri-plugin-mlx"
|
||||
version = "0.6.599"
|
||||
authors = ["Jan <service@jan.ai>"]
|
||||
description = "Tauri plugin for managing MLX-Swift server processes and model loading on Apple Silicon"
|
||||
license = "MIT"
|
||||
repository = "https://github.com/janhq/jan"
|
||||
edition = "2021"
|
||||
rust-version = "1.77.2"
|
||||
exclude = ["/examples", "/dist-js", "/guest-js", "/node_modules"]
|
||||
links = "tauri-plugin-mlx"
|
||||
|
||||
[dependencies]
|
||||
log = "0.4"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
sysinfo = "0.34.2"
|
||||
tauri = { version = "2.5.0", default-features = false, features = [] }
|
||||
thiserror = "2.0.12"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
jan-utils = { path = "../../utils" }
|
||||
|
||||
# Unix-specific dependencies (macOS)
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
nix = { version = "=0.30.1", features = ["signal", "process"] }
|
||||
|
||||
[build-dependencies]
|
||||
tauri-plugin = { version = "2.3.1", features = ["build"] }
|
||||
14
src-tauri/plugins/tauri-plugin-mlx/build.rs
Normal file
14
src-tauri/plugins/tauri-plugin-mlx/build.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
const COMMANDS: &[&str] = &[
|
||||
"cleanup_mlx_processes",
|
||||
"load_mlx_model",
|
||||
"unload_mlx_model",
|
||||
"is_mlx_process_running",
|
||||
"get_mlx_random_port",
|
||||
"find_mlx_session_by_model",
|
||||
"get_mlx_loaded_models",
|
||||
"get_mlx_all_sessions",
|
||||
];
|
||||
|
||||
fn main() {
|
||||
tauri_plugin::Builder::new(COMMANDS).build();
|
||||
}
|
||||
63
src-tauri/plugins/tauri-plugin-mlx/dist-js/index.cjs
Normal file
63
src-tauri/plugins/tauri-plugin-mlx/dist-js/index.cjs
Normal file
@@ -0,0 +1,63 @@
|
||||
'use strict';
|
||||
|
||||
var core = require('@tauri-apps/api/core');
|
||||
|
||||
function asNumber(v, defaultValue = 0) {
|
||||
if (v === '' || v === null || v === undefined)
|
||||
return defaultValue;
|
||||
const n = Number(v);
|
||||
return isFinite(n) ? n : defaultValue;
|
||||
}
|
||||
function asString(v, defaultValue = '') {
|
||||
if (v === '' || v === null || v === undefined)
|
||||
return defaultValue;
|
||||
return String(v);
|
||||
}
|
||||
function normalizeMlxConfig(config) {
|
||||
return {
|
||||
ctx_size: asNumber(config.ctx_size),
|
||||
n_predict: asNumber(config.n_predict),
|
||||
threads: asNumber(config.threads),
|
||||
chat_template: asString(config.chat_template),
|
||||
};
|
||||
}
|
||||
async function loadMlxModel(binaryPath, modelId, modelPath, port, cfg, envs, isEmbedding = false, timeout = 600) {
|
||||
const config = normalizeMlxConfig(cfg);
|
||||
return await core.invoke('plugin:mlx|load_mlx_model', {
|
||||
binaryPath,
|
||||
modelId,
|
||||
modelPath,
|
||||
port,
|
||||
config,
|
||||
envs,
|
||||
isEmbedding,
|
||||
timeout,
|
||||
});
|
||||
}
|
||||
async function unloadMlxModel(pid) {
|
||||
return await core.invoke('plugin:mlx|unload_mlx_model', { pid });
|
||||
}
|
||||
async function isMlxProcessRunning(pid) {
|
||||
return await core.invoke('plugin:mlx|is_mlx_process_running', { pid });
|
||||
}
|
||||
async function getMlxRandomPort() {
|
||||
return await core.invoke('plugin:mlx|get_mlx_random_port');
|
||||
}
|
||||
async function findMlxSessionByModel(modelId) {
|
||||
return await core.invoke('plugin:mlx|find_mlx_session_by_model', { modelId });
|
||||
}
|
||||
async function getMlxLoadedModels() {
|
||||
return await core.invoke('plugin:mlx|get_mlx_loaded_models');
|
||||
}
|
||||
async function getMlxAllSessions() {
|
||||
return await core.invoke('plugin:mlx|get_mlx_all_sessions');
|
||||
}
|
||||
|
||||
exports.findMlxSessionByModel = findMlxSessionByModel;
|
||||
exports.getMlxAllSessions = getMlxAllSessions;
|
||||
exports.getMlxLoadedModels = getMlxLoadedModels;
|
||||
exports.getMlxRandomPort = getMlxRandomPort;
|
||||
exports.isMlxProcessRunning = isMlxProcessRunning;
|
||||
exports.loadMlxModel = loadMlxModel;
|
||||
exports.normalizeMlxConfig = normalizeMlxConfig;
|
||||
exports.unloadMlxModel = unloadMlxModel;
|
||||
10
src-tauri/plugins/tauri-plugin-mlx/dist-js/index.d.ts
vendored
Normal file
10
src-tauri/plugins/tauri-plugin-mlx/dist-js/index.d.ts
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
import { SessionInfo, UnloadResult, MlxConfig } from './types';
|
||||
export { SessionInfo, UnloadResult, MlxConfig } from './types';
|
||||
export declare function normalizeMlxConfig(config: any): MlxConfig;
|
||||
export declare function loadMlxModel(binaryPath: string, modelId: string, modelPath: string, port: number, cfg: MlxConfig, envs: Record<string, string>, isEmbedding?: boolean, timeout?: number): Promise<SessionInfo>;
|
||||
export declare function unloadMlxModel(pid: number): Promise<UnloadResult>;
|
||||
export declare function isMlxProcessRunning(pid: number): Promise<boolean>;
|
||||
export declare function getMlxRandomPort(): Promise<number>;
|
||||
export declare function findMlxSessionByModel(modelId: string): Promise<SessionInfo | null>;
|
||||
export declare function getMlxLoadedModels(): Promise<string[]>;
|
||||
export declare function getMlxAllSessions(): Promise<SessionInfo[]>;
|
||||
54
src-tauri/plugins/tauri-plugin-mlx/dist-js/index.js
Normal file
54
src-tauri/plugins/tauri-plugin-mlx/dist-js/index.js
Normal file
@@ -0,0 +1,54 @@
|
||||
import { invoke } from '@tauri-apps/api/core';
|
||||
|
||||
function asNumber(v, defaultValue = 0) {
|
||||
if (v === '' || v === null || v === undefined)
|
||||
return defaultValue;
|
||||
const n = Number(v);
|
||||
return isFinite(n) ? n : defaultValue;
|
||||
}
|
||||
function asString(v, defaultValue = '') {
|
||||
if (v === '' || v === null || v === undefined)
|
||||
return defaultValue;
|
||||
return String(v);
|
||||
}
|
||||
function normalizeMlxConfig(config) {
|
||||
return {
|
||||
ctx_size: asNumber(config.ctx_size),
|
||||
n_predict: asNumber(config.n_predict),
|
||||
threads: asNumber(config.threads),
|
||||
chat_template: asString(config.chat_template),
|
||||
};
|
||||
}
|
||||
async function loadMlxModel(binaryPath, modelId, modelPath, port, cfg, envs, isEmbedding = false, timeout = 600) {
|
||||
const config = normalizeMlxConfig(cfg);
|
||||
return await invoke('plugin:mlx|load_mlx_model', {
|
||||
binaryPath,
|
||||
modelId,
|
||||
modelPath,
|
||||
port,
|
||||
config,
|
||||
envs,
|
||||
isEmbedding,
|
||||
timeout,
|
||||
});
|
||||
}
|
||||
async function unloadMlxModel(pid) {
|
||||
return await invoke('plugin:mlx|unload_mlx_model', { pid });
|
||||
}
|
||||
async function isMlxProcessRunning(pid) {
|
||||
return await invoke('plugin:mlx|is_mlx_process_running', { pid });
|
||||
}
|
||||
async function getMlxRandomPort() {
|
||||
return await invoke('plugin:mlx|get_mlx_random_port');
|
||||
}
|
||||
async function findMlxSessionByModel(modelId) {
|
||||
return await invoke('plugin:mlx|find_mlx_session_by_model', { modelId });
|
||||
}
|
||||
async function getMlxLoadedModels() {
|
||||
return await invoke('plugin:mlx|get_mlx_loaded_models');
|
||||
}
|
||||
async function getMlxAllSessions() {
|
||||
return await invoke('plugin:mlx|get_mlx_all_sessions');
|
||||
}
|
||||
|
||||
export { findMlxSessionByModel, getMlxAllSessions, getMlxLoadedModels, getMlxRandomPort, isMlxProcessRunning, loadMlxModel, normalizeMlxConfig, unloadMlxModel };
|
||||
18
src-tauri/plugins/tauri-plugin-mlx/dist-js/types.d.ts
vendored
Normal file
18
src-tauri/plugins/tauri-plugin-mlx/dist-js/types.d.ts
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
export interface SessionInfo {
|
||||
pid: number;
|
||||
port: number;
|
||||
model_id: string;
|
||||
model_path: string;
|
||||
is_embedding: boolean;
|
||||
api_key: string;
|
||||
}
|
||||
export interface UnloadResult {
|
||||
success: boolean;
|
||||
error?: string;
|
||||
}
|
||||
export type MlxConfig = {
|
||||
ctx_size: number;
|
||||
n_predict: number;
|
||||
threads: number;
|
||||
chat_template: string;
|
||||
};
|
||||
73
src-tauri/plugins/tauri-plugin-mlx/guest-js/index.ts
Normal file
73
src-tauri/plugins/tauri-plugin-mlx/guest-js/index.ts
Normal file
@@ -0,0 +1,73 @@
|
||||
import { invoke } from '@tauri-apps/api/core'
|
||||
import { SessionInfo, UnloadResult, MlxConfig } from './types'
|
||||
|
||||
export { SessionInfo, UnloadResult, MlxConfig } from './types'
|
||||
|
||||
function asNumber(v: any, defaultValue = 0): number {
|
||||
if (v === '' || v === null || v === undefined) return defaultValue
|
||||
const n = Number(v)
|
||||
return isFinite(n) ? n : defaultValue
|
||||
}
|
||||
|
||||
function asString(v: any, defaultValue = ''): string {
|
||||
if (v === '' || v === null || v === undefined) return defaultValue
|
||||
return String(v)
|
||||
}
|
||||
|
||||
export function normalizeMlxConfig(config: any): MlxConfig {
|
||||
return {
|
||||
ctx_size: asNumber(config.ctx_size),
|
||||
n_predict: asNumber(config.n_predict),
|
||||
threads: asNumber(config.threads),
|
||||
chat_template: asString(config.chat_template),
|
||||
}
|
||||
}
|
||||
|
||||
export async function loadMlxModel(
|
||||
binaryPath: string,
|
||||
modelId: string,
|
||||
modelPath: string,
|
||||
port: number,
|
||||
cfg: MlxConfig,
|
||||
envs: Record<string, string>,
|
||||
isEmbedding: boolean = false,
|
||||
timeout: number = 600
|
||||
): Promise<SessionInfo> {
|
||||
const config = normalizeMlxConfig(cfg)
|
||||
return await invoke('plugin:mlx|load_mlx_model', {
|
||||
binaryPath,
|
||||
modelId,
|
||||
modelPath,
|
||||
port,
|
||||
config,
|
||||
envs,
|
||||
isEmbedding,
|
||||
timeout,
|
||||
})
|
||||
}
|
||||
|
||||
export async function unloadMlxModel(pid: number): Promise<UnloadResult> {
|
||||
return await invoke('plugin:mlx|unload_mlx_model', { pid })
|
||||
}
|
||||
|
||||
export async function isMlxProcessRunning(pid: number): Promise<boolean> {
|
||||
return await invoke('plugin:mlx|is_mlx_process_running', { pid })
|
||||
}
|
||||
|
||||
export async function getMlxRandomPort(): Promise<number> {
|
||||
return await invoke('plugin:mlx|get_mlx_random_port')
|
||||
}
|
||||
|
||||
export async function findMlxSessionByModel(
|
||||
modelId: string
|
||||
): Promise<SessionInfo | null> {
|
||||
return await invoke('plugin:mlx|find_mlx_session_by_model', { modelId })
|
||||
}
|
||||
|
||||
export async function getMlxLoadedModels(): Promise<string[]> {
|
||||
return await invoke('plugin:mlx|get_mlx_loaded_models')
|
||||
}
|
||||
|
||||
export async function getMlxAllSessions(): Promise<SessionInfo[]> {
|
||||
return await invoke('plugin:mlx|get_mlx_all_sessions')
|
||||
}
|
||||
20
src-tauri/plugins/tauri-plugin-mlx/guest-js/types.ts
Normal file
20
src-tauri/plugins/tauri-plugin-mlx/guest-js/types.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
export interface SessionInfo {
|
||||
pid: number
|
||||
port: number
|
||||
model_id: string
|
||||
model_path: string
|
||||
is_embedding: boolean
|
||||
api_key: string
|
||||
}
|
||||
|
||||
export interface UnloadResult {
|
||||
success: boolean
|
||||
error?: string
|
||||
}
|
||||
|
||||
export type MlxConfig = {
|
||||
ctx_size: number
|
||||
n_predict: number
|
||||
threads: number
|
||||
chat_template: string
|
||||
}
|
||||
33
src-tauri/plugins/tauri-plugin-mlx/package.json
Normal file
33
src-tauri/plugins/tauri-plugin-mlx/package.json
Normal file
@@ -0,0 +1,33 @@
|
||||
{
|
||||
"name": "@janhq/tauri-plugin-mlx-api",
|
||||
"version": "0.6.6",
|
||||
"private": true,
|
||||
"description": "Tauri plugin API for MLX-Swift inference on Apple Silicon",
|
||||
"type": "module",
|
||||
"types": "./dist-js/index.d.ts",
|
||||
"main": "./dist-js/index.cjs",
|
||||
"module": "./dist-js/index.js",
|
||||
"exports": {
|
||||
"types": "./dist-js/index.d.ts",
|
||||
"import": "./dist-js/index.js",
|
||||
"require": "./dist-js/index.cjs"
|
||||
},
|
||||
"files": [
|
||||
"dist-js",
|
||||
"README.md"
|
||||
],
|
||||
"scripts": {
|
||||
"build": "rollup -c",
|
||||
"prepublishOnly": "yarn build",
|
||||
"pretest": "yarn build"
|
||||
},
|
||||
"dependencies": {
|
||||
"@tauri-apps/api": ">=2.0.0-beta.6"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@rollup/plugin-typescript": "^12.0.0",
|
||||
"rollup": "^4.9.6",
|
||||
"tslib": "^2.6.2",
|
||||
"typescript": "^5.3.3"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
# Automatically generated - DO NOT EDIT!
|
||||
|
||||
"$schema" = "../../schemas/schema.json"
|
||||
|
||||
[[permission]]
|
||||
identifier = "allow-cleanup-mlx-processes"
|
||||
description = "Enables the cleanup_mlx_processes command without any pre-configured scope."
|
||||
commands.allow = ["cleanup_mlx_processes"]
|
||||
|
||||
[[permission]]
|
||||
identifier = "deny-cleanup-mlx-processes"
|
||||
description = "Denies the cleanup_mlx_processes command without any pre-configured scope."
|
||||
commands.deny = ["cleanup_mlx_processes"]
|
||||
@@ -0,0 +1,13 @@
|
||||
# Automatically generated - DO NOT EDIT!
|
||||
|
||||
"$schema" = "../../schemas/schema.json"
|
||||
|
||||
[[permission]]
|
||||
identifier = "allow-find-mlx-session-by-model"
|
||||
description = "Enables the find_mlx_session_by_model command without any pre-configured scope."
|
||||
commands.allow = ["find_mlx_session_by_model"]
|
||||
|
||||
[[permission]]
|
||||
identifier = "deny-find-mlx-session-by-model"
|
||||
description = "Denies the find_mlx_session_by_model command without any pre-configured scope."
|
||||
commands.deny = ["find_mlx_session_by_model"]
|
||||
@@ -0,0 +1,13 @@
|
||||
# Automatically generated - DO NOT EDIT!
|
||||
|
||||
"$schema" = "../../schemas/schema.json"
|
||||
|
||||
[[permission]]
|
||||
identifier = "allow-get-mlx-all-sessions"
|
||||
description = "Enables the get_mlx_all_sessions command without any pre-configured scope."
|
||||
commands.allow = ["get_mlx_all_sessions"]
|
||||
|
||||
[[permission]]
|
||||
identifier = "deny-get-mlx-all-sessions"
|
||||
description = "Denies the get_mlx_all_sessions command without any pre-configured scope."
|
||||
commands.deny = ["get_mlx_all_sessions"]
|
||||
@@ -0,0 +1,13 @@
|
||||
# Automatically generated - DO NOT EDIT!
|
||||
|
||||
"$schema" = "../../schemas/schema.json"
|
||||
|
||||
[[permission]]
|
||||
identifier = "allow-get-mlx-loaded-models"
|
||||
description = "Enables the get_mlx_loaded_models command without any pre-configured scope."
|
||||
commands.allow = ["get_mlx_loaded_models"]
|
||||
|
||||
[[permission]]
|
||||
identifier = "deny-get-mlx-loaded-models"
|
||||
description = "Denies the get_mlx_loaded_models command without any pre-configured scope."
|
||||
commands.deny = ["get_mlx_loaded_models"]
|
||||
@@ -0,0 +1,13 @@
|
||||
# Automatically generated - DO NOT EDIT!
|
||||
|
||||
"$schema" = "../../schemas/schema.json"
|
||||
|
||||
[[permission]]
|
||||
identifier = "allow-get-mlx-random-port"
|
||||
description = "Enables the get_mlx_random_port command without any pre-configured scope."
|
||||
commands.allow = ["get_mlx_random_port"]
|
||||
|
||||
[[permission]]
|
||||
identifier = "deny-get-mlx-random-port"
|
||||
description = "Denies the get_mlx_random_port command without any pre-configured scope."
|
||||
commands.deny = ["get_mlx_random_port"]
|
||||
@@ -0,0 +1,13 @@
|
||||
# Automatically generated - DO NOT EDIT!
|
||||
|
||||
"$schema" = "../../schemas/schema.json"
|
||||
|
||||
[[permission]]
|
||||
identifier = "allow-is-mlx-process-running"
|
||||
description = "Enables the is_mlx_process_running command without any pre-configured scope."
|
||||
commands.allow = ["is_mlx_process_running"]
|
||||
|
||||
[[permission]]
|
||||
identifier = "deny-is-mlx-process-running"
|
||||
description = "Denies the is_mlx_process_running command without any pre-configured scope."
|
||||
commands.deny = ["is_mlx_process_running"]
|
||||
@@ -0,0 +1,13 @@
|
||||
# Automatically generated - DO NOT EDIT!
|
||||
|
||||
"$schema" = "../../schemas/schema.json"
|
||||
|
||||
[[permission]]
|
||||
identifier = "allow-load-mlx-model"
|
||||
description = "Enables the load_mlx_model command without any pre-configured scope."
|
||||
commands.allow = ["load_mlx_model"]
|
||||
|
||||
[[permission]]
|
||||
identifier = "deny-load-mlx-model"
|
||||
description = "Denies the load_mlx_model command without any pre-configured scope."
|
||||
commands.deny = ["load_mlx_model"]
|
||||
@@ -0,0 +1,13 @@
|
||||
# Automatically generated - DO NOT EDIT!
|
||||
|
||||
"$schema" = "../../schemas/schema.json"
|
||||
|
||||
[[permission]]
|
||||
identifier = "allow-unload-mlx-model"
|
||||
description = "Enables the unload_mlx_model command without any pre-configured scope."
|
||||
commands.allow = ["unload_mlx_model"]
|
||||
|
||||
[[permission]]
|
||||
identifier = "deny-unload-mlx-model"
|
||||
description = "Denies the unload_mlx_model command without any pre-configured scope."
|
||||
commands.deny = ["unload_mlx_model"]
|
||||
@@ -0,0 +1,232 @@
|
||||
## Default Permission
|
||||
|
||||
Default permissions for the MLX plugin
|
||||
|
||||
#### This default permission set includes the following:
|
||||
|
||||
- `allow-cleanup-mlx-processes`
|
||||
- `allow-load-mlx-model`
|
||||
- `allow-unload-mlx-model`
|
||||
- `allow-is-mlx-process-running`
|
||||
- `allow-get-mlx-random-port`
|
||||
- `allow-find-mlx-session-by-model`
|
||||
- `allow-get-mlx-loaded-models`
|
||||
- `allow-get-mlx-all-sessions`
|
||||
|
||||
## Permission Table
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th>Identifier</th>
|
||||
<th>Description</th>
|
||||
</tr>
|
||||
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:allow-cleanup-mlx-processes`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Enables the cleanup_mlx_processes command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:deny-cleanup-mlx-processes`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Denies the cleanup_mlx_processes command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:allow-find-mlx-session-by-model`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Enables the find_mlx_session_by_model command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:deny-find-mlx-session-by-model`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Denies the find_mlx_session_by_model command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:allow-get-mlx-all-sessions`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Enables the get_mlx_all_sessions command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:deny-get-mlx-all-sessions`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Denies the get_mlx_all_sessions command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:allow-get-mlx-loaded-models`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Enables the get_mlx_loaded_models command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:deny-get-mlx-loaded-models`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Denies the get_mlx_loaded_models command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:allow-get-mlx-random-port`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Enables the get_mlx_random_port command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:deny-get-mlx-random-port`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Denies the get_mlx_random_port command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:allow-is-mlx-process-running`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Enables the is_mlx_process_running command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:deny-is-mlx-process-running`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Denies the is_mlx_process_running command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:allow-load-mlx-model`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Enables the load_mlx_model command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:deny-load-mlx-model`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Denies the load_mlx_model command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:allow-unload-mlx-model`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Enables the unload_mlx_model command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
`mlx:deny-unload-mlx-model`
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
Denies the unload_mlx_model command without any pre-configured scope.
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
12
src-tauri/plugins/tauri-plugin-mlx/permissions/default.toml
Normal file
12
src-tauri/plugins/tauri-plugin-mlx/permissions/default.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
[default]
|
||||
description = "Default permissions for the MLX plugin"
|
||||
permissions = [
|
||||
"allow-cleanup-mlx-processes",
|
||||
"allow-load-mlx-model",
|
||||
"allow-unload-mlx-model",
|
||||
"allow-is-mlx-process-running",
|
||||
"allow-get-mlx-random-port",
|
||||
"allow-find-mlx-session-by-model",
|
||||
"allow-get-mlx-loaded-models",
|
||||
"allow-get-mlx-all-sessions",
|
||||
]
|
||||
@@ -0,0 +1,402 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "PermissionFile",
|
||||
"description": "Permission file that can define a default permission, a set of permissions or a list of inlined permissions.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"default": {
|
||||
"description": "The default permission set for the plugin",
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/definitions/DefaultPermission"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
},
|
||||
"set": {
|
||||
"description": "A list of permissions sets defined",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/PermissionSet"
|
||||
}
|
||||
},
|
||||
"permission": {
|
||||
"description": "A list of inlined permissions",
|
||||
"default": [],
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/Permission"
|
||||
}
|
||||
}
|
||||
},
|
||||
"definitions": {
|
||||
"DefaultPermission": {
|
||||
"description": "The default permission set of the plugin.\n\nWorks similarly to a permission with the \"default\" identifier.",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"permissions"
|
||||
],
|
||||
"properties": {
|
||||
"version": {
|
||||
"description": "The version of the permission.",
|
||||
"type": [
|
||||
"integer",
|
||||
"null"
|
||||
],
|
||||
"format": "uint64",
|
||||
"minimum": 1.0
|
||||
},
|
||||
"description": {
|
||||
"description": "Human-readable description of what the permission does. Tauri convention is to use `<h4>` headings in markdown content for Tauri documentation generation purposes.",
|
||||
"type": [
|
||||
"string",
|
||||
"null"
|
||||
]
|
||||
},
|
||||
"permissions": {
|
||||
"description": "All permissions this set contains.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"PermissionSet": {
|
||||
"description": "A set of direct permissions grouped together under a new name.",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"description",
|
||||
"identifier",
|
||||
"permissions"
|
||||
],
|
||||
"properties": {
|
||||
"identifier": {
|
||||
"description": "A unique identifier for the permission.",
|
||||
"type": "string"
|
||||
},
|
||||
"description": {
|
||||
"description": "Human-readable description of what the permission does.",
|
||||
"type": "string"
|
||||
},
|
||||
"permissions": {
|
||||
"description": "All permissions this set contains.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/PermissionKind"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"Permission": {
|
||||
"description": "Descriptions of explicit privileges of commands.\n\nIt can enable commands to be accessible in the frontend of the application.\n\nIf the scope is defined it can be used to fine grain control the access of individual or multiple commands.",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"identifier"
|
||||
],
|
||||
"properties": {
|
||||
"version": {
|
||||
"description": "The version of the permission.",
|
||||
"type": [
|
||||
"integer",
|
||||
"null"
|
||||
],
|
||||
"format": "uint64",
|
||||
"minimum": 1.0
|
||||
},
|
||||
"identifier": {
|
||||
"description": "A unique identifier for the permission.",
|
||||
"type": "string"
|
||||
},
|
||||
"description": {
|
||||
"description": "Human-readable description of what the permission does. Tauri internal convention is to use `<h4>` headings in markdown content for Tauri documentation generation purposes.",
|
||||
"type": [
|
||||
"string",
|
||||
"null"
|
||||
]
|
||||
},
|
||||
"commands": {
|
||||
"description": "Allowed or denied commands when using this permission.",
|
||||
"default": {
|
||||
"allow": [],
|
||||
"deny": []
|
||||
},
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/Commands"
|
||||
}
|
||||
]
|
||||
},
|
||||
"scope": {
|
||||
"description": "Allowed or denied scoped when using this permission.",
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/Scopes"
|
||||
}
|
||||
]
|
||||
},
|
||||
"platforms": {
|
||||
"description": "Target platforms this permission applies. By default all platforms are affected by this permission.",
|
||||
"type": [
|
||||
"array",
|
||||
"null"
|
||||
],
|
||||
"items": {
|
||||
"$ref": "#/definitions/Target"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"Commands": {
|
||||
"description": "Allowed and denied commands inside a permission.\n\nIf two commands clash inside of `allow` and `deny`, it should be denied by default.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"allow": {
|
||||
"description": "Allowed command.",
|
||||
"default": [],
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"deny": {
|
||||
"description": "Denied command, which takes priority.",
|
||||
"default": [],
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"Scopes": {
|
||||
"description": "An argument for fine grained behavior control of Tauri commands.\n\nIt can be of any serde serializable type and is used to allow or prevent certain actions inside a Tauri command. The configured scope is passed to the command and will be enforced by the command implementation.\n\n## Example\n\n```json { \"allow\": [{ \"path\": \"$HOME/**\" }], \"deny\": [{ \"path\": \"$HOME/secret.txt\" }] } ```",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"allow": {
|
||||
"description": "Data that defines what is allowed by the scope.",
|
||||
"type": [
|
||||
"array",
|
||||
"null"
|
||||
],
|
||||
"items": {
|
||||
"$ref": "#/definitions/Value"
|
||||
}
|
||||
},
|
||||
"deny": {
|
||||
"description": "Data that defines what is denied by the scope. This should be prioritized by validation logic.",
|
||||
"type": [
|
||||
"array",
|
||||
"null"
|
||||
],
|
||||
"items": {
|
||||
"$ref": "#/definitions/Value"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"Value": {
|
||||
"description": "All supported ACL values.",
|
||||
"anyOf": [
|
||||
{
|
||||
"description": "Represents a null JSON value.",
|
||||
"type": "null"
|
||||
},
|
||||
{
|
||||
"description": "Represents a [`bool`].",
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"description": "Represents a valid ACL [`Number`].",
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/Number"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Represents a [`String`].",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"description": "Represents a list of other [`Value`]s.",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/Value"
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "Represents a map of [`String`] keys to [`Value`]s.",
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"$ref": "#/definitions/Value"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"Number": {
|
||||
"description": "A valid ACL number.",
|
||||
"anyOf": [
|
||||
{
|
||||
"description": "Represents an [`i64`].",
|
||||
"type": "integer",
|
||||
"format": "int64"
|
||||
},
|
||||
{
|
||||
"description": "Represents a [`f64`].",
|
||||
"type": "number",
|
||||
"format": "double"
|
||||
}
|
||||
]
|
||||
},
|
||||
"Target": {
|
||||
"description": "Platform target.",
|
||||
"oneOf": [
|
||||
{
|
||||
"description": "MacOS.",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"macOS"
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Windows.",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"windows"
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Linux.",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "Android.",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"android"
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "iOS.",
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"iOS"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"PermissionKind": {
|
||||
"type": "string",
|
||||
"oneOf": [
|
||||
{
|
||||
"description": "Enables the cleanup_mlx_processes command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "allow-cleanup-mlx-processes",
|
||||
"markdownDescription": "Enables the cleanup_mlx_processes command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the cleanup_mlx_processes command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "deny-cleanup-mlx-processes",
|
||||
"markdownDescription": "Denies the cleanup_mlx_processes command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the find_mlx_session_by_model command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "allow-find-mlx-session-by-model",
|
||||
"markdownDescription": "Enables the find_mlx_session_by_model command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the find_mlx_session_by_model command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "deny-find-mlx-session-by-model",
|
||||
"markdownDescription": "Denies the find_mlx_session_by_model command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the get_mlx_all_sessions command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "allow-get-mlx-all-sessions",
|
||||
"markdownDescription": "Enables the get_mlx_all_sessions command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the get_mlx_all_sessions command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "deny-get-mlx-all-sessions",
|
||||
"markdownDescription": "Denies the get_mlx_all_sessions command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the get_mlx_loaded_models command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "allow-get-mlx-loaded-models",
|
||||
"markdownDescription": "Enables the get_mlx_loaded_models command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the get_mlx_loaded_models command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "deny-get-mlx-loaded-models",
|
||||
"markdownDescription": "Denies the get_mlx_loaded_models command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the get_mlx_random_port command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "allow-get-mlx-random-port",
|
||||
"markdownDescription": "Enables the get_mlx_random_port command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the get_mlx_random_port command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "deny-get-mlx-random-port",
|
||||
"markdownDescription": "Denies the get_mlx_random_port command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the is_mlx_process_running command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "allow-is-mlx-process-running",
|
||||
"markdownDescription": "Enables the is_mlx_process_running command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the is_mlx_process_running command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "deny-is-mlx-process-running",
|
||||
"markdownDescription": "Denies the is_mlx_process_running command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the load_mlx_model command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "allow-load-mlx-model",
|
||||
"markdownDescription": "Enables the load_mlx_model command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the load_mlx_model command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "deny-load-mlx-model",
|
||||
"markdownDescription": "Denies the load_mlx_model command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Enables the unload_mlx_model command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "allow-unload-mlx-model",
|
||||
"markdownDescription": "Enables the unload_mlx_model command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Denies the unload_mlx_model command without any pre-configured scope.",
|
||||
"type": "string",
|
||||
"const": "deny-unload-mlx-model",
|
||||
"markdownDescription": "Denies the unload_mlx_model command without any pre-configured scope."
|
||||
},
|
||||
{
|
||||
"description": "Default permissions for the MLX plugin\n#### This default permission set includes:\n\n- `allow-cleanup-mlx-processes`\n- `allow-load-mlx-model`\n- `allow-unload-mlx-model`\n- `allow-is-mlx-process-running`\n- `allow-get-mlx-random-port`\n- `allow-find-mlx-session-by-model`\n- `allow-get-mlx-loaded-models`\n- `allow-get-mlx-all-sessions`",
|
||||
"type": "string",
|
||||
"const": "default",
|
||||
"markdownDescription": "Default permissions for the MLX plugin\n#### This default permission set includes:\n\n- `allow-cleanup-mlx-processes`\n- `allow-load-mlx-model`\n- `allow-unload-mlx-model`\n- `allow-is-mlx-process-running`\n- `allow-get-mlx-random-port`\n- `allow-find-mlx-session-by-model`\n- `allow-get-mlx-loaded-models`\n- `allow-get-mlx-all-sessions`"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
31
src-tauri/plugins/tauri-plugin-mlx/rollup.config.js
Normal file
31
src-tauri/plugins/tauri-plugin-mlx/rollup.config.js
Normal file
@@ -0,0 +1,31 @@
|
||||
import { readFileSync } from 'node:fs'
|
||||
import { dirname, join } from 'node:path'
|
||||
import { cwd } from 'node:process'
|
||||
import typescript from '@rollup/plugin-typescript'
|
||||
|
||||
const pkg = JSON.parse(readFileSync(join(cwd(), 'package.json'), 'utf8'))
|
||||
|
||||
export default {
|
||||
input: 'guest-js/index.ts',
|
||||
output: [
|
||||
{
|
||||
file: pkg.exports.import,
|
||||
format: 'esm'
|
||||
},
|
||||
{
|
||||
file: pkg.exports.require,
|
||||
format: 'cjs'
|
||||
}
|
||||
],
|
||||
plugins: [
|
||||
typescript({
|
||||
declaration: true,
|
||||
declarationDir: dirname(pkg.exports.import)
|
||||
})
|
||||
],
|
||||
external: [
|
||||
/^@tauri-apps\/api/,
|
||||
...Object.keys(pkg.dependencies || {}),
|
||||
...Object.keys(pkg.peerDependencies || {})
|
||||
]
|
||||
}
|
||||
59
src-tauri/plugins/tauri-plugin-mlx/src/cleanup.rs
Normal file
59
src-tauri/plugins/tauri-plugin-mlx/src/cleanup.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use tauri::{Manager, Runtime};
|
||||
|
||||
pub async fn cleanup_processes<R: Runtime>(app_handle: &tauri::AppHandle<R>) {
|
||||
let app_state = match app_handle.try_state::<crate::state::MlxState>() {
|
||||
Some(state) => state,
|
||||
None => {
|
||||
log::warn!("MlxState not found in app_handle");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let mut map = app_state.mlx_server_process.lock().await;
|
||||
let pids: Vec<i32> = map.keys().cloned().collect();
|
||||
for pid in pids {
|
||||
if let Some(session) = map.remove(&pid) {
|
||||
let mut child = session.child;
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use nix::sys::signal::{kill, Signal};
|
||||
use nix::unistd::Pid;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
if let Some(raw_pid) = child.id() {
|
||||
let raw_pid = raw_pid as i32;
|
||||
log::info!("Sending SIGTERM to MLX PID {} during shutdown", raw_pid);
|
||||
let _ = kill(Pid::from_raw(raw_pid), Signal::SIGTERM);
|
||||
|
||||
match timeout(Duration::from_secs(2), child.wait()).await {
|
||||
Ok(Ok(status)) => {
|
||||
log::info!("MLX process {} exited gracefully: {}", raw_pid, status)
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
log::error!(
|
||||
"Error waiting after SIGTERM for MLX process {}: {}",
|
||||
raw_pid,
|
||||
e
|
||||
)
|
||||
}
|
||||
Err(_) => {
|
||||
log::warn!(
|
||||
"SIGTERM timed out for MLX PID {}; sending SIGKILL",
|
||||
raw_pid
|
||||
);
|
||||
let _ = kill(Pid::from_raw(raw_pid), Signal::SIGKILL);
|
||||
let _ = child.wait().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn cleanup_mlx_processes<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
) -> Result<(), String> {
|
||||
cleanup_processes(&app_handle).await;
|
||||
Ok(())
|
||||
}
|
||||
360
src-tauri/plugins/tauri-plugin-mlx/src/commands.rs
Normal file
360
src-tauri/plugins/tauri-plugin-mlx/src/commands.rs
Normal file
@@ -0,0 +1,360 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
use tauri::{Manager, Runtime, State};
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::error::{ErrorCode, MlxError, ServerError, ServerResult};
|
||||
use crate::process::{
|
||||
find_session_by_model_id, get_all_active_sessions, get_all_loaded_model_ids,
|
||||
get_random_available_port, is_process_running_by_pid,
|
||||
};
|
||||
use crate::state::{MlxBackendSession, MlxState, SessionInfo};
|
||||
|
||||
#[cfg(unix)]
|
||||
use crate::process::graceful_terminate_process;
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
pub struct UnloadResult {
|
||||
success: bool,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
/// MLX server configuration passed from the frontend
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct MlxConfig {
|
||||
#[serde(default)]
|
||||
pub ctx_size: i32,
|
||||
#[serde(default)]
|
||||
pub n_predict: i32,
|
||||
#[serde(default)]
|
||||
pub threads: i32,
|
||||
#[serde(default)]
|
||||
pub chat_template: String,
|
||||
}
|
||||
|
||||
/// Load a model using the MLX server binary
|
||||
#[tauri::command]
|
||||
pub async fn load_mlx_model<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
binary_path: String,
|
||||
model_id: String,
|
||||
model_path: String,
|
||||
port: u16,
|
||||
config: MlxConfig,
|
||||
envs: HashMap<String, String>,
|
||||
is_embedding: bool,
|
||||
timeout: u64,
|
||||
) -> ServerResult<SessionInfo> {
|
||||
let state: State<MlxState> = app_handle.state();
|
||||
let mut process_map = state.mlx_server_process.lock().await;
|
||||
|
||||
log::info!("Attempting to launch MLX server at path: {:?}", binary_path);
|
||||
log::info!("Using MLX configuration: {:?}", config);
|
||||
|
||||
// Validate binary path
|
||||
let bin_path = PathBuf::from(&binary_path);
|
||||
if !bin_path.exists() {
|
||||
return Err(MlxError::new(
|
||||
ErrorCode::BinaryNotFound,
|
||||
format!("MLX server binary not found at: {}", binary_path),
|
||||
None,
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
// Validate model path
|
||||
let model_path_pb = PathBuf::from(&model_path);
|
||||
if !model_path_pb.exists() {
|
||||
return Err(MlxError::new(
|
||||
ErrorCode::ModelFileNotFound,
|
||||
format!("Model file not found at: {}", model_path),
|
||||
None,
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
let api_key: String = envs
|
||||
.get("MLX_API_KEY")
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| {
|
||||
log::warn!("API key not provided for MLX server");
|
||||
String::new()
|
||||
});
|
||||
|
||||
// Build command arguments
|
||||
let mut args: Vec<String> = vec![
|
||||
"--model".to_string(),
|
||||
model_path.clone(),
|
||||
"--port".to_string(),
|
||||
port.to_string(),
|
||||
];
|
||||
|
||||
if config.ctx_size > 0 {
|
||||
args.push("--ctx-size".to_string());
|
||||
args.push(config.ctx_size.to_string());
|
||||
}
|
||||
|
||||
if !api_key.is_empty() {
|
||||
args.push("--api-key".to_string());
|
||||
args.push(api_key.clone());
|
||||
}
|
||||
|
||||
if !config.chat_template.is_empty() {
|
||||
args.push("--chat-template".to_string());
|
||||
args.push(config.chat_template.clone());
|
||||
}
|
||||
|
||||
if is_embedding {
|
||||
args.push("--embedding".to_string());
|
||||
}
|
||||
|
||||
log::info!("MLX server arguments: {:?}", args);
|
||||
|
||||
// Configure the command
|
||||
let mut command = Command::new(&bin_path);
|
||||
command.args(&args);
|
||||
command.envs(envs);
|
||||
command.stdout(Stdio::piped());
|
||||
command.stderr(Stdio::piped());
|
||||
|
||||
// Spawn the child process
|
||||
let mut child = command.spawn().map_err(ServerError::Io)?;
|
||||
|
||||
let stderr = child.stderr.take().expect("stderr was piped");
|
||||
let stdout = child.stdout.take().expect("stdout was piped");
|
||||
|
||||
// Create channels for communication between tasks
|
||||
let (ready_tx, mut ready_rx) = mpsc::channel::<bool>(1);
|
||||
|
||||
// Spawn task to monitor stdout for readiness
|
||||
let stdout_ready_tx = ready_tx.clone();
|
||||
let _stdout_task = tokio::spawn(async move {
|
||||
let mut reader = BufReader::new(stdout);
|
||||
let mut byte_buffer = Vec::new();
|
||||
|
||||
loop {
|
||||
byte_buffer.clear();
|
||||
match reader.read_until(b'\n', &mut byte_buffer).await {
|
||||
Ok(0) => break,
|
||||
Ok(_) => {
|
||||
let line = String::from_utf8_lossy(&byte_buffer);
|
||||
let line = line.trim_end();
|
||||
if !line.is_empty() {
|
||||
log::info!("[mlx stdout] {}", line);
|
||||
}
|
||||
|
||||
let line_lower = line.to_lowercase();
|
||||
if line_lower.contains("http server listening")
|
||||
|| line_lower.contains("server is listening")
|
||||
|| line_lower.contains("server started")
|
||||
|| line_lower.contains("ready to accept")
|
||||
|| line_lower.contains("server started and listening on")
|
||||
{
|
||||
log::info!(
|
||||
"MLX server appears to be ready based on stdout: '{}'",
|
||||
line
|
||||
);
|
||||
let _ = stdout_ready_tx.send(true).await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Error reading MLX stdout: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn task to capture stderr and monitor for errors
|
||||
let stderr_task = tokio::spawn(async move {
|
||||
let mut reader = BufReader::new(stderr);
|
||||
let mut byte_buffer = Vec::new();
|
||||
let mut stderr_buffer = String::new();
|
||||
|
||||
loop {
|
||||
byte_buffer.clear();
|
||||
match reader.read_until(b'\n', &mut byte_buffer).await {
|
||||
Ok(0) => break,
|
||||
Ok(_) => {
|
||||
let line = String::from_utf8_lossy(&byte_buffer);
|
||||
let line = line.trim_end();
|
||||
|
||||
if !line.is_empty() {
|
||||
stderr_buffer.push_str(line);
|
||||
stderr_buffer.push('\n');
|
||||
log::info!("[mlx] {}", line);
|
||||
|
||||
let line_lower = line.to_lowercase();
|
||||
if line_lower.contains("server is listening")
|
||||
|| line_lower.contains("server listening on")
|
||||
|| line_lower.contains("server started and listening on")
|
||||
{
|
||||
log::info!(
|
||||
"MLX model appears to be ready based on logs: '{}'",
|
||||
line
|
||||
);
|
||||
let _ = ready_tx.send(true).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Error reading MLX logs: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stderr_buffer
|
||||
});
|
||||
|
||||
// Check if process exited early
|
||||
if let Some(status) = child.try_wait()? {
|
||||
if !status.success() {
|
||||
let stderr_output = stderr_task.await.unwrap_or_default();
|
||||
log::error!("MLX server failed early with code {:?}", status);
|
||||
log::error!("{}", stderr_output);
|
||||
return Err(MlxError::from_stderr(&stderr_output).into());
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for server to be ready or timeout
|
||||
let timeout_duration = Duration::from_secs(timeout);
|
||||
let start_time = Instant::now();
|
||||
log::info!("Waiting for MLX model session to be ready...");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(true) = ready_rx.recv() => {
|
||||
log::info!("MLX model is ready to accept requests!");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_millis(50)) => {
|
||||
if let Some(status) = child.try_wait()? {
|
||||
let stderr_output = stderr_task.await.unwrap_or_default();
|
||||
if !status.success() {
|
||||
log::error!("MLX server exited with error code {:?}", status);
|
||||
return Err(MlxError::from_stderr(&stderr_output).into());
|
||||
} else {
|
||||
log::error!("MLX server exited successfully but without ready signal");
|
||||
return Err(MlxError::from_stderr(&stderr_output).into());
|
||||
}
|
||||
}
|
||||
|
||||
if start_time.elapsed() > timeout_duration {
|
||||
log::error!("Timeout waiting for MLX server to be ready");
|
||||
let _ = child.kill().await;
|
||||
let stderr_output = stderr_task.await.unwrap_or_default();
|
||||
return Err(MlxError::new(
|
||||
ErrorCode::ModelLoadTimedOut,
|
||||
"The MLX model took too long to load and timed out.".into(),
|
||||
Some(format!(
|
||||
"Timeout: {}s\n\nStderr:\n{}",
|
||||
timeout_duration.as_secs(),
|
||||
stderr_output
|
||||
)),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let pid = child.id().map(|id| id as i32).unwrap_or(-1);
|
||||
|
||||
log::info!("MLX server process started with PID: {} and is ready", pid);
|
||||
let session_info = SessionInfo {
|
||||
pid,
|
||||
port: port.into(),
|
||||
model_id,
|
||||
model_path: model_path_pb.display().to_string(),
|
||||
is_embedding,
|
||||
api_key,
|
||||
};
|
||||
|
||||
process_map.insert(
|
||||
pid,
|
||||
MlxBackendSession {
|
||||
child,
|
||||
info: session_info.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
Ok(session_info)
|
||||
}
|
||||
|
||||
/// Unload an MLX model by terminating its process
|
||||
#[tauri::command]
|
||||
pub async fn unload_mlx_model<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
pid: i32,
|
||||
) -> ServerResult<UnloadResult> {
|
||||
let state: State<MlxState> = app_handle.state();
|
||||
let mut map = state.mlx_server_process.lock().await;
|
||||
|
||||
if let Some(session) = map.remove(&pid) {
|
||||
let mut child = session.child;
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
graceful_terminate_process(&mut child).await;
|
||||
}
|
||||
|
||||
Ok(UnloadResult {
|
||||
success: true,
|
||||
error: None,
|
||||
})
|
||||
} else {
|
||||
log::warn!("No MLX server with PID '{}' found", pid);
|
||||
Ok(UnloadResult {
|
||||
success: true,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a process is still running
|
||||
#[tauri::command]
|
||||
pub async fn is_mlx_process_running<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
pid: i32,
|
||||
) -> Result<bool, String> {
|
||||
is_process_running_by_pid(app_handle, pid).await
|
||||
}
|
||||
|
||||
/// Get a random available port
|
||||
#[tauri::command]
|
||||
pub async fn get_mlx_random_port<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
) -> Result<u16, String> {
|
||||
get_random_available_port(app_handle).await
|
||||
}
|
||||
|
||||
/// Find session information by model ID
|
||||
#[tauri::command]
|
||||
pub async fn find_mlx_session_by_model<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
model_id: String,
|
||||
) -> Result<Option<SessionInfo>, String> {
|
||||
find_session_by_model_id(app_handle, &model_id).await
|
||||
}
|
||||
|
||||
/// Get all loaded model IDs
|
||||
#[tauri::command]
|
||||
pub async fn get_mlx_loaded_models<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
) -> Result<Vec<String>, String> {
|
||||
get_all_loaded_model_ids(app_handle).await
|
||||
}
|
||||
|
||||
/// Get all active sessions
|
||||
#[tauri::command]
|
||||
pub async fn get_mlx_all_sessions<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
) -> Result<Vec<SessionInfo>, String> {
|
||||
get_all_active_sessions(app_handle).await
|
||||
}
|
||||
91
src-tauri/plugins/tauri-plugin-mlx/src/error.rs
Normal file
91
src-tauri/plugins/tauri-plugin-mlx/src/error.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum ErrorCode {
|
||||
BinaryNotFound,
|
||||
ModelFileNotFound,
|
||||
ModelLoadFailed,
|
||||
ModelLoadTimedOut,
|
||||
OutOfMemory,
|
||||
MlxProcessError,
|
||||
IoError,
|
||||
InternalError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, thiserror::Error)]
|
||||
#[error("MlxError {{ code: {code:?}, message: \"{message}\" }}")]
|
||||
pub struct MlxError {
|
||||
pub code: ErrorCode,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub details: Option<String>,
|
||||
}
|
||||
|
||||
impl MlxError {
|
||||
pub fn new(code: ErrorCode, message: String, details: Option<String>) -> Self {
|
||||
Self {
|
||||
code,
|
||||
message,
|
||||
details,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses stderr from the MLX server and creates a specific MlxError.
|
||||
pub fn from_stderr(stderr: &str) -> Self {
|
||||
let lower_stderr = stderr.to_lowercase();
|
||||
|
||||
if lower_stderr.contains("out of memory")
|
||||
|| lower_stderr.contains("failed to allocate")
|
||||
|| lower_stderr.contains("insufficient memory")
|
||||
{
|
||||
return Self::new(
|
||||
ErrorCode::OutOfMemory,
|
||||
"Out of memory. The model requires more RAM than available.".into(),
|
||||
Some(stderr.into()),
|
||||
);
|
||||
}
|
||||
|
||||
Self::new(
|
||||
ErrorCode::MlxProcessError,
|
||||
"The MLX model process encountered an unexpected error.".into(),
|
||||
Some(stderr.into()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ServerError {
|
||||
#[error(transparent)]
|
||||
Mlx(#[from] MlxError),
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("Tauri error: {0}")]
|
||||
Tauri(#[from] tauri::Error),
|
||||
}
|
||||
|
||||
impl serde::Serialize for ServerError {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let error_to_serialize: MlxError = match self {
|
||||
ServerError::Mlx(err) => err.clone(),
|
||||
ServerError::Io(e) => MlxError::new(
|
||||
ErrorCode::IoError,
|
||||
"An input/output error occurred.".into(),
|
||||
Some(e.to_string()),
|
||||
),
|
||||
ServerError::Tauri(e) => MlxError::new(
|
||||
ErrorCode::InternalError,
|
||||
"An internal application error occurred.".into(),
|
||||
Some(e.to_string()),
|
||||
),
|
||||
};
|
||||
error_to_serialize.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
pub type ServerResult<T> = Result<T, ServerError>;
|
||||
32
src-tauri/plugins/tauri-plugin-mlx/src/lib.rs
Normal file
32
src-tauri/plugins/tauri-plugin-mlx/src/lib.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use tauri::{
|
||||
plugin::{Builder, TauriPlugin},
|
||||
Manager, Runtime,
|
||||
};
|
||||
|
||||
pub mod cleanup;
|
||||
mod commands;
|
||||
mod error;
|
||||
mod process;
|
||||
pub mod state;
|
||||
|
||||
pub use cleanup::cleanup_mlx_processes;
|
||||
|
||||
/// Initializes the MLX plugin.
|
||||
pub fn init<R: Runtime>() -> TauriPlugin<R> {
|
||||
Builder::new("mlx")
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
cleanup::cleanup_mlx_processes,
|
||||
commands::load_mlx_model,
|
||||
commands::unload_mlx_model,
|
||||
commands::is_mlx_process_running,
|
||||
commands::get_mlx_random_port,
|
||||
commands::find_mlx_session_by_model,
|
||||
commands::get_mlx_loaded_models,
|
||||
commands::get_mlx_all_sessions,
|
||||
])
|
||||
.setup(|app, _api| {
|
||||
app.manage(state::MlxState::new());
|
||||
Ok(())
|
||||
})
|
||||
.build()
|
||||
}
|
||||
120
src-tauri/plugins/tauri-plugin-mlx/src/process.rs
Normal file
120
src-tauri/plugins/tauri-plugin-mlx/src/process.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
use std::collections::HashSet;
|
||||
use sysinfo::{Pid, System};
|
||||
use tauri::{Manager, Runtime, State};
|
||||
|
||||
use crate::state::{MlxState, SessionInfo};
|
||||
use jan_utils::generate_random_port;
|
||||
|
||||
/// Check if a process is running by PID
|
||||
pub async fn is_process_running_by_pid<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
pid: i32,
|
||||
) -> Result<bool, String> {
|
||||
let mut system = System::new();
|
||||
system.refresh_processes(sysinfo::ProcessesToUpdate::All, true);
|
||||
let process_pid = Pid::from(pid as usize);
|
||||
let alive = system.process(process_pid).is_some();
|
||||
|
||||
if !alive {
|
||||
let state: State<MlxState> = app_handle.state();
|
||||
let mut map = state.mlx_server_process.lock().await;
|
||||
map.remove(&pid);
|
||||
}
|
||||
|
||||
Ok(alive)
|
||||
}
|
||||
|
||||
/// Get a random available port, avoiding ports used by existing sessions
|
||||
pub async fn get_random_available_port<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
) -> Result<u16, String> {
|
||||
let state: State<MlxState> = app_handle.state();
|
||||
let map = state.mlx_server_process.lock().await;
|
||||
|
||||
let used_ports: HashSet<u16> = map
|
||||
.values()
|
||||
.filter_map(|session| {
|
||||
if session.info.port > 0 && session.info.port <= u16::MAX as i32 {
|
||||
Some(session.info.port as u16)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
drop(map);
|
||||
|
||||
generate_random_port(&used_ports)
|
||||
}
|
||||
|
||||
/// Gracefully terminate a process on Unix systems (macOS)
|
||||
#[cfg(unix)]
|
||||
pub async fn graceful_terminate_process(child: &mut tokio::process::Child) {
|
||||
use nix::sys::signal::{kill, Signal};
|
||||
use nix::unistd::Pid;
|
||||
use std::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
if let Some(raw_pid) = child.id() {
|
||||
let raw_pid = raw_pid as i32;
|
||||
log::info!("Sending SIGTERM to MLX process PID {}", raw_pid);
|
||||
let _ = kill(Pid::from_raw(raw_pid), Signal::SIGTERM);
|
||||
|
||||
match timeout(Duration::from_secs(5), child.wait()).await {
|
||||
Ok(Ok(status)) => log::info!("MLX process exited gracefully: {}", status),
|
||||
Ok(Err(e)) => log::error!("Error waiting after SIGTERM for MLX process: {}", e),
|
||||
Err(_) => {
|
||||
log::warn!(
|
||||
"SIGTERM timed out for MLX PID {}; sending SIGKILL",
|
||||
raw_pid
|
||||
);
|
||||
let _ = kill(Pid::from_raw(raw_pid), Signal::SIGKILL);
|
||||
match child.wait().await {
|
||||
Ok(s) => log::info!("Force-killed MLX process exited: {}", s),
|
||||
Err(e) => log::error!("Error waiting after SIGKILL for MLX process: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Find a session by model ID
|
||||
pub async fn find_session_by_model_id<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
model_id: &str,
|
||||
) -> Result<Option<SessionInfo>, String> {
|
||||
let state: State<MlxState> = app_handle.state();
|
||||
let map = state.mlx_server_process.lock().await;
|
||||
|
||||
let session_info = map
|
||||
.values()
|
||||
.find(|backend_session| backend_session.info.model_id == model_id)
|
||||
.map(|backend_session| backend_session.info.clone());
|
||||
|
||||
Ok(session_info)
|
||||
}
|
||||
|
||||
/// Get all loaded model IDs
|
||||
pub async fn get_all_loaded_model_ids<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
) -> Result<Vec<String>, String> {
|
||||
let state: State<MlxState> = app_handle.state();
|
||||
let map = state.mlx_server_process.lock().await;
|
||||
|
||||
let model_ids = map
|
||||
.values()
|
||||
.map(|backend_session| backend_session.info.model_id.clone())
|
||||
.collect();
|
||||
|
||||
Ok(model_ids)
|
||||
}
|
||||
|
||||
/// Get all active sessions
|
||||
pub async fn get_all_active_sessions<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
) -> Result<Vec<SessionInfo>, String> {
|
||||
let state: State<MlxState> = app_handle.state();
|
||||
let map = state.mlx_server_process.lock().await;
|
||||
let sessions: Vec<SessionInfo> = map.values().map(|s| s.info.clone()).collect();
|
||||
Ok(sessions)
|
||||
}
|
||||
39
src-tauri/plugins/tauri-plugin-mlx/src/state.rs
Normal file
39
src-tauri/plugins/tauri-plugin-mlx/src/state.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::process::Child;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionInfo {
|
||||
pub pid: i32,
|
||||
pub port: i32,
|
||||
pub model_id: String,
|
||||
pub model_path: String,
|
||||
pub is_embedding: bool,
|
||||
pub api_key: String,
|
||||
}
|
||||
|
||||
pub struct MlxBackendSession {
|
||||
pub child: Child,
|
||||
pub info: SessionInfo,
|
||||
}
|
||||
|
||||
/// MLX plugin state
|
||||
pub struct MlxState {
|
||||
pub mlx_server_process: Arc<Mutex<HashMap<i32, MlxBackendSession>>>,
|
||||
}
|
||||
|
||||
impl Default for MlxState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mlx_server_process: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MlxState {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
14
src-tauri/plugins/tauri-plugin-mlx/tsconfig.json
Normal file
14
src-tauri/plugins/tauri-plugin-mlx/tsconfig.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "es2021",
|
||||
"module": "esnext",
|
||||
"moduleResolution": "bundler",
|
||||
"skipLibCheck": true,
|
||||
"strict": true,
|
||||
"noUnusedLocals": true,
|
||||
"noImplicitAny": true,
|
||||
"noEmit": true
|
||||
},
|
||||
"include": ["guest-js/*.ts"],
|
||||
"exclude": ["dist-js", "node_modules"]
|
||||
}
|
||||
@@ -52,6 +52,18 @@ __metadata:
|
||||
languageName: unknown
|
||||
linkType: soft
|
||||
|
||||
"@janhq/tauri-plugin-mlx-api@workspace:tauri-plugin-mlx":
|
||||
version: 0.0.0-use.local
|
||||
resolution: "@janhq/tauri-plugin-mlx-api@workspace:tauri-plugin-mlx"
|
||||
dependencies:
|
||||
"@rollup/plugin-typescript": "npm:^12.0.0"
|
||||
"@tauri-apps/api": "npm:>=2.0.0-beta.6"
|
||||
rollup: "npm:^4.9.6"
|
||||
tslib: "npm:^2.6.2"
|
||||
typescript: "npm:^5.3.3"
|
||||
languageName: unknown
|
||||
linkType: soft
|
||||
|
||||
"@janhq/tauri-plugin-rag-api@workspace:tauri-plugin-rag":
|
||||
version: 0.0.0-use.local
|
||||
resolution: "@janhq/tauri-plugin-rag-api@workspace:tauri-plugin-rag"
|
||||
|
||||
@@ -41,6 +41,11 @@ pub fn run() {
|
||||
app_builder = app_builder.plugin(tauri_plugin_deep_link::init());
|
||||
}
|
||||
|
||||
#[cfg(feature = "mlx")]
|
||||
{
|
||||
app_builder = app_builder.plugin(tauri_plugin_mlx::init());
|
||||
}
|
||||
|
||||
#[cfg(not(any(target_os = "android", target_os = "ios")))]
|
||||
{
|
||||
app_builder = app_builder.plugin(tauri_plugin_hardware::init());
|
||||
@@ -328,6 +333,17 @@ pub fn run() {
|
||||
} else {
|
||||
log::info!("Llama processes cleaned up successfully");
|
||||
}
|
||||
|
||||
#[cfg(feature = "mlx")]
|
||||
{
|
||||
use tauri_plugin_mlx::cleanup_mlx_processes;
|
||||
if let Err(e) = cleanup_mlx_processes(app_handle.clone()).await {
|
||||
log::warn!("Failed to cleanup MLX processes: {}", e);
|
||||
} else {
|
||||
log::info!("MLX processes cleaned up successfully");
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("App cleanup completed");
|
||||
});
|
||||
});
|
||||
|
||||
BIN
web-app/public/images/model-provider/mlx.png
Normal file
BIN
web-app/public/images/model-provider/mlx.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 63 KiB |
@@ -41,7 +41,7 @@ export const DialogDeleteModel = ({
|
||||
deleteModelCache(selectedModelId)
|
||||
serviceHub
|
||||
.models()
|
||||
.deleteModel(selectedModelId)
|
||||
.deleteModel(selectedModelId, provider.provider)
|
||||
.then(() => {
|
||||
serviceHub
|
||||
.providers()
|
||||
|
||||
264
web-app/src/containers/dialogs/ImportMlxModelDialog.tsx
Normal file
264
web-app/src/containers/dialogs/ImportMlxModelDialog.tsx
Normal file
@@ -0,0 +1,264 @@
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { useServiceHub } from '@/hooks/useServiceHub'
|
||||
import { useState } from 'react'
|
||||
import { toast } from 'sonner'
|
||||
import {
|
||||
IconLoader2,
|
||||
IconCheck,
|
||||
} from '@tabler/icons-react'
|
||||
import { ExtensionManager } from '@/lib/extension'
|
||||
|
||||
type ImportMlxModelDialogProps = {
|
||||
provider: ModelProvider
|
||||
trigger?: React.ReactNode
|
||||
onSuccess?: (importedModelName?: string) => void
|
||||
}
|
||||
|
||||
export const ImportMlxModelDialog = ({
|
||||
provider,
|
||||
trigger,
|
||||
onSuccess,
|
||||
}: ImportMlxModelDialogProps) => {
|
||||
const serviceHub = useServiceHub()
|
||||
const [open, setOpen] = useState(false)
|
||||
const [importing, setImporting] = useState(false)
|
||||
const [selectedPath, setSelectedPath] = useState<string | null>(null)
|
||||
const [modelName, setModelName] = useState('')
|
||||
|
||||
const handleFileSelect = async () => {
|
||||
const result = await serviceHub.dialog().open({
|
||||
multiple: false,
|
||||
directory: false,
|
||||
filters: [
|
||||
{
|
||||
name: 'Safetensor Files',
|
||||
extensions: ['safetensors'],
|
||||
},
|
||||
{
|
||||
name: 'All Files',
|
||||
extensions: ['*'],
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
if (result && typeof result === 'string') {
|
||||
setSelectedPath(result)
|
||||
|
||||
// Extract model name from path
|
||||
const pathParts = result.split(/[\\/]/)
|
||||
const nameFromPath = pathParts[pathParts.length - 1] || 'mlx-model'
|
||||
const sanitizedName = nameFromPath
|
||||
.replace(/\s/g, '-')
|
||||
.replace(/[^a-zA-Z0-9/_.\-]/g, '')
|
||||
setModelName(sanitizedName)
|
||||
}
|
||||
}
|
||||
|
||||
const handleImport = async () => {
|
||||
if (!selectedPath) {
|
||||
toast.error('Please select a safetensor file or folder')
|
||||
return
|
||||
}
|
||||
|
||||
if (!modelName) {
|
||||
toast.error('Please enter a model name')
|
||||
return
|
||||
}
|
||||
|
||||
// Validate model name - only allow alphanumeric, underscore, hyphen, and dot
|
||||
if (!/^[a-zA-Z0-9/_.\-]+$/.test(modelName)) {
|
||||
toast.error('Invalid model name. Only alphanumeric and _ - . characters are allowed.')
|
||||
return
|
||||
}
|
||||
|
||||
// Check if model already exists
|
||||
const modelExists = provider.models.some(
|
||||
(model) => model.id === modelName
|
||||
)
|
||||
|
||||
if (modelExists) {
|
||||
toast.error('Model already exists', {
|
||||
description: `${modelName} already imported`,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setImporting(true)
|
||||
|
||||
try {
|
||||
console.log('[MLX Import] Starting import:', { modelName, selectedPath })
|
||||
|
||||
// Get the MLX engine and call its import method
|
||||
const engine = ExtensionManager.getInstance().getEngine('mlx')
|
||||
if (!engine) {
|
||||
throw new Error('MLX engine not found')
|
||||
}
|
||||
|
||||
console.log('[MLX Import] Calling engine.import()...')
|
||||
await engine.import(modelName, {
|
||||
modelPath: selectedPath,
|
||||
})
|
||||
console.log('[MLX Import] Import completed')
|
||||
|
||||
toast.success('Model imported successfully', {
|
||||
description: `${modelName} has been imported`,
|
||||
})
|
||||
|
||||
// Reset form and close dialog
|
||||
setSelectedPath(null)
|
||||
setModelName('')
|
||||
setOpen(false)
|
||||
onSuccess?.(modelName)
|
||||
} catch (error) {
|
||||
console.error('[MLX Import] Import model error:', error)
|
||||
toast.error('Failed to import model', {
|
||||
description:
|
||||
error instanceof Error ? error.message : String(error),
|
||||
})
|
||||
} finally {
|
||||
setImporting(false)
|
||||
}
|
||||
}
|
||||
|
||||
const resetForm = () => {
|
||||
setSelectedPath(null)
|
||||
setModelName('')
|
||||
}
|
||||
|
||||
const handleOpenChange = (newOpen: boolean) => {
|
||||
if (!importing) {
|
||||
setOpen(newOpen)
|
||||
if (!newOpen) {
|
||||
resetForm()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const displayPath = selectedPath
|
||||
? selectedPath.split(/[\\/]/).pop() || selectedPath
|
||||
: null
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={handleOpenChange}>
|
||||
<DialogTrigger asChild>{trigger}</DialogTrigger>
|
||||
<DialogContent
|
||||
onInteractOutside={(e) => {
|
||||
e.preventDefault()
|
||||
}}
|
||||
>
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
Import MLX Model
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Import a safetensor model file or folder for use with MLX. MLX models
|
||||
are typically downloaded from HuggingFace and use the safetensors format.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="space-y-6">
|
||||
{/* Model Name Input */}
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">
|
||||
Model Name
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={modelName}
|
||||
onChange={(e) => setModelName(e.target.value)}
|
||||
placeholder="my-mlx-model"
|
||||
className="w-full px-3 py-2 bg-background border rounded-lg focus:outline-none focus:ring-2 focus:ring-accent/50"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Only alphanumeric and _ - . characters are allowed
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* File Selection Area */}
|
||||
<div className="border rounded-lg p-4 space-y-3">
|
||||
<div className="flex items-center gap-2">
|
||||
<h3 className="font-medium">
|
||||
Safetensor File or Folder
|
||||
</h3>
|
||||
<span className="text-xs bg-secondary px-2 py-1 rounded-sm">
|
||||
Required
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{displayPath ? (
|
||||
<div className="bg-accent/10 border rounded-lg p-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<IconCheck size={16} className="text-accent" />
|
||||
<span className="text-sm font-medium">
|
||||
{displayPath}
|
||||
</span>
|
||||
</div>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
onClick={handleFileSelect}
|
||||
disabled={importing}
|
||||
>
|
||||
Change
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<Button
|
||||
type="button"
|
||||
variant="link"
|
||||
onClick={handleFileSelect}
|
||||
disabled={importing}
|
||||
className="w-full h-12 border border-dashed text-muted-foreground"
|
||||
>
|
||||
Select Safetensor File or Folder
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Preview */}
|
||||
{modelName && (
|
||||
<div className="rounded-lg p-3">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm font-medium text-muted-foreground">
|
||||
Model will be saved as:
|
||||
</span>
|
||||
</div>
|
||||
<p className="text-sm font-mono mt-1">
|
||||
mlx/models/{modelName}/
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex gap-2 pt-4 justify-end">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => handleOpenChange(false)}
|
||||
disabled={importing}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleImport}
|
||||
size="sm"
|
||||
disabled={importing || !selectedPath || !modelName}
|
||||
>
|
||||
{importing && <IconLoader2 className="mr-2 size-4 animate-spin" />}
|
||||
{importing ? 'Importing...' : 'Import Model'}
|
||||
</Button>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
@@ -55,10 +55,7 @@ export class CustomChatTransport implements ChatTransport<UIMessage> {
|
||||
private serviceHub: ServiceHub | null
|
||||
private threadId?: string
|
||||
|
||||
constructor(
|
||||
systemMessage?: string,
|
||||
threadId?: string
|
||||
) {
|
||||
constructor(systemMessage?: string, threadId?: string) {
|
||||
this.systemMessage = systemMessage
|
||||
this.threadId = threadId
|
||||
this.serviceHub = useServiceStore.getState().serviceHub
|
||||
@@ -248,17 +245,17 @@ export class CustomChatTransport implements ChatTransport<UIMessage> {
|
||||
|
||||
return result.toUIMessageStream({
|
||||
messageMetadata: ({ part }) => {
|
||||
if (!streamStartTime) {
|
||||
// Track stream start time on start
|
||||
if (part.type === 'start' && !streamStartTime) {
|
||||
streamStartTime = Date.now()
|
||||
}
|
||||
// Track stream start time on first text delta
|
||||
if (part.type === 'text-delta') {
|
||||
// Count text deltas as a rough token approximation
|
||||
// Each delta typically represents one token in streaming
|
||||
textDeltaCount++
|
||||
|
||||
// Report streaming token speed in real-time
|
||||
if (this.onStreamingTokenSpeed) {
|
||||
if (this.onStreamingTokenSpeed && streamStartTime) {
|
||||
const elapsedMs = Date.now() - streamStartTime
|
||||
this.onStreamingTokenSpeed(textDeltaCount, elapsedMs)
|
||||
}
|
||||
@@ -279,22 +276,18 @@ export class CustomChatTransport implements ChatTransport<UIMessage> {
|
||||
}
|
||||
}
|
||||
const usage = finishPart.totalUsage
|
||||
const llamacppMeta = finishPart.providerMetadata?.llamacpp
|
||||
const durationMs = streamStartTime ? Date.now() - streamStartTime : 0
|
||||
const durationSec = durationMs / 1000
|
||||
|
||||
// Use provider's outputTokens, or llama.cpp completionTokens, or fall back to text delta count
|
||||
const outputTokens =
|
||||
usage?.outputTokens ??
|
||||
llamacppMeta?.completionTokens ??
|
||||
textDeltaCount
|
||||
const inputTokens = usage?.inputTokens ?? llamacppMeta?.promptTokens
|
||||
const inputTokens = usage?.inputTokens
|
||||
|
||||
// Use llama.cpp's tokens per second if available, otherwise calculate from duration
|
||||
let tokenSpeed: number
|
||||
if (llamacppMeta?.tokensPerSecond != null) {
|
||||
tokenSpeed = llamacppMeta.tokensPerSecond
|
||||
} else if (durationSec > 0 && outputTokens > 0) {
|
||||
if (durationSec > 0 && outputTokens > 0) {
|
||||
tokenSpeed = outputTokens / durationSec
|
||||
} else {
|
||||
tokenSpeed = 0
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
*
|
||||
* Supported Providers:
|
||||
* - llamacpp: Local models via llama.cpp (requires running session)
|
||||
* - mlx: Local models via MLX-Swift on Apple Silicon (requires running session)
|
||||
* - anthropic: Claude models via Anthropic API (@ai-sdk/anthropic v2.0)
|
||||
* - google/gemini: Gemini models via Google Generative AI API (@ai-sdk/google v2.0)
|
||||
* - openai: OpenAI models via OpenAI API (@ai-sdk/openai)
|
||||
@@ -30,6 +31,7 @@ import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
||||
import { createAnthropic } from '@ai-sdk/anthropic'
|
||||
import { invoke } from '@tauri-apps/api/core'
|
||||
import { SessionInfo } from '@janhq/core'
|
||||
import { fetch } from '@tauri-apps/plugin-http'
|
||||
|
||||
/**
|
||||
* Llama.cpp timings structure from the response
|
||||
@@ -109,6 +111,9 @@ export class ModelFactory {
|
||||
case 'llamacpp':
|
||||
return this.createLlamaCppModel(modelId, provider)
|
||||
|
||||
case 'mlx':
|
||||
return this.createMlxModel(modelId, provider)
|
||||
|
||||
case 'anthropic':
|
||||
return this.createAnthropicModel(modelId, provider)
|
||||
|
||||
@@ -173,6 +178,59 @@ export class ModelFactory {
|
||||
Origin: 'tauri://localhost',
|
||||
},
|
||||
includeUsage: true,
|
||||
fetch: fetch,
|
||||
})
|
||||
|
||||
return openAICompatible.languageModel(modelId, {
|
||||
metadataExtractor: llamaCppMetadataExtractor,
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an MLX model by starting the model and finding the running session.
|
||||
* MLX uses the same OpenAI-compatible API pattern as llamacpp.
|
||||
*/
|
||||
private static async createMlxModel(
|
||||
modelId: string,
|
||||
provider?: ProviderObject
|
||||
): Promise<LanguageModel> {
|
||||
// Start the model first if provider is available
|
||||
if (provider) {
|
||||
try {
|
||||
const { useServiceStore } = await import('@/hooks/useServiceHub')
|
||||
const serviceHub = useServiceStore.getState().serviceHub
|
||||
|
||||
if (serviceHub) {
|
||||
await serviceHub.models().startModel(provider, modelId)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to start MLX model:', error)
|
||||
throw new Error(
|
||||
`Failed to start model: ${error instanceof Error ? error.message : JSON.stringify(error)}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Get session info which includes port and api_key
|
||||
const sessionInfo = await invoke<SessionInfo | null>(
|
||||
'plugin:mlx|find_mlx_session_by_model',
|
||||
{ modelId }
|
||||
)
|
||||
|
||||
if (!sessionInfo) {
|
||||
throw new Error(`No running MLX session found for model: ${modelId}`)
|
||||
}
|
||||
|
||||
// Create OpenAI-compatible client for MLX server
|
||||
const openAICompatible = createOpenAICompatible({
|
||||
name: 'mlx',
|
||||
baseURL: `http://localhost:${sessionInfo.port}/v1`,
|
||||
headers: {
|
||||
Authorization: `Bearer ${sessionInfo.api_key}`,
|
||||
Origin: 'tauri://localhost',
|
||||
},
|
||||
includeUsage: true,
|
||||
fetch: fetch,
|
||||
})
|
||||
|
||||
return openAICompatible.languageModel(modelId, {
|
||||
|
||||
@@ -68,6 +68,8 @@ export function getProviderLogo(provider: string) {
|
||||
return '/images/model-provider/jan.png'
|
||||
case 'llamacpp':
|
||||
return '/images/model-provider/llamacpp.svg'
|
||||
case 'mlx':
|
||||
return '/images/model-provider/mlx.png'
|
||||
case 'anthropic':
|
||||
return '/images/model-provider/anthropic.svg'
|
||||
case 'huggingface':
|
||||
@@ -97,6 +99,8 @@ export const getProviderTitle = (provider: string) => {
|
||||
return 'Jan'
|
||||
case 'llamacpp':
|
||||
return 'Llama.cpp'
|
||||
case 'mlx':
|
||||
return 'MLX'
|
||||
case 'openai':
|
||||
return 'OpenAI'
|
||||
case 'openrouter':
|
||||
|
||||
@@ -4,17 +4,14 @@ import HeaderPage from '@/containers/HeaderPage'
|
||||
import SettingsMenu from '@/containers/SettingsMenu'
|
||||
import { useModelProvider } from '@/hooks/useModelProvider'
|
||||
import { cn, getProviderTitle, getModelDisplayName } from '@/lib/utils'
|
||||
import {
|
||||
createFileRoute,
|
||||
Link,
|
||||
useParams,
|
||||
} from '@tanstack/react-router'
|
||||
import { createFileRoute, Link, useParams } from '@tanstack/react-router'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import Capabilities from '@/containers/Capabilities'
|
||||
import { DynamicControllerSetting } from '@/containers/dynamicControllerSetting'
|
||||
import { RenderMarkdown } from '@/containers/RenderMarkdown'
|
||||
import { DialogEditModel } from '@/containers/dialogs/EditModel'
|
||||
import { ImportVisionModelDialog } from '@/containers/dialogs/ImportVisionModelDialog'
|
||||
import { ImportMlxModelDialog } from '@/containers/dialogs/ImportMlxModelDialog'
|
||||
import { ModelSetting } from '@/containers/ModelSetting'
|
||||
import { DialogDeleteModel } from '@/containers/dialogs/DeleteModel'
|
||||
import { FavoriteModelAction } from '@/containers/FavoriteModelAction'
|
||||
@@ -68,9 +65,9 @@ function ProviderDetail() {
|
||||
const { getProviderByName, setProviders, updateProvider } = useModelProvider()
|
||||
const provider = getProviderByName(providerName)
|
||||
|
||||
// Check if llamacpp provider needs backend configuration
|
||||
// Check if llamacpp/mlx provider needs backend configuration
|
||||
const needsBackendConfig =
|
||||
provider?.provider === 'llamacpp' &&
|
||||
(provider?.provider === 'llamacpp' || provider?.provider === 'mlx') &&
|
||||
provider.settings?.some(
|
||||
(setting) =>
|
||||
setting.key === 'version_backend' &&
|
||||
@@ -144,12 +141,14 @@ function ProviderDetail() {
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
// Initial data fetch
|
||||
serviceHub
|
||||
.models()
|
||||
.getActiveModels()
|
||||
.then((models) => setActiveModels(models || []))
|
||||
}, [serviceHub, setActiveModels])
|
||||
// Initial data fetch - load active models for the current provider
|
||||
if (provider?.provider) {
|
||||
serviceHub
|
||||
.models()
|
||||
.getActiveModels(provider.provider)
|
||||
.then((models) => setActiveModels(models || []))
|
||||
}
|
||||
}, [serviceHub, setActiveModels, provider?.provider])
|
||||
|
||||
// Clear importing state when model appears in the provider's model list
|
||||
useEffect(() => {
|
||||
@@ -270,10 +269,10 @@ function ProviderDetail() {
|
||||
// Start the model with plan result
|
||||
await serviceHub.models().startModel(provider, modelId)
|
||||
|
||||
// Refresh active models after starting
|
||||
// Refresh active models after starting (pass provider to get correct engine's loaded models)
|
||||
serviceHub
|
||||
.models()
|
||||
.getActiveModels()
|
||||
.getActiveModels(provider.provider)
|
||||
.then((models) => setActiveModels(models || []))
|
||||
} catch (error) {
|
||||
setModelLoadError(error as ErrorObject)
|
||||
@@ -288,12 +287,12 @@ function ProviderDetail() {
|
||||
// Original: stopModel(modelId).then(() => { setActiveModels((prevModels) => prevModels.filter((model) => model !== modelId)) })
|
||||
serviceHub
|
||||
.models()
|
||||
.stopModel(modelId)
|
||||
.stopModel(modelId, provider?.provider)
|
||||
.then(() => {
|
||||
// Refresh active models after stopping
|
||||
// Refresh active models after stopping (pass provider to get correct engine's loaded models)
|
||||
serviceHub
|
||||
.models()
|
||||
.getActiveModels()
|
||||
.getActiveModels(provider?.provider)
|
||||
.then((models) => setActiveModels(models || []))
|
||||
})
|
||||
.catch((error) => {
|
||||
@@ -302,7 +301,8 @@ function ProviderDetail() {
|
||||
}
|
||||
|
||||
const handleCheckForBackendUpdate = useCallback(async () => {
|
||||
if (provider?.provider !== 'llamacpp') return
|
||||
if (provider?.provider !== 'llamacpp' && provider?.provider !== 'mlx')
|
||||
return
|
||||
|
||||
setIsCheckingBackendUpdate(true)
|
||||
try {
|
||||
@@ -320,7 +320,8 @@ function ProviderDetail() {
|
||||
}, [provider, checkForBackendUpdate, t])
|
||||
|
||||
const handleInstallBackendFromFile = useCallback(async () => {
|
||||
if (provider?.provider !== 'llamacpp') return
|
||||
if (provider?.provider !== 'llamacpp' && provider?.provider !== 'mlx')
|
||||
return
|
||||
|
||||
setIsInstallingBackend(true)
|
||||
try {
|
||||
@@ -345,8 +346,12 @@ function ProviderDetail() {
|
||||
// Extract filename from the selected file path and replace spaces with dashes
|
||||
const fileName = basenameNoExt(selectedFile).replace(/\s+/g, '-')
|
||||
|
||||
// Capitalize provider name for display
|
||||
const providerDisplayName =
|
||||
provider?.provider === 'llamacpp' ? 'Llamacpp' : 'MLX'
|
||||
|
||||
toast.success(t('settings:backendInstallSuccess'), {
|
||||
description: `Llamacpp ${fileName} installed`,
|
||||
description: `${providerDisplayName} ${fileName} installed`,
|
||||
})
|
||||
|
||||
// Refresh settings to update backend configuration
|
||||
@@ -367,7 +372,9 @@ function ProviderDetail() {
|
||||
<div className="flex flex-col h-svh w-full">
|
||||
<HeaderPage>
|
||||
<div className="flex items-center gap-2 w-full">
|
||||
<span className='font-medium text-base font-studio'>{t('common:settings')}</span>
|
||||
<span className="font-medium text-base font-studio">
|
||||
{t('common:settings')}
|
||||
</span>
|
||||
</div>
|
||||
</HeaderPage>
|
||||
<div className="flex h-[calc(100%-60px)]">
|
||||
@@ -384,8 +391,9 @@ function ProviderDetail() {
|
||||
className={cn(
|
||||
'flex flex-col gap-3',
|
||||
provider &&
|
||||
provider.provider === 'llamacpp' &&
|
||||
'flex-col-reverse'
|
||||
(provider.provider === 'llamacpp' ||
|
||||
provider.provider === 'mlx') &&
|
||||
'flex-col-reverse'
|
||||
)}
|
||||
>
|
||||
{/* Settings */}
|
||||
@@ -395,7 +403,7 @@ function ProviderDetail() {
|
||||
const actionComponent = (
|
||||
<div className="mt-2">
|
||||
{needsBackendConfig &&
|
||||
setting.key === 'version_backend' ? (
|
||||
setting.key === 'version_backend' ? (
|
||||
<div className="flex items-center gap-1 text-sm">
|
||||
<IconLoader size={16} className="animate-spin" />
|
||||
<span>loading</span>
|
||||
@@ -404,21 +412,18 @@ function ProviderDetail() {
|
||||
<DynamicControllerSetting
|
||||
controllerType={setting.controller_type}
|
||||
controllerProps={setting.controller_props}
|
||||
className={cn(
|
||||
setting.key === 'device' && 'hidden'
|
||||
)}
|
||||
className={cn(setting.key === 'device' && 'hidden')}
|
||||
onChange={(newValue) => {
|
||||
if (provider) {
|
||||
const newSettings = [...provider.settings]
|
||||
// Handle different value types by forcing the type
|
||||
// Use type assertion to bypass type checking
|
||||
// Handle different value types by forcing the type
|
||||
// Use type assertion to bypass type checking
|
||||
|
||||
; (
|
||||
newSettings[settingIndex]
|
||||
.controller_props as {
|
||||
value: string | boolean | number
|
||||
}
|
||||
).value = newValue
|
||||
;(
|
||||
newSettings[settingIndex].controller_props as {
|
||||
value: string | boolean | number
|
||||
}
|
||||
).value = newValue
|
||||
|
||||
// Create update object with updated settings
|
||||
const updateObj: Partial<ModelProvider> = {
|
||||
@@ -446,11 +451,11 @@ function ProviderDetail() {
|
||||
)
|
||||
|
||||
if (deviceSettingIndex !== -1) {
|
||||
(
|
||||
;(
|
||||
newSettings[deviceSettingIndex]
|
||||
.controller_props as {
|
||||
value: string
|
||||
}
|
||||
value: string
|
||||
}
|
||||
).value = ''
|
||||
}
|
||||
|
||||
@@ -480,9 +485,7 @@ function ProviderDetail() {
|
||||
serviceHub
|
||||
.models()
|
||||
.getActiveModels()
|
||||
.then((models) =>
|
||||
setActiveModels(models || [])
|
||||
)
|
||||
.then((models) => setActiveModels(models || []))
|
||||
}
|
||||
}}
|
||||
/>
|
||||
@@ -497,7 +500,7 @@ function ProviderDetail() {
|
||||
className={cn(setting.key === 'device' && 'hidden')}
|
||||
column={
|
||||
setting.controller_type === 'input' &&
|
||||
setting.controller_props.type !== 'number'
|
||||
setting.controller_props.type !== 'number'
|
||||
? true
|
||||
: false
|
||||
}
|
||||
@@ -535,7 +538,8 @@ function ProviderDetail() {
|
||||
</div>
|
||||
)}
|
||||
{setting.key === 'version_backend' &&
|
||||
provider?.provider === 'llamacpp' && (
|
||||
(provider?.provider === 'llamacpp' ||
|
||||
provider?.provider === 'mlx') && (
|
||||
<div className="mt-2 flex flex-wrap gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
@@ -543,27 +547,22 @@ function ProviderDetail() {
|
||||
className={cn(
|
||||
'p-0',
|
||||
isCheckingBackendUpdate &&
|
||||
'pointer-events-none'
|
||||
'pointer-events-none'
|
||||
)}
|
||||
onClick={handleCheckForBackendUpdate}
|
||||
>
|
||||
<IconRefresh
|
||||
size={12}
|
||||
className={cn(
|
||||
'text-muted-foreground',
|
||||
isCheckingBackendUpdate &&
|
||||
'animate-spin'
|
||||
)}
|
||||
/>
|
||||
<span>
|
||||
{isCheckingBackendUpdate
|
||||
? t(
|
||||
'settings:checkingForBackendUpdates'
|
||||
)
|
||||
: t(
|
||||
'settings:checkForBackendUpdates'
|
||||
)}
|
||||
</span>
|
||||
size={12}
|
||||
className={cn(
|
||||
'text-muted-foreground',
|
||||
isCheckingBackendUpdate && 'animate-spin'
|
||||
)}
|
||||
/>
|
||||
<span>
|
||||
{isCheckingBackendUpdate
|
||||
? t('settings:checkingForBackendUpdates')
|
||||
: t('settings:checkForBackendUpdates')}
|
||||
</span>
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
@@ -605,7 +604,7 @@ function ProviderDetail() {
|
||||
{t('providers:models')}
|
||||
</h1>
|
||||
<div className="flex items-center gap-2">
|
||||
{provider && provider.provider !== 'llamacpp' && (
|
||||
{provider && provider.provider !== 'llamacpp' && provider.provider !== 'mlx' && (
|
||||
<>
|
||||
<Button
|
||||
variant="secondary"
|
||||
@@ -648,6 +647,21 @@ function ProviderDetail() {
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{provider && provider.provider === 'mlx' && (
|
||||
<ImportMlxModelDialog
|
||||
provider={provider}
|
||||
onSuccess={handleModelImportSuccess}
|
||||
trigger={
|
||||
<Button variant="secondary" size="sm">
|
||||
<IconFolderPlus
|
||||
size={18}
|
||||
className="text-muted-foreground"
|
||||
/>
|
||||
<span>{t('providers:import')}</span>
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
@@ -676,10 +690,7 @@ function ProviderDetail() {
|
||||
modelId={model.id}
|
||||
/>
|
||||
{model.settings && (
|
||||
<ModelSetting
|
||||
provider={provider}
|
||||
model={model}
|
||||
/>
|
||||
<ModelSetting provider={provider} model={model} />
|
||||
)}
|
||||
{((provider &&
|
||||
!predefinedProviders.some(
|
||||
@@ -690,8 +701,8 @@ function ProviderDetail() {
|
||||
(p) => p.provider === provider.provider
|
||||
) &&
|
||||
Boolean(provider.api_key?.length))) && (
|
||||
<FavoriteModelAction model={model} />
|
||||
)}
|
||||
<FavoriteModelAction model={model} />
|
||||
)}
|
||||
<DialogDeleteModel
|
||||
provider={provider}
|
||||
modelId={model.id}
|
||||
@@ -711,9 +722,7 @@ function ProviderDetail() {
|
||||
) : (
|
||||
<Button
|
||||
size="sm"
|
||||
disabled={loadingModels.includes(
|
||||
model.id
|
||||
)}
|
||||
disabled={loadingModels.includes(model.id)}
|
||||
onClick={() => handleStartModel(model.id)}
|
||||
>
|
||||
{loadingModels.includes(model.id) ? (
|
||||
|
||||
@@ -282,8 +282,8 @@ export class DefaultModelsService implements ModelsService {
|
||||
}
|
||||
}
|
||||
|
||||
async deleteModel(id: string): Promise<void> {
|
||||
return this.getEngine()?.delete(id)
|
||||
async deleteModel(id: string, provider?: string): Promise<void> {
|
||||
return this.getEngine(provider)?.delete(id)
|
||||
}
|
||||
|
||||
async getActiveModels(provider?: string): Promise<string[]> {
|
||||
|
||||
@@ -125,7 +125,7 @@ export interface ModelsService {
|
||||
skipVerification?: boolean
|
||||
): Promise<void>
|
||||
abortDownload(id: string): Promise<void>
|
||||
deleteModel(id: string): Promise<void>
|
||||
deleteModel(id: string, provider?: string): Promise<void>
|
||||
getActiveModels(provider?: string): Promise<string[]>
|
||||
stopModel(model: string, provider?: string): Promise<UnloadResult | undefined>
|
||||
stopAllModels(): Promise<void>
|
||||
|
||||
Reference in New Issue
Block a user