MLX では、導関数を求めてから、微分係数を計算することができます。 これは、JAX由来のものだそうです。 MLX とは何かというと、Apple シリコン搭載の Mac コンピュータ向けの機械学習フレームワークです。
まずは、\(f(x)\) の計算から見てみましょう。
$$f(x) = \sin(x)$$について、\(a = 0\) のとき、
$$f(a) = 0$$です。これを、コードで表現すると次のようになります。
>> import mlx.core as mx
>> f = mx.sin
>> a = mx.array(0.0) # 数値は mx.array オブジェクトに変換
>> f(a)
array(0, dtype=float32)
1階微分
さて、1階微分を計算してみます。
$$f'(x) = \frac{d}{dx} f(x)$$より、 \(a = 0\) のときの微分係数を求めると、\(f'(x) = \cos(x)\) ですので、
$$f'(a) = 1$$となります。数式表現に対応づけてコードでも
>> fd = mx.grad(f)
>> fd(a)
array(1, dtype=float32)
と書くことができます。導関数を求めてから、微分係数を計算するというステップになっていますね。
2階微分
2階微分でも同様に行うことができます。2次導関数
$$f''(x) = \frac{d^2}{dx^2} f(x)$$を求めてから、\(a = 0\) を代入して微分係数を求めます。ここで、\(f''(x) = -\sin(x)\) ですので、
$$f''(a) = 0$$となります。コードでも数式を書くように、
>> fdd = mx.grad(mx.grad(f))
>> fdd(a)
array(-0, dtype=float32)
と表現できます。
**
このように、MLX を使えば1階微分から高階微分まで、まるで数式を書き下すように直感的に扱うことができます。