Explore the power of machine learning within apps. Discuss integrating machine learning features, share best practices, and explore the possibilities for your app.

Posts under General subtopic

Post

Replies

Boosts

Views

Activity

RDMA API Documentation
With the release of the newest version of tahoe and MLX supporting RDMA. Is there a documentation link to how to utilizes the libdrma dylib as well as what functions are available? I am currently assuming it mostly follows the standard linux infiniband library but I would like the apple specific details.
0
0
166
1w
recent JAX versions fail on Metal
Hi, I'm not sure whether this is the appropriate forum for this topic. I just followed a link from the JAX Metal plugin page https://developer.apple.com/metal/jax/ I'm writing a Python app with JAX, and recent JAX versions fail on Metal. E.g. v0.8.2 I have to downgrade JAX pretty hard to make it work: pip install jax==0.4.35 jaxlib==0.4.35 jax-metal==0.1.1 Can we get an updated release of jax-metal that would fix this issue? Here is the error I get with JAX v0.8.2: WARNING:2025-12-26 09:55:28,117:jax._src.xla_bridge:881: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported! WARNING: All log messages before absl::InitializeLog() is called are written to STDERR W0000 00:00:1766771728.118004 207582 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported! Metal device set to: Apple M3 Max systemMemory: 36.00 GB maxCacheSize: 13.50 GB I0000 00:00:1766771728.129886 207582 service.cc:145] XLA service 0x600001fad300 initialized for platform METAL (this does not guarantee that XLA will be used). Devices: I0000 00:00:1766771728.129893 207582 service.cc:153] StreamExecutor device (0): Metal, <undefined> I0000 00:00:1766771728.130856 207582 mps_client.cc:406] Using Simple allocator. I0000 00:00:1766771728.130864 207582 mps_client.cc:384] XLA backend will use up to 28990554112 bytes on device 0 for SimpleAllocator. Traceback (most recent call last): File "<string>", line 1, in <module> import jax; print(jax.numpy.arange(10)) ~~~~~~~~~~~~~~~~^^^^ File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py", line 5951, in arange return _arange(start, stop=stop, step=step, dtype=dtype, out_sharding=sharding) File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/numpy/lax_numpy.py", line 6012, in _arange return lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding) ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/lax/lax.py", line 3415, in broadcasted_iota return iota_p.bind(dtype=dtype, shape=shape, ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^ dimension=dimension, sharding=out_sharding) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 633, in bind return self._true_bind(*args, **params) ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 649, in _true_bind return self.bind_with_trace(prev_trace, args, params) ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 661, in bind_with_trace return trace.process_primitive(self, args, params) ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^ File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/core.py", line 1210, in process_primitive return primitive.impl(*args, **params) ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/Users/florin/git/FlorinAndrei/star-cluster-simulator/.venv/lib/python3.13/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive outs = fun(*args) jax.errors.JaxRuntimeError: UNKNOWN: -:0:0: error: unknown attribute code: 22 -:0:0: note: in bytecode version 6 produced by: StableHLO_v1.13.0 -------------------- For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. I0000 00:00:1766771728.149951 207582 mps_client.h:209] MetalClient destroyed.
0
0
274
2d