diff --git a/.gitignore b/.gitignore index d12c881..b1eb4f3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,9 @@ *.tgz yarn.lock +bun.lock package-lock.json +.cache/ npm-debug.log yarn-error.log /node_modules/ diff --git a/.gitmodules b/.gitmodules index 40a4a0f..d3fbe2b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "deps/mlx"] path = deps/mlx - url = https://github.com/ml-explore/mlx + url = https://github.com/robert-johansson/mlx [submodule "deps/kizunapi"] path = deps/kizunapi - url = https://github.com/photoionization/kizunapi + url = https://github.com/robert-johansson/kizunapi diff --git a/CMakeLists.txt b/CMakeLists.txt index 2cb0999..c1b0b69 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,5 +38,6 @@ target_include_directories(${PROJECT_NAME} PRIVATE "deps/kizunapi") option(MLX_BUILD_TESTS "Build tests for mlx" OFF) option(MLX_BUILD_EXAMPLES "Build examples for mlx" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" ON) +option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" ON) add_subdirectory(deps/mlx) target_link_libraries(${PROJECT_NAME} mlx) diff --git a/deps/kizunapi b/deps/kizunapi index b8d0622..05b7857 160000 --- a/deps/kizunapi +++ b/deps/kizunapi @@ -1 +1 @@ -Subproject commit b8d06226897a0cfe42a6efab39c413efd35b2276 +Subproject commit 05b7857c9166af966aa8da061dee993d9cc8c6f6 diff --git a/deps/mlx b/deps/mlx index b529515..e793c9a 160000 --- a/deps/mlx +++ b/deps/mlx @@ -1 +1 @@ -Subproject commit b529515eb158edd0919746ce4e545fe0879d6437 +Subproject commit e793c9ac6a4c3e6ae7ba717ffd19e3121b79519b diff --git a/node_mlx.node.d.ts b/node_mlx.node.d.ts index 2000c8a..3f859f9 100644 --- a/node_mlx.node.d.ts +++ b/node_mlx.node.d.ts @@ -195,6 +195,7 @@ declare module '*node_mlx.node' { function argmin(array: ScalarOrArray, axis?: number, keepdims?: boolean, s?: StreamOrDevice): array; function argpartition(array: ScalarOrArray, kth: number, axis?: number, s?: StreamOrDevice): array; function argsort(array: ScalarOrArray, s?: StreamOrDevice): array; + function searchsorted(a: ScalarOrArray, v: ScalarOrArray, right?: boolean, s?: StreamOrDevice): array; function arrayEqual(a: ScalarOrArray, b: ScalarOrArray, equalNan?: boolean, s?: StreamOrDevice): array; function asStrided(array: ScalarOrArray, shape?: number[], strides?: number[], offset?: number, s?: StreamOrDevice): array; function atleast1d(...arrays: array[]): array; @@ -240,6 +241,10 @@ declare module '*node_mlx.node' { function notEqual(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array; function erf(array: ScalarOrArray, s?: StreamOrDevice): array; function erfinv(array: ScalarOrArray, s?: StreamOrDevice): array; + function lgamma(array: ScalarOrArray, s?: StreamOrDevice): array; + function digamma(array: ScalarOrArray, s?: StreamOrDevice): array; + function besselI0e(array: ScalarOrArray, s?: StreamOrDevice): array; + function besselI1e(array: ScalarOrArray, s?: StreamOrDevice): array; function exp(array: ScalarOrArray, s?: StreamOrDevice): array; function expm1(array: ScalarOrArray, s?: StreamOrDevice): array; function expandDims(array: ScalarOrArray, dims: number | number[], s?: StreamOrDevice): array; @@ -378,6 +383,7 @@ declare module '*node_mlx.node' { function tidy(func: () => U): U; function dispose(...args: unknown[]): void; function getWrappersCount(): number; + function sweepDeadArrays(): number; // Metal. namespace metal { diff --git a/src/array.cc b/src/array.cc index 8edbb55..f9485fd 100644 --- a/src/array.cc +++ b/src/array.cc @@ -349,6 +349,7 @@ napi_value Item(mx::array* a, napi_env env) { return nullptr; } a->eval(); + a->detach(); return VisitArrayData([env](auto* data) { return ki::ToNodeValue(env, *data); }, a); @@ -383,6 +384,7 @@ napi_value ToList(mx::array* a, napi_env env) { if (a->ndim() == 0) return Item(a, env); a->eval(); + a->detach(); return VisitArrayData([env, a](auto* data) { return MxArrayToJsArray(env, *a, data); }, a); @@ -409,6 +411,7 @@ napi_value ToTypedArray(mx::array* a, napi_env env) { return nullptr; } a->eval(); + a->detach(); // Create a ArrayBuffer that stores a reference to array's data. using DataType = std::shared_ptr; napi_value buffer; @@ -462,38 +465,58 @@ std::stack> g_tidy_arrays; // Release all array pointers allocated during the call. napi_value Tidy(napi_env env, std::function func) { - // Push a new set to stack. + // Push a new set to stack. TypeBridge::Wrap inserts arrays here during func(). g_tidy_arrays.push(std::set()); - auto& top = g_tidy_arrays.top(); + // Shared flag: tracks whether cpp_then already popped the stack. + // Prevents double-pop in nested tidy (inner finally must not pop outer set). + auto popped = std::make_shared(false); return AwaitFunction( env, std::move(func), - [&top](napi_env env, napi_value result) { - // Exclude the arrays in result from the stack. + [popped](napi_env env, napi_value result) { + // Move the set out of the stack so it's safe from concurrent modification. + auto top = std::move(g_tidy_arrays.top()); + g_tidy_arrays.pop(); + *popped = true; + // Exclude the arrays in result from the set. TreeVisit(env, result, [&top](napi_env env, napi_value value) { if (auto a = ki::FromNodeTo(env, value); a) top.erase(*a); return napi_value(); }); - // Clear the arrays in the stack. + // Clear the arrays in the set. ki::InstanceData* instance_data = ki::InstanceData::Get(env); for (mx::array* a : top) { - // The arary might be in 3 states: + // The array might be in 3 states: // 1. Its JS object is well alive. - // 2. The JS object has been fully GCed. + // 2. The JS object has been fully GCed (finalizer ran, ptr freed). // 3. The JS object is marked as dead, but the finalizer has not run. - // We have to unbind the JS object in 1, and only delete array in 1 - // and 3. + // We must check wrapper validity BEFORE dereferencing the pointer, + // because in state 2 the pointer is dangling (already deleted by + // TypeBridge::Finalize during GC). napi_value value; - if (instance_data->GetWrapper(a, &value)) + bool has_wrapper = instance_data->GetWrapper(a, &value); + if (has_wrapper) { + // State 1: JS object alive — unbind it napi_remove_wrap(env, value, nullptr); - if (instance_data->DeleteWrapper(a)) + } + // Try to claim ownership (returns true for states 1 and 3) + if (instance_data->DeleteWrapper(a)) { + // Safe to dereference: pointer is still valid (not yet finalized) + int64_t ext = ki::internal::ExternalMemorySize::Get(a); + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); + } delete a; + } + // State 2: fully GC'd — skip, pointer is dangling } return result; }, - [](napi_env env) { - // Always pop even when error happened. - g_tidy_arrays.pop(); + [popped](napi_env env) { + // Only pop if cpp_then didn't already handle it. + if (!*popped) + g_tidy_arrays.pop(); }); } @@ -504,8 +527,13 @@ void Dispose(const ki::Arguments& args) { TreeVisit(args.Env(), args[i], [instance_data](napi_env env, napi_value value) { if (auto a = ki::FromNodeTo(env, value); a) { + int64_t ext = ki::internal::ExternalMemorySize::Get(a.value()); napi_remove_wrap(env, value, nullptr); instance_data->DeleteWrapper(a.value()); + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); + } delete a.value(); } return napi_value(); @@ -520,6 +548,69 @@ size_t GetWrappersCount(napi_env env) { } // namespace +// Pending set for double-check sweep. Arrays whose weak references reported +// null on the first scan are held here and re-checked on the next sweep. +// This works around a Bun/JSC bug where napi_get_reference_value temporarily +// returns null for weak references to objects that are still alive. +#include +static std::unordered_set g_pending_sweep; + +// Synchronously sweep dead array wrappers using double-check. +// +// Bun's JSC N-API implementation has a bug where weak references temporarily +// report null (napi_get_reference_value returns nullptr) for objects that are +// still reachable from JS. A single-check sweep that deletes immediately on +// null would cause use-after-free when SCI later accesses those arrays. +// +// The double-check protocol: +// Sweep N: scan finds null → add to pending set (don't delete) +// Sweep N+1: re-check pending → still null = confirmed dead → delete +// alive again = false positive → keep +// +// This adds one sweep cycle of delay before deletion (~50 ops). The pending +// set is small relative to the wrappers map and the scan is O(pending). +size_t SweepDeadArrays(napi_env env) { + ki::InstanceData* instance_data = ki::InstanceData::Get(env); + + // Phase 1: Re-check pending items from previous sweep + size_t deleted = 0; + for (auto it = g_pending_sweep.begin(); it != g_pending_sweep.end(); ) { + void* ptr = *it; + napi_value value; + bool in_map = instance_data->GetWrapper(ptr, &value); + if (!in_map) { + // Finalizer already ran between sweeps + it = g_pending_sweep.erase(it); + continue; + } + if (value == nullptr) { + // Still dead on second check — confirmed, delete + if (instance_data->DeleteWrapper(ptr)) { + mx::array* a = static_cast(ptr); + int64_t ext = ki::internal::ExternalMemorySize::Get(a); + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); + } + delete a; + deleted++; + } + it = g_pending_sweep.erase(it); + } else { + // Resurrected — was null, now alive. Bun/JSC false positive. + it = g_pending_sweep.erase(it); + } + } + + // Phase 2: Scan for newly dead wrappers — mark as pending + auto newly_dead = instance_data->ScanDeadWrappers(); + for (void* ptr : newly_dead) { + g_pending_sweep.insert(ptr); + } + + return deleted; +} + namespace ki { // Allow passing Dtype to JS directly, no memory management involved as they are @@ -808,5 +899,6 @@ void InitArray(napi_env env, napi_value exports) { ki::Set(env, exports, "tidy", &Tidy, "dispose", &Dispose, - "getWrappersCount", &GetWrappersCount); + "getWrappersCount", &GetWrappersCount, + "sweepDeadArrays", &SweepDeadArrays); } diff --git a/src/array.h b/src/array.h index 901e2a5..9c3346e 100644 --- a/src/array.h +++ b/src/array.h @@ -65,6 +65,32 @@ struct Type : public AllowPassByValue { napi_value value); }; +namespace internal { + +// Report external memory for mx::array to enable GC pressure signaling. +// MLX arrays hold Metal GPU buffers that are invisible to the JS GC. +// Without this, the GC doesn't know about GPU memory pressure and doesn't +// collect array wrappers fast enough, causing Metal resource exhaustion. +template<> +struct ExternalMemorySize { + static int64_t Get(mx::array* a) { + // Metal has a hard limit of 499K buffer allocations. We must create + // enough external memory pressure to force the GC to collect array + // wrappers before hitting it. Report 1MB per array as the minimum + // external cost — this is much larger than the actual data size but + // necessary to trigger sufficiently aggressive GC for GPU resources. + size_t n = a->nbytes(); + constexpr int64_t min_cost = 1024 * 1024; // 1MB + return static_cast(n) > min_cost ? static_cast(n) : min_cost; + } +}; + +} // namespace internal + } // namespace ki +// Synchronously sweep dead array wrappers whose JS finalizers haven't run yet. +// Called automatically from Eval() to prevent Metal resource accumulation. +size_t SweepDeadArrays(napi_env env); + #endif // SRC_ARRAY_H_ diff --git a/src/fast.cc b/src/fast.cc index 6cfe1e0..24aa539 100644 --- a/src/fast.cc +++ b/src/fast.cc @@ -42,16 +42,16 @@ mx::array ScaledDotProductAttention( throw std::invalid_argument(msg.str()); } return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, mask_str, {}, s); + queries, keys, values, scale, mask_str, {}, {}, s); } else { auto mask_arr = std::get(mask); return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {mask_arr}, s); + queries, keys, values, scale, "", {mask_arr}, {}, s); } } else { return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {}, s); + queries, keys, values, scale, "", {}, {}, s); } } diff --git a/src/fft.cc b/src/fft.cc index 808889a..1255986 100644 --- a/src/fft.cc +++ b/src/fft.cc @@ -1,27 +1,33 @@ #include "src/array.h" #include "src/stream.h" +// Type aliases for the three FFT function signatures (with FFTNorm parameter). +using FFTNFunc1 = mx::array(*)(const mx::array&, const mx::Shape&, + const std::vector&, + mx::fft::FFTNorm, mx::StreamOrDevice); +using FFTNFunc2 = mx::array(*)(const mx::array&, const std::vector&, + mx::fft::FFTNorm, mx::StreamOrDevice); +using FFTNFunc3 = mx::array(*)(const mx::array&, + mx::fft::FFTNorm, mx::StreamOrDevice); + // A template converter for ops that accept |n| and |axis|. inline std::function n, std::optional axis, mx::StreamOrDevice s)> -FFTOpWrapper(mx::array(*func1)(const mx::array&, - int, - int, - mx::StreamOrDevice), - mx::array(*func2)(const mx::array&, - int, - mx::StreamOrDevice)) { +FFTOpWrapper(mx::array(*func1)(const mx::array&, int, int, + mx::fft::FFTNorm, mx::StreamOrDevice), + mx::array(*func2)(const mx::array&, int, + mx::fft::FFTNorm, mx::StreamOrDevice)) { return [func1, func2](const mx::array& a, std::optional n, std::optional axis, mx::StreamOrDevice s) { if (n) - return func1(a, *n, axis.value_or(-1), s); + return func1(a, *n, axis.value_or(-1), mx::fft::FFTNorm::Backward, s); else - return func2(a, axis.value_or(-1), s); + return func2(a, axis.value_or(-1), mx::fft::FFTNorm::Backward, s); }; } @@ -31,30 +37,23 @@ std::function> axes, mx::StreamOrDevice s)> FFTNOpWrapper(const char* name, - mx::array(*func1)(const mx::array&, - const std::vector&, - const std::vector&, - mx::StreamOrDevice), - mx::array(*func2)(const mx::array&, - const std::vector&, - mx::StreamOrDevice), - mx::array(*func3)(const mx::array&, - mx::StreamOrDevice)) { + FFTNFunc1 func1, FFTNFunc2 func2, FFTNFunc3 func3) { return [name, func1, func2, func3](const mx::array& a, std::optional> n, std::optional> axes, mx::StreamOrDevice s) { if (n && axes) { - return mx::fft::fftn(a, std::move(*n), std::move(*axes), s); + mx::Shape shape_n(n->begin(), n->end()); + return func1(a, shape_n, std::move(*axes), mx::fft::FFTNorm::Backward, s); } else if (axes) { - return mx::fft::fftn(a, std::move(*axes), s); + return func2(a, std::move(*axes), mx::fft::FFTNorm::Backward, s); } else if (n) { std::ostringstream msg; msg << "[" << name << "] " << "`axes` should not be `None` if `s` is not `None`."; throw std::invalid_argument(msg.str()); } else { - return mx::fft::fftn(a, s); + return func3(a, mx::fft::FFTNorm::Backward, s); } }; } @@ -65,15 +64,7 @@ std::function> axes, mx::StreamOrDevice s)> FFT2OpWrapper(const char* name, - mx::array(*func1)(const mx::array&, - const std::vector&, - const std::vector&, - mx::StreamOrDevice), - mx::array(*func2)(const mx::array&, - const std::vector&, - mx::StreamOrDevice), - mx::array(*func3)(const mx::array&, - mx::StreamOrDevice)) { + FFTNFunc1 func1, FFTNFunc2 func2, FFTNFunc3 func3) { return [name, func1, func2, func3](const mx::array& a, std::optional> n, std::optional> axes, @@ -88,42 +79,40 @@ void InitFFT(napi_env env, napi_value exports) { ki::Set(env, exports, "fft", fft); ki::Set(env, fft, - "fft", FFTOpWrapper(&mx::fft::fft, - &mx::fft::fft), - "ifft", FFTOpWrapper(&mx::fft::ifft, - &mx::fft::ifft), + "fft", FFTOpWrapper(&mx::fft::fft, &mx::fft::fft), + "ifft", FFTOpWrapper(&mx::fft::ifft, &mx::fft::ifft), "fft2", FFT2OpWrapper("fft2", - &mx::fft::fftn, - &mx::fft::fftn, - &mx::fft::fftn), + static_cast(&mx::fft::fftn), + static_cast(&mx::fft::fftn), + static_cast(&mx::fft::fftn)), "ifft2", FFT2OpWrapper("ifft2", - &mx::fft::ifftn, - &mx::fft::ifftn, - &mx::fft::ifftn), + static_cast(&mx::fft::ifftn), + static_cast(&mx::fft::ifftn), + static_cast(&mx::fft::ifftn)), "fftn", FFTNOpWrapper("fftn", - &mx::fft::fftn, - &mx::fft::fftn, - &mx::fft::fftn), + static_cast(&mx::fft::fftn), + static_cast(&mx::fft::fftn), + static_cast(&mx::fft::fftn)), "ifftn", FFTNOpWrapper("ifftn", - &mx::fft::ifftn, - &mx::fft::ifftn, - &mx::fft::ifftn), + static_cast(&mx::fft::ifftn), + static_cast(&mx::fft::ifftn), + static_cast(&mx::fft::ifftn)), "rfft", FFTOpWrapper(&mx::fft::rfft, &mx::fft::rfft), "irfft", FFTOpWrapper(&mx::fft::irfft, &mx::fft::irfft), "rfft2", FFT2OpWrapper("rfft2", - &mx::fft::rfftn, - &mx::fft::rfftn, - &mx::fft::rfftn), + static_cast(&mx::fft::rfftn), + static_cast(&mx::fft::rfftn), + static_cast(&mx::fft::rfftn)), "irfft2", FFT2OpWrapper("irfft2", - &mx::fft::irfftn, - &mx::fft::irfftn, - &mx::fft::irfftn), + static_cast(&mx::fft::irfftn), + static_cast(&mx::fft::irfftn), + static_cast(&mx::fft::irfftn)), "rfftn", FFTNOpWrapper("rfftn", - &mx::fft::rfftn, - &mx::fft::rfftn, - &mx::fft::rfftn), + static_cast(&mx::fft::rfftn), + static_cast(&mx::fft::rfftn), + static_cast(&mx::fft::rfftn)), "irfftn", FFTNOpWrapper("irfftn", - &mx::fft::irfftn, - &mx::fft::irfftn, - &mx::fft::irfftn)); + static_cast(&mx::fft::irfftn), + static_cast(&mx::fft::irfftn), + static_cast(&mx::fft::irfftn))); } diff --git a/src/indexing.cc b/src/indexing.cc index e5d7286..b61dc12 100644 --- a/src/indexing.cc +++ b/src/indexing.cc @@ -563,7 +563,7 @@ ScatterResult ScatterArgsNDimentional(const mx::array* a, a->shape().begin() + non_none_indices, a->shape().end()); up = mx::reshape(std::move(up), std::move(up_reshape)); - mx::Shape axes(arr_indices.size(), 0); + std::vector axes(arr_indices.size(), 0); std::iota(axes.begin(), axes.end(), 0); return {std::move(arr_indices), std::move(up), std::move(axes)}; } diff --git a/src/memory.cc b/src/memory.cc index 64f789f..af641b6 100644 --- a/src/memory.cc +++ b/src/memory.cc @@ -9,5 +9,7 @@ void InitMemory(napi_env env, napi_value exports) { "setMemoryLimit", &mx::set_memory_limit, "setWiredLimit", &mx::set_wired_limit, "setCacheLimit", &mx::set_cache_limit, - "clearCache", &mx::clear_cache); + "clearCache", &mx::clear_cache, + "getNumResources", &mx::get_num_resources, + "getResourceLimit", &mx::get_resource_limit); } diff --git a/src/metal.cc b/src/metal.cc index 91f79ae..777f1b4 100644 --- a/src/metal.cc +++ b/src/metal.cc @@ -1,4 +1,14 @@ #include "src/bindings.h" +#include "mlx/backend/gpu/device_info.h" + +namespace metal_ops { + +const std::unordered_map>& +DeviceInfo() { + return mx::gpu::device_info(0); +} + +} // namespace metal_ops void InitMetal(napi_env env, napi_value exports) { napi_value metal = ki::CreateObject(env); @@ -8,5 +18,5 @@ void InitMetal(napi_env env, napi_value exports) { "isAvailable", &mx::metal::is_available, "startCapture", &mx::metal::start_capture, "stopCapture", &mx::metal::stop_capture, - "deviceInfo", &mx::metal::device_info); + "deviceInfo", &metal_ops::DeviceInfo); } diff --git a/src/ops.cc b/src/ops.cc index aad9d05..bdbd461 100644 --- a/src/ops.cc +++ b/src/ops.cc @@ -191,7 +191,7 @@ mx::array Full(std::variant shape, ScalarOrArray vals, std::optional dtype, mx::StreamOrDevice s) { - return mx::full(PutIntoVector(std::move(shape)), + return mx::full(PutIntoShape(std::move(shape)), ToArray(std::move(vals), std::move(dtype)), s); } @@ -199,13 +199,13 @@ mx::array Full(std::variant shape, mx::array Zeros(std::variant shape, std::optional dtype, mx::StreamOrDevice s) { - return mx::zeros(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s); + return mx::zeros(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s); } mx::array Ones(std::variant shape, std::optional dtype, mx::StreamOrDevice s) { - return mx::ones(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s); + return mx::ones(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s); } mx::array Eye(int n, @@ -303,8 +303,9 @@ std::vector Split(const mx::array& a, if (auto i = std::get_if(&indices); i) { return mx::split(a, *i, axis.value_or(0), s); } else { - return mx::split(a, std::move(std::get>(indices)), - axis.value_or(0), s); + auto& v = std::get>(indices); + mx::Shape shape_indices(v.begin(), v.end()); + return mx::split(a, std::move(shape_indices), axis.value_or(0), s); } } @@ -346,6 +347,13 @@ mx::array ArgSort(const mx::array& a, return mx::argsort(a, s); } +mx::array SearchSorted(const mx::array& a, + const mx::array& v, + std::optional right, + mx::StreamOrDevice s) { + return mx::searchsorted(a, v, right.value_or(false), s); +} + mx::array Softmax(const mx::array& a, OptionalAxes axis, std::optional precise, @@ -544,7 +552,7 @@ mx::array ConvTranspose1d( mx::StreamOrDevice s) { return mx::conv_transpose1d(input, weight, stride.value_or(1), padding.value_or(0), dilation.value_or(1), - groups.value_or(1), s); + /*output_padding=*/0, groups.value_or(1), s); } mx::array ConvTranspose2d( @@ -574,7 +582,7 @@ mx::array ConvTranspose2d( dilation_pair = std::move(*p); return mx::conv_transpose2d(input, weight, stride_pair, padding_pair, - dilation_pair, groups.value_or(1), s); + dilation_pair, {0, 0}, groups.value_or(1), s); } mx::array ConvTranspose3d( @@ -604,7 +612,7 @@ mx::array ConvTranspose3d( dilation_tuple = std::move(*p); return mx::conv_transpose3d(input, weight, stride_tuple, padding_tuple, - dilation_tuple, groups.value_or(1), s); + dilation_tuple, {0, 0, 0}, groups.value_or(1), s); } mx::array ConvGeneral( @@ -789,6 +797,10 @@ void InitOps(napi_env env, napi_value exports) { "expm1", &mx::expm1, "erf", &mx::erf, "erfinv", &mx::erfinv, + "lgamma", &mx::lgamma, + "digamma", &mx::digamma, + "besselI0e", &mx::bessel_i0e, + "besselI1e", &mx::bessel_i1e, "sin", &mx::sin, "cos", &mx::cos, "tan", &mx::tan, @@ -811,7 +823,9 @@ void InitOps(napi_env env, napi_value exports) { "stopGradient", &mx::stop_gradient, "sigmoid", &mx::sigmoid, "power", BinOpWrapper(&mx::power), - "arange", &ops::ARange, + "arange", &ops::ARange); + + ki::Set(env, exports, "linspace", &ops::Linspace, "kron", &mx::kron, "take", &ops::Take, @@ -848,15 +862,20 @@ void InitOps(napi_env env, napi_value exports) { "min", DimOpWrapper(&mx::min), "max", DimOpWrapper(&mx::max), "logcumsumexp", CumOpWrapper(&mx::logcumsumexp), - "logsumexp", DimOpWrapper(&mx::logsumexp), + "logsumexp", DimOpWrapper(&mx::logsumexp)); + + ki::Set(env, exports, "mean", DimOpWrapper(&mx::mean), "variance", &ops::Var, "std", &ops::Std, "split", &ops::Split, - "argmin", &ops::ArgMin, + "argmin", &ops::ArgMin); + + ki::Set(env, exports, "argmax", &ops::ArgMax, "sort", &ops::Sort, "argsort", &ops::ArgSort, + "searchsorted", &ops::SearchSorted, "partition", KthOpWrapper(&mx::partition, &mx::partition), "argpartition", KthOpWrapper(&mx::argpartition, &mx::argpartition), "topk", KthOpWrapper(&mx::topk, &mx::topk), @@ -864,7 +883,9 @@ void InitOps(napi_env env, napi_value exports) { "blockMaskedMM", &mx::block_masked_mm, "gatherMM", &mx::gather_mm, "gatherQMM", &mx::gather_qmm, - "softmax", &ops::Softmax, + "softmax", &ops::Softmax); + + ki::Set(env, exports, "concatenate", &ops::Concatenate, "concat", &ops::Concatenate, "stack", &ops::Stack, @@ -876,7 +897,9 @@ void InitOps(napi_env env, napi_value exports) { "cumsum", CumOpWrapper(&mx::cumsum), "cumprod", CumOpWrapper(&mx::cumprod), "cummax", CumOpWrapper(&mx::cummax), - "cummin", CumOpWrapper(&mx::cummin), + "cummin", CumOpWrapper(&mx::cummin)); + + ki::Set(env, exports, "conj", &mx::conjugate, "conjugate", &mx::conjugate, "convolve", &ops::Convolve, @@ -912,7 +935,9 @@ void InitOps(napi_env env, napi_value exports) { "bitwiseXor", BinOpWrapper(&mx::bitwise_xor), "leftShift", BinOpWrapper(&mx::left_shift), "rightShift", BinOpWrapper(&mx::right_shift), - "view", &mx::view, + "view", &mx::view); + + ki::Set(env, exports, "hadamardTransform", &mx::hadamard_transform, "einsumPath", &mx::einsum_path, "einsum", &mx::einsum, diff --git a/src/transforms.cc b/src/transforms.cc index c2be8c0..f03ad14 100644 --- a/src/transforms.cc +++ b/src/transforms.cc @@ -166,11 +166,14 @@ ValueAndGradImpl(const char* error_tag, std::iota(gradient_indices.begin(), gradient_indices.end(), 0); // The result of |js_func| execution. napi_value result = nullptr; + // Flag set when the JS callback fails during tracing. + bool callback_failed = false; // Call value_and_grad with the JS function. napi_env env = js_func.Env(); auto value_and_grad_func = mx::value_and_grad( [error_tag, scalar_func_only, - &js_func, &args, &argnums, &arrays, &strides, &result, &env]( + &js_func, &args, &argnums, &arrays, &strides, &result, + &callback_failed, &env]( const std::vector& primals) -> std::vector { // Read the args into |js_args| vector, and replace the arrays in it // with the traced |primals|. @@ -191,6 +194,7 @@ ValueAndGradImpl(const char* error_tag, js_args.size(), js_args.empty() ? nullptr : &js_args.front(), &result) != napi_ok) { + callback_failed = true; return {}; } // Validate the return value. @@ -240,6 +244,18 @@ ValueAndGradImpl(const char* error_tag, // Call the function immediately, because this C++ lambda is actually the // result of value_and_grad. const auto& [values, gradients] = value_and_grad_func(arrays); + // If the JS callback threw during tracing, propagate the error instead + // of continuing with garbage results (stale tracer Symbol objects). + if (callback_failed) { + // Re-throw if there's a pending exception, otherwise create one. + bool has_exception = false; + napi_is_exception_pending(env, &has_exception); + if (!has_exception) { + ki::ThrowError(env, error_tag, + " The function threw an error during tracing."); + } + return {nullptr, nullptr}; + } // Convert gradients to JS value. For array inputs the gradients will be // returned, for Array and Object inputs the original arg will be returned // with their array properties replaced with corresponding gradients. @@ -265,7 +281,18 @@ ValueAndGradImpl(const char* error_tag, namespace transforms_ops { void Eval(ki::Arguments* args) { - mx::eval(TreeFlatten(args)); + auto arrays = TreeFlatten(args); + mx::eval(arrays); + // Detach evaluated arrays from the computation graph. + // After eval, each array still holds shared_ptr references to its inputs + // (for potential re-evaluation or gradient computation). In long-running + // processes, these graph chains prevent Metal buffers from being freed, + // causing num_resources to grow monotonically until crash. + // Since node-mlx manages gradients explicitly via valueAndGrad/grad + // (which trace their own graphs), the forward graph is not needed after eval. + for (auto& a : arrays) { + a.detach(); + } } napi_value AsyncEval(ki::Arguments* args) { @@ -477,6 +504,7 @@ void InitTransforms(napi_env env, napi_value exports) { "grad", &transforms_ops::Grad, "vmap", &transforms_ops::VMap, "compile", &transforms_ops::Compile, + "compileClearCache", &mx::detail::compile_clear_cache, "disableCompile", &mx::disable_compile, "enableCompile", &mx::enable_compile); } diff --git a/src/utils.cc b/src/utils.cc index b827c53..7195c18 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -1,7 +1,7 @@ #include "src/array.h" #include "src/utils.h" -mx::Shape PutIntoVector(std::variant shape) { +mx::Shape PutIntoShape(std::variant shape) { if (auto i = std::get_if(&shape); i) return {*i}; return std::move(std::get(shape)); @@ -39,6 +39,16 @@ napi_value AwaitFunction( std::function cpp_then, std::function cpp_finally) { napi_value result = func(); + // If func() threw a JS exception, skip cpp_then and go straight to cleanup. + // Check both the standard exception-pending flag AND null result, because + // Bun's N-API implementation does not always set the pending exception flag + // when a JS exception occurs inside a native callback. + bool has_exception = false; + napi_is_exception_pending(env, &has_exception); + if (has_exception || result == nullptr) { + cpp_finally(env); + return nullptr; + } // Return immediately if the result is not promise. bool is_promise = false; napi_is_promise(env, result, &is_promise); diff --git a/src/utils.h b/src/utils.h index 8e9fd41..3cf2fb2 100644 --- a/src/utils.h +++ b/src/utils.h @@ -8,12 +8,45 @@ namespace mx = mlx::core; +// Teach kizunapi how to serialize/deserialize SmallVector (used for Shape +// and other types in MLX >= 0.26). Mirrors the std::vector specialization. +namespace ki { + +template +struct Type> { + static constexpr const char* name = "Array"; + static napi_status ToNode(napi_env env, + const mlx::core::SmallVector& vec, + napi_value* result) { + napi_status s = napi_create_array_with_length(env, vec.size(), result); + if (s != napi_ok) return s; + for (size_t i = 0; i < vec.size(); ++i) { + napi_value el; + s = ConvertToNode(env, vec[i], &el); + if (s != napi_ok) return s; + s = napi_set_element(env, *result, i, el); + if (s != napi_ok) return s; + } + return napi_ok; + } + static std::optional> FromNode( + napi_env env, napi_value value) { + // Read as std::vector then convert to SmallVector. + auto vec = Type>::FromNode(env, value); + if (!vec) return std::nullopt; + return mlx::core::SmallVector(vec->begin(), vec->end()); + } +}; + +} // namespace ki + using OptionalAxes = std::variant>; using ScalarOrArray = std::variant; -// Read args into a vector of types. -template -bool ReadArgs(ki::Arguments* args, std::vector* results) { +// Read args into a container of types (vector or SmallVector). +template +bool ReadArgs(ki::Arguments* args, Container* results) { + using T = typename Container::value_type; while (args->RemainingsLength() > 0) { std::optional a = args->GetNext(); if (!a) { @@ -45,8 +78,15 @@ void DefineToString(napi_env env, napi_value prototype) { symbol, ki::MemberFunction(&ToString)); } +// If input is one int, put it into a Shape, otherwise just return the Shape. +mx::Shape PutIntoShape(std::variant shape); + // If input is one int, put it into a vector, otherwise just return the vector. -std::vector PutIntoVector(std::variant> shape); +inline std::vector PutIntoVector(std::variant> v) { + if (auto i = std::get_if(&v); i) + return {*i}; + return std::move(std::get>(v)); +} // Get axis arg from js value. std::vector GetReduceAxes(OptionalAxes value, int dims);