2022-09-20 15:40:53 +02:00

308 lines
9.6 KiB
JavaScript

/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
function extractHeaders(/*Uint8Array*/ bytes, offset){
var header1a = bytes.slice(offset+0,offset+4);
var header1b = bytes.slice(offset+4,offset+8);
var headerLength = byteArrayToInt(header1a);
var contentLength = byteArrayToInt(header1b);
return [headerLength, contentLength];
}
function byteArrayToInt(byteArray) {
var value = 0;
for ( var i = 0; i < byteArray.length; i++) {
value = (value * 256) + byteArray[i];
}
return value;
}
function decodeStaticInfo(headerContentBytes, bufferContentBytes){
var headerBuffer = new flatbuffers.ByteBuffer(headerContentBytes);
var contentBuffer = new flatbuffers.ByteBuffer(bufferContentBytes);
var header = nd4j.graph.UIStaticInfoRecord.getRootAsUIStaticInfoRecord(headerBuffer);
var infoType = header.infoType();
switch(infoType){
case nd4j.graph.UIInfoType.GRAPH_STRUCTURE:
var graphStructure = nd4j.graph.UIGraphStructure.getRootAsUIGraphStructure(contentBuffer);
return ["graph", graphStructure];
case nd4j.graph.UIInfoType.SYTEM_INFO:
var info = nd4j.graph.UISystemInfo.getRootAsUISystemInfo(contentBuffer);
return ["systeminfo", info];
case nd4j.graph.UIInfoType.START_EVENTS:
return ["startevents", null];
default:
console.log("Unknown static information type: " + infoType);
return null;
}
}
//Return graph inputs as a String[]
function uiGraphGetInputs(/*UIGraphStructure*/ graph){
var inLength = graph.inputsLength();
var inputs = [];
for( var i=0; i<inLength; i++ ){
inputs.push(graph.inputs(i));
}
return inputs;
}
//Return graph outputs as a String[]
function uiGraphGetOutputs(/*UIGraphStructure*/ graph){
var inLength = graph.outputsLength();
var outputs = [];
for( var i=0; i<inLength; i++ ){
outputs.push(graph.outputs(i));
}
return outputs;
}
//Return graph variables as nd4j.graph.UIVariable[]
function uiGraphGetVariables(/*UIGraphStructure*/ graph){
var varsLength = graph.variablesLength();
var vars = [];
for( var i=0; i<varsLength; i++ ){
vars.push(graph.variables(i));
}
return vars;
}
//Return graph variables as String[]
function uiGraphGetVariableNames(/*UIGraphStructure*/ graph){
var varsLength = graph.variablesLength();
var vars = [];
for( var i=0; i<varsLength; i++ ){
vars.push(graph.variables(i).name());
}
return vars;
}
//Returns nd4j.graph.UIOp[]
function uiGraphGetOps(/*UIGraphStructure*/ graph){
var opsLength = graph.opsLength();
var ops = [];
for( var i=0; i<opsLength; i++ ){
ops.push(graph.ops(i));
}
return ops;
}
function varTypeToString(varType){
switch (varType){
case nd4j.graph.VarType.CONSTANT:
return "constant";
case nd4j.graph.VarType.PLACEHOLDER:
return "placeholder";
case nd4j.graph.VarType.VARIABLE:
return "variable";
case nd4j.graph.VarType.ARRAY:
return "array";
default:
return "" + varType;
}
}
function varShapeToString(/*UIVariable*/ uivar){
var n = uivar.shapeLength();
if(n === 0)
return "";
var shape = [];
for( var i=0; i<n; i++ ){
var l = uivar.shape(i);
var s = l.toString();
var s2 = l.toFloat64().toString();
shape.push(s2);
}
return "[" + shape.toString() + "]";
}
function dataTypeToString(dataTypeByte){
switch (dataTypeByte){
case nd4j.graph.DataType.INHERIT:
return "INHERIT";
case nd4j.graph.DataType.BOOL:
return "BOOL";
case nd4j.graph.DataType.FLOAT8:
return "FLOAT8";
case nd4j.graph.DataType.HALF:
return "HALF";
case nd4j.graph.DataType.HALF2:
return "HALF2";
case nd4j.graph.DataType.FLOAT:
return "FLOAT";
case nd4j.graph.DataType.DOUBLE:
return "DOUBLE";
case nd4j.graph.DataType.INT8:
return "INT8";
case nd4j.graph.DataType.INT16:
return "INT16";
case nd4j.graph.DataType.INT32:
return "INT32";
case nd4j.graph.DataType.INT64:
return "INT64";
case nd4j.graph.DataType.UINT8:
return "UINT8";
case nd4j.graph.DataType.UINT16:
return "UINT16";
case nd4j.graph.DataType.UINT32:
return "UINT32";
case nd4j.graph.DataType.UINT64:
return "UINT64";
case nd4j.graph.DataType.QINT8:
return "QINT8";
case nd4j.graph.DataType.QINT16:
return "QINT16";
case nd4j.graph.DataType.BFLOAT16:
return "BFLOAT16";
case nd4j.graph.DataType.UTF8:
return "UTF8";
default:
return "" + dataTypeByte;
}
}
function dataTypeBytesPerElement(dataTypeByte){
switch (dataTypeByte){
case nd4j.graph.DataType.BOOL:
case nd4j.graph.DataType.FLOAT8:
case nd4j.graph.DataType.INT8:
case nd4j.graph.DataType.UINT8:
case nd4j.graph.DataType.QINT8:
return 1;
case nd4j.graph.DataType.HALF:
case nd4j.graph.DataType.HALF2:
case nd4j.graph.DataType.INT16:
case nd4j.graph.DataType.UINT16:
case nd4j.graph.DataType.QINT16:
case nd4j.graph.DataType.BFLOAT16:
return 2;
case nd4j.graph.DataType.FLOAT:
case nd4j.graph.DataType.INT32:
case nd4j.graph.DataType.UINT32:
return 4;
case nd4j.graph.DataType.DOUBLE:
case nd4j.graph.DataType.INT64:
case nd4j.graph.DataType.UINT64:
return 8;
case nd4j.graph.DataType.UTF8:
return 0; //TODO
default:
return "" + dataTypeByte;
}
}
function getScalar(/*FlatArray*/ flatArray, /*number[]*/ idxs){
//First: work out offset... assume C order here
var offset = 0;
// var prod = 1;
var rank = flatArray.shape(0).toFloat64(); //Note: shape is in nd4j format. So rank, shape. We're assuming C order here. Note also shape(i) returns flatbuffers long object
for( var i=0; i<rank; i++ ){
var size = flatArray.shape(i+1).toFloat64();
var stride = flatArray.shape(rank+i+1).toFloat64();
offset += stride * idxs[i];
// prod *= size.toFloat64();
}
return scalarFromFlatArrayIdx(flatArray, offset);
}
function scalarFromFlatArray(/*FlatArray*/ flatArray) {
return scalarFromFlatArrayIdx(flatArray, 0);
}
function scalarFromFlatArrayIdx(/*FlatArray*/ flatArray, idx){
//TODO OPTIMIZE THIS!
var dt = flatArray.dtype();
switch (dt){
//Skip hard to decode types for now
case nd4j.graph.DataType.FLOAT8:
case nd4j.graph.DataType.QINT8:
case nd4j.graph.DataType.HALF:
case nd4j.graph.DataType.HALF2:
case nd4j.graph.DataType.QINT16:
case nd4j.graph.DataType.BFLOAT16:
return null;
}
var bytesPerElem = dataTypeBytesPerElement(dt);
// var array = new Uint8Array(numBytes);
var dv = new DataView(new ArrayBuffer(bytesPerElem));
var byteOffset = idx * bytesPerElem;
var j=0;
for(var i=byteOffset; i<byteOffset + bytesPerElem; i++ ){
var signedByte = flatArray.buffer(i);
// array[i] = signedByte; //TODO do we need to convert here???
dv.setInt8(j++, signedByte);
}
//Note: "get(idx)" is byte offset
var out;
switch (dt){
case nd4j.graph.DataType.BOOL:
out = (dv.getUint8(0) === 0 ? "false" : "true");
break;
case nd4j.graph.DataType.INT8:
out = dv.getInt8(0);
break;
case nd4j.graph.DataType.UINT8:
out = dv.getUint8(0);
break;
case nd4j.graph.DataType.INT16:
out = dv.getInt16(0);
break;
case nd4j.graph.DataType.UINT16:
out = dv.getUint16(0);
break;
case nd4j.graph.DataType.INT32:
out = dv.getInt32(0);
break;
case nd4j.graph.DataType.UINT32:
out = dv.getUint32(0);
break;
case nd4j.graph.DataType.INT64:
//No getInt64 method... :/
//TODO Need a solution to this!
out = "<int64>";
break;
case nd4j.graph.DataType.UINT64:
out = "<uint64>";
break;
case nd4j.graph.DataType.FLOAT:
out = dv.getFloat32(0);
break;
case nd4j.graph.DataType.DOUBLE:
out = dv.getFloat64(0);
break;
case nd4j.graph.DataType.UTF8:
//TODO need to decode from bytes here
out = "<utf8>";
break;
default:
return "";
}
return out;
}