Small number of fixes (#98)
* Remove no longer functioning LegacyPooling2D class Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8066 First steps for reflection scanning - INDArray constructors Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small build fix for ND4S Signed-off-by: AlexDBlack <blacka101@gmail.com> * More nd4s fixes Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
b597fb942b
commit
923ab15583
|
@ -26,16 +26,17 @@ import org.deeplearning4j.nn.conf.ConvolutionMode;
|
|||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Convolution3D;
|
||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.LegacyPooling2D;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
|
@ -624,7 +625,7 @@ public class ConvolutionUtils {
|
|||
INDArray reshaped4d = in.reshape(in.size(0), 1, in.size(1), 1);
|
||||
|
||||
int[] outSize;
|
||||
int[] pad;
|
||||
int[] pad = null;
|
||||
int[] k = new int[]{kernel,1};
|
||||
int[] s = new int[]{stride, 1};
|
||||
int[] d = new int[]{dilation, 1};
|
||||
|
@ -638,8 +639,15 @@ public class ConvolutionUtils {
|
|||
|
||||
INDArray output = Nd4j.createUninitialized(new int[]{(int)in.size(0), 1, outH, 1}, 'c');
|
||||
|
||||
Op op = new LegacyPooling2D(reshaped4d, kernel, 1, stride, 1, padding, 0, dilation, 1,
|
||||
cm == ConvolutionMode.Same, LegacyPooling2D.Pooling2DType.MAX, 0.0, output);
|
||||
DynamicCustomOp op = new MaxPooling2D(in, output, Pooling2DConfig.builder()
|
||||
.kH(k[0]).kW(k[1])
|
||||
.sH(s[0]).sW(s[1])
|
||||
.pH(pad == null ? 0 : pad[0]).pW(pad == null ? 0 : pad[1])
|
||||
.dH(d[0]).dW(d[1])
|
||||
.isSameMode(cm== ConvolutionMode.Same)
|
||||
.isNHWC(false)
|
||||
.build());
|
||||
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
return output.reshape('c', in.size(0), outH);
|
||||
}
|
||||
|
@ -717,10 +725,18 @@ public class ConvolutionUtils {
|
|||
}
|
||||
|
||||
long[] outArraySize = new long[]{inMask.size(0), inMask.size(1), outSize[0], outSize[1]};
|
||||
INDArray outMask = Nd4j.createUninitialized(outArraySize);
|
||||
Op op = new LegacyPooling2D(inMask, kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], dilation[0], dilation[1],
|
||||
convolutionMode == ConvolutionMode.Same, LegacyPooling2D.Pooling2DType.MAX, 0.0, outMask);
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
INDArray outMask = Nd4j.createUninitialized(inMask.dataType(), outArraySize);
|
||||
|
||||
DynamicCustomOp op = new MaxPooling2D(inMask, outMask, Pooling2DConfig.builder()
|
||||
.kH(k[0]).kW(k[1])
|
||||
.sH(s[0]).sW(s[1])
|
||||
.pH(p[0]).pW(p[1])
|
||||
.dH(d[0]).dW(d[1])
|
||||
.isSameMode(convolutionMode == ConvolutionMode.Same)
|
||||
.isNHWC(false)
|
||||
.build());
|
||||
|
||||
Nd4j.exec(op);
|
||||
return outMask;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -846,7 +846,6 @@ public class OpValidation {
|
|||
RestoreV2.class,
|
||||
SaveV2.class,
|
||||
ScalarSetValue.class, //Not used in SameDiff (it's a "set to X if less than X" type op, redundant given other ops)
|
||||
LegacyPooling2D.class, //Deprecated; not used in samediff
|
||||
BinomialDistributionEx.class, //Redundant?
|
||||
|
||||
//Exclude manual broadcast ops: SameDiff uses auto broadcasting
|
||||
|
|
|
@ -119,7 +119,6 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.LegacyPooling2D.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalizationDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D.class,
|
||||
|
|
|
@ -1,116 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
|
||||
import org.nd4j.linalg.api.ops.BaseTransformOp;
|
||||
import org.nd4j.linalg.convolution.Convolution;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
/**
|
||||
* Legacy version of the Pooling2D operation
|
||||
* @deprecated Note: This operation will be removed in a future release
|
||||
*/
|
||||
@Deprecated
|
||||
@Slf4j
|
||||
public class LegacyPooling2D extends BaseTransformFloatOp {
|
||||
|
||||
public enum Pooling2DType {
|
||||
MAX, AVG, PNORM,
|
||||
}
|
||||
|
||||
private int kh, kw, sy, sx, ph, pw, dh, dw;
|
||||
private Pooling2DType type;
|
||||
boolean isSameMode;
|
||||
double extra;
|
||||
@Getter protected DataBuffer im2colShape;
|
||||
|
||||
public LegacyPooling2D() {}
|
||||
|
||||
public LegacyPooling2D(INDArray x, int kh, int kw, int sy, int sx, int ph, int pw, int dh, int dw, boolean isSameMode,
|
||||
Pooling2DType type, double extra, INDArray z) {
|
||||
super(x);
|
||||
|
||||
// FIXME: int csast
|
||||
int outHeight = Convolution.outputSize((int) x.size(2), kh, sy, ph, dh, isSameMode);
|
||||
int outWidth = Convolution.outputSize((int) x.size(3), kw, sx, pw, dw, isSameMode);
|
||||
|
||||
this.kh = kh;
|
||||
this.kw = kw;
|
||||
this.sy = sy;
|
||||
this.sx = sx;
|
||||
this.ph = ph;
|
||||
this.pw = pw;
|
||||
this.dh = dh;
|
||||
this.dw = dw;
|
||||
this.isSameMode = isSameMode;
|
||||
this.type = type;
|
||||
this.z = z;
|
||||
this.extra = extra;
|
||||
this.im2colShape = getIm2ColShape(x, kh, kw, outHeight, outWidth);
|
||||
extraArgs = this.extraArgs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 2;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "legacypooling2d";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object[] extraArgs() {
|
||||
return new Object[] {kh, kw, sy, sx, ph, pw, dh, dw, isSameMode ? 1.0 : 0.0, type.ordinal(), extra};
|
||||
}
|
||||
|
||||
private static DataBuffer getIm2ColShape(INDArray img, int kernelHeight, int kernelWidth, int outHeight, int outWidth) {
|
||||
//number of images
|
||||
long n = img.size(0);
|
||||
//number of channels (depth)
|
||||
long c = img.size(1);
|
||||
|
||||
return Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {n, c, kernelHeight, kernelWidth, outHeight, outWidth}, 'c', img.dataType()).getFirst();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("Not supported");
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("Not supported");
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||
throw new UnsupportedOperationException("Not supported");
|
||||
}
|
||||
|
||||
}
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
|
@ -69,6 +70,14 @@ public class MaxPooling2D extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||
super(null, new INDArray[]{input}, output == null ? null : new INDArray[]{output});
|
||||
config.setType(Pooling2D.Pooling2DType.MAX);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isConfigProperties() {
|
||||
return true;
|
||||
|
|
|
@ -32,7 +32,6 @@ import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
|||
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.LegacyPooling2D;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
|
@ -488,52 +487,6 @@ public class SpecialTests extends BaseNd4jTest {
|
|||
System.out.println(out);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void legacyPooling2dTest_double(){
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1,1,3,3});
|
||||
INDArray out = Nd4j.create(DataType.DOUBLE, 1,1,2,2).assign(-119);
|
||||
|
||||
Nd4j.getExecutioner().commit();
|
||||
|
||||
val op = new LegacyPooling2D(in, 2, 2, 1, 1, 0, 0, 1, 1, true, LegacyPooling2D.Pooling2DType.MAX, 0.0, out);
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
Nd4j.getExecutioner().commit();
|
||||
System.out.println(in);
|
||||
System.out.println(out);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void legacyPooling2dTes_float(){
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.rand(DataType.FLOAT, new int[]{1,1,3,3});
|
||||
INDArray out = Nd4j.create(DataType.FLOAT, 1,1,2,2);
|
||||
|
||||
val op = new LegacyPooling2D(in, 2, 2, 1, 1, 0, 0, 1, 1, true, LegacyPooling2D.Pooling2DType.MAX, 0.0, out);
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
Nd4j.getExecutioner().commit();
|
||||
System.out.println(in);
|
||||
System.out.println(out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void legacyPooling2dTes_half(){
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.rand(DataType.HALF, new int[]{1,1,3,3});
|
||||
INDArray out = Nd4j.create(DataType.HALF, 1,1,2,2);
|
||||
|
||||
val op = new LegacyPooling2D(in, 2, 2, 1, 1, 0, 0, 1, 1, true, LegacyPooling2D.Pooling2DType.MAX, 0.0, out);
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
Nd4j.getExecutioner().commit();
|
||||
System.out.println(in);
|
||||
System.out.println(out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testYoloStyle(){
|
||||
WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder()
|
||||
|
|
|
@ -1,87 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.crash;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.LegacyPooling2D;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
public class YuriiTests extends BaseNd4jTest {
|
||||
public YuriiTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void legacyPooling2dTest_double(){
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1,1,3,3});
|
||||
INDArray out = Nd4j.create(DataType.DOUBLE, 1,1,2,2).assign(-119);
|
||||
|
||||
Nd4j.getExecutioner().commit();
|
||||
|
||||
val op = new LegacyPooling2D(in, 2, 2, 1, 1, 0, 0, 1, 1, true, LegacyPooling2D.Pooling2DType.MAX, 0.0, out);
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
Nd4j.getExecutioner().commit();
|
||||
System.out.println(in);
|
||||
System.out.println(out);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void legacyPooling2dTes_float(){
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.rand(DataType.FLOAT, new int[]{1,1,3,3});
|
||||
INDArray out = Nd4j.create(DataType.FLOAT, 1,1,2,2);
|
||||
|
||||
val op = new LegacyPooling2D(in, 2, 2, 1, 1, 0, 0, 1, 1, true, LegacyPooling2D.Pooling2DType.MAX, 0.0, out);
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
Nd4j.getExecutioner().commit();
|
||||
System.out.println(in);
|
||||
System.out.println(out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void legacyPooling2dTes_half(){
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.rand(DataType.HALF, new int[]{1,1,3,3});
|
||||
INDArray out = Nd4j.create(DataType.HALF, 1,1,2,2);
|
||||
|
||||
val op = new LegacyPooling2D(in, 2, 2, 1, 1, 0, 0, 1, 1, true, LegacyPooling2D.Pooling2DType.MAX, 0.0, out);
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
Nd4j.getExecutioner().commit();
|
||||
System.out.println(in);
|
||||
System.out.println(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package org.nd4j.linalg.ops;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
||||
import org.reflections.Reflections;
|
||||
import org.reflections.scanners.SubTypesScanner;
|
||||
import org.reflections.util.ClasspathHelper;
|
||||
import org.reflections.util.ConfigurationBuilder;
|
||||
import org.reflections.util.FilterBuilder;
|
||||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class OpConstructorTests extends BaseNd4jTest {
|
||||
|
||||
public OpConstructorTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void checkForINDArrayConstructors() throws Exception {
|
||||
/*
|
||||
Check that all op classes have at least one INDArray or INDArray[] constructor, so they can actually
|
||||
be used outside of SameDiff
|
||||
*/
|
||||
|
||||
Reflections f = new Reflections(new ConfigurationBuilder()
|
||||
.filterInputsBy(new FilterBuilder().include(FilterBuilder.prefix("org.nd4j.*")).exclude("^(?!.*\\.class$).*$"))
|
||||
.setUrls(ClasspathHelper.forPackage("org.nd4j")).setScanners(new SubTypesScanner()));
|
||||
|
||||
Set<Class<? extends DifferentialFunction>> classSet = f.getSubTypesOf(DifferentialFunction.class);
|
||||
|
||||
int count = 0;
|
||||
for(Class<?> c : classSet){
|
||||
if(Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers()) || c == SDVariable.class || ILossFunction.class.isAssignableFrom(c))
|
||||
continue;
|
||||
|
||||
// System.out.println(c.getName());
|
||||
|
||||
Constructor<?>[] constructors = c.getConstructors();
|
||||
boolean foundINDArray = false;
|
||||
for( int i=0; i<constructors.length; i++ ){
|
||||
Constructor<?> co = constructors[i];
|
||||
String str = co.toGenericString(); //This is a convenience hack for checking - returns strings like "public org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(org.nd4j.linalg.api.ndarray.INDArray,int...)"
|
||||
if(str.contains("INDArray") && !str.contains("SameDiff")){
|
||||
foundINDArray = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(!foundINDArray){
|
||||
System.out.println("No INDArray constructor: " + c.getName());
|
||||
count++;
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, count);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
return 'c';
|
||||
}
|
||||
|
||||
}
|
|
@ -80,8 +80,9 @@
|
|||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-native-platform</artifactId>
|
||||
<artifactId>nd4j-native</artifactId>
|
||||
<version>${nd4j.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
|
|
|
@ -46,8 +46,8 @@ object Implicits {
|
|||
implicit def jfloatColl2INDArray(s: Seq[java.lang.Float]): FloatArray2INDArray =
|
||||
new FloatArray2INDArray(s.map(x => x: Float)(breakOut))
|
||||
class FloatArray2INDArray(val underlying: Array[Float]) extends AnyVal {
|
||||
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order()), offset: Int = 0): INDArray =
|
||||
Nd4j.create(underlying, shape, ord.value, offset)
|
||||
def mkNDArray(shape: Array[Int], ord: NDOrdering = NDOrdering(Nd4j.order())): INDArray =
|
||||
Nd4j.create(underlying, shape, ord.value)
|
||||
|
||||
def asNDArray(shape: Int*): INDArray =
|
||||
Nd4j.create(underlying.toArray, shape.toArray: _*)
|
||||
|
|
|
@ -327,7 +327,7 @@ case object FloatNDArrayEvidence extends RealNDArrayEvidence[Float] {
|
|||
arr.asNDArray(shape: _*)
|
||||
|
||||
override def create(arr: Array[Float], shape: Array[Int], ordering: NDOrdering, offset: Int): INDArray =
|
||||
arr.mkNDArray(shape, ordering, offset)
|
||||
arr.mkNDArray(shape, ordering)
|
||||
|
||||
override def update(underlying: INDArray, ir: Array[IndexRange], num: Float): INDArray = {
|
||||
if (ir.length == 1 && !ir.head.hasNegative && ir.head
|
||||
|
|
Loading…
Reference in New Issue