Update Nd4jTestsC.java

master
agibsonccc 2021-03-23 20:21:56 +09:00
parent 95f3067010
commit b1f8819bde
1 changed files with 1 additions and 1 deletions

View File

@ -8591,7 +8591,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
private static INDArray fwd(INDArray input, INDArray W, INDArray b){
INDArray ret = Nd4j.createUninitialized(input.size(0), W.size(1));
INDArray ret = Nd4j.createUninitialized(input.size(0), W.size(1)).castTo(DataType.DOUBLE);
input.mmuli(W, ret);
ret.addiRowVector(b);
return ret;