70 lines
2.8 KiB
Python
70 lines
2.8 KiB
Python
################################################################################
|
|
# Copyright (c) 2015-2018 Skymind, Inc.
|
|
#
|
|
# This program and the accompanying materials are made available under the
|
|
# terms of the Apache License, Version 2.0 which is available at
|
|
# https://www.apache.org/licenses/LICENSE-2.0.
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
################################################################################
|
|
|
|
from .java_classes import TFGraphMapper, Nd4j, NDArrayIndex
|
|
from .ndarray import array
|
|
|
|
|
|
class TFModel(object):
|
|
def __init__(self, filepath, inputs=None, outputs=None):
|
|
self.sd = TFGraphMapper.getInstance().importGraph(filepath)
|
|
self.inputs = inputs
|
|
self.outputs = outputs
|
|
if inputs is None:
|
|
input_vars = [self.sd.variables().get(0)]
|
|
elif type(inputs) in [list, tuple]:
|
|
input_vars = []
|
|
for x in inputs:
|
|
var = self.sd.getVariable(x)
|
|
if var is None:
|
|
raise ValueError('Variable not found in samediff graph: ' + x)
|
|
input_vars.append(var)
|
|
else:
|
|
input_vars = [self.sd.getVariable(inputs)]
|
|
if input_vars[0] is None:
|
|
raise ValueError('Variable not found in samediff graph: ' + inputs)
|
|
if outputs is None:
|
|
nvars = self.sd.variables().size()
|
|
output_vars = [self.sd.variables().get(nvars - 1)]
|
|
elif type(outputs) in [list, tuple]:
|
|
output_vars = []
|
|
for x in outputs:
|
|
var = self.sd.getVariable(x)
|
|
if var is None:
|
|
raise ValueError('Variable not found in samediff graph: ' + x)
|
|
output_vars.append(var)
|
|
else:
|
|
output_vars = [self.sd.getVariable(outputs)]
|
|
if output_vars[0] is None:
|
|
raise ValueError('Variable not found in samediff graph: ' + outputs)
|
|
self.input_vars = input_vars
|
|
self.output_vars = output_vars
|
|
|
|
def __call__(self, input):
|
|
if type(input) in (list, tuple):
|
|
input_arrays = [array(x).array for x in input]
|
|
else:
|
|
input_arrays = [array(input).array]
|
|
for arr, var in zip(input_arrays, self.input_vars):
|
|
self.sd.associateArrayWithVariable(arr, var)
|
|
output_arrays = []
|
|
getattr(self.sd, 'exec')()
|
|
for var in self.output_vars:
|
|
output_arrays.append(array(var.getArr()))
|
|
if len(output_arrays) == 1:
|
|
return output_arrays[0]
|
|
return output_arrays
|