Java で行列の演算 - nd4j, commons-math, la4j, ujmp, jblas, colt
Java で以下のような行列の演算を複数のライブラリで試しました。
- (a) 和
- (b) 積
- (c) 転置
とりあえず今回は、更新日が比較的新しめのライブラリを試してみました。
- ND4J 0.6.0
- Commons Math 3.6.1
- la4j 0.6.0
- UJMP 0.3.0
- jblas 1.2.4
- Colt Blazegraph 版 2.1.4
また、あくまでも個人的な印象ですが、手軽に使いたいなら la4j か Commons Math、性能重視なら ND4J か jblas、可視化や DB との連携を考慮するなら UJMP を使えば良さそうです。
ソースは http://github.com/fits/try_samples/tree/master/blog/20161031/
ND4J
Deeplearning4J で使用しているライブラリ。
build.gradle
apply plugin: 'application' mainClassName = 'SampleApp' repositories { jcenter() } dependencies { compile 'org.nd4j:nd4j-native-platform:0.6.0' runtime 'org.slf4j:slf4j-nop:1.7.21' }
src/main/java/SampleApp.java
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; public class SampleApp { public static void main(String... args) { INDArray x = Nd4j.create(new double[][] { {1, 2}, {3, 4} }); INDArray y = Nd4j.create(new double[][] { {5, 6}, {7, 8} }); // (a) System.out.println( x.add(y) ); System.out.println("-----"); // (b) System.out.println( x.mmul(y) ); System.out.println("-----"); // (c) System.out.println( x.transpose() ); } }
実行結果
> gradle -q run [[ 6.00, 8.00], [10.00, 12.00]] ----- [[19.00, 22.00], [43.00, 50.00]] ----- [[1.00, 3.00], [2.00, 4.00]]
Commons Math
Apache Commons のライブラリ。
build.gradle
・・・
dependencies {
compile 'org.apache.commons:commons-math3:3.6.1'
}
src/main/java/SampleApp.java
import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.RealMatrix; public class SampleApp { public static void main(String... args) { RealMatrix x = MatrixUtils.createRealMatrix(new double[][] { {1, 2}, {3, 4} }); RealMatrix y = MatrixUtils.createRealMatrix(new double[][] { {5, 6}, {7, 8} }); // (a) System.out.println( x.add(y) ); System.out.println("-----"); // (b) System.out.println( x.multiply(y) ); System.out.println("-----"); // (c) System.out.println( x.transpose() ); } }
実行結果
> gradle -q run Array2DRowRealMatrix{{6.0,8.0},{10.0,12.0}} ----- Array2DRowRealMatrix{{19.0,22.0},{43.0,50.0}} ----- Array2DRowRealMatrix{{1.0,3.0},{2.0,4.0}}
la4j
Java のみで実装された軽量なライブラリ。
build.gradle
・・・
dependencies {
compile 'org.la4j:la4j:0.6.0'
}
src/main/java/SampleApp.java
import org.la4j.Matrix; public class SampleApp { public static void main(String... args) { Matrix x = Matrix.from2DArray(new double[][] { {1, 2}, {3, 4} }); Matrix y = Matrix.from2DArray(new double[][] { {5, 6}, {7, 8} }); // (a) System.out.println( x.add(y) ); System.out.println("-----"); // (b) System.out.println( x.multiply(y) ); System.out.println("-----"); // (c) System.out.println( x.transpose() ); } }
実行結果
> gradle -q run 6.000 8.000 10.000 12.000 ----- 19.000 22.000 43.000 50.000 ----- 1.000 3.000 2.000 4.000
UJMP
データ可視化や JDBC との連携など、機能豊富そうなライブラリ。 Colt や jblas 等ともプラグインモジュールで連携できる模様。
build.gradle
・・・
dependencies {
compile 'org.ujmp:ujmp-core:0.3.0'
}
src/main/java/SampleApp.java
import org.ujmp.core.DenseMatrix; import org.ujmp.core.Matrix; public class SampleApp { public static void main(String... args) { Matrix x = DenseMatrix.Factory.linkToArray( new double[] {1, 2}, new double[] {3, 4} ); Matrix y = DenseMatrix.Factory.linkToArray( new double[] {5, 6}, new double[] {7, 8} ); // (a) System.out.println( x.plus(y) ); System.out.println("-----"); // (b) System.out.println( x.mtimes(y) ); System.out.println("-----"); // (c) System.out.println( x.transpose() ); } }
実行結果
> gradle -q run 6.0000 8.0000 10.0000 12.0000 ----- 19.0000 22.0000 43.0000 50.0000 ----- 1.0000 3.0000 2.0000 4.0000
jblas
BLAS/LAPACK をベースとしたライブラリ。 ネイティブライブラリを使用する。
build.gradle
・・・
dependencies {
compile 'org.jblas:jblas:1.2.4'
}
src/main/java/SampleApp.java
import org.jblas.DoubleMatrix; public class SampleApp { public static void main(String... args) { DoubleMatrix x = new DoubleMatrix(new double[][] { {1, 2}, {3, 4} }); DoubleMatrix y = new DoubleMatrix(new double[][] { {5, 6}, {7, 8} }); // (a) System.out.println( x.add(y) ); System.out.println("-----"); // (b) System.out.println( x.mmul(y) ); System.out.println("-----"); // (c) System.out.println( x.transpose() ); } }
実行結果
> gradle -q run [6.000000, 8.000000; 10.000000, 12.000000] ----- [19.000000, 22.000000; 43.000000, 50.000000] ----- [1.000000, 3.000000; 2.000000, 4.000000] -- org.jblas INFO Starting temp DLL cleanup task. -- org.jblas INFO Deleted 4 unused temp DLL libraries from ・・・
Colt Blazegraph 版
Colt は長らく更新されていないようなので、今回は Blazegraph による fork 版? を使いました。
build.gradle
・・・
dependencies {
compile 'com.blazegraph:colt:2.1.4'
}
src/main/java/SampleApp.java
import cern.colt.matrix.DoubleFactory2D; import cern.colt.matrix.DoubleMatrix2D; import cern.colt.matrix.linalg.Algebra; import cern.jet.math.Functions; public class SampleApp { public static void main(String... args) { DoubleMatrix2D x = DoubleFactory2D.dense.make(new double[][] { {1, 2}, {3, 4} }); DoubleMatrix2D y = DoubleFactory2D.dense.make(new double[][] { {5, 6}, {7, 8} }); // (a) System.out.println( x.copy().assign(y, Functions.plus) ); System.out.println("-----"); Algebra algebra = new Algebra(); // (b) System.out.println( algebra.mult(x, y) ); System.out.println("-----"); // (c) System.out.println( x.viewDice() ); } }
assign
を使うと自身の値を更新するため copy
を使っています。
実行結果
> gradle -q run 2 x 2 matrix 6 8 10 12 ----- 2 x 2 matrix 19 22 43 50 ----- 2 x 2 matrix 1 3 2 4