Fix SDLoss null weights array issue (#185)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
This commit is contained in:
parent
4db28a9300
commit
458d141d8e
@ -45,7 +45,7 @@ public class SDLoss extends SDOps {
|
|||||||
*/
|
*/
|
||||||
private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){
|
private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){
|
||||||
String weightName = (name == null) ? null : name + "/weight";
|
String weightName = (name == null) ? null : name + "/weight";
|
||||||
return (weights == null) ? null : sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
|
return (weights == null) ? sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)) : weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -17,9 +17,10 @@
|
|||||||
package org.nd4j.systeminfo;
|
package org.nd4j.systeminfo;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.BaseND4JTest;
|
||||||
import org.nd4j.systeminfo.SystemInfo;
|
import org.nd4j.systeminfo.SystemInfo;
|
||||||
|
|
||||||
public class TestSystemInfo {
|
public class TestSystemInfo extends BaseND4JTest {
|
||||||
@Test
|
@Test
|
||||||
public void testSystemInfo(){
|
public void testSystemInfo(){
|
||||||
SystemInfo.printSystemInfo();
|
SystemInfo.printSystemInfo();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user