Update parameterized python tests, move python tests to proper package

master
agibsonccc 2021-03-18 12:18:39 +09:00
parent 13cae7fb60
commit 224f18a586
6 changed files with 157 additions and 53 deletions

View File

@ -1,31 +1,31 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
* * ******************************************************************************
* * *
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * See the NOTICE file distributed with this work for additional
* * * information regarding copyright ownership.
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.nd4j.python4j;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.python4j.*;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -35,7 +35,6 @@ import javax.annotation.concurrent.NotThreadSafe;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Stream;
@ -43,6 +42,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
@NotThreadSafe
public class PythonNumpyBasicTest {
public static Stream<Arguments> params() {
DataType[] types = new DataType[] {
DataType.BOOL,
@ -61,9 +61,9 @@ public class PythonNumpyBasicTest {
};
long[][] shapes = new long[][]{
new long[]{2, 3},
new long[]{3},
new long[]{1},
new long[]{2, 3},
new long[]{3},
new long[]{1},
new long[]{} // scalar
};
@ -78,23 +78,23 @@ public class PythonNumpyBasicTest {
}
@ParameterizedTest
@MethodSource("#params")
public void testConversion(DataType dataType,long[] shape){
try(PythonGIL pythonGIL = PythonGIL.lock()) {
INDArray arr = Nd4j.zeros(dataType, shape);
PythonObject npArr = PythonTypes.convert(arr);
INDArray arr2 = PythonTypes.<INDArray>getPythonTypeForPythonObject(npArr).toJava(npArr);
if (dataType == DataType.BFLOAT16){
arr = arr.castTo(DataType.FLOAT);
}
assertEquals(arr,arr2);
}
@MethodSource("org.nd4j.python4j.PythonNumpyBasicTest#params")
public void testConversion(DataType dataType,long[] shape) {
try(PythonGIL pythonGIL = PythonGIL.lock()) {
INDArray arr = Nd4j.zeros(dataType, shape);
PythonObject npArr = PythonTypes.convert(arr);
INDArray arr2 = PythonTypes.<INDArray>getPythonTypeForPythonObject(npArr).toJava(npArr);
if (dataType == DataType.BFLOAT16){
arr = arr.castTo(DataType.FLOAT);
}
assertEquals(arr,arr2);
}
}
@ParameterizedTest
@MethodSource("#params")
@MethodSource("org.nd4j.python4j.PythonNumpyBasicTest#params")
public void testExecution(DataType dataType,long[] shape) {
try(PythonGIL pythonGIL = PythonGIL.lock()) {
List<PythonVariable> inputs = new ArrayList<>();
@ -124,7 +124,7 @@ public class PythonNumpyBasicTest {
@ParameterizedTest
@MethodSource("#params")
@MethodSource("org.nd4j.python4j.PythonNumpyBasicTest#params")
public void testInplaceExecution(DataType dataType,long[] shape) {
try(PythonGIL pythonGIL = PythonGIL.lock()) {
if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16)return;

View File

@ -1,4 +1,27 @@
/*
*
* * ******************************************************************************
* * *
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * See the NOTICE file distributed with this work for additional
* * * information regarding copyright ownership.
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.nd4j.python4j;/*
* ******************************************************************************
* *
* *
@ -61,8 +84,7 @@ public class PythonNumpyCollectionsTest {
}).stream().map(Arguments::of);
}
@Test
@MethodSource("#params")
@MethodSource("org.nd4j.python4j.PythonNumpyCollectionsTest#params")
@ParameterizedTest
public void testPythonDictFromMap(DataType dataType) throws PythonException {
try(PythonGIL pythonGIL = PythonGIL.lock()) {
@ -84,8 +106,7 @@ public class PythonNumpyCollectionsTest {
}
@Test
@MethodSource("#params")
@MethodSource("org.nd4j.python4j.PythonNumpyCollectionsTest#params")
@ParameterizedTest
public void testPythonListFromList(DataType dataType) throws PythonException {
try(PythonGIL pythonGIL = PythonGIL.lock()) {

View File

@ -1,4 +1,27 @@
/*
*
* * ******************************************************************************
* * *
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * See the NOTICE file distributed with this work for additional
* * * information regarding copyright ownership.
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.nd4j.python4j;/*
* ******************************************************************************
* *
* *
@ -18,11 +41,6 @@
* *****************************************************************************
*/
import org.nd4j.python4j.Python;
import org.nd4j.python4j.PythonGC;
import org.nd4j.python4j.PythonGIL;
import org.nd4j.python4j.PythonObject;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.factory.Nd4j;

View File

@ -1,4 +1,27 @@
/*
*
* * ******************************************************************************
* * *
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * See the NOTICE file distributed with this work for additional
* * * information regarding copyright ownership.
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.nd4j.python4j;/*
* ******************************************************************************
* *
* *
@ -18,8 +41,6 @@
* *****************************************************************************
*/
import org.nd4j.python4j.*;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -1,4 +1,27 @@
/*
*
* * ******************************************************************************
* * *
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * See the NOTICE file distributed with this work for additional
* * * information regarding copyright ownership.
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.nd4j.python4j;/*
* ******************************************************************************
* *
* *
@ -61,8 +84,7 @@ public class PythonNumpyMultiThreadTest {
}
@Test
@MethodSource("#params")
@MethodSource("org.nd4j.python4j.PythonNumpyMultiThreadTest#params")
@ParameterizedTest
public void testMultiThreading1(DataType dataType) throws Throwable {
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
@ -100,8 +122,7 @@ public class PythonNumpyMultiThreadTest {
}
@Test
@MethodSource("#params")
@MethodSource("org.nd4j.python4j.PythonNumpyMultiThreadTest#params")
@ParameterizedTest
public void testMultiThreading2(DataType dataType) throws Throwable {
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<>());

View File

@ -1,4 +1,27 @@
/*
*
* * ******************************************************************************
* * *
* * *
* * * This program and the accompanying materials are made available under the
* * * terms of the Apache License, Version 2.0 which is available at
* * * https://www.apache.org/licenses/LICENSE-2.0.
* * *
* * * See the NOTICE file distributed with this work for additional
* * * information regarding copyright ownership.
* * * Unless required by applicable law or agreed to in writing, software
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * * License for the specific language governing permissions and limitations
* * * under the License.
* * *
* * * SPDX-License-Identifier: Apache-2.0
* * *****************************************************************************
*
*
*/
package org.nd4j.python4j;/*
* ******************************************************************************
* *
* *