From 458d141d8ea19ed76aa67933020e5441d144cb96 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 25 Jan 2020 20:13:23 +1100 Subject: [PATCH] Fix SDLoss null weights array issue (#185) Signed-off-by: AlexDBlack --- .../src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java | 2 +- .../src/test/java/org/nd4j/systeminfo/TestSystemInfo.java | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java index 96d76c52d..f0e94a4e5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java @@ -45,7 +45,7 @@ public class SDLoss extends SDOps { */ private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){ String weightName = (name == null) ? null : name + "/weight"; - return (weights == null) ? null : sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); + return (weights == null) ? sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)) : weights; } /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java index 764fcbd23..ddfffef17 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java @@ -17,9 +17,10 @@ package org.nd4j.systeminfo; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.systeminfo.SystemInfo; -public class TestSystemInfo { +public class TestSystemInfo extends BaseND4JTest { @Test public void testSystemInfo(){ SystemInfo.printSystemInfo();