diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java index 96d76c52d..f0e94a4e5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java @@ -45,7 +45,7 @@ public class SDLoss extends SDOps { */ private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){ String weightName = (name == null) ? null : name + "/weight"; - return (weights == null) ? null : sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); + return (weights == null) ? sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)) : weights; } /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java index 764fcbd23..ddfffef17 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java @@ -17,9 +17,10 @@ package org.nd4j.systeminfo; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.systeminfo.SystemInfo; -public class TestSystemInfo { +public class TestSystemInfo extends BaseND4JTest { @Test public void testSystemInfo(){ SystemInfo.printSystemInfo();