MLXと微分計算

MLX では、導関数を求めてから、微分係数を計算することができます。 これは、JAX由来のものだそうです。 MLX とは何かというと、Apple シリコン搭載の Mac コンピュータ向けの機械学習フレームワークです。

まずは、\(f(x)\) の計算から見てみましょう。

$$f(x) = \sin(x)$$

について、\(a = 0\) のとき、

$$f(a) = 0$$

です。これを、コードで表現すると次のようになります。

>> import mlx.core as mx
>> f = mx.sin
>> a = mx.array(0.0)  # 数値は mx.array オブジェクトに変換
>> f(a)
array(0, dtype=float32)

1階微分

さて、1階微分を計算してみます。

$$f'(x) = \frac{d}{dx} f(x)$$

より、 \(a = 0\) のときの微分係数を求めると、\(f'(x) = \cos(x)\) ですので、

$$f'(a) = 1$$

となります。数式表現に対応づけてコードでも

>> fd = mx.grad(f)
>> fd(a)
array(1, dtype=float32)

と書くことができます。導関数を求めてから、微分係数を計算するというステップになっていますね。


2階微分

2階微分でも同様に行うことができます。2次導関数

$$f''(x) = \frac{d^2}{dx^2} f(x)$$

を求めてから、\(a = 0\) を代入して微分係数を求めます。ここで、\(f''(x) = -\sin(x)\) ですので、

$$f''(a) = 0$$

となります。コードでも数式を書くように、

>> fdd = mx.grad(mx.grad(f))
>> fdd(a)
array(-0, dtype=float32)

と表現できます。

**

このように、MLX を使えば1階微分から高階微分まで、まるで数式を書き下すように直感的に扱うことができます。

by lwgena