/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.cp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.Lineage;
import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.utils.Statistics;

/**
 * Control-program instruction for a DML function call ({@code fcall}).
 * <p>
 * On execution it looks up the target {@link FunctionProgramBlock}, binds the caller's input
 * operands to the function's formal parameters, runs the function body in a fresh
 * {@link ExecutionContext}, and binds the function's return variables back to the caller's
 * output variable names. When lineage is enabled it also propagates lineage items. For
 * deterministic functions it consults and populates the lineage cache for function-level reuse.
 */
public class FunctionCallCPInstruction
extends CPInstruction {
    private static final Log LOG = LogFactory.getLog(FunctionCallCPInstruction.class.getName());
    // name of the called function
    private final String _functionName;
    // namespace of the called function
    private final String _namespace;
    // forwarded to getFunctionProgramBlock; presumably selects the optimized function variant (verify)
    private final boolean _opt;
    // caller-side operands (variables or literals) passed as arguments, positionally aligned with _funArgNames
    private final CPOperand[] _boundInputs;
    // optional precomputed lineage items of the inputs; null means they are derived from the execution context
    private final LineageItem[] _lineageInputs;
    // names of _boundInputs, used to pin/unpin the caller's variables around the call
    private final List<String> _boundInputNames;
    // formal argument name each bound input is assigned to (same index as _boundInputs)
    private final List<String> _funArgNames;
    // caller variable names that receive the function's return values, in output-parameter order
    private final List<String> _boundOutputNames;

    /**
     * Creates a function call instruction.
     *
     * @param namespace        namespace of the called function
     * @param functName        name of the called function
     * @param opt              flag forwarded to the function-block lookup
     * @param boundInputs      caller-side input operands
     * @param lineageInputs    precomputed lineage items of the inputs, or {@code null} to derive them at call time
     * @param funArgNames      formal argument names, aligned with {@code boundInputs}
     * @param boundOutputNames caller variable names receiving the return values
     * @param istr             the original instruction string
     */
    public FunctionCallCPInstruction(String namespace, String functName, boolean opt, CPOperand[] boundInputs, LineageItem[] lineageInputs, List<String> funArgNames, List<String> boundOutputNames, String istr) {
        super(CPInstruction.CPType.FCall, null, functName, istr);
        this._functionName = functName;
        this._namespace = namespace;
        this._opt = opt;
        this._boundInputs = boundInputs;
        this._lineageInputs = lineageInputs;
        this._boundInputNames = Arrays.stream(boundInputs).map(CPOperand::getName).collect(Collectors.toCollection(ArrayList::new));
        this._funArgNames = funArgNames;
        this._boundOutputNames = boundOutputNames;
    }

    /**
     * Creates a function call instruction without precomputed input lineage.
     * Input lineage is then obtained from the execution context at call time.
     */
    public FunctionCallCPInstruction(String namespace, String functName, boolean opt, CPOperand[] boundInputs, List<String> funArgNames, List<String> boundOutputNames, String istr) {
        this(namespace, functName, opt, boundInputs, null, funArgNames, boundOutputNames, istr);
    }

    /** @return the name of the called function */
    public String getFunctionName() {
        return this._functionName;
    }

    /** @return the namespace of the called function */
    public String getNamespace() {
        return this._namespace;
    }

    /**
     * Parses an fcall instruction string.
     * <p>
     * The fields read, in order, are:
     * <ul>
     *   <li>[1] namespace;</li>
     *   <li>[2] function name;</li>
     *   <li>[3] opt flag;</li>
     *   <li>[4] number of inputs {@code k};</li>
     *   <li>[5] number of outputs {@code m};</li>
     *   <li>{@code k} entries of the form {@code argName=operand};</li>
     *   <li>{@code m} output variable names.</li>
     * </ul>
     *
     * @param str the instruction string
     * @return the parsed instruction
     */
    public static FunctionCallCPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String namespace = parts[1];
        String functionName = parts[2];
        boolean opt = Boolean.parseBoolean(parts[3]);
        int numInputs = Integer.parseInt(parts[4]);
        int numOutputs = Integer.parseInt(parts[5]);
        CPOperand[] boundInputs = new CPOperand[numInputs];
        ArrayList<String> funArgNames = new ArrayList<String>();
        ArrayList<String> boundOutputNames = new ArrayList<String>();
        for (int i = 0; i < numInputs; ++i) {
            // left of '=' is the formal argument name, right of it the caller-side operand
            String[] nameValue = IOUtilFunctions.splitByFirst(parts[6 + i], "=");
            boundInputs[i] = new CPOperand(nameValue[1]);
            funArgNames.add(nameValue[0]);
        }
        for (int i = 0; i < numOutputs; ++i) {
            boundOutputNames.add(parts[6 + numInputs + i]);
        }
        return new FunctionCallCPInstruction(namespace, functionName, opt, boundInputs, funArgNames, boundOutputNames, str);
    }

    @Override
    public Instruction preprocessInstruction(ExecutionContext ec) {
        return super.preprocessInstruction(ec);
    }

    /**
     * Executes the function call.
     * <p>
     * Validates the arity and argument names, then tries lineage-based reuse for deterministic
     * functions. Otherwise it executes the function body in a new context and binds the return
     * values (and lineage) to the caller's output variables. Finally it optionally caches the outputs.
     *
     * @param ec the caller's execution context
     * @throws DMLRuntimeException on signature mismatch, missing inputs or outputs, or function failure
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("Executing instruction : " + this.toString());
        }
        FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(this._namespace, this._functionName, this._opt);
        int numFormalInputs = fpb.getInputParams().size();
        if (this._boundInputs.length < numFormalInputs) {
            throw new DMLRuntimeException("fcall " + this._functionName + ": Number of bound input parameters does not match the function signature (" + this._boundInputs.length + ", but " + numFormalInputs + " expected)");
        }

        // input lineage is only materialized when multi-level reuse or lineage estimation needs it
        LineageItem[] liInputs = this._lineageInputs;
        if (liInputs == null && (LineageCacheConfig.isMultiLevelReuse() || DMLScript.LINEAGE_ESTIMATE)) {
            liInputs = LineageItemUtils.getLineage(ec, this._boundInputs);
        }
        if (!fpb.isNondeterministic() && this.reuseFunctionOutputs(liInputs, fpb, ec)) {
            return;
        }

        // bind caller inputs to the function's formal parameters
        LocalVariableMap functionVariables = new LocalVariableMap();
        Lineage lineage = DMLScript.LINEAGE ? new Lineage() : null;
        for (int i = 0; i < this._boundInputs.length; ++i) {
            CPOperand input = this._boundInputs[i];
            if (!input.isLiteral() && !ec.containsVariable(input.getName())) {
                throw new DMLRuntimeException("Input variable '" + input.getName() + "' not existing on call of " + DMLProgram.constructFunctionKey(this._namespace, this._functionName) + " (line " + this.getLineNum() + ").");
            }
            String argName = this._funArgNames.get(i);
            DataIdentifier currFormalParam = fpb.getInputParam(argName);
            if (currFormalParam == null) {
                throw new DMLRuntimeException("fcall " + this._functionName + ": Non-existing named function argument: '" + argName + "' (line " + this.getLineNum() + ").");
            }
            String formalName = currFormalParam.getName();
            Data value = ec.getVariable(input);
            // coerce scalars to the declared value type of the formal parameter
            if (value.getDataType() == Types.DataType.SCALAR && value.getValueType() != currFormalParam.getValueType()) {
                value = ScalarObjectFactory.createScalarObject(currFormalParam.getValueType(), (ScalarObject) value);
            }
            functionVariables.put(formalName, value);
            if (lineage != null) {
                LineageItem inLitem = this._lineageInputs == null ? ec.getLineageItem(input) : this._lineageInputs[i];
                if (inLitem == null) {
                    inLitem = ec.getLineage().getOrCreate(input);
                }
                if (LineageItemUtils.isFunctionDebugging()) {
                    inLitem = new LineageItem(this._functionName + "_INP_" + formalName, new LineageItem[]{inLitem});
                }
                lineage.set(formalName, inLitem);
            }
        }

        // The caller's inputs stay pinned while the function runs and while its locals are
        // cleaned up (presumably so that cleanup does not free data shared with the caller).
        // Unpinning happens in 'finally' so the previous pin status is restored even if the
        // function body fails.
        Queue<Boolean> pinStatus = ec.pinVariables(this._boundInputNames);
        ExecutionContext fnEc;
        LocalVariableMap retVars;
        long t0;
        long t1;
        try {
            fnEc = ExecutionContextFactory.createContext(false, false, ec.getProgram());
            if (DMLScript.USE_ACCELERATOR) {
                fnEc.setGPUContexts(ec.getGPUContexts());
                fnEc.getGPUContext(0).initializeThread();
            }
            fnEc.setVariables(functionVariables);
            fnEc.setLineage(lineage);
            // execution time is only measured when it feeds the lineage cache or estimator
            t0 = !LineageCacheConfig.ReuseCacheType.isNone() || DMLScript.LINEAGE_ESTIMATE ? System.nanoTime() : 0L;
            this.executeFunctionBody(fpb, fnEc);
            t1 = !LineageCacheConfig.ReuseCacheType.isNone() || DMLScript.LINEAGE_ESTIMATE ? System.nanoTime() : 0L;

            // drop every function-local variable that is not a declared return variable
            HashSet<String> expectRetVars = new HashSet<String>();
            for (DataIdentifier outParam : fpb.getOutputParams()) {
                expectRetVars.add(outParam.getName());
            }
            retVars = fnEc.getVariables();
            for (String varName : new ArrayList<String>(retVars.keySet())) {
                if (!expectRetVars.contains(varName)) {
                    fnEc.cleanupDataObject(fnEc.removeVariable(varName));
                }
            }
        }
        finally {
            ec.unpinVariables(this._boundInputNames, pinStatus);
        }

        // bind return values positionally to the caller's output variables
        int n = Math.min(this._boundOutputNames.size(), fpb.getOutputParams().size());
        ArrayList<Data> toBeCleanedUp = new ArrayList<Data>();
        for (int i = 0; i < n; ++i) {
            String boundVarName = this._boundOutputNames.get(i);
            String retVarName = fpb.getOutputParams().get(i).getName();
            Data boundValue = retVars.get(retVarName);
            if (boundValue == null) {
                throw new DMLRuntimeException("fcall " + this._functionName + ": " + boundVarName + " was not assigned a return value");
            }
            // the overwritten caller value is freed later, unless it is the returned object itself
            // or still referenced by another return value
            Data exdata = ec.removeVariable(boundVarName);
            if (exdata != boundValue && !retVars.hasReferences(exdata)) {
                toBeCleanedUp.add(exdata);
            }
            ec.setVariable(boundVarName, boundValue);
            if (lineage != null) {
                LineageItem outLitem = lineage.get(retVarName);
                if (LineageItemUtils.isFunctionDebugging()) {
                    outLitem = new LineageItem(this._functionName + "_RET_" + boundVarName, new LineageItem[]{outLitem});
                }
                ec.getLineage().set(boundVarName, outLitem);
            }
        }
        for (Data dat : toBeCleanedUp) {
            ec.cleanupDataObject(dat);
        }
        if ((DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() && !fpb.isNondeterministic()) || (LineageCacheConfig.isEstimator() && !fpb.isNondeterministic())) {
            LineageCache.putValue(fpb.getOutputParams(), liInputs, this.getCacheFunctionName(this._functionName, fpb), fnEc, t1 - t0);
        }
    }

    /**
     * Runs the function body in the given context.
     * Script-level exceptions propagate unchanged; any other failure is wrapped with the function key.
     */
    private void executeFunctionBody(FunctionProgramBlock fpb, ExecutionContext fnEc) {
        try {
            fpb._functionName = this._functionName;
            fpb._namespace = this._namespace;
            fpb.execute(fnEc);
        }
        catch (DMLScriptException e) {
            throw e;
        }
        catch (Exception e) {
            String fname = DMLProgram.constructFunctionKey(this._namespace, this._functionName);
            throw new DMLRuntimeException("error executing function " + fname, e);
        }
    }

    @Override
    public void postprocessInstruction(ExecutionContext ec) {
        super.postprocessInstruction(ec);
    }

    @Override
    public void printMe() {
        LOG.debug("ExternalBuiltInFunction: " + this.toString());
    }

    /** @return the caller variable names receiving the function's return values */
    public List<String> getBoundOutputParamNames() {
        return this._boundOutputNames;
    }

    /** @return the formal argument names, aligned with {@link #getInputs()} */
    public List<String> getFunArgNames() {
        return this._funArgNames;
    }

    /**
     * Returns a copy of the instruction string with the field at index 3 replaced if it equals {@code pattern}.
     * <p>
     * Presumably field 3 of the raw string (fields separated by U+00B0) is the function name. Note that
     * {@link String#split(String)} drops trailing empty fields, so they are absent from the result.
     *
     * @param pattern the expected current value of field 3
     * @param replace the replacement value
     * @return the rebuilt instruction string (this instruction is not modified)
     */
    public String updateInstStringFunctionName(String pattern, String replace) {
        String[] parts = this.instString.split("\u00b0");
        if (parts[3].equals(pattern)) {
            parts[3] = replace;
        }
        return String.join("\u00b0", parts);
    }

    /** @return the caller-side input operands */
    public CPOperand[] getInputs() {
        return this._boundInputs;
    }

    /**
     * Tries to satisfy this call from the lineage cache.
     *
     * @return {@code true} if the outputs were reused and the function must not be executed
     */
    private boolean reuseFunctionOutputs(LineageItem[] liInputs, FunctionProgramBlock fpb, ExecutionContext ec) {
        String funcName = this.getCacheFunctionName(this._functionName, fpb);
        int numOutputs = Math.min(this._boundOutputNames.size(), fpb.getOutputParams().size());
        boolean reuse = LineageCache.reuse(this._boundOutputNames, fpb.getOutputParams(), numOutputs, liInputs, funcName, ec);
        if (reuse && DMLScript.STATISTICS) {
            Statistics.maintainCPFuncCallStats(this.getExtendedOpcode());
            LineageCacheStatistics.incrementFuncHits();
        }
        return reuse;
    }

    /**
     * Builds the lineage-cache key for the called function.
     * <p>
     * If the function block carries a thread ID, the last {@code "_t<threadID>"} suffix is cut from
     * the name, so that thread-local copies share one key. If the suffix is not present, the
     * name is used unchanged instead of failing with an index exception.
     */
    private String getCacheFunctionName(String fname, FunctionProgramBlock fpb) {
        String baseName = fname;
        if (fpb.hasThreadID()) {
            int suffixPos = fname.lastIndexOf("_t" + fpb.getThreadID());
            if (suffixPos >= 0) {
                baseName = fname.substring(0, suffixPos);
            }
        }
        return DMLProgram.constructFunctionKey(this._namespace, baseName);
    }
}

