parent
29104083cc
commit
53d3bd1269
|
@ -185,54 +185,6 @@ public abstract class SDBaseOps {
|
||||||
return argmin(null, in, keepDims, dimensions);
|
return argmin(null, in, keepDims, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Assign/copy op: out = x.assign(y). Supports broadcasting
|
|
||||||
*
|
|
||||||
* @param x Input variable x
|
|
||||||
* @param y Input variable y
|
|
||||||
* @return Output variable
|
|
||||||
*/
|
|
||||||
public SDVariable assign(SDVariable x, SDVariable y) {
|
|
||||||
return assign(null, x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Assign/copy op: out = x.assign(y). Supports broadcasting
|
|
||||||
*
|
|
||||||
* @param name Name of the output variable
|
|
||||||
* @param x Input variable x
|
|
||||||
* @param y Input variable y
|
|
||||||
* @return Output variable
|
|
||||||
*/
|
|
||||||
public SDVariable assign(String name, SDVariable x, SDVariable y) {
|
|
||||||
SDVariable ret = f().assign(x, y);
|
|
||||||
return updateVariableNameAndReference(ret, name);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return an array with equal shape to the input, but all elements set to 'value'
|
|
||||||
*
|
|
||||||
* @param in Input variable
|
|
||||||
* @param value Value to set
|
|
||||||
* @return Output variable
|
|
||||||
*/
|
|
||||||
public SDVariable assign(SDVariable in, Number value) {
|
|
||||||
return assign(null, in, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return an array with equal shape to the input, but all elements set to 'value'
|
|
||||||
*
|
|
||||||
* @param name Name of the output variable
|
|
||||||
* @param in Input variable
|
|
||||||
* @param value Value to set
|
|
||||||
* @return Output variable
|
|
||||||
*/
|
|
||||||
public SDVariable assign(String name, SDVariable in, Number value) {
|
|
||||||
SDVariable ret = f().assign(in, value);
|
|
||||||
return updateVariableNameAndReference(ret, name);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
|
* Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
|
||||||
* length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
|
* length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
|
||||||
|
|
|
@ -939,9 +939,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
tc.expectedOutput(t.name(), Transforms.min(ia, 0.5, true));
|
tc.expectedOutput(t.name(), Transforms.min(ia, 0.5, true));
|
||||||
break;
|
break;
|
||||||
case 65:
|
case 65:
|
||||||
t = sd.assign(in, 0.5);
|
continue; // assign op was removed.
|
||||||
tc.expectedOutput(t.name(), ia.dup().assign(0.5));
|
|
||||||
break;
|
|
||||||
case 66:
|
case 66:
|
||||||
t = sd.scalarFloorMod(in, 0.5);
|
t = sd.scalarFloorMod(in, 0.5);
|
||||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ScalarFMod(ia.dup(), 0.5)));
|
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ScalarFMod(ia.dup(), 0.5)));
|
||||||
|
@ -1181,9 +1179,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
tc.expectedOutput(t.name(), Transforms.xor(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false);
|
tc.expectedOutput(t.name(), Transforms.xor(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false);
|
||||||
break;
|
break;
|
||||||
case 18:
|
case 18:
|
||||||
t = sd.assign(in1, in2);
|
continue; //assign op was removed.
|
||||||
tc.expectedOutput(t.name(), ib);
|
|
||||||
break;
|
|
||||||
case 19:
|
case 19:
|
||||||
t = sd.math().atan2(in1, in2);
|
t = sd.math().atan2(in1, in2);
|
||||||
tc.expectedOutput(t.name(), Transforms.atan2(ib, ia)); //Note: y,x order for samediff; x,y order for transforms
|
tc.expectedOutput(t.name(), Transforms.atan2(ib, ia)); //Note: y,x order for samediff; x,y order for transforms
|
||||||
|
|
Loading…
Reference in New Issue