From c8fa27a5f16d85c811818393dc2997beb0403ebc Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 22 Aug 2025 08:41:48 +0200 Subject: [PATCH 1/4] Add mlx-vlm Signed-off-by: Ettore Di Giacinto --- Makefile | 4 + backend/index.yaml | 22 ++ backend/python/mlx-vlm/Makefile | 23 ++ backend/python/mlx-vlm/backend.py | 477 ++++++++++++++++++++++++ backend/python/mlx-vlm/install.sh | 14 + backend/python/mlx-vlm/requirements.txt | 4 + backend/python/mlx-vlm/run.sh | 11 + backend/python/mlx-vlm/test.py | 146 ++++++++ backend/python/mlx-vlm/test.sh | 12 + 9 files changed, 713 insertions(+) create mode 100644 backend/python/mlx-vlm/Makefile create mode 100644 backend/python/mlx-vlm/backend.py create mode 100755 backend/python/mlx-vlm/install.sh create mode 100644 backend/python/mlx-vlm/requirements.txt create mode 100755 backend/python/mlx-vlm/run.sh create mode 100644 backend/python/mlx-vlm/test.py create mode 100755 backend/python/mlx-vlm/test.sh diff --git a/Makefile b/Makefile index e9197fd3b0d3..80738e78e5c1 100644 --- a/Makefile +++ b/Makefile @@ -373,6 +373,10 @@ backends/diffuser-darwin: USE_PIP=true BACKEND=diffusers BUILD_TYPE=mps $(MAKE) build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/diffusers.tar)" +backends/mlx-vlm: build + BACKEND=mlx-vlm BUILD_TYPE=mps bash ./scripts/build/python-darwin.sh + ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-vlm.tar)" + backend-images: mkdir -p backend-images diff --git a/backend/index.yaml b/backend/index.yaml index dc8036e1cc93..110dc6fbb198 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -142,6 +142,23 @@ - text-to-text - LLM - MLX +- &mlx-vlm + name: "mlx-vlm" + uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-vlm" + icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4 + urls: + - https://github.com/ml-explore/mlx-vlm + mirrors: + - localai/localai-backends:latest-metal-darwin-arm64-mlx-vlm + license: MIT + description: | + Run Vision-Language Models with MLX + tags: + - text-to-text + - multimodal + - vision-language + - LLM + - MLX - &rerankers name: "rerankers" alias: "rerankers" @@ -392,6 +409,11 @@ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx" mirrors: - localai/localai-backends:master-metal-darwin-arm64-mlx +- !!merge <<: *mlx-vlm + name: "mlx-vlm-development" + uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-vlm" + mirrors: + - localai/localai-backends:master-metal-darwin-arm64-mlx-vlm - !!merge <<: *kitten-tts name: "kitten-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts" diff --git a/backend/python/mlx-vlm/Makefile b/backend/python/mlx-vlm/Makefile new file mode 100644 index 000000000000..804031aa970d --- /dev/null +++ b/backend/python/mlx-vlm/Makefile @@ -0,0 +1,23 @@ +.PHONY: mlx-vlm +mlx-vlm: + bash install.sh + +.PHONY: run +run: mlx-vlm + @echo "Running mlx-vlm..." + bash run.sh + @echo "mlx run." + +.PHONY: test +test: mlx-vlm + @echo "Testing mlx-vlm..." + bash test.sh + @echo "mlx tested." 
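+
+# protogen-clean removes the protoc-generated gRPC stubs for
+# backend.proto (backend_pb2.py / backend_pb2_grpc.py); clean also
+# drops the local venv and __pycache__.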
+ +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ \ No newline at end of file diff --git a/backend/python/mlx-vlm/backend.py b/backend/python/mlx-vlm/backend.py new file mode 100644 index 000000000000..02730c814965 --- /dev/null +++ b/backend/python/mlx-vlm/backend.py @@ -0,0 +1,477 @@ +#!/usr/bin/env python3 +import asyncio +from concurrent import futures +import argparse +import signal +import sys +import os +from typing import List +import time + +import backend_pb2 +import backend_pb2_grpc + +import grpc +from mlx_vlm import load, generate, stream_generate +from mlx_vlm.prompt_utils import apply_chat_template +from mlx_vlm.utils import load_config, load_image +import mlx.core as mx +import base64 +import io +from PIL import Image +import tempfile + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 + +# If MAX_WORKERS are specified in the environment use it, otherwise default to 1 +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) + +# Implement the BackendServicer class with the service methods +class BackendServicer(backend_pb2_grpc.BackendServicer): + """ + A gRPC servicer that implements the Backend service defined in backend.proto. + """ + + def _is_float(self, s): + """Check if a string can be converted to float.""" + try: + float(s) + return True + except ValueError: + return False + + def _is_int(self, s): + """Check if a string can be converted to int.""" + try: + int(s) + return True + except ValueError: + return False + + def Health(self, request, context): + """ + Returns a health check message. + + Args: + request: The health check request. + context: The gRPC context. + + Returns: + backend_pb2.Reply: The health check reply. + """ + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + async def LoadModel(self, request, context): + """ + Loads a multimodal vision-language model using MLX-VLM. + + Args: + request: The load model request. + context: The gRPC context. + + Returns: + backend_pb2.Result: The load model result. 
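+
+        Note: request.Options is a list of "key:value" strings (for
+        example ["max_tokens:512", "temp:0.7", "top_p:0.9"]); numeric
+        and boolean values are converted to native Python types before
+        being stored in self.options.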
+ """ + try: + print(f"Loading MLX-VLM model: {request.Model}", file=sys.stderr) + print(f"Request: {request}", file=sys.stderr) + + # Parse options like in the diffusers backend + options = request.Options + self.options = {} + + # The options are a list of strings in this form optname:optvalue + # We store all the options in a dict for later use + for opt in options: + if ":" not in opt: + continue + key, value = opt.split(":", 1) # Split only on first colon to handle values with colons + + # Convert numeric values to appropriate types + if self._is_float(value): + value = float(value) + elif self._is_int(value): + value = int(value) + elif value.lower() in ["true", "false"]: + value = value.lower() == "true" + + self.options[key] = value + + print(f"Options: {self.options}", file=sys.stderr) + + # Load model and processor using MLX-VLM + # mlx-vlm load function returns (model, processor) instead of (model, tokenizer) + self.model, self.processor = load(request.Model) + + # Load model config for chat template support + self.config = load_config(request.Model) + + except Exception as err: + print(f"Error loading MLX-VLM model {err=}, {type(err)=}", file=sys.stderr) + return backend_pb2.Result(success=False, message=f"Error loading MLX-VLM model: {err}") + + print("MLX-VLM model loaded successfully", file=sys.stderr) + return backend_pb2.Result(message="MLX-VLM model loaded successfully", success=True) + + async def Predict(self, request, context): + """ + Generates text based on the given prompt and sampling parameters using MLX-VLM with multimodal support. + + Args: + request: The predict request. + context: The gRPC context. + + Returns: + backend_pb2.Reply: The predict result. + """ + temp_files = [] + try: + # Process images and audios from request + image_paths = [] + audio_paths = [] + + # Process images + if request.Images: + for img_data in request.Images: + img_path = self.load_image_from_base64(img_data) + if img_path: + image_paths.append(img_path) + temp_files.append(img_path) + + # Process audios + if request.Audios: + for audio_data in request.Audios: + audio_path = self.load_audio_from_base64(audio_data) + if audio_path: + audio_paths.append(audio_path) + temp_files.append(audio_path) + + # Prepare the prompt with multimodal information + prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths)) + + # Build generation parameters using request attributes and options + max_tokens, generation_params = self._build_generation_params(request) + + print(f"Generating text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr) + print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr) + + # Generate text using MLX-VLM with multimodal inputs + response = generate( + model=self.model, + processor=self.processor, + prompt=prompt, + image=image_paths if image_paths else None, + audio=audio_paths if audio_paths else None, + max_tokens=max_tokens, + temperature=generation_params.get('temp', 0.6), + top_p=generation_params.get('top_p', 1.0), + verbose=False + ) + + return backend_pb2.Reply(message=bytes(response, encoding='utf-8')) + + except Exception as e: + print(f"Error in MLX-VLM Predict: {e}", file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"Generation failed: {str(e)}") + return backend_pb2.Reply(message=bytes("", encoding='utf-8')) + finally: + # Clean up temporary files + self.cleanup_temp_files(temp_files) + + def Embedding(self, request, context): 
+ """ + A gRPC method that calculates embeddings for a given sentence. + + Note: MLX-VLM doesn't support embeddings directly. This method returns an error. + + Args: + request: An EmbeddingRequest object that contains the request parameters. + context: A grpc.ServicerContext object that provides information about the RPC. + + Returns: + An EmbeddingResult object that contains the calculated embeddings. + """ + print("Embeddings not supported in MLX-VLM backend", file=sys.stderr) + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Embeddings are not supported in the MLX-VLM backend.") + return backend_pb2.EmbeddingResult() + + async def PredictStream(self, request, context): + """ + Generates text based on the given prompt and sampling parameters, and streams the results using MLX-VLM with multimodal support. + + Args: + request: The predict stream request. + context: The gRPC context. + + Yields: + backend_pb2.Reply: Streaming predict results. + """ + temp_files = [] + try: + # Process images and audios from request + image_paths = [] + audio_paths = [] + + # Process images + if request.Images: + for img_data in request.Images: + img_path = self.load_image_from_base64(img_data) + if img_path: + image_paths.append(img_path) + temp_files.append(img_path) + + # Process audios + if request.Audios: + for audio_data in request.Audios: + audio_path = self.load_audio_from_base64(audio_data) + if audio_path: + audio_paths.append(audio_path) + temp_files.append(audio_path) + + # Prepare the prompt with multimodal information + prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths)) + + # Build generation parameters using request attributes and options + max_tokens, generation_params = self._build_generation_params(request, default_max_tokens=512) + + print(f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr) + print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr) + + # Stream text generation using MLX-VLM with multimodal inputs + for response in stream_generate( + model=self.model, + processor=self.processor, + prompt=prompt, + image=image_paths if image_paths else None, + audio=audio_paths if audio_paths else None, + max_tokens=max_tokens, + temperature=generation_params.get('temp', 0.6), + top_p=generation_params.get('top_p', 1.0), + ): + yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) + + except Exception as e: + print(f"Error in MLX-VLM PredictStream: {e}", file=sys.stderr) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"Streaming generation failed: {str(e)}") + yield backend_pb2.Reply(message=bytes("", encoding='utf-8')) + finally: + # Clean up temporary files + self.cleanup_temp_files(temp_files) + + def _prepare_prompt(self, request, num_images=0, num_audios=0): + """ + Prepare the prompt for MLX-VLM generation, handling chat templates and multimodal inputs. + + Args: + request: The gRPC request containing prompt and message information. + num_images: Number of images in the request. + num_audios: Number of audio files in the request. + + Returns: + str: The prepared prompt. 
+ """ + # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template + if not request.Prompt and request.UseTokenizerTemplate and request.Messages: + # Convert gRPC messages to the format expected by apply_chat_template + messages = [] + for msg in request.Messages: + messages.append({"role": msg.role, "content": msg.content}) + + # Use mlx-vlm's apply_chat_template which handles multimodal inputs + prompt = apply_chat_template( + self.processor, + self.config, + messages, + num_images=num_images, + num_audios=num_audios + ) + return prompt + elif request.Prompt: + # If we have a direct prompt but also have images/audio, we need to format it properly + if num_images > 0 or num_audios > 0: + # Create a simple message structure for multimodal prompt + messages = [{"role": "user", "content": request.Prompt}] + prompt = apply_chat_template( + self.processor, + self.config, + messages, + num_images=num_images, + num_audios=num_audios + ) + return prompt + else: + return request.Prompt + else: + # Fallback to empty prompt with multimodal template if we have media + if num_images > 0 or num_audios > 0: + messages = [{"role": "user", "content": ""}] + prompt = apply_chat_template( + self.processor, + self.config, + messages, + num_images=num_images, + num_audios=num_audios + ) + return prompt + else: + return "" + + + + + + def _build_generation_params(self, request, default_max_tokens=200): + """ + Build generation parameters from request attributes and options for MLX-VLM. + + Args: + request: The gRPC request. + default_max_tokens: Default max_tokens if not specified. + + Returns: + tuple: (max_tokens, generation_params dict) + """ + # Extract max_tokens + max_tokens = getattr(request, 'Tokens', default_max_tokens) + if max_tokens == 0: + max_tokens = default_max_tokens + + # Extract generation parameters from request attributes + temp = getattr(request, 'Temperature', 0.0) + if temp == 0.0: + temp = 0.6 # Default temperature + + top_p = getattr(request, 'TopP', 0.0) + if top_p == 0.0: + top_p = 1.0 # Default top_p + + # Initialize generation parameters for MLX-VLM + generation_params = { + 'temp': temp, + 'top_p': top_p, + } + + # Add seed if specified + seed = getattr(request, 'Seed', 0) + if seed != 0: + mx.random.seed(seed) + + # Override with options if available + if hasattr(self, 'options'): + # Max tokens from options + if 'max_tokens' in self.options: + max_tokens = self.options['max_tokens'] + + # Generation parameters from options + param_option_mapping = { + 'temp': 'temp', + 'temperature': 'temp', # alias + 'top_p': 'top_p', + } + + for option_key, param_key in param_option_mapping.items(): + if option_key in self.options: + generation_params[param_key] = self.options[option_key] + + # Handle seed from options + if 'seed' in self.options: + mx.random.seed(self.options['seed']) + + return max_tokens, generation_params + + def load_image_from_base64(self, image_data: str): + """ + Load an image from base64 encoded data. + + Args: + image_data (str): Base64 encoded image data. + + Returns: + PIL.Image or str: The loaded image or path to the image. 
+ """ + try: + decoded_data = base64.b64decode(image_data) + image = Image.open(io.BytesIO(decoded_data)) + + # Save to temporary file for mlx-vlm + with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file: + image.save(tmp_file.name, format='JPEG') + return tmp_file.name + + except Exception as e: + print(f"Error loading image from base64: {e}", file=sys.stderr) + return None + + def load_audio_from_base64(self, audio_data: str): + """ + Load audio from base64 encoded data. + + Args: + audio_data (str): Base64 encoded audio data. + + Returns: + str: Path to the loaded audio file. + """ + try: + decoded_data = base64.b64decode(audio_data) + + # Save to temporary file for mlx-vlm + with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file: + tmp_file.write(decoded_data) + return tmp_file.name + + except Exception as e: + print(f"Error loading audio from base64: {e}", file=sys.stderr) + return None + + def cleanup_temp_files(self, file_paths: List[str]): + """ + Clean up temporary files. + + Args: + file_paths (List[str]): List of file paths to clean up. + """ + for file_path in file_paths: + try: + if file_path and os.path.exists(file_path): + os.remove(file_path) + except Exception as e: + print(f"Error removing temporary file {file_path}: {e}", file=sys.stderr) + +async def serve(address): + # Start asyncio gRPC server + server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + options=[ + ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB + ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB + ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB + ]) + # Add the servicer to the server + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + # Bind the server to the address + server.add_insecure_port(address) + + # Gracefully shutdown the server on SIGTERM or SIGINT + loop = asyncio.get_event_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler( + sig, lambda: asyncio.ensure_future(server.stop(5)) + ) + + # Start the server + await server.start() + print("Server started. Listening on: " + address, file=sys.stderr) + # Wait for the server to be terminated + await server.wait_for_termination() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the gRPC server.") + parser.add_argument( + "--addr", default="localhost:50051", help="The address to bind the server to." 
+ ) + args = parser.parse_args() + + asyncio.run(serve(args.addr)) diff --git a/backend/python/mlx-vlm/install.sh b/backend/python/mlx-vlm/install.sh new file mode 100755 index 000000000000..b8ee48552490 --- /dev/null +++ b/backend/python/mlx-vlm/install.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -e + +USE_PIP=true + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +installRequirements diff --git a/backend/python/mlx-vlm/requirements.txt b/backend/python/mlx-vlm/requirements.txt new file mode 100644 index 000000000000..f1771cc4adb4 --- /dev/null +++ b/backend/python/mlx-vlm/requirements.txt @@ -0,0 +1,4 @@ +grpcio==1.71.0 +protobuf +certifi +setuptools \ No newline at end of file diff --git a/backend/python/mlx-vlm/run.sh b/backend/python/mlx-vlm/run.sh new file mode 100755 index 000000000000..fc88f97da712 --- /dev/null +++ b/backend/python/mlx-vlm/run.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ \ No newline at end of file diff --git a/backend/python/mlx-vlm/test.py b/backend/python/mlx-vlm/test.py new file mode 100644 index 000000000000..827aa71a3e33 --- /dev/null +++ b/backend/python/mlx-vlm/test.py @@ -0,0 +1,146 @@ +import unittest +import subprocess +import time +import backend_pb2 +import backend_pb2_grpc + +import grpc + +import unittest +import subprocess +import time +import grpc +import backend_pb2_grpc +import backend_pb2 + +class TestBackendServicer(unittest.TestCase): + """ + TestBackendServicer is the class that tests the gRPC service. + + This class contains methods to test the startup and shutdown of the gRPC service. 
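+
+    Note: the tests spawn backend.py on localhost:50051 and assume an
+    Apple Silicon host (MLX) with network access to fetch the referenced
+    Hugging Face models.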
+ """ + def setUp(self): + self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"]) + time.sleep(10) + + def tearDown(self) -> None: + self.service.terminate() + self.service.wait() + + def test_server_startup(self): + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.Health(backend_pb2.HealthMessage()) + self.assertEqual(response.message, b'OK') + except Exception as err: + print(err) + self.fail("Server failed to start") + finally: + self.tearDown() + def test_load_model(self): + """ + This method tests if the model is loaded successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + self.assertTrue(response.success) + self.assertEqual(response.message, "Model loaded successfully") + except Exception as err: + print(err) + self.fail("LoadModel service failed") + finally: + self.tearDown() + + def test_text(self): + """ + This method tests if the embeddings are generated successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + self.assertTrue(response.success) + req = backend_pb2.PredictOptions(Prompt="The capital of France is") + resp = stub.Predict(req) + self.assertIsNotNone(resp.message) + except Exception as err: + print(err) + self.fail("text service failed") + finally: + self.tearDown() + + def test_sampling_params(self): + """ + This method tests if all sampling parameters are correctly processed + NOTE: this does NOT test for correctness, just that we received a compatible response + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) + self.assertTrue(response.success) + + req = backend_pb2.PredictOptions( + Prompt="The capital of France is", + TopP=0.8, + Tokens=50, + Temperature=0.7, + TopK=40, + PresencePenalty=0.1, + FrequencyPenalty=0.2, + RepetitionPenalty=1.1, + MinP=0.05, + Seed=42, + StopPrompts=["\n"], + StopTokenIds=[50256], + BadWords=["badword"], + IncludeStopStrInOutput=True, + IgnoreEOS=True, + MinTokens=5, + Logprobs=5, + PromptLogprobs=5, + SkipSpecialTokens=True, + SpacesBetweenSpecialTokens=True, + TruncatePromptTokens=10, + GuidedDecoding=True, + N=2, + ) + resp = stub.Predict(req) + self.assertIsNotNone(resp.message) + self.assertIsNotNone(resp.logprobs) + except Exception as err: + print(err) + self.fail("sampling params service failed") + finally: + self.tearDown() + + + def test_embedding(self): + """ + This method tests if the embeddings are generated successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct")) + self.assertTrue(response.success) + embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.") + embedding_response = stub.Embedding(embedding_request) + self.assertIsNotNone(embedding_response.embeddings) + # assert that is a list of floats + self.assertIsInstance(embedding_response.embeddings, list) + # assert that 
the list is not empty + self.assertTrue(len(embedding_response.embeddings) > 0) + except Exception as err: + print(err) + self.fail("Embedding service failed") + finally: + self.tearDown() \ No newline at end of file diff --git a/backend/python/mlx-vlm/test.sh b/backend/python/mlx-vlm/test.sh new file mode 100755 index 000000000000..f31ae54e47dc --- /dev/null +++ b/backend/python/mlx-vlm/test.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) + +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runUnittests From 271a179a7ae28dce9fb4f2f2058599a2bec28ee8 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 23 Aug 2025 22:40:05 +0200 Subject: [PATCH 2/4] Add to CI workflows Signed-off-by: Ettore Di Giacinto --- .github/workflows/backend.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index cae602f21ff1..4293f460561a 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -972,6 +972,19 @@ jobs: dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }} quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} + mlx-vlm-darwin: + uses: ./.github/workflows/backend_build_darwin.yml + with: + backend: "mlx-vlm" + build-type: "mps" + go-version: "1.24.x" + tag-suffix: "-metal-darwin-arm64-mlx-vlm" + runs-on: "macOS-14" + secrets: + dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }} + dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }} + quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} + quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} llama-cpp-darwin: runs-on: macOS-14 strategy: From ab215c345f9f482b8eaeef5c2616fd5c10e0fa06 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 23 Aug 2025 22:41:26 +0200 Subject: [PATCH 3/4] Add requirements-mps.txt Signed-off-by: Ettore Di Giacinto --- backend/python/mlx-vlm/requirements-mps.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 backend/python/mlx-vlm/requirements-mps.txt diff --git a/backend/python/mlx-vlm/requirements-mps.txt b/backend/python/mlx-vlm/requirements-mps.txt new file mode 100644 index 000000000000..8737f6091c70 --- /dev/null +++ b/backend/python/mlx-vlm/requirements-mps.txt @@ -0,0 +1 @@ +git+https://github.com/Blaizzy/mlx-vlm \ No newline at end of file From dd3cee30b1e5dc6a022825d6bbb7b89f457c37a8 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 23 Aug 2025 22:43:53 +0200 Subject: [PATCH 4/4] Simplify Signed-off-by: Ettore Di Giacinto --- Makefile | 8 ++++---- scripts/build/python-darwin.sh | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 80738e78e5c1..48bf7c9e12a9 100644 --- a/Makefile +++ b/Makefile @@ -366,15 +366,15 @@ build-darwin-python-backend: build bash ./scripts/build/python-darwin.sh backends/mlx: - BACKEND=mlx BUILD_TYPE=mps $(MAKE) build-darwin-python-backend + BACKEND=mlx $(MAKE) build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx.tar)" backends/diffuser-darwin: - USE_PIP=true BACKEND=diffusers BUILD_TYPE=mps $(MAKE) build-darwin-python-backend + BACKEND=diffusers $(MAKE) build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/diffusers.tar)" -backends/mlx-vlm: build - BACKEND=mlx-vlm BUILD_TYPE=mps bash ./scripts/build/python-darwin.sh +backends/mlx-vlm: + BACKEND=mlx-vlm $(MAKE) 
build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-vlm.tar)" backend-images: diff --git a/scripts/build/python-darwin.sh b/scripts/build/python-darwin.sh index 4b0a373e0e8a..513de2ea5f4d 100644 --- a/scripts/build/python-darwin.sh +++ b/scripts/build/python-darwin.sh @@ -3,6 +3,8 @@ set -ex export PORTABLE_PYTHON=true +export BUILD_TYPE=mps +export USE_PIP=true IMAGE_NAME="${IMAGE_NAME:-localai/llama-cpp-darwin}" mkdir -p backend-images make -C backend/python/${BACKEND}
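
For reference, below is a minimal client sketch that exercises the new backend over gRPC, mirroring the calls used in test.py. It assumes the protoc-generated backend_pb2/backend_pb2_grpc stubs are importable, that run.sh has started the server on localhost:50051, and it uses a placeholder MLX vision-language model id that is not part of this patch; substitute a model available locally.

import grpc

import backend_pb2
import backend_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)

    # Health check; the backend replies with b"OK"
    print(stub.Health(backend_pb2.HealthMessage()).message)

    # Load a vision-language model; Options are "key:value" strings
    # parsed by LoadModel (the model id below is only an example)
    res = stub.LoadModel(backend_pb2.ModelOptions(
        Model="mlx-community/Qwen2-VL-2B-Instruct-4bit",
        Options=["max_tokens:256", "temp:0.7"],
    ))
    assert res.success, res.message

    # Text-only generation; images/audio would go base64-encoded into
    # PredictOptions.Images / PredictOptions.Audios
    reply = stub.Predict(backend_pb2.PredictOptions(Prompt="Describe MLX in one sentence."))
    print(reply.message.decode("utf-8"))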