Automatic differentiation of Java code using Code Reflection

原文はこちら。
The original article was written by Paul Sandoz (Architect, Java at Oracle).
https://openjdk.org/projects/babylon/articles/auto-diff

この記事では、自動微分(automatic differentiation)とは何か、なぜ便利なのか、そしてJavaメソッドの自動微分を実装するのに役立つCode Reflectionの使い方について説明します。

Code Reflectionは、OpenJDKのProject Babylonの下で研究開発されているJavaプラットフォームの機能です。

Project Babylon
https://openjdk.org/projects/babylon

課題を説明し、解決策を提示しながら、Code Reflectionの概念とAPIを紹介します。説明は網羅的でも非常に詳細でもなく、読者がCode Reflectionとその機能を直感的に感じ、理解できるようにデザインしています。

Machine learning

機械学習の分野を考えてみましょう。機械学習とは、数字の画像にどのような数字が含まれているかを予測したり、人の話し声を録音した音声から、その人がどのような文章を話しているかを予測したりするというモデルを作り出す科学です。

機械学習モデルとは、数学的な関数fです。多くの入力、出力、定数やパラメータを持つ複雑なものです。

モデルを効果的にするためには学習が必要で、多くのモデルは勾配降下アルゴリズムを使って学習されています。このアルゴリズムは、fに既知の入力を繰り返し適用し、fの計算された出力と期待される出力を比較し、計算された出力と期待される出力の差、つまり誤差が十分に小さくなるまでfのパラメータを調整します。このアルゴリズムは、誤差をfの微分f'に適用した結果である勾配ベクトルによってモデルを調整します。

機械学習開発者は、このようなモデルをコンピュータ・コードを使って実装します。そのためには、fとその微分であるf'の実装が実行可能である必要があります。

ある数学関数を実装する次のメソッドfを考えてみましょう。

f(x, y) = x * (-sin(x * y) + y) * 4

入力(または独立変数)xに関するfの偏導関数を、手動で計算して以下のように記述できます。

df_dx(x, y) = (-sin(x * y) + y - x * cos(x * y) * y) * 4

yに関するfの偏導関数は、後でJavaコードとして示します)

偏導関数メソッドの呼び出し結果を組み合わせると、偏導関数から勾配ベクトルを生成できます。

Automatic differentiation

手作業による微分は非常にミスの多いプロセスです。微分は機械的なプロセスですが、上に示したような単純な数学関数であっても、デバッグしづらいミスを犯しやすく、数学関数が複雑になればなるほど、そのプロセスはすぐに困難なものになります。特に機械学習モデルの場合はそうです。コンピュータは、この機械的な作業に理想的に適しています。

Javaメソッドfとして書かれた数学関数を自動的に微分するには、微分のルールを実装し、そのルールをfの記号表現に適用して微分メソッドf'を生成するJavaプログラムDを書く必要があります。

プログラムDでは、Code Reflectionを使用して、コード・モデルと呼ばれるメソッドfの記号表現が得られます。そしてDは、fのコードモデル内の記号情報であるオペレーションを走査し、そうしたオペレーションに微分の規則を適用できます。例えば、オペレーションは加算や乗算を表す数学演算であったり、超越関数を実装する Java メソッド(java.lang.Math::sinメソッドなど)への呼び出しを表す呼び出しオペレーションだったりします。

Dは、微分を計算するオペレーションを含むf'を表す新しいコードモデルを生成します。それをバイトコードにコンパイルしてJavaプログラムとして呼び出し可能です。

自動微分には、forward-mode(順モード)とreverse-mode(逆モード)の2つのアプローチがあります。N個の独立変数に対して、順モードの自動微分はN個の偏微分メソッドを生成する必要があり、したがって勾配を生成するためにN回のメソッド呼び出しが必要です。逆モードの自動微分にはこのような制限はありませんが、独立変数が少ない場合は効率が悪くなります。

プログラムDをJavaライブラリにカプセル化できます。理想的には次のように使うことになるでしょう。

@CodeReflection
double f(double x, double y) {
    return x * (-sin(x * y) + y) * 4;
}

Function<double[], double[]> g_f = AD.gradientFunction(this::f);
double[] g = g_f.apply(new double[]{x, y});

微分対象の関数であるメソッドf@CodeReflectionと注釈を付けます。これにより、fで利用可能なコード・モデルが存在し、呼び出しと同様のアクセス制御ルールでアクセスできることが保証されます。次に、メソッド参照としてfを渡してメソッドAD.gradientFunctionを呼び出します。このメソッド参照は、そのインスタンスがfのコードモデルへのアクセスを与えるcode reflection型を対象としています。

gradientFunctionメソッドのライブラリ作成者は、どのようにメソッドfを微分できるのでしょうか?

Implementing forward-mode automatic differentiation

以後の章では、Code Reflectionを使用してgradientFunctionメソッドを実装する方法を説明します。理解しやすいので、順モードでの自動微分に焦点を当てますが、同じ一般原則が逆モードの自動微分にも適用できます。

概念実証の実装は、Babylonリポジトリにあるテストコードにあります。この実装は完全には程遠く、この問題にアプローチする多くの可能な方法の一つに過ぎません。

PoCコード
https://github.com/openjdk/babylon/tree/code-reflection/test/jdk/java/lang/reflect/code/ad

Differentiating simple functions

先ほどの簡単な数学関数に注目してみましょう。この関数にはxyという独立変数が2個あります。

@CodeReflection
static double f(double x, double y) {
    return x * (-Math.sin(x * y) + y) * 4.0d;
}

これは@CodeReflectionでアノテーションされています。コンパイル時にコードモデルが生成され、リフレクションによって実行時にアクセスできるようになります。

また、xyに関するfの偏導関数を手作業で導出できるので、生成されたものに対してテストできます。

static double df_dx(double x, double y) {
    return (-Math.sin(x * y) + y - x * Math.cos(x * y) * y) * 4.0d;
}

static double df_dy(double x, double y) {
    return x * (1 - Math.cos(x * y) * x) * 4.0d;
}

Obtaining a code model

基本的には、コードモデルと微分対象の独立変数への参照を受け取り、入力の偏導関数である新しいコードモデルを生成するメソッドを実装する必要があります。先ほどの例(プログラムDを使います。ここで作者はメソッドの注釈にCode Reflection APIを最低限しか使用していないことがわかっています)ほどユーザーフレンドリーではありませんが、この側面に焦点を当てます。

まず、fがクラスTのstaticメソッドとして宣言されているとします。以下のようにして、リフレクションを使用してそのコードモデルを取得します。

Method fm = T.class.getDeclaredMethod("f", double.class, double.class);
Optional<CoreOps.FuncOp> o = fm.getCodeModel();
CoreOps.FuncOp fcm = o.orElseThrow();

リフレクションAPIを使用して、java.lang.reflect.Methodインスタンスであるfを見つけ、メソッドgetCodeModelを呼び出してそのコード・モデルを求めます。@CodeReflectionでアノテーションされたメソッドのみがコードモデルを持つため、このメソッドはpartialです。

fのコード・モデルはCoreOps.FuncOpのインスタンスとして表され、Javaメソッドの宣言をモデル化する関数宣言操作に対応します。

Explaining the code model

コード・モデルは、オペレーション(operation)、ボディ(body)、およびブロック(block)を含むツリーです。オペレーションには、0 個以上のボディが含まれます。ボディには、1 つ以上のブロックが含まれます。ブロックには、1 つ以上のオペレーションのシーケンスが含まれます。ブロックには、0 個以上のブロック・パラメーター(値)を宣言できます。オペレーションは、オペレーションの結果である値を宣言します。オペレーションでは、オペランドとして値を使用できます。

この単純なツリー構造を使用すると、多くのJava言語コンストラクトをモデル化する操作を定義でき、したがって、多くのJavaプログラムをモデル化するコード・モデルを構築できます。これは最初は意外に見えるかもしれません。読者は、算術演算のような従来の意味でのオペレーション(operation)という用語に馴染みがあるかもしれません。しかし、上述した構造を考えれば、このような従来の意味に限定する必要はありません。関数を宣言するオペレーション(CoreOps.FuncOpのインスタンス)、Javaのラムダ式をモデル化するオペレーション(CoreOps.LambdaOpのインスタンス)、またはJavaのtry文をモデル化するオペレーション(ExtendedOps.JavaTryOpのインスタンス)といったセマンティクスを持つオペレーションを自由に定義できます。

fのコード・モデルはどのように見えるでしょうか。メモリ内の形式(CoreOps.FuncOpのインスタンス)をテキスト形式にシリアライズできます。

System.out.println(fcm.toText());

以下のように出力されます。

func @"f" (%0 : double, %1 : double)double -> {
    %2 : Var<double> = var %0 @"x";
    %3 : Var<double> = var %1 @"y";
    %4 : double = var.load %2;
    %5 : double = var.load %2;
    %6 : double = var.load %3;
    %7 : double = mul %5 %6;
    %8 : double = invoke %7 @"java.lang.Math::sin(double)double";
    %9 : double = neg %8;
    %10 : double = var.load %3;
    %11 : double = add %9 %10;
    %12 : double = mul %4 %11;
    %13 : double = constant @"4.0";
    %14 : double = mul %12 %13;
    return %14;
};

テキスト形式から、コード・モデルのルートが関数宣言(func)のオペレーションであることがわかります。関数宣言のオペレーションは、他のすべてのオペレーションと同様にオペレーション結果を持ちますが、ツリーのルートであるため、それを提示する必要はありません。

ラムダのような式は、関数宣言オペレーションの単一のボディと、エントリブロックと呼ばれるボディの最初で唯一のブロックの融合を表します。そして、エントリーブロックには一連のオペレーションがあります。各オペレーションには、対応するクラスのインスタンスがインメモリ形式で存在し、これらはすべて抽象クラスjava.lang.reflect.code.Opから拡張されています。

エントリ・ブロックには、(xyに対応する)%0%1という2つのブロック・パラメータがあり、それぞれをdouble型で記述していますが、これらはfのメソッドパラメータをモデル化しています。これらのパラメータを、さまざまなオペレーションのオペランドとして使用します。多くのオペレーションはオペレーション結果、例えば乗算演算の結果%12を生成し、それを後続のオペレーションのオペランドとして使用する、などです。returnオペレーションは、他のすべてのオペレーションと同様に結果を持ちますが、その結果を意味あるように使用できないため、提示しません。

コードモデルにはSSA(Static Single-Assignment、静的単一代入)という性質があります。例えば、値%12は決して変更できません。変数宣言は、値(ボックス)を保持する値を生成する操作としてモデル化され、アクセス操作はそのボックスにロードまたはストアします。

(読者の中には、これがMLIRに非常に似ていると思う人もいるかもしれませんが、それは意図的なものです)。

MLIR (Multi-Level Intermediate Representation) Overview
https://mlir.llvm.org/

メソッド宣言、変数(メソッドパラメータやローカル変数)や変数へのアクセス、バイナリや単項の数学演算、メソッド呼び出し(例えばjava.lang.Math::sinメソッド)のようなJava言語の構成要素を、オペレーションがどのようにモデル化しているかを確認できます。

Analyzing the model

変数の宣言とアクセスを取り除いたモデルに変換することで、このモデルを単純化することができます。これを純粋なSSA変換と呼びます。

fcm = SSA.transform(fcm);

結果として得られるコード・モデルのテキスト形式は次のようになります。

func @"f" (%0 : double, %1 : double)double -> {
    %2 : double = mul %0 %1;
    %3 : double = invoke %2 @"java.lang.Math::sin(double)double";
    %4 : double = neg %3;
    %5 : double = add %4 %1;
    %6 : double = mul %0 %5;
    %7 : double = constant @"4.0";
    %8 : double = mul %6 %7;
    return %8;
};

これはプログラムの意味が保たれた、より単純なモデルです。モデルが単純なので、自動微分の準備のための分析も単純になります。

独立変数に関して微分する場合、その変数に関してモデルを分析し、独立変数に推移的に依存するアクティブな一連の値を計算します。

例えば、(独立変数xを表す)ブロック・パラメータ%0を、演算結果%2(乗算の結果)を生成するオペレーションでオペランドとして使用します。

%0のアクティブセットは

 {%0, %2, %3, %4, %5, %6, %8, %9} 

です。値%9はreturnオペレーションの結果を表しますが、その型はvoidであり、その値は明示的にテキスト形式には現れません。

この場合、(独立変数yを表す)%1のアクティブセットは同じです。

このセットは、コード・モデルの値の用途を走査することで計算できます。コード・モデルのインメモリ表現は、簡単に利用できるように構築されます。

単純で素朴な実装は以下のようになります。

static Set<Value> activeSet(Value root) {
    Set<Value> s = new LinkedHashSet<>();
    activeSet(s, root);
    return s;
}

static void activeSet(Set<Value> s, Value v) {
    s.add(v);
    // Iterate over the uses of v
    for (Op.Result use : v.uses()) {
        activeSet(s, use);
    }
}

テストコードで実装され、いくつかの制御フローに数式が埋め込まれた例で後で使用されるように、実際には、アクティブセットの計算はもう少し複雑にする必要があります。

Reporting programming errors

一般的に、すべてのJavaプログラムが微分可能なわけではありません。これは制約のあるプログラミング・モデルだから微分可能なのです。微分対象のJavaメソッドの作者は制約を意識する必要があり、自動微分プログラムはプログラミング・エラーを報告する必要があります。

アクティブセットを計算した後(あるいは計算中)、値が微分できない場合や独立値が結果に寄与しない場合などのチェックを行い、エラーを報告できます。さらに、モデルに対して追加のチェックを行い、サポートされていない言語構文が存在する場合、例えばtry文が存在する場合、あるいは他の構文を無視する場合などにエラーにしたいかもしれません。テストコードでは、そのような大規模なチェックは行っていません。

このようなエラーの報告は、実行時ではなく、メソッドがソース・コンパイラによってコンパイルされるときに行われるのが理想的です。Code Reflectionは、このような目的のためにコンパイル時に同じコード・モデルを利用できるようにすることもできますが、この記事ではこの機能については説明しません。

Differentiating a code model

アクティブセットができたら、それを使ってコードモデルの微分を行います。アクティブセットの演算結果は、微分する必要があるオペレーションを参照しています。

必要なオペレーションに微分のルールを適用することで、遭遇するコードモデルを変換できます。

コード・モデルは不変(immutable)です。コード・モデルは、構築もしくは既存のコード・モデルを変換することによって生成できます。変換は、入力コード・モデルを受け取り、出力コード・モデルを構築します。入力コード・モデルで遭遇する各入力演算に対して、その演算を出力コード・モデルのビルダーに追加するか(コピー)、追加しないか(削除)、新しい出力操作を追加するか(置換または追加)を選択します。

まず、コンストラクタがアクティブ・セットを計算するクラスForwardDifferentiationを宣言します。

import static java.lang.reflect.code.op.CoreOps.*;

public final class ForwardDifferentiation {
    // The function to differentiate
    final FuncOp fcm;
    // The independent variable
    final Block.Parameter ind;
    // The active set for the independent variable
    final Set<Value> activeSet;
    // The map of input value to it's (output) differentiated value
    final Map<Value, Value> diffValueMapping;

    // The constant value 0.0d
    // Declared in the (output) function's entry block
    Value zero;

    private ForwardDifferentiation(FuncOp fcm, Block.Parameter ind) {
        int indI = fcm.body().entryBlock().parameters().indexOf(ind);
        if (indI == -1) {
            throw new IllegalArgumentException("Independent argument not defined by function");
        }
        this.fcm = fcm;
        this.ind = ind;

        // Calculate the active set of dependent values for the independent value
        this.activeSet = ActiveSet.activeSet(fcm, ind);
        // A mapping of input values to their (output) differentiated values
        this.diffValueMapping = new HashMap<>();
    }
}

入力演算結果から(出力される)微分値へのマップであるdiffValueMappingも作成されます。これは、依存する計算で使用される可能性のある微分値を追跡するために使われます。コード・モデルは不変なので、入力コードモデルで参照される値が古くなったり、変更されたりする心配はありません。

次に、偏微分を計算するForwardDifferentiationのstatic factoryメソッドを宣言します。

public static FuncOp partialDiff(FuncOp fcm, Block.Parameter ind) {
    return new ForwardDifferentiation(fcm, ind).partialDiff();
}

このstaticメソッドはForwardDifferentiationのインスタンスを作成し、変換を実施するpartialDiffメソッドを呼び出します。

FuncOp partialDiff() {
    int indI = fcm.body().entryBlock().parameters().indexOf(ind);

    // Transform f to f' w.r.t ind
    AtomicBoolean first = new AtomicBoolean(true);
    FuncOp dfcm = fcm.transform(STR."d\{fcm.funcName()}_darg\{indI}",
            (block, op) -> {
                // Initialization
                if (first.getAndSet(false)) {
                    // Declare the zero value constant
                    zero = block.op(constant(ind.type(), 0.0d));
                    // Declare the one value constant
                    Value one = block.op(constant(ind.type(), 1.0d));
                    // The differential of ind is one
                    // For all other parameters it is zero (absence from the map)
                    diffValueMapping.put(ind, one);
                }

                // If the result of the operation is in the active set,
                // then differentiate it, otherwise copy it
                if (activeSet.contains(op.result())) {
                    Value dor = diffOp(block, op);
                    // Map the input result to its (output) differentiated result
                    // so that it can be used when differentiating subsequent operations
                    diffValueMapping.put(op.result(), dor);
                } else {
                    // Block is not part of the active set, just copy it
                    block.op(op);
                }
                return block;
            });

    return dfcm;
}

FuncOpオペレーションのtransformメソッドを使用してコード・モデルを変換します。このメソッドには、出力コードモデルの関数の名前と、(出力)ブロック・ビルダーと入力演算を受け付けるラムダ式を渡します。transformメソッドは入力コードモデル内のすべての操作を走査し、遭遇した入力操作をラムダ式に報告します。

最初の遭遇時に、出力モデルに定数演算(ConstantOpのインスタンス)を追加することで、出力モデルにいくつかの定数値を宣言します。具体的には、ゼロ定数値0.0dを宣言します。これは、出力モデルに追加する後続の演算のオペランドとして使用されます。

遭遇の都度、演算の結果がアクティブセットのメンバーであれば、それを微分し、(入力)結果をその(出力)微分値にマッピングします。内部から外部へ連鎖規則を適用します。

diffOpメソッドは、微分のルールを適用します。このメソッドは、微分するための演算のインスタンスをターゲットとする、パターンマッチングによるswitch式で構成されます。以下は、このメソッドから得られる興味深いswitch caseのサブセットです。

Value diffOp(Block.Builder block, Op op) {
    // Switch on the op, using pattern matching
    return switch (op) {
        case ... -> {
        }
        case CoreOps.MulOp _ -> {
            // Copy input operation
            block.op(op);

            // Product rule
            // diff(l) * r + l * diff(r)
            Value lhs = op.operands().get(0);
            Value rhs = op.operands().get(1);
            Value dlhs = diffValueMapping.getOrDefault(lhs, zero);
            Value drhs = diffValueMapping.getOrDefault(rhs, zero);
            Value outputLhs = block.context().getValue(lhs);
            Value outputRhs = block.context().getValue(rhs);
            yield block.op(add(
                    block.op(mul(dlhs, outputRhs)),
                    block.op(mul(outputLhs, drhs))));
        }
        case CoreOps.InvokeOp c -> {
            MethodDesc md = c.invokeDescriptor();
            String operationName = null;
            if (md.refType().equals(J_L_MATH)) {
                operationName = md.name();
            }
            // Differentiate sin(x)
            if ("sin".equals(operationName)) {
                // Copy input operation
                block.op(op);

                // Chain rule
                // cos(expr) * diff(expr)
                Value a = op.operands().get(0);
                Value da = diffValueMapping.getOrDefault(a, zero);
                Value outputA = block.context().getValue(a);
                Op.Result cosx = block.op(invoke(J_L_MATH_COS, outputA));
                yield block.op(mul(cosx, da));
            } else {
                throw new UnsupportedOperationException("Operation not supported: " + op.opName());
            }
        }
    };
}

最初のケースは、積の微分法則 (product rule) を適用してMulOpのインスタンスである乗算演算の差分を計算する方法を示しています。

入力の乗算演算をビルダーに追加します。このビルダーでは、出力モデルに追加される更なる計算で使用される可能性があるため、演算結果を入力モデルから出力モデルにコピーします。

乗算演算の第1オペランド(左)と第2オペランド(右)の入力値が与えられたら、それらの微分値を求めますが、これは事前に計算済みであるか、計算できていない場合にはゼロである必要があります。次に、積の微分法則に対応する2つの新しい乗算演算を追加します。そのためには、第1オペランドと第2オペランドの入力値に対応する出力値も求める必要があります(これもまた、前の入力演算をコピーするときに事前に計算済みである必要があります)。最後に、乗算の結果を加算演算で合計し、結果を得ます。

2つ目のケースは、sin(x)の微分(cos(x)x')を計算する方法です。メソッドjava.lang.Math::sinを呼び出すInvokeOpのインスタンスに対応します。呼び出し操作をコピーし、java.lang.Math::cosを呼び出す呼び出し操作を追加し、乗算操作を追加します。

私たちは、パターンとswitchが多くの種類の変換に使用されることを期待しており、これを念頭に置いてCode Reflection APIを設計しています。将来の言語機能では、独自のパターンを書くことができ、より洗練されたツリーベースのオペレーションのマッチング(用途やオペランドを含む)が可能になることを期待しています。

すべてをまとめて、微分されたコード・モデルのテキスト形式を出力してみましょう。

import ad.ForwardDifferentiation;

Method fm = T.class.getDeclaredMethod("f", double.class, double.class);
Optional<CoreOps.FuncOp> o = fm.getCodeModel();
CoreOps.FuncOp fcm = SSA.transform(o.orElseThrow());
Block.Parameter x = fcm.body().entryBlock().parameters().get(0);
// Code model in, code model out
CoreOps.FuncOp dfcm_x = ForwardDifferentiation.partialDiff(fcm, x);
func @"df_darg0" (%0 : double, %1 : double)double -> {
    %2 : double = constant @"0.0";
    %3 : double = constant @"1.0";
    %4 : double = mul %0 %1;
    %5 : double = mul %3 %1;
    %6 : double = mul %0 %2;
    %7 : double = add %5 %6;
    %8 : double = invoke %4 @"java.lang.Math::sin(double)double";
    %9 : double = invoke %4 @"java.lang.Math::cos(double)double";
    %10 : double = mul %9 %7;
    %11 : double = neg %8;
    %12 : double = neg %10;
    %13 : double = add %11 %1;
    %14 : double = add %12 %2;
    %15 : double = mul %0 %13;
    %16 : double = mul %3 %13;
    %17 : double = mul %0 %14;
    %18 : double = add %16 %17;
    %19 : double = constant @"4.0";
    %20 : double = mul %15 %19;
    %21 : double = mul %18 %19;
    %22 : double = mul %15 %2;
    %23 : double = add %21 %22;
    return %23;
};

ハンドコーディングしたJavaコードと比較してみます。

static double df_dx(double x, double y) {
    return (-Math.sin(x * y) + y - x * Math.cos(x * y) * y) * 4.0d;
}

微分されたコード・モデルには数学演算が多く、その多くが冗長であることがすぐにわかります。例えば、0や1による乗算があり、1つの演算の結果(%20)は使用されません。

コードモデルをさらに変換して、冗長な演算を削除する必要があります(一般に式の削除と呼ばれます)。式を削除したコードモデルを以下に示します。

(式の除去がどのように実装されているかについては、ここでは触れません。上で示したのと同じCode Reflection APIと同様のテクニックを使用しています。興味関心のある読者はコードを見たいと思うかもしれません)

func @"df_darg0" (%0 : double, %1 : double)double -> {
    %2 : double = constant @"0.0";
    %3 : double = mul %0 %1;
    %4 : double = add %1 %2;
    %5 : double = invoke %3 @"java.lang.Math::sin(double)double";
    %6 : double = invoke %3 @"java.lang.Math::cos(double)double";
    %7 : double = mul %6 %4;
    %8 : double = sub %1 %5;
    %9 : double = sub %2 %7;
    %10 : double = mul %0 %9;
    %11 : double = add %8 %10;
    %12 : double = constant @"4.0";
    %13 : double = mul %11 %12;
    %14 : double = add %13 %2;
    return %14;
};

微分変換でいくつかの式をも削除できますが、そうすると変換が複雑になります。多くの場合、変換を集中させ、より多くの作業を犠牲にしてでも、2つ以上の変換ステージに分けた方が良いでしょう。

このコードモデルから、それを解釈したり、バイトコードに変換したり、メソッドハンドルを呼び出して実行したりできます。

同じことを独立変数yに適用すると、2つのコードモデルが得られ、そこから勾配を計算し、gradientFunctionメソッドを実装できます。

Differentiating models with control flow

前の例では、簡単な数式を実装したJavaメソッドを微分しました。このセクションでは、制御フロー文に埋め込まれた数式を含む、より複雑なメソッドについて考えます。

@CodeReflection
static double f(/* independent */ double x, int y) {
    /* dependent */
    double o = 1.0;
    for (int i = 0; i < y; i = i + 1) {
        if (i > 1) {
            if (i < 5) {
                o = o * x;
            }
        }
    }
    return o;
}

メソッドfには1個の独立変数xがあります。パラメータyxを使った計算に間接的に影響しますが、xyの間には直接的なデータ依存関係はありません。乗算演算がループの中に埋め込まれており、条件付きの制御フローがあります。このメソッドは、これまで紹介したのと同じテクニックを使って微分できます。

以下はハンドコーディングしたものです。

Here is the hand-coded version.

static double df_dx(/* independent */ double x, int y) {
    double d_o = 0.0;
    double o = 1.0;
    for (int i = 0; i < y; i = i + 1) {
        if (i > 1) {
            if (i < 5) {
                d_o = d_o * x + o * 1.0;
                o = o * x;
            }
        }
    }
    return d_o;
}

o * xの乗算に適用される積の微分法則に注意してください。ソースではoが更新される前にこの計算が適用されることに注意しなければなりません(賢明な読者は、od_oxと1がdual number(二重数もしくは双対数)の2要素に対応することに気づいてらっしゃることでしょう)。

Dual Number
https://en.wikipedia.org/wiki/Dual_number#Differentiation

fのコードモデルには、forループとif文をモデル化したオペレーションが含まれており、コードの構造は保持されています。

func @"f" (%0 : double, %1 : int)double -> {
    %2 : Var<double> = var %0 @"x";
    %3 : Var<int> = var %1 @"y";
    %4 : double = constant @"1.0";
    %5 : Var<double> = var %4 @"o";
    java.for
        ()Var<int> -> {
            %6 : int = constant @"0";
            %7 : Var<int> = var %6 @"i";
            yield %7;
        }
        (%8 : Var<int>)boolean -> {
            %9 : int = var.load %8;
            %10 : int = var.load %3;
            %11 : boolean = lt %9 %10;
            yield %11;
        }
        (%12 : Var<int>)void -> {
            %13 : int = var.load %12;
            %14 : int = constant @"1";
            %15 : int = add %13 %14;
            var.store %12 %15;
            yield;
        }
        (%16 : Var<int>)void -> {
            java.if
                ()boolean -> {
                    %17 : int = var.load %16;
                    %18 : int = constant @"1";
                    %19 : boolean = gt %17 %18;
                    yield %19;
                }
                ()void -> {
                    java.if
                        ()boolean -> {
                            %20 : int = var.load %16;
                            %21 : int = constant @"5";
                            %22 : boolean = lt %20 %21;
                            yield %22;
                        }
                        ()void -> {
                            %23 : double = var.load %5;
                            %24 : double = var.load %2;
                            %25 : double = mul %23 %24;
                            var.store %5 %25;
                            yield;
                        }
                        ()void -> {
                            yield;
                        };
                    yield;
                }
                ()void -> {
                    yield;
                };
            java.continue;
        };
    %26 : double = var.load %5;
    return %26;
};

オペレーションの中には、多くのレベルのボディを含むものがあることがよくわかります。この場合、前の例と同じように、各ボディは1つの(エントリ)ブロックを持ちます。for操作の4個のボディは、Java言語仕様で定められている入れ子になった式と文(nested expressions and statements)に対応しています。

diffOpメソッドを修正してこのようなオペレーションとその振る舞いを知る代わりに、解決策を単純化して一般化できます。Code Reflectionでは、2つの操作セットを定義しています。幅広いJavaプログラムのモデリングに使用できるコアオペレーションと、拡張(または補助)オペレーションです。拡張オペレーションは、for文、if文、try文などのJava言語構文をモデル化します。拡張オペレーションは、ボディ内の連結されたブロック・セット内の一連のコアオペレーションに落とし込むことができます。

fのコード・モデルをlowering transformationを使用して低次元化し、純粋な SSA に変換します。

fcm = fcm.transform((block, op) -> {
    if (op instanceof Op.Lowerable lop) {
        return lop.lower(block);
    } else {
        block.op(op);
        return block;
    }
});
fcm = SSA.transform(fcm);
func @"fcf" (%0 : double, %1 : int)double -> {
    %2 : double = constant @"1.0";
    %3 : int = constant @"0";
    branch ^block_0(%2, %3);
  
  ^block_0(%4 : double, %5 : int):
    %6 : boolean = lt %5 %1;
    cbranch %6 ^block_1 ^block_2;
  
  ^block_1:
    %7 : int = constant @"1";
    %8 : boolean = gt %5 %7;
    cbranch %8 ^block_3 ^block_4;
  
  ^block_3:
    %9 : int = constant @"5";
    %10 : boolean = lt %5 %9;
    cbranch %10 ^block_5 ^block_6;
  
  ^block_5:
    %11 : double = mul %4 %0;
    branch ^block_7(%11);
  
  ^block_6:
    branch ^block_7(%4);
  
  ^block_7(%12 : double):
    branch ^block_8(%12);
  
  ^block_4:
    branch ^block_8(%4);
  
  ^block_8(%13 : double):
    branch ^block_9;
  
  ^block_9:
    %14 : int = constant @"1";
    %15 : int = add %5 %14;
    branch ^block_0(%13, %15);
  
  ^block_2:
    return %4;
};

結果のコード・モデル(プログラムの意味は保持されています)には、funcオペレーションのボディ内に複数のブロックがあることがわかります。もはやボディの入れ子はありません。

プログラムの制御フローのモデル化では、ブロックと、それらがどのように接続されて制御フロー・グラフを形成するかによって決まります。block_9では、次のループ反復のためのブロック引数としてoiの値を渡すblock_0への分岐を確認できます。(ブロック引数とパラメータは、コードをモデル化する他のアプローチにおけるphiノードに類似しています)

このようなブロック間のつながりを理解するために、アクティブセットと微分演算を計算する我々の実装を拡張する必要があります。

ささやかな改良で、xのアクティブセットを計算し、このJavaメソッドを微分できます。アクティブセットは、(oxに依存するため)現在のループ反復の値oを表すパラメータ%4のようなブロック・パラメータをメンバとして持ちます。微分されたモデルを以下に示します。

func @"dfcf_darg0" (%0 : double, %1 : int)double -> {
    %2 : double = constant @"0.0";
    %3 : double = constant @"1.0";
    %4 : double = constant @"1.0";
    %5 : int = constant @"0";
    branch ^block_0(%4, %5, %2);
  
  ^block_0(%6 : double, %7 : int, %8 : double):
    %9 : boolean = lt %7 %1;
    cbranch %9 ^block_1 ^block_2;
  
  ^block_1:
    %10 : int = constant @"1";
    %11 : boolean = gt %7 %10;
    cbranch %11 ^block_3 ^block_4;
  
  ^block_3:
    %12 : int = constant @"5";
    %13 : boolean = lt %7 %12;
    cbranch %13 ^block_5 ^block_6;
  
  ^block_5:
    %14 : double = mul %6 %0;
    %15 : double = mul %8 %0;
    %16 : double = mul %6 %3;
    %17 : double = add %15 %16;
    branch ^block_7(%14, %17);
  
  ^block_6:
    branch ^block_7(%6, %8);
  
  ^block_7(%18 : double, %19 : double):
    branch ^block_8(%18, %19);
  
  ^block_4:
    branch ^block_8(%6, %8);
  
  ^block_8(%20 : double, %21 : double):
    branch ^block_9;
  
  ^block_9:
    %22 : int = constant @"1";
    %23 : int = add %7 %22;
    branch ^block_0(%20, %23, %21);
  
  ^block_2:
    return %8;
};

block_5で積の微分法則が適用されていることがわかります。さらに、oの微分d_oを渡すために必要な追加のブロック引数やパラメータがあります。block_9では、ブロック引数%21が追加された分岐があり、これがblock_0d_oの次の値、パラメータ%8になります。

コメントを残す

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください