/*
* This file is part of OpenModelica.
*
* Copyright (c) 1998-2020, Open Source Modelica Consortium (OSMC),
* c/o Linköpings universitet, Department of Computer and Information Science,
* SE-58183 Linköping, Sweden.
*
* All rights reserved.
*
* THIS PROGRAM IS PROVIDED UNDER THE TERMS OF GPL VERSION 3 LICENSE OR
* THIS OSMC PUBLIC LICENSE (OSMC-PL) VERSION 1.2.
* ANY USE, REPRODUCTION OR DISTRIBUTION OF THIS PROGRAM CONSTITUTES
* RECIPIENT'S ACCEPTANCE OF THE OSMC PUBLIC LICENSE OR THE GPL VERSION 3,
* ACCORDING TO RECIPIENTS CHOICE.
*
* The OpenModelica software and the Open Source Modelica
* Consortium (OSMC) Public License (OSMC-PL) are obtained
* from OSMC, either from the above address,
* from the URLs: http://www.ida.liu.se/projects/OpenModelica or
* http://www.openmodelica.org, and in the OpenModelica distribution.
* GNU version 3 is obtained from: http://www.gnu.org/copyleft/gpl.html.
*
* This program is distributed WITHOUT ANY WARRANTY; without
* even the implied warranty of  MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE, EXCEPT AS EXPRESSLY SET FORTH
* IN THE BY RECIPIENT SELECTED SUBSIDIARY LICENSE CONDITIONS OF OSMC-PL.
*
* See the full OSMC Public License conditions for more details.
*
*/
encapsulated package NBDifferentiate
"file:        NBDifferentiate.mo
 package:     NBDifferentiate
 description: This file contains the functions to differentiate equations and
              expressions symbolically.
"
public
  // OF imports
  import Absyn.Path;
  import AbsynUtil;
  import DAE;

  // NF imports
  import Algorithm = NFAlgorithm;
  import Binding = NFBinding;
  import BuiltinFuncs = NFBuiltinFuncs;
  import Call = NFCall;
  import Class = NFClass;
  import NFClassTree.ClassTree;
  import Component = NFComponent;
  import ComponentRef = NFComponentRef;
  import Dimension = NFDimension;
  import Expression = NFExpression;
  import InstContext = NFInstContext;
  import NFInstNode.{InstNode, CachedData};
  import NFFunction.{Function, Slot};
  import FunctionDerivative = NFFunctionDerivative;
  import Operator = NFOperator;
  import Prefixes = NFPrefixes;
  import Sections = NFSections;
  import SimplifyExp = NFSimplifyExp;
  import Statement = NFStatement;
  import Subscript = NFSubscript;
  import Type = NFType;
  import NFPrefixes.Variability;
  import Variable = NFVariable;

  // Backend imports
  import NBEquation.{Equation, EquationAttributes, EquationPointer, EquationPointers, IfEquationBody, WhenEquationBody, WhenStatement};
  import NBVariable.{VariablePointer};
  import BVariable = NBVariable;
  import Replacements = NBReplacements;
  import StrongComponent = NBStrongComponent;
  import Tearing = NBTearing;

  // Util imports
  import Array;
  import BackendUtil = NBBackendUtil;
  import Error;
  import UnorderedMap;
  import Slice = NBSlice;

  // ================================
  //        TYPES AND UNIONTYPES
  // ================================
  // Differentiation use case. It is stored in DifferentiationArguments.diffType;
  // within this file it decides e.g. whether derivatives of equations are cached
  // (TIME) and whether algebraic loops are marked linear (JACOBIAN).
  type DifferentiationType = enumeration(TIME, SIMPLE, FUNCTION, JACOBIAN);

  uniontype DifferentiationArguments
    // Settings and accumulated state that are passed into and returned from the
    // differentiation functions of this package.
    record DIFFERENTIATION_ARGUMENTS
      ComponentRef diffCref                                     "The input will be differentiated w.r.t. this cref (only SIMPLE).";
      list<Pointer<Variable>> new_vars                          "contains all new variables that need to be added to the system";
      Option<UnorderedMap<ComponentRef, ComponentRef>> diff_map "seed and temporary cref map x --> $SEED.MATRIX.x, y --> $pDer.MATRIX.y. Can be used for any differentiation rules";
      DifferentiationType diffType                              "Differentiation use case (time, simple, function, jacobian)";
      UnorderedMap<Path, Function> funcMap                      "Function tree containing all functions and their known derivatives";
      Boolean scalarized                                        "true if the variables are scalarized";
      Option<UnorderedMap<ComponentRef, list<Expression>>> adjoint_map  "map for accumulating adjoint gradients for component refs";
      Expression current_grad                                   "current gradient expression, used in reverse mode";
      Boolean collectAdjoints                                   "If false, skip writing into adjoint_map (used for LHS traversal in reverse/Jacobian).";
    end DIFFERENTIATION_ARGUMENTS;

    function default
      "Creates arguments for the given use case with an empty diffCref, no new
      variables, no diff_map, no adjoint_map, an empty Real gradient and
      adjoint collection switched off."
      input DifferentiationType ty = DifferentiationType.TIME;
      input UnorderedMap<Path, Function> funcMap = UnorderedMap.new<Function>(AbsynUtil.pathHash, AbsynUtil.pathEqual);
      output DifferentiationArguments diffArgs;
    algorithm
      // positional arguments follow the field order of DIFFERENTIATION_ARGUMENTS
      diffArgs := DIFFERENTIATION_ARGUMENTS(
        ComponentRef.EMPTY(),           // diffCref
        {},                             // new_vars
        NONE(),                         // diff_map
        ty,                             // diffType
        funcMap,                        // funcMap
        false,                          // scalarized
        NONE(),                         // adjoint_map
        Expression.EMPTY(Type.REAL()),  // current_grad
        false                           // collectAdjoints
      );
    end default;

    function simpleCref
      "Creates arguments for SIMPLE differentiation w.r.t. the given cref.
      All other fields have the same values as produced by default()."
      input ComponentRef cref;
      input UnorderedMap<Path, Function> funcMap = UnorderedMap.new<Function>(AbsynUtil.pathHash, AbsynUtil.pathEqual);
      output DifferentiationArguments diffArgs;
    algorithm
      diffArgs := default(DifferentiationType.SIMPLE, funcMap);
      diffArgs.diffCref := cref;
    end simpleCref;

    function toString
      "Returns the use case in brackets, followed by the differentiation cref
      if the use case is SIMPLE."
      input DifferentiationArguments diffArgs;
      output String str;
    protected
      Boolean is_simple = diffArgs.diffType == DifferentiationType.SIMPLE;
    algorithm
      str := "[" + diffTypeStr(diffArgs.diffType) + "]";
      if is_simple then
        str := str + " " + ComponentRef.toString(diffArgs.diffCref);
      end if;
    end toString;

    function diffTypeStr
      "Returns the name of the use case, or FAIL for any value that is not handled."
      input DifferentiationType diffType;
      output String str;
    algorithm
      if diffType == DifferentiationType.TIME then
        str := "TIME";
      elseif diffType == DifferentiationType.SIMPLE then
        str := "SIMPLE";
      elseif diffType == DifferentiationType.FUNCTION then
        str := "FUNCTION";
      elseif diffType == DifferentiationType.JACOBIAN then
        str := "JACOBIAN";
      else
        str := "FAIL";
      end if;
    end diffTypeStr;
  end DifferentiationArguments;

  // ================================
  //             FUNCTIONS
  // ================================

  function differentiateStrongComponentList
    "author: kabdelhak
    Differentiates every strong component of the list, front to back, and
    returns the differentiation arguments as they are after the last component."
    input output list<StrongComponent> comps;
    input output DifferentiationArguments diffArguments;
    input Pointer<Integer> idx;
    input String context;
    input String name;
  protected
    Pointer<DifferentiationArguments> args_ptr;
  algorithm
    // one shared mutable cell: differentiateStrongComponent writes updated
    // arguments back into it, so later components see them
    args_ptr := Pointer.create(diffArguments);
    comps := list(differentiateStrongComponent(c, args_ptr, idx, context, name) for c in comps);
    diffArguments := Pointer.access(args_ptr);
  end differentiateStrongComponentList;


  function differentiateStrongComponentListAdjoint
    "author: fbrandt
    Differentiates a list of strong components for reverse mode (adjoint)
    differentiation.
    Before a component is differentiated, current_grad is set to the seed/pDER
    cref that diff_map stores for the name of the component's first variable.
    Every component is differentiated exactly once and contributes exactly one
    entry to the result, so the output list has the same length and order as
    the input list.
    diff_map has to be set, it is unwrapped with Util.getOption. The seed is
    looked up with UnorderedMap.getOrFail, so a non-empty variable cref is
    expected to have an entry in diff_map.
    Many rules for reverse mode differentiation can be found e.g. here: https://fkoehler.site/autodiff-table/"
    input output list<StrongComponent> comps;
    input output DifferentiationArguments diffArguments;
    input Pointer<Integer> idx;
    input String context;
    input String name;
  protected
    Pointer<DifferentiationArguments> diffArguments_ptr = Pointer.create(diffArguments);
    list<StrongComponent> newComps = {};
    UnorderedMap<ComponentRef,ComponentRef> diff_map;
    ComponentRef lhsCref;
    ComponentRef gradCref;
    list<VariablePointer> compVars;
    DifferentiationArguments da;
  algorithm
    diff_map := Util.getOption(diffArguments.diff_map);
    for comp in comps loop
      dbg("Component: " + StrongComponent.toString(comp));
      // compVars := match comp
      //   case StrongComponent.ALGEBRAIC_LOOP() then StrongComponent.getLoopIterationVars(comp);
      //   else StrongComponent.getVariables(comp);
      // end match;
      compVars := StrongComponent.getVariables(comp);

      // Seed current_grad from the component's variable. This is done once per
      // component and is kept separate from the differentiation below: doing
      // both inside a loop over all variables would differentiate a component
      // with several variables repeatedly (the later passes on its own
      // derivative), emit it several times, and drop a component without
      // variables from the result.
      // NOTE(review): for components with several variables only the seed of
      // the first variable is used. A complete reverse mode treatment of such
      // components probably needs one seed per variable - confirm.
      if not listEmpty(compVars) then
        lhsCref := BVariable.getVarName(listHead(compVars));
        if (not ComponentRef.isEmpty(lhsCref)) then
          gradCref := UnorderedMap.getOrFail(lhsCref, diff_map);
          // this is currently not needed, but in case we have subscripts on LHS later, we need to copy them to the seed
          gradCref := match comp
            case StrongComponent.RESIZABLE_COMPONENT() then ComponentRef.copySubscripts(StrongComponent.getVarCref(comp), gradCref); // put subscript on the seed;
            case StrongComponent.SLICED_COMPONENT() then ComponentRef.copySubscripts(StrongComponent.getVarCref(comp), gradCref); // put subscript on the seed;
            else gradCref;
          end match;
          // and update in diffArguments
          da := Pointer.access(diffArguments_ptr);
          da.current_grad := Expression.fromCref(gradCref);
          Pointer.update(diffArguments_ptr, da);
        else
          dbg("  No seed mapping for: " + ComponentRef.toString(lhsCref));
        end if;
      end if;

      // Differentiate this component exactly once. The loop variable is not
      // overwritten, the result is pushed directly.
      dbg("  Differentiating component...");
      newComps := differentiateStrongComponent(comp, diffArguments_ptr, idx, context, name) :: newComps;
      dbg("  Done differentiating component.");
    end for;

    // results were collected by prepending, restore the input order
    comps := listReverse(newComps);
    diffArguments := Pointer.access(diffArguments_ptr);
  end differentiateStrongComponentListAdjoint;

  function differentiateStrongComponent
    "Differentiates one strong component: its variable(s), its equation(s) and,
    where present, its variable cref. Names for the new equations are created
    from idx and context. Updated differentiation arguments are written back
    into diffArguments_ptr.
    Fails with an internal error for entwined and unknown components."
    input output StrongComponent comp;
    input Pointer<DifferentiationArguments> diffArguments_ptr;
    input Pointer<Integer> idx;
    input String context;
    input String name;
  algorithm
    comp := match comp
      local
        Pointer<Variable> d_var;
        Pointer<Equation> d_eqn;
        list<Slice<VariablePointer>> d_var_slices;
        Slice<VariablePointer> d_var_slice;
        Slice<EquationPointer> d_eqn_slice;
        Expression d_cref_exp;
        ComponentRef d_cref;
        DifferentiationArguments args;
        Tearing d_strict;
        Option<Tearing> d_casual;
        Boolean is_linear;

      // one variable, one equation
      case StrongComponent.SINGLE_COMPONENT() algorithm
        d_var := differentiateVariablePointer(comp.var, diffArguments_ptr);
        d_eqn := differentiateEquationPointer(comp.eqn, diffArguments_ptr, name);
        Equation.createName(d_eqn, idx, context);
      then StrongComponent.SINGLE_COMPONENT(d_var, d_eqn, comp.status);

      // several variables, one equation
      case StrongComponent.MULTI_COMPONENT() algorithm
        d_var_slices := list(Slice.apply(v, function differentiateVariablePointer(diffArguments_ptr = diffArguments_ptr)) for v in comp.vars);
        d_eqn_slice := Slice.apply(comp.eqn, function differentiateEquationPointer(diffArguments_ptr = diffArguments_ptr, name = name));
        Equation.createName(Slice.getT(d_eqn_slice), idx, context);
      then StrongComponent.MULTI_COMPONENT(d_var_slices, d_eqn_slice, comp.status);

      case StrongComponent.SLICED_COMPONENT() algorithm
        // the subscripted LHS cref is mapped with the NoCollect variant, i.e.
        // without collecting into the adjoint_map if one exists
        (d_cref_exp, args) := differentiateComponentRefNoCollect(Expression.fromCref(comp.var_cref), Pointer.access(diffArguments_ptr));
        Expression.CREF(cref = d_cref) := d_cref_exp;
        Pointer.update(diffArguments_ptr, args);
        d_var_slice := Slice.apply(comp.var, function differentiateVariablePointer(diffArguments_ptr = diffArguments_ptr));
        d_eqn_slice := Slice.apply(comp.eqn, function differentiateEquationPointer(diffArguments_ptr = diffArguments_ptr, name = name));
        Slice.applyMutable(d_eqn_slice, function Equation.createName(idx = idx, context = context));
      then StrongComponent.SLICED_COMPONENT(d_cref, d_var_slice, d_eqn_slice, comp.status);

      case StrongComponent.RESIZABLE_COMPONENT() algorithm
        // NOTE(review): unlike SLICED_COMPONENT this uses the regular
        // differentiateComponentRef for the variable cref - confirm whether the
        // NoCollect variant is intended here as well.
        (d_cref_exp, args) := differentiateComponentRef(Expression.fromCref(comp.var_cref), Pointer.access(diffArguments_ptr));
        Expression.CREF(cref = d_cref) := d_cref_exp;
        Pointer.update(diffArguments_ptr, args);
        d_var_slice := Slice.apply(comp.var, function differentiateVariablePointer(diffArguments_ptr = diffArguments_ptr));
        d_eqn_slice := Slice.apply(comp.eqn, function differentiateEquationPointer(diffArguments_ptr = diffArguments_ptr, name = name));
        Slice.applyMutable(d_eqn_slice, function Equation.createName(idx = idx, context = context));
      then StrongComponent.RESIZABLE_COMPONENT(d_cref, d_var_slice, d_eqn_slice, comp.order, comp.status);

      case StrongComponent.GENERIC_COMPONENT() algorithm
        // NOTE(review): same remark as for RESIZABLE_COMPONENT regarding the
        // collecting vs. NoCollect cref differentiation.
        (d_cref_exp, args) := differentiateComponentRef(Expression.fromCref(comp.var_cref), Pointer.access(diffArguments_ptr));
        Expression.CREF(cref = d_cref) := d_cref_exp;
        Pointer.update(diffArguments_ptr, args);
        d_var_slice := Slice.apply(comp.var, function differentiateVariablePointer(diffArguments_ptr = diffArguments_ptr));
        d_eqn_slice := Slice.apply(comp.eqn, function differentiateEquationPointer(diffArguments_ptr = diffArguments_ptr, name = name));
        Slice.applyMutable(d_eqn_slice, function Equation.createName(idx = idx, context = context));
      then StrongComponent.GENERIC_COMPONENT(d_cref, d_var_slice, d_eqn_slice);

      case StrongComponent.ALGEBRAIC_LOOP() algorithm
        d_strict := differentiateTearing(comp.strict, diffArguments_ptr, idx, context, name);
        d_casual := Util.applyOption(comp.casual, function differentiateTearing(diffArguments_ptr=diffArguments_ptr, idx=idx, context=context, name=name));
        // if we differentiate for jacobian, the algebraic loops will always be
        // linear, otherwise the linearity of the original loop is kept
        args := Pointer.access(diffArguments_ptr);
        is_linear := args.diffType == DifferentiationType.JACOBIAN or comp.linear;
      then StrongComponent.ALGEBRAIC_LOOP(-1, d_strict, d_casual, is_linear, false, comp.homotopy, comp.status);

      case StrongComponent.ENTWINED_COMPONENT() algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " not implemented for entwined equation:\n" + StrongComponent.toString(comp)});
      then fail();

      // an alias is replaced by the derivative of the component it refers to
      case StrongComponent.ALIAS() then differentiateStrongComponent(comp.original, diffArguments_ptr, idx, context, name);

      else algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " not implemented for unknown strong component:\n" + StrongComponent.toString(comp)});
      then fail();
    end match;
  end differentiateStrongComponent;

  function differentiateTearing
    "Differentiates a tearing set: first the iteration variables, then the
    residual equations, then the inner equations. The resulting tearing set
    carries NONE() as its last constructor argument."
    input Tearing tearing;
    input Pointer<DifferentiationArguments> diffArguments_ptr;
    input Pointer<Integer> idx;
    input String context;
    input String name;
    output Tearing diff_tearing;
  protected
    list<Slice<VariablePointer>> d_iteration_vars = {};
    list<Slice<EquationPointer>> d_residuals = {};
    list<StrongComponent> d_inner;
  algorithm
    // iteration variables, collected by prepending and reversed afterwards
    for v in tearing.iteration_vars loop
      d_iteration_vars := Slice.apply(v, function differentiateVariablePointer(diffArguments_ptr = diffArguments_ptr)) :: d_iteration_vars;
    end for;
    d_iteration_vars := listReverse(d_iteration_vars);

    // residual equations
    for e in tearing.residual_eqns loop
      d_residuals := Slice.apply(e, function differentiateEquationPointer(diffArguments_ptr = diffArguments_ptr, name = name)) :: d_residuals;
    end for;
    d_residuals := listReverse(d_residuals);

    // filter discretes?
    d_inner := list(differentiateStrongComponent(ie, diffArguments_ptr, idx, context, name) for ie in tearing.innerEquations);

    // diff jac?
    diff_tearing := Tearing.TEARING_SET(d_iteration_vars, d_residuals, listArray(d_inner), NONE());
  end differentiateTearing;

  function differentiateEquationPointerList
    "author: kabdelhak
    Differentiates a list of equations wrapped in pointers and creates a name
    for each resulting equation from idx and context."
    input output list<Pointer<Equation>> equations;
    input output DifferentiationArguments diffArguments;
    input Pointer<Integer> idx;
    input String context;
    input String name;
  protected
    Pointer<DifferentiationArguments> args_ptr;
  algorithm
    args_ptr := Pointer.create(diffArguments);
    // phase 1: differentiate all equations, sharing the arguments via args_ptr
    equations := list(differentiateEquationPointer(e, args_ptr, name) for e in equations);
    // phase 2: name the derivatives only after all of them have been created
    for d_eqn in equations loop
      Equation.createName(d_eqn, idx, context);
    end for;
    diffArguments := Pointer.access(args_ptr);
  end differentiateEquationPointerList;

  function differentiateEquationPointer
    "Differentiates the equation behind eq_ptr and returns a pointer to the
    derivative. When differentiating w.r.t. time a derivative that is already
    stored in the equation attributes is reused, and a newly created one is
    stored in the original equation. Changed differentiation arguments are
    written back into diffArguments_ptr."
    input Pointer<Equation> eq_ptr;
    input Pointer<DifferentiationArguments> diffArguments_ptr;
    input String name = "";
    output Pointer<Equation> derivative_ptr;
  protected
    Equation original, derivative;
    DifferentiationArguments args_before, args_after;
    Boolean wrt_time;
  algorithm
    original := Pointer.access(eq_ptr);
    args_before := Pointer.access(diffArguments_ptr);
    wrt_time := args_before.diffType == DifferentiationType.TIME;

    derivative_ptr := match Equation.getAttributes(original)
      local
        Pointer<Equation> cached;

      // time derivative is already known: reuse it, nothing else changes
      case EquationAttributes.EQUATION_ATTRIBUTES(derivative = SOME(cached)) guard(wrt_time) then cached;

      // otherwise create the derivative
      else algorithm
        (derivative, args_after) := differentiateEquation(original, args_before, name);
        derivative_ptr := Pointer.create(derivative);
        // remember the derivative on the original equation if we derive w.r.t. time
        if args_after.diffType == DifferentiationType.TIME then
          Pointer.update(eq_ptr, Equation.setDerivative(original, derivative_ptr));
        end if;
        // only touch the shared arguments if a new object was returned
        if not referenceEq(args_after, args_before) then
          Pointer.update(diffArguments_ptr, args_after);
        end if;
      then derivative_ptr;
    end match;
  end differentiateEquationPointer;

  function differentiateEquation
    "Differentiates a single equation and simplifies the result.
    The left hand side of scalar, array and record equations is differentiated
    with differentiateExpressionNoCollect, the right hand side with
    differentiateExpression. If, for and when equations are differentiated
    body by body, algorithms via differentiateAlgorithm.
    If the flag DEBUG_DIFFERENTIATION is set and name is not empty, the equation
    is printed before and after differentiation.
    Fails with an internal error for all other kinds of equations."
    input output Equation eq;
    input output DifferentiationArguments diffArguments;
    input String name = "";
  algorithm
    if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) and not stringEqual(name, "") then
      print("### debugDifferentiation | " + name + " ###\n");
      print("[BEFORE] " + Equation.toString(eq) + "\n");
    end if;
    (eq, diffArguments) := match eq
      local
        Expression lhs, rhs;
        ComponentRef lhs_cref;
        list<Equation> forBody = {};
        IfEquationBody ifBody;
        WhenEquationBody whenBody;
        Pointer<DifferentiationArguments> diffArguments_ptr;
        EquationAttributes attr;
        Algorithm alg;

        UnorderedMap<ComponentRef,ComponentRef> dm;
        ComponentRef lhs_base = ComponentRef.EMPTY();
        ComponentRef seed_base;
        Integer n = 0;
        list<Type.Dimension> dims;
        Expression grad_save, rhs_i, grad_i;
        Boolean collect_save;

      // ToDo: Element source stuff (see old backend)
      case Equation.SCALAR_EQUATION() algorithm
        (lhs, diffArguments) := differentiateExpressionNoCollect(eq.lhs, diffArguments);
        (rhs, diffArguments) := differentiateExpression(eq.rhs, diffArguments);
        attr := differentiateEquationAttributes(eq.attr, diffArguments);
      then (Equation.SCALAR_EQUATION(eq.ty, lhs, rhs, eq.source, attr), diffArguments);

      case Equation.ARRAY_EQUATION() algorithm
        (lhs, diffArguments) := differentiateExpressionNoCollect(eq.lhs, diffArguments);
        // Only do per-element reverse seeding for explicit element-wise array assembly on RHS.
        // diff_map is part of the condition because it is unwrapped below; without
        // it the regular differentiation in the else branch is used.
        if Util.isSome(diffArguments.adjoint_map) and
          Util.isSome(diffArguments.diff_map) and
          diffArguments.diffType == DifferentiationType.JACOBIAN and
          Expression.isArray(eq.rhs) then

          SOME(dm) := diffArguments.diff_map;

          // The LHS is expected to be a variable cref. Anything else yields an
          // empty cref, which selects the fallback below instead of failing the
          // whole differentiation on a pattern mismatch.
          lhs_base := match eq.lhs
            case Expression.CREF(cref = lhs_cref) then lhs_cref;
            else ComponentRef.EMPTY();
          end match;

          // Vector length from equation type
          if Type.isArray(eq.ty) then
            dims := Type.arrayDims(eq.ty);
            if not listEmpty(dims) then
              n := Dimension.size(listHead(dims));
            end if;
          end if;

          if (not ComponentRef.isEmpty(lhs_base)) and UnorderedMap.contains(lhs_base, dm) and n > 0 then
            seed_base := UnorderedMap.getOrFail(lhs_base, dm);

            // Save and prepare flags
            grad_save := diffArguments.current_grad;
            collect_save := diffArguments.collectAdjoints;

            // Accumulate adjoints per element with scalar seeds seed_base[i] on rhs[i]
            for iel in 1:n loop
              // current_grad := $SEED...y[i]
              grad_i := Expression.applySubscripts(
                {Subscript.INDEX(Expression.INTEGER(iel))},
                Expression.fromCref(seed_base),
                true);

              // rhs_i := rhs[i]
              rhs_i := Expression.applySubscripts(
                {Subscript.INDEX(Expression.INTEGER(iel))},
                eq.rhs,
                true);

              diffArguments.current_grad := grad_i;
              diffArguments.collectAdjoints := true;

              // Differentiate rhs element to accumulate into adjoint_map, the
              // differentiated expression itself is not needed
              (_, diffArguments) := differentiateExpression(rhs_i, diffArguments);
            end for;

            // Restore state
            diffArguments.current_grad := grad_save;
            diffArguments.collectAdjoints := collect_save;

            // Also differentiate the full RHS without collecting (avoid duplicates)
            (rhs, diffArguments) := differentiateExpressionNoCollect(eq.rhs, diffArguments);
          else
            // Fallback: regular vector reverse-mode
            (rhs, diffArguments) := differentiateExpression(eq.rhs, diffArguments);
          end if;
        else
          // Non-explicit RHS (e.g., A*x): let reverse-mode handle vectors/matrices
          (rhs, diffArguments) := differentiateExpression(eq.rhs, diffArguments);
        end if;
        attr := differentiateEquationAttributes(eq.attr, diffArguments);
      then (Equation.ARRAY_EQUATION(eq.ty, lhs, rhs, eq.source, attr, eq.recordSize), diffArguments);

      case Equation.RECORD_EQUATION() algorithm
        (lhs, diffArguments) := differentiateExpressionNoCollect(eq.lhs, diffArguments);
        (rhs, diffArguments) := differentiateExpression(eq.rhs, diffArguments);
        attr := differentiateEquationAttributes(eq.attr, diffArguments);
      then (Equation.RECORD_EQUATION(eq.ty, lhs, rhs, eq.source, attr, eq.recordSize), diffArguments);

      case Equation.IF_EQUATION() algorithm
        (ifBody, diffArguments_ptr) := differentiateIfEquationBody(eq.body, Pointer.create(diffArguments));
        // read the arguments back first, so that the attributes are
        // differentiated with the updated arguments like in all other cases
        diffArguments := Pointer.access(diffArguments_ptr);
        attr := differentiateEquationAttributes(eq.attr, diffArguments);
      then (Equation.IF_EQUATION(eq.size, ifBody, eq.source, attr), diffArguments);

      case Equation.FOR_EQUATION() algorithm
        // body equations are collected by prepending and reversed afterwards
        for body_eqn in eq.body loop
          (body_eqn, diffArguments) := differentiateEquation(body_eqn, diffArguments);
          forBody := body_eqn :: forBody;
        end for;
        attr := differentiateEquationAttributes(eq.attr, diffArguments);
      then (Equation.FOR_EQUATION(eq.size, eq.iter, listReverse(forBody), eq.source, attr), diffArguments);

      case Equation.WHEN_EQUATION() algorithm
        (whenBody, diffArguments) := differentiateWhenEquationBody(eq.body, diffArguments);
        attr := differentiateEquationAttributes(eq.attr, diffArguments);
      then (Equation.WHEN_EQUATION(eq.size, whenBody, eq.source, attr), diffArguments);

      // the attributes of algorithms are kept as they are
      case Equation.ALGORITHM() algorithm
        (alg, diffArguments) := differentiateAlgorithm(eq.alg, diffArguments); // may need differentiateAlgorithmAdjoint
      then (Equation.ALGORITHM(eq.size, alg, eq.source, eq.expand, eq.attr), diffArguments);

      else algorithm
        // maybe add failtrace here and allow failing
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Equation.toString(eq)});
      then fail();

    end match;

/* ToDo
    record AUX_EQUATION
      "Auxiliary equations are generated when auxiliary variables are generated
      that are known to always be solved in this specific equation. E.G. $CSE
      The variable binding contains the equation, but this equation is also
      allowed to have a body for special cases."
      Pointer<Variable> auxiliary     "Corresponding auxiliary variable";
      Option<Equation> body           "Optional body equation"; // -> Expression
    end AUX_EQUATION;

    record DUMMY_EQUATION
    end DUMMY_EQUATION;

  */
    // the result is always simplified, the debug variant passes an additional indentation
    if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) and not stringEqual(name, "") then
      eq := Equation.simplify(eq, name, "\t");
      print("[AFTER ] " + Equation.toString(eq) + "\n\n");
    else
      eq := Equation.simplify(eq, name);
    end if;
  end differentiateEquation;

  function differentiateIfEquationBody
    "Differentiates the equations of this branch and afterwards, recursively,
    the else-if body if there is one. The condition is kept unchanged."
    input output IfEquationBody body;
    input output Pointer<DifferentiationArguments> diffArguments_ptr;
  protected
    list<Pointer<Equation>> d_then_eqns;
    Option<IfEquationBody> d_else_if;
  algorithm
    // ToDo: this is a little ugly
    // 1. why are the then_eqns Pointers? no need for that
    // 2. we could just traverse it regularly without creating a pointer for diffArguments
    d_then_eqns := list(differentiateEquationPointer(e, diffArguments_ptr, "") for e in body.then_eqns);

    d_else_if := match body.else_if
      local
        IfEquationBody next;
      case SOME(next) algorithm
        (next, diffArguments_ptr) := differentiateIfEquationBody(next, diffArguments_ptr);
      then SOME(next);
      else NONE();
    end match;

    body := IfEquationBody.IF_EQUATION_BODY(body.condition, d_then_eqns, d_else_if);
  end differentiateIfEquationBody;

  function differentiateWhenEquationBody
    "Differentiates the statements of this when body in order and afterwards,
    recursively, the else-when body if there is one. The condition is kept
    unchanged."
    input output WhenEquationBody body;
    input output DifferentiationArguments diffArguments;
  protected
    list<WhenStatement> d_stmts = {};
    WhenStatement d_stmt;
    Option<WhenEquationBody> d_else_when;
  algorithm
    // thread the arguments through all statements, results are collected by
    // prepending and reversed afterwards
    for s in body.when_stmts loop
      (d_stmt, diffArguments) := differentiateWhenStatement(s, diffArguments);
      d_stmts := d_stmt :: d_stmts;
    end for;
    d_stmts := listReverse(d_stmts);

    d_else_when := match body.else_when
      local
        WhenEquationBody next;
      case SOME(next) algorithm
        (next, diffArguments) := differentiateWhenEquationBody(next, diffArguments);
      then SOME(next);
      else NONE();
    end match;

    body := WhenEquationBody.WHEN_EQUATION_BODY(body.condition, d_stmts, d_else_when);
  end differentiateWhenEquationBody;

  function differentiateWhenStatement
    "Differentiates both sides of an assignment statement. All other when
    statements are returned unchanged."
    input output WhenStatement stmt;
    input output DifferentiationArguments diffArguments;
  algorithm
    () := match stmt
      local
        Expression d_lhs, d_rhs;

      // Only differentiate assignments
      case WhenStatement.ASSIGN() algorithm
        // NOTE(review): the LHS is differentiated with differentiateExpression
        // here, whereas differentiateEquation uses the NoCollect variant for
        // left hand sides - confirm that this is intended.
        (d_lhs, diffArguments) := differentiateExpression(stmt.lhs, diffArguments);
        (d_rhs, diffArguments) := differentiateExpression(stmt.rhs, diffArguments);
        stmt := WhenStatement.ASSIGN(d_lhs, d_rhs, stmt.source);
      then ();

      else ();
    end match;
  end differentiateWhenStatement;

  function differentiateExpressionDump
    "Wrapper around differentiateExpression. If the flag DEBUG_DIFFERENTIATION
    is set, the expression is printed before and after differentiation, each
    line prefixed with indent."
    input output Expression exp;
    input output DifferentiationArguments diffArguments;
    input String name = "";
    input String indent = "";
  protected
    Boolean dump;
  algorithm
    // the flag is read once and decides about both the header and the footer
    dump := Flags.isSet(Flags.DEBUG_DIFFERENTIATION);

    if dump then
      print(indent + "### debugDifferentiation | " + name + " ###\n");
      print(indent + "[BEFORE] " + Expression.toString(exp) + "\n");
    end if;

    (exp, diffArguments) := differentiateExpression(exp, diffArguments);

    if dump then
      print(indent + "[AFTER ] " + Expression.toString(exp) + "\n\n");
    end if;
  end differentiateExpressionDump;

  function differentiateExpression
    input output Expression exp;
    input output DifferentiationArguments diffArguments;
  algorithm
    (exp, diffArguments) := match exp
      local
        Expression elem1, elem2, current_grad, gradTrue, gradFalse;
        list<Expression> new_elements = {};
        list<list<Expression>> new_matrix_elements = {};
        array<Expression> arr;
        ComponentRef d_fn;
        Boolean isReverse = Util.isSome(diffArguments.adjoint_map);

      // differentiation of constant expressions results in zero
      case Expression.INTEGER()   then (Expression.INTEGER(0), diffArguments);
      case Expression.REAL()      then (Expression.REAL(0.0), diffArguments);
      // leave boolean and string expressions as is
      case Expression.STRING()    then (exp, diffArguments);
      case Expression.BOOLEAN()   then (exp, diffArguments);

      // differentiate cref
      case Expression.CREF() then differentiateComponentRef(exp, diffArguments);

      // [a, b, c, ...]' = [a', b', c', ...]
      case Expression.ARRAY() algorithm
        (arr, diffArguments) := Array.mapFold(exp.elements, differentiateExpression, diffArguments);
        exp.elements := arr;
      then (exp, diffArguments);

      // |a, b, c|'   |a', b', c'|
      // |d, e, f|  = |d', e', f'|
      // |g, h, i|    |g', h', i'|
      case Expression.MATRIX() algorithm
        for element_lst in exp.elements loop
          new_elements := {};
          for element in element_lst loop
            (element, diffArguments) := differentiateExpression(element, diffArguments);
            new_elements := element :: new_elements;
          end for;
          new_matrix_elements := listReverse(new_elements) :: new_matrix_elements;
        end for;
      then (Expression.MATRIX(listReverse(new_matrix_elements)), diffArguments);

      // (a, b, c, ...)' = (a', b', c', ...)
      case Expression.TUPLE() algorithm
        for element in exp.elements loop
          (element, diffArguments) := differentiateExpression(element, diffArguments);
          new_elements := element :: new_elements;
        end for;
      then (Expression.TUPLE(exp.ty, listReverse(new_elements)), diffArguments);

      // REC(a, b, c, ...)' = REC(a', b', c', ...)
      case Expression.RECORD() algorithm
        for element in exp.elements loop
          (element, diffArguments) := differentiateExpression(element, diffArguments);
          new_elements := element :: new_elements;
        end for;
      then (Expression.RECORD(exp.path, exp.ty, listReverse(new_elements)), diffArguments);

      // e.g. (f(x))' = f'(x) * x' (more rules in differentiateCall)
      case Expression.CALL() then differentiateCall(exp, diffArguments);

      // Forward: (if c then a else b)' = if c then a' else b'
      // Reverse: upstream G is only sent to taken branch:
      //   grad_a = if c then G else 0
      //   grad_b = if c then 0 else G
      // Then recurse with those masked gradients.
      case Expression.IF() algorithm
        if isReverse then
          // Keep original upstream
          current_grad := diffArguments.current_grad;

          // Masked gradients
          gradTrue  := Expression.IF(Expression.typeOf(current_grad), exp.condition, current_grad, Expression.makeZero(Expression.typeOf(current_grad)));
          gradFalse := Expression.IF(Expression.typeOf(current_grad), exp.condition, Expression.makeZero(Expression.typeOf(current_grad)), current_grad);

          // Recurse true branch
          diffArguments.current_grad := gradTrue;
          (elem1, diffArguments) := differentiateExpression(exp.trueBranch, diffArguments);

          // Recurse false branch
          diffArguments.current_grad := gradFalse;
          (elem2, diffArguments) := differentiateExpression(exp.falseBranch, diffArguments);

          // Restore upstream
          diffArguments.current_grad := current_grad;
        else
          (elem1, diffArguments) := differentiateExpression(exp.trueBranch, diffArguments);
          (elem2, diffArguments) := differentiateExpression(exp.falseBranch, diffArguments);
        end if;
      then (Expression.IF(exp.ty, exp.condition, elem1, elem2), diffArguments);

      // e.g. (fg)' = fg' + f'g (more rules in differentiateBinary)
      case Expression.BINARY() then differentiateBinary(exp, diffArguments);

      // e.g. (fgh)' = f'gh + fg'h + fgh' (more rules in differentiateMultary)
      case Expression.MULTARY() then differentiateMultary(exp, diffArguments);

      // (-x)' = -(x')
      case Expression.UNARY() algorithm
        if isReverse then
          current_grad := diffArguments.current_grad;

          // apply same unary operator to current_grad
          diffArguments.current_grad := Expression.UNARY(exp.operator, current_grad);
          (elem1, diffArguments) := differentiateExpression(exp.exp, diffArguments);

          diffArguments.current_grad := current_grad;
        else
          (elem1, diffArguments) := differentiateExpression(exp.exp, diffArguments);
        end if;
      then (Expression.UNARY(exp.operator, elem1), diffArguments);

      // ((Real) x)' = (Real) x'
      case Expression.CAST() algorithm
        (elem1, diffArguments) := differentiateExpression(exp.exp, diffArguments);
      then (Expression.CAST(exp.ty, elem1), diffArguments);

      // BOX(x)' = BOX(x')
      case Expression.BOX() algorithm
        (elem1, diffArguments) := differentiateExpression(exp.exp, diffArguments);
      then (Expression.BOX(elem1), diffArguments);

      // UNBOX(x)' = UNBOX(x')
      case Expression.UNBOX() algorithm
        (elem1, diffArguments) := differentiateExpression(exp.exp, diffArguments);
      then (Expression.UNBOX(elem1, exp.ty), diffArguments);

      // (x(1))' = x'(1)
      case Expression.SUBSCRIPTED_EXP() algorithm
        (elem1, diffArguments) := differentiateExpression(exp.exp, diffArguments);
      then (Expression.SUBSCRIPTED_EXP(elem1, exp.subscripts, exp.ty, exp.split), diffArguments);

      // (..., a_i,...)' = (..., a'_i, ...)
      case Expression.TUPLE_ELEMENT() algorithm
        (elem1, diffArguments) := differentiateExpression(exp.tupleExp, diffArguments);
      then (Expression.TUPLE_ELEMENT(elem1, exp.index, exp.ty), diffArguments);

      // REC(i, ...)' = REC(i', ...)
      case Expression.RECORD_ELEMENT() algorithm
        // check if differentiating for simple cref and if it contains it
        if diffArguments.diffType == DifferentiationType.SIMPLE and not Expression.containsCref(exp.recordExp, diffArguments.diffCref) then
          elem1 := Expression.makeZero(Expression.typeOf(exp));
        else
          (elem1, diffArguments) := differentiateExpression(exp.recordExp, diffArguments);
          elem1 := Expression.RECORD_ELEMENT(elem1, exp.index, exp.fieldName, exp.ty);
        end if;
      then (elem1, diffArguments);

      // differentiate a passed function pointer
      case Expression.PARTIAL_FUNCTION_APPLICATION() algorithm
        d_fn := BVariable.makeFDerVar(exp.fn);
        for element in exp.args loop
          (element, diffArguments) := differentiateExpression(element, diffArguments);
          new_elements := element :: new_elements;
        end for;
      then (Expression.PARTIAL_FUNCTION_APPLICATION(d_fn, listAppend(exp.args, listReverse(new_elements)),
        listAppend(exp.argNames, list(BackendUtil.makeFDerString(name) for name in exp.argNames)), exp.ty), diffArguments);

      // Logical expressions, relations and the remaining non-differentiable expressions (size, range, end, empty, enum literals, type names) are left as they are
      case Expression.LBINARY()       then (exp, diffArguments);
      case Expression.LUNARY()        then (exp, diffArguments);
      case Expression.RELATION()      then (exp, diffArguments);
      case Expression.SIZE()          then (exp, diffArguments);
      case Expression.RANGE()         then (exp, diffArguments);
      case Expression.END()           then (exp, diffArguments);
      case Expression.EMPTY()         then (exp, diffArguments);
      case Expression.ENUM_LITERAL()  then (exp, diffArguments);
      case Expression.TYPENAME()      then (exp, diffArguments);

      else algorithm
        // maybe add failtrace here and allow failing
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp)});
      then fail();
    end match;
  end differentiateExpression;

  function differentiateExpressionNoCollect
    "Differentiates an expression via differentiateExpression while adjoint
    collection is switched off. Only if an adjoint map is present the flag
    collectAdjoints is set to false for the duration of the call and reset to
    its previous value afterwards. Without an adjoint map the call is forwarded
    unchanged."
    input output Expression expr;
    input output DifferentiationArguments diffArguments;
  protected
    // both values are sampled before differentiating
    Boolean hasAdjointMap = Util.isSome(diffArguments.adjoint_map);
    Boolean savedCollect = diffArguments.collectAdjoints;
  algorithm
    // temporarily disable collecting
    if hasAdjointMap then
      diffArguments.collectAdjoints := false;
    end if;

    (expr, diffArguments) := differentiateExpression(expr, diffArguments);

    // put the previous flag back in place
    if hasAdjointMap then
      diffArguments.collectAdjoints := savedCollect;
    end if;
  end differentiateExpressionNoCollect;

  function differentiateComponentRef
    // Differentiates a single component reference expression.
    // The rule that is applied depends on diffArguments.diffType (FUNCTION, TIME,
    // SIMPLE, JACOBIAN), on the optional diffArguments.diff_map and on properties
    // of the variable the cref points to.
    // Returns the differentiated expression together with the (possibly updated)
    // differentiation arguments. Side effects visible in this function:
    //  - TIME: an algebraic continuous variable gets a new derivative variable
    //    which is prepended to diffArguments.new_vars, and the variable itself is
    //    updated with BVariable.setStateDerivativeVar.
    //  - JACOBIAN with collectAdjoints = true: the adjoint map is updated via
    //    UnorderedMap.tryAddUpdate.
    // Adds an internal error and fails if no rule applies.
    // NOTE: the cases of the main match are tried top to bottom and the first
    // matching one is used, so their order is part of the logic.
    input output Expression exp "Has to be Expression.CREF()";
    input output DifferentiationArguments diffArguments;
  protected
    Pointer<Variable> var_ptr, der_ptr;
    ComponentRef derCref, strippedCref;
  algorithm
    // extract var pointer first to have following code more readable
    var_ptr := match exp
      // function body expressions, empty and wild crefs are not lowered (maybe do it?)
      // in these three cases a fresh pointer to the DUMMY_VARIABLE is used instead of a lookup
      case _ guard(diffArguments.diffType == DifferentiationType.FUNCTION) then Pointer.create(NBVariable.DUMMY_VARIABLE);
      case Expression.CREF(cref = ComponentRef.EMPTY()) then Pointer.create(NBVariable.DUMMY_VARIABLE);
      case Expression.CREF(cref = ComponentRef.WILD())  then Pointer.create(NBVariable.DUMMY_VARIABLE);
      case Expression.CREF() then BVariable.getVarPointer(exp.cref, sourceInfo());
      // anything that is not a CREF is an internal error
      else algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp)});
      then fail();
    end match;

    // Debug entry summary
    // NOTE(review): the message strings are built before dbg is called - confirm this is acceptable cost-wise
    dbg("[dCREF] exp=" + Expression.toString(exp)
        + " | diffType=" + DifferentiationArguments.diffTypeStr(diffArguments.diffType)
        + " | scalarized=" + boolString(diffArguments.scalarized)
        + " | collectAdjoints=" + boolString(diffArguments.collectAdjoints));
    if Util.isSome(diffArguments.adjoint_map) then
      dbg("[dCREF] current_grad=" + Expression.toString(diffArguments.current_grad));
    end if;

    // dispatch on the expression, the differentiation type and the optional diff_map
    (exp, diffArguments) := match (exp, diffArguments.diffType, diffArguments.diff_map)
      local
        Expression res, adjExpr;
        UnorderedMap<ComponentRef,ComponentRef> diff_map;
        list<Subscript> expCrefSubscripts;

      // -------------------------------------
      //    EMPTY and WILD crefs do nothing
      // -------------------------------------
      case (Expression.CREF(cref = ComponentRef.EMPTY()), _, _) then (exp, diffArguments);
      case (Expression.CREF(cref = ComponentRef.WILD()), _, _)  then (exp, diffArguments);

      // -------------------------------------
      //    Special rules for Type: FUNCTION
      //    (needs to be first because var_ptr is DUMMY)
      // -------------------------------------

      // Types: (FUNCTION)
      // Any variable that is in the HT will be differentiated accordingly. 0 otherwise
      // The lookup key is the cref without any subscripts; the original subscripts
      // are copied onto the derivative cref afterwards.
      case (Expression.CREF(), DifferentiationType.FUNCTION, SOME(diff_map)) algorithm
        strippedCref := ComponentRef.stripSubscriptsAll(exp.cref);
        if UnorderedMap.contains(strippedCref, diff_map) then
          // get the derivative and reapply subscripts
          derCref := UnorderedMap.getOrFail(strippedCref, diff_map);
          derCref := ComponentRef.copySubscripts(exp.cref, derCref);
          res     := Expression.fromCref(derCref);
        else
          res     := Expression.makeZero(exp.ty);
        end if;
      then (res, diffArguments);

      // -------------------------------------
      //    Generic Rules
      // -------------------------------------

      // Types: (TIME)
      // differentiate time cref => 1
      case (Expression.CREF(), DifferentiationType.TIME, _)
        guard(ComponentRef.isTime(exp.cref))
      then (Expression.makeOne(exp.ty), diffArguments);

      // Types: not (TIME)
      // differentiate time cref => 0
      // (only reached for non-TIME types because the case above catches TIME)
      case (Expression.CREF(), _, _)
        guard(ComponentRef.isTime(exp.cref))
      then (Expression.makeZero(exp.ty), diffArguments);

      // Types: (ALL)
      // differentiate start cref => 0
      case (Expression.CREF(), _, _)
        guard(BVariable.isStart(var_ptr))
      then (Expression.makeZero(exp.ty), diffArguments);

      // ToDo: Records, Arrays, WILD (?)

      // Types: (SIMPLE)
      //  D(x)/dx => 1
      case (Expression.CREF(), DifferentiationType.SIMPLE, _)
        guard(ComponentRef.isEqual(exp.cref, diffArguments.diffCref))
      then (Expression.makeOne(exp.ty), diffArguments);

      // Types: (SIMPLE)
      // D(y)/dx => 0
      // (catches every remaining cref for SIMPLE, none of the cases below is reached for this type)
      case (Expression.CREF(), DifferentiationType.SIMPLE, _)
      then (Expression.makeZero(exp.ty), diffArguments);

      // Types: (ALL)
      // Known variables, except for top level inputs have a 0-derivative
      // (optimizable variables are excluded from this rule as well)
      case (Expression.CREF(), _, _)
        guard(BVariable.isParamOrConst(var_ptr) and
              not (ComponentRef.isTopLevel(exp.cref) and BVariable.isInput(var_ptr))
              and not BVariable.isOptimizable(var_ptr) /* TODO? */ )
      then (Expression.makeZero(exp.ty), diffArguments);

      // -------------------------------------
      //    Special rules for Type: TIME
      // -------------------------------------

      // Types: (TIME)
      // D(discrete)/d(x) = 0
      case (Expression.CREF(), DifferentiationType.TIME, _)
        guard(BVariable.isDiscrete(var_ptr) or BVariable.isDiscreteState(var_ptr))
      then (Expression.makeZero(exp.ty), diffArguments);

      // Types: (TIME)
      // known derivatives by state order
      // (same subscript-stripped lookup as in the FUNCTION rule)
      case (Expression.CREF(), DifferentiationType.TIME, SOME(diff_map))
        guard(UnorderedMap.contains(ComponentRef.stripSubscriptsAll(exp.cref), diff_map)) algorithm
        // get the derivative and reapply subscripts
        derCref := UnorderedMap.getOrFail(ComponentRef.stripSubscriptsAll(exp.cref), diff_map);
        derCref := ComponentRef.copySubscripts(exp.cref, derCref);
        res     := Expression.fromCref(derCref);
      then (res, diffArguments);

      // Types: (TIME)
      // DUMMY_STATES => DUMMY_DER
      case (Expression.CREF(), DifferentiationType.TIME, _)
        guard(BVariable.isDummyState(var_ptr))
      then (Expression.fromCref(BVariable.getPartnerCref(exp.cref, BVariable.getVarDummyDer)), diffArguments);

      // Types: (TIME)
      // D(x)/dtime --> der(x) --> $DER.x
      // STATE => STATE_DER
      case (Expression.CREF(), DifferentiationType.TIME, _)
        guard(BVariable.isState(var_ptr))
      then (Expression.fromCref(BVariable.getPartnerCref(exp.cref, BVariable.getVarDer)), diffArguments);

      // Types: (TIME)
      // D(y)/dtime --> der(y) --> $DER.y
      // ALGEBRAIC => STATE_DER
      // make y a state and add new STATE_DER
      // NOTE: this case has side effects on the variable behind var_ptr and on diffArguments.new_vars
      case (Expression.CREF(), DifferentiationType.TIME, _)
        guard(BVariable.isContinuous(var_ptr, false))
        algorithm
          // create derivative
          (derCref, der_ptr) := BVariable.makeDerVar(exp.cref);
          // add derivative to new_vars
          diffArguments.new_vars := der_ptr :: diffArguments.new_vars;
          // update algebraic variable to be a state
          BVariable.setStateDerivativeVar(var_ptr, der_ptr);
      then (Expression.fromCref(derCref), diffArguments);

      // -------------------------------------
      //    Special rules for Type: JACOBIAN
      // -------------------------------------

      // Types: (JACOBIAN)
      // cref in diff_map => get $SEED or $pDER variable from hash table
      // Scalarized variant: the full cref (including subscripts) is the lookup key.
      case (Expression.CREF(), DifferentiationType.JACOBIAN, SOME(diff_map))
        guard(diffArguments.scalarized)
      algorithm
        if UnorderedMap.contains(exp.cref, diff_map) then
          res := Expression.fromCref(UnorderedMap.getOrFail(exp.cref, diff_map));

          // Accumulate adjoint contribution: append current_grad to list at key exp.cref.
          // NOTE(review): the key used here is the original cref, while the non-scalarized
          // rule below uses the mapped derivative cref - confirm this difference is intended.
          // NOTE(review): assumes adjoint_map is SOME whenever collectAdjoints is true - TODO confirm
          if diffArguments.collectAdjoints then
            UnorderedMap.tryAddUpdate(exp.cref, function updateAdjointList(current_grad = diffArguments.current_grad), Util.getOption(diffArguments.adjoint_map));
          end if;
        else
          // Everything that is not in diff_map gets differentiated to zero
          res := Expression.makeZero(exp.ty);
        end if;
      then (res, diffArguments);

      // Types: (JACOBIAN)
      // cref in diff_map => get $SEED or $pDER variable from hash table
      // Non-scalarized variant: the cref is looked up without subscripts and the
      // subscripts are reapplied to the result. If adjoints are collected, the
      // flattened subscripts decide how the upstream gradient is stored.
      case (Expression.CREF(), DifferentiationType.JACOBIAN, SOME(diff_map))
        guard(not diffArguments.scalarized)
      algorithm
        strippedCref := ComponentRef.stripSubscriptsAll(exp.cref);
        expCrefSubscripts := ComponentRef.subscriptsAllFlat(exp.cref);
        dbg("[dCREF:JAC] cref=" + ComponentRef.toString(exp.cref)
            + " | stripped=" + ComponentRef.toString(strippedCref)
            + " | subs=" + Subscript.toStringList(expCrefSubscripts));
        if UnorderedMap.contains(strippedCref, diff_map) then
          // get the derivative and reapply subscripts
          derCref := UnorderedMap.getOrFail(strippedCref, diff_map);
          dbg("[dCREF:JAC] mapped -> " + ComponentRef.toString(derCref));
          res     := Expression.fromCref(ComponentRef.copySubscripts(exp.cref, derCref));
          // NOTE(review): the argument below performs a variable lookup for derCref before dbg
          // is called - confirm that NBVariable.getVarPointer cannot fail here
          dbg("[dCREF:JAC] get variable for derivative cref: " + NBVariable.pointerToString(NBVariable.getVarPointer(derCref, sourceInfo())));
          if diffArguments.collectAdjoints then // if derCref is on the rhs then collect adjoint (collectAdjoints is false when differentiating lhs)
            // Create adjoint expression from subscripts:
            // only a single flattened subscript is handled specially, everything
            // else falls back to the unmodified upstream gradient
            adjExpr := match expCrefSubscripts
              local
                Integer iidx;
                Option<Expression> onehotOpt;
                Option<Expression> multiOpt;
              // Single literal index -> one-hot
              // (falls back to current_grad if the helper returns NONE)
              case {Subscript.INDEX(Expression.INTEGER(iidx))}
                algorithm
                  dbg("[dCREF:JAC] adjoint via INDEX[" + intString(iidx) + "]");
                  onehotOpt := buildOneHotVectorAdjoint(derCref, iidx, diffArguments.current_grad);
                then (if Util.isSome(onehotOpt) then Util.getOption(onehotOpt) else diffArguments.current_grad);

              // Single slice/range -> multi-hot scatter
              // (falls back to current_grad if the helper returns NONE)
              case {Subscript.SLICE()}
                algorithm
                  dbg("[dCREF:JAC] adjoint via SLICE " + Subscript.toString(listHead(expCrefSubscripts)));
                  multiOpt := buildMultiHotVectorAdjoint(derCref, listHead(expCrefSubscripts), diffArguments.current_grad);
                then (if Util.isSome(multiOpt) then Util.getOption(multiOpt) else diffArguments.current_grad);

              // Whole dimension -> pass upstream as-is
              case {Subscript.WHOLE()}
                then diffArguments.current_grad;

              // Fallback: keep previous behavior
              // (no subscripts, several subscripts or any other subscript kind)
              else diffArguments.current_grad;
            end match;
            dbg("[dCREF:JAC] append adjoint key=" + ComponentRef.toString(derCref)
                + " expr=" + Expression.toString(adjExpr));
            // the key is the mapped derivative cref without the subscripts of exp.cref
            // NOTE(review): assumes adjoint_map is SOME whenever collectAdjoints is true - TODO confirm
            UnorderedMap.tryAddUpdate(derCref, function updateAdjointList(current_grad = adjExpr), Util.getOption(diffArguments.adjoint_map));
          else
            dbg("[dCREF:JAC] collectAdjoints=false, skip append");
          end if;
        else
          // everything that is not in diff_map gets differentiated to zero
          res     := Expression.makeZero(exp.ty);
        end if;
      then (res, diffArguments);

      // no rule matched (e.g. FUNCTION or JACOBIAN without a diff_map)
      else algorithm
        // maybe add failtrace here and allow failing
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp)});
      then fail();

    end match;
  end differentiateComponentRef;

  function differentiateComponentRefNoCollect
    "Differentiates a component reference via differentiateComponentRef while
    adjoint collection is switched off. Only if an adjoint map is present the
    flag collectAdjoints is set to false for the duration of the call and reset
    to its previous value afterwards. Without an adjoint map the call is
    forwarded unchanged."
    input output Expression exp;
    input output DifferentiationArguments diffArguments;
  protected
    // both values are sampled before differentiating
    Boolean hasAdjointMap = Util.isSome(diffArguments.adjoint_map);
    Boolean savedCollect = diffArguments.collectAdjoints;
  algorithm
    // temporarily disable collecting
    if hasAdjointMap then
      diffArguments.collectAdjoints := false;
    end if;

    (exp, diffArguments) := differentiateComponentRef(exp, diffArguments);

    // put the previous flag back in place
    if hasAdjointMap then
      diffArguments.collectAdjoints := savedCollect;
    end if;
  end differentiateComponentRefNoCollect;

  function differentiateVariablePointer
    "Differentiates the name of the variable behind var_ptr (without collecting
    adjoints) and returns the pointer of the resulting variable. EMPTY and WILD
    results yield a fresh pointer to the dummy variable. Fails with an internal
    error if the result is not a component reference. On success the updated
    differentiation arguments are written back to diffArguments_ptr."
    input Pointer<Variable> var_ptr;
    input Pointer<DifferentiationArguments> diffArguments_ptr;
    output Pointer<Variable> diff_ptr;
  protected
    DifferentiationArguments args;
    Variable orig_var;
    Expression diff_exp;
  algorithm
    // unpack both pointers
    args      := Pointer.access(diffArguments_ptr);
    orig_var  := Pointer.access(var_ptr);

    // differentiate the variable name as a cref expression
    (diff_exp, args) := differentiateComponentRefNoCollect(Expression.fromCref(orig_var.name), args);

    // the derivative has to be a cref again to have a variable pointer
    diff_ptr := match diff_exp
      local
        ComponentRef diff_cref;

      case Expression.CREF(cref = ComponentRef.EMPTY())
      then Pointer.create(NBVariable.DUMMY_VARIABLE);

      case Expression.CREF(cref = ComponentRef.WILD())
      then Pointer.create(NBVariable.DUMMY_VARIABLE);

      case Expression.CREF(cref = diff_cref)
      then BVariable.getVarPointer(diff_cref, sourceInfo());

      else algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + Variable.toString(orig_var)
          + " because the result is expected to be a variable but turned out to be " + Expression.toString(diff_exp) + "."});
      then fail();
    end match;

    // only reached on success: store the updated arguments
    Pointer.update(diffArguments_ptr, args);
  end differentiateVariablePointer;

  function differentiateCall
  "Differentiates a call expression.
  1. array constructors -> only the constructor body is differentiated
  2. reductions -> handled by differentiateReduction
  3. builtin functions -> handled by differentiateBuiltinCall (hardcoded logic)
  4. user defined functions -> check if there is a 'fitting' derivative defined.
    - 'fitting' means that all the zeroDerivative annotations have to hold
    4.1 fitting function found -> use it
    4.2 fitting function not found -> differentiate the body of the function
    The call to the derivative function gets the original arguments followed
    by the differentiated arguments that are kept in the interface.
  Fails with an internal error if the call is not typed or the function is
  neither builtin nor part of diffArguments.funcMap.
  ToDo: respect the 'order' of the derivative when differentiating!"
    input output Expression exp "Has to be Expression.CALL()";
    input output DifferentiationArguments diffArguments;
  protected
    constant Boolean debug = false;
  algorithm
    if debug then
      print("\nDifferentiate Exp-Call: "+ Expression.toString(exp) + "\n");
    end if;

    (exp, diffArguments) := match exp
      local
        Expression ret, arg;
        Call call;
        Option<Function> func_opt, der_func_opt;
        Function func, der_func;
        // arguments that are differentiated and appended to the original ones
        list<Expression> arguments = {};
        list<tuple<Expression, InstNode>> arguments_inputs;
        InstNode inp;
        Boolean isCont, isReal, isFunc;
        // interface map. If the map contains a variable it has a zero derivative
        // if the value is "true" it has to be stripped from the interface
        // (it is possible that a variable has a zero derivative, but still appears in the interface)
        UnorderedMap<String, Boolean> interface_map;

      // for array constructors only differentiate the argument
      case ret as Expression.CALL(call = call as Call.TYPED_ARRAY_CONSTRUCTOR()) algorithm
        (arg, diffArguments) := differentiateExpression(call.exp, diffArguments);
        call.exp := arg;
        ret.call := call;
      then (ret, diffArguments);

      // handle reductions
      case Expression.CALL(call = call as Call.TYPED_REDUCTION()) algorithm
        (ret, diffArguments) := differentiateReduction(AbsynUtil.pathString(Function.nameConsiderBuiltin(call.fn)), exp, diffArguments);
      then (ret, diffArguments);

      // builtin functions
      case Expression.CALL(call = call as Call.TYPED_CALL()) guard(Function.isBuiltin(call.fn)) algorithm
        (ret, diffArguments) := differentiateBuiltinCall(AbsynUtil.pathString(Function.nameConsiderBuiltin(call.fn)), exp, diffArguments);
      then (ret, diffArguments);

      // user defined functions
      case Expression.CALL(call = call as Call.TYPED_CALL()) algorithm
        func_opt := UnorderedMap.get(call.fn.path, diffArguments.funcMap);
        if Util.isSome(func_opt) then
          // The function is in the function tree
          SOME(func) := func_opt;

          interface_map := UnorderedMap.new<Boolean>(stringHashDjb2, stringEqual);

          // build interface map to check if a function fits
          // save all inputs that would end up in a zero derivative in a map
          arguments_inputs := List.zip(call.arguments, func.inputs);
          for tpl in arguments_inputs loop
            (arg, inp) := tpl;
            // do not check for continuous if it is for functions (differentiating a function inside a function)
            // crefs are not lowered there! assume it is continuous
            isCont := (diffArguments.diffType == DifferentiationType.FUNCTION) or BackendUtil.isContinuous(arg, false);
            // input type has to be real value or a function pointer
            isReal := Type.isReal(Type.arrayElementType(Expression.typeOf(arg)));
            isFunc := InstNode.isFunction(inp);
            if not (isFunc or (isCont and isReal)) then
              // add to map; if it is not Real also already set to true (always removed from interface)
              UnorderedMap.add(InstNode.name(inp), not (isFunc or isReal), interface_map);
            end if;
          end for;

          // try to get a fitting function from derivatives -> if none is found, differentiate
          der_func_opt := Function.getDerivative(func, interface_map);
          if Util.isSome(der_func_opt) then
            SOME(der_func) := der_func_opt;
          else
            (der_func, diffArguments) := differentiateFunction(func, interface_map, diffArguments);
          end if;

          // traverse in reverse because the kept arguments are prepended, this restores the original order
          for tpl in listReverse(arguments_inputs) loop
            (arg, inp) := tpl;
            // only keep the arguments which are not in the map or have value false
            if not UnorderedMap.getOrDefault(InstNode.name(inp), interface_map, false) then
              arguments := arg :: arguments;
            end if;
          end for;

          // differentiate type arguments and append to original ones
          (arguments, diffArguments) := List.mapFold(arguments, differentiateExpression, diffArguments);
          arguments := listAppend(call.arguments, arguments);

          ret := Expression.CALL(Call.makeTypedCall(der_func, arguments, call.var, call.purity));
        else
          // The function is not in the function tree and not builtin -> error
          Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName()
            + " failed because the function is not a builtin function and could not be found in the function tree: "
            + Expression.toString(exp)});
          fail();
        end if;
      then (ret, diffArguments);

      // If the call was not typed correctly by the frontend
      else algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp)});
      then fail();
    end match;

    if debug then
      print("Differentiate-ExpCall-result: " + Expression.toString(exp) + "\n");
    end if;
  end differentiateCall;

  function differentiateReduction
    "This function differentiates reduction expressions with respect to a given variable.
    Currently only the reduction 'sum' is supported, for which the reduced
    expression is differentiated and the reduction itself is kept:
      (sum(f(i) for i in ...))' = sum(f'(i) for i in ...)
    Fails with an internal error for any other reduction and for expressions
    that are not a typed reduction call."
    input String name "name of the reduction function, e.g. sum";
    input output Expression exp "Has to be Expression.CALL() with a Call.TYPED_REDUCTION()";
    input output DifferentiationArguments diffArguments;
  algorithm
    exp := match exp
      local
        Call call;
        Expression arg;

      // sum is linear -> only differentiate the reduced expression
      case Expression.CALL(call = call as Call.TYPED_REDUCTION()) guard(name == "sum") algorithm
        (arg, diffArguments) := differentiateExpression(call.exp, diffArguments);
        call.exp := arg;
        exp.call := call;
      then exp;

      // ToDo: product, min, max
      // a proper reduction, but there is no differentiation rule for it yet
      case Expression.CALL(call = Call.TYPED_REDUCTION()) algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because the reduction '" + name
          + "' is not yet supported: " + Expression.toString(exp)});
      then fail();

      // not a reduction at all
      else algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because of non-reduction expression: " + Expression.toString(exp)});
      then fail();
    end match;
  end differentiateReduction;

  function differentiateBuiltinCall
    "This function differentiates built-in call expressions with respect to a given variable.
    Also creates and multiplies inner derivatives."
    input String name;
    input output Expression exp;
    input output DifferentiationArguments diffArguments;
  protected
    // these need to be adapted to size and type of exp
    Operator.SizeClassification sizeClass = NFOperator.SizeClassification.SCALAR;
    Operator addOp = Operator.fromClassification((NFOperator.MathClassification.ADDITION, sizeClass), Type.REAL());
    Operator mulOp = Operator.fromClassification((NFOperator.MathClassification.MULTIPLICATION, sizeClass), Type.REAL());
  algorithm
    exp := match (exp)
      local
        Integer i;
        Expression ret, ret1, ret2, arg1, arg2, arg3, diffArg1, diffArg2, diffArg3, current_grad, cond1, cond2, cond, zero1, zero2, grad_x, grad_y, old_grad;
        list<Expression> rest;
        Type ty;
        DifferentiationType diffType;
        Integer rY, rX;
        Boolean isReverse = Util.isSome(diffArguments.adjoint_map);

        Type elTy;
        // sumG = G + Gᵀ
        Operator addM, subM;
        Expression sumG, triuG;

        // diagG = G .* I(n), I(n) from diagonal(ones(n))
        Integer nExp;
        Expression eyeNN;
        Operator mulEW;
        Expression diagG;

      // d/dz delay(x, delta) = (dt/dz - d delta/dz) * delay(der(x), delta)
      case (Expression.CALL()) guard(name == "delay")
      algorithm
        (arg1, arg2, arg3) := match Call.arguments(exp.call)
          case {arg1, arg2, arg3} then (arg1, arg2, arg3);
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;
        // if z = t then dt/dz = 1 else dt/dz = 0
        ret1 := Expression.REAL(if diffArguments.diffType == DifferentiationType.TIME then 1.0 else 0.0);
        // d delta/dz
        (ret2, diffArguments) := differentiateExpression(arg2, diffArguments);
        // dt/dz - d delta/dz
        ret2 := SimplifyExp.simplifyDump(Expression.MULTARY({ret1}, {ret2}, addOp), true, getInstanceName());
        if Expression.isZero(ret2) then
          ret := Expression.makeZero(Expression.typeOf(arg1));
        else
          diffType := diffArguments.diffType;
          diffArguments.diffType := DifferentiationType.TIME;
          (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);
          diffArguments.diffType := diffType;
          exp.call := Call.setArguments(exp.call, {ret1, arg2, arg3});
          ret := Expression.MULTARY({ret2, exp}, {}, mulOp);
        end if;
      then ret;

      // SMOOTH
      case (Expression.CALL()) guard(name == "smooth")
      algorithm
        ret := match Call.arguments(exp.call)
          case {arg1 as Expression.INTEGER(i), arg2} guard(i > 0) algorithm
            (ret2, diffArguments) := differentiateExpression(arg2, diffArguments);
            exp.call := Call.setArguments(exp.call, {Expression.INTEGER(i-1), ret2});
          then exp;
          case {arg1 as Expression.INTEGER(i), arg2} algorithm
            (ret2, diffArguments) := differentiateExpression(arg2, diffArguments);
            exp := Expression.CALL(Call.makeTypedCall(
              fn          = NFBuiltinFuncs.NO_EVENT,
              args        = {ret2},
              variability = Expression.variability(ret2),
              purity      = NFPrefixes.Purity.PURE
            ));
          then exp;
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;
      then ret;

      case (Expression.CALL()) guard(name == "sum")
      algorithm
        arg1 := match Call.arguments(exp.call)
          case {arg1} then arg1;
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;
        if isReverse then
          current_grad := diffArguments.current_grad;
          // sum is linear -> multiply upstream gradient with ones of the right size
          diffArguments.current_grad := Expression.BINARY(
            Expression.makeOne(Expression.typeOf(arg1)),
            Operator.fromClassification(
              (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.ARRAY_SCALAR),
              Expression.typeOf(arg1)
            ),
            current_grad);
        end if;

        (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);

        if isReverse then
          // restore upstream
          diffArguments.current_grad := current_grad;
        end if;
        exp.call := Call.setArguments(exp.call, {ret1});
      then exp;

      // symmetric(A):
      // Forward: symmetric(dA/dz)
      // Reverse: grad_A = triu(G + Gᵀ) - diag(G)
      case (Expression.CALL()) guard(name == "symmetric")
      algorithm
        arg1 := match Call.arguments(exp.call)
          case {arg1} then arg1;
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR, {getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;

        if isReverse then
          current_grad := diffArguments.current_grad;

          // upstream gradient type (matrix)
          ty := Expression.typeOf(current_grad);
          // element type
          elTy := if Type.isArray(ty) then Type.arrayElementType(ty) else ty;
          // matrix dimension (assume square)
          nExp := Dimension.size(listHead(Type.arrayDims(Expression.typeOf(arg1))));

          // element-wise add / mul operators with full matrix type (not element type)
          addM := Operator.fromClassification(
            (NFOperator.MathClassification.ADDITION, NFOperator.SizeClassification.ELEMENT_WISE),
            ty);
          subM := Operator.fromClassification(
            (NFOperator.MathClassification.SUBTRACTION, NFOperator.SizeClassification.ELEMENT_WISE),
            ty);
          mulEW := Operator.fromClassification(
            (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.ELEMENT_WISE),
            ty);

          // sumG = G + Gᵀ   (binary)
          sumG := Expression.BINARY(
            current_grad,
            addM,
            typeTransposeCall(current_grad));

          // triu(sumG) = sumG .* triu(ones(n,n))  (binary)
          triuG := Expression.BINARY(
            sumG,
            mulEW,
            Expression.makeTriuMask(nExp, elTy));

          // I(n)
          eyeNN := Expression.makeIdentityMatrix(nExp, elTy);

          // diagG = G .* I  (binary)
          diagG := Expression.BINARY(
            current_grad,
            mulEW,
            eyeNN);

          // triu(G + Gᵀ) - diag(G)  (binary)
          diffArguments.current_grad := Expression.BINARY(
            triuG,
            subM,
            diagG);
        end if;

        // Forward: symmetric(dA/dz)
        (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);

        if isReverse then
          // restore upstream
          diffArguments.current_grad := current_grad;
        end if;
        exp.call := Call.setArguments(exp.call, {ret1});
      then exp;

      // diagonal(v):
      // Forward: diagonal(dv/dz)
      // Reverse: grad_v = diag(G)  (extract diagonal of upstream matrix)
      case (Expression.CALL()) guard(name == "diagonal")
      algorithm
        arg1 := match Call.arguments(exp.call)
          case {arg1} then arg1;
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR, {getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;

        if isReverse then
          current_grad := diffArguments.current_grad;
          // number of elements in v and in diagonal of G
          nExp := Dimension.size(listHead(Type.arrayDims(Expression.typeOf(arg1))));
          // Literal: [ G[1,1], G[2,2], ..., G[n,n] ]
          diffArguments.current_grad := extractDiagonalVector(current_grad, nExp, Expression.typeOf(arg1));
        end if;

        // Forward: diagonal(dv/dz)
        (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);

        if isReverse then
          // Restore upstream and return updated call
          diffArguments.current_grad := current_grad;
        end if;
        exp.call := Call.setArguments(exp.call, {ret1});
      then exp;

      // matrix(A)
      // Forward: matrix(dA/dz)
      // Reverse: let rX = ndims(A), G the upstream matrix:
      //   - if rX < 2: dropLastDimIndex1(G) (2-rX times)
      //   - if rX = 2: G
      //   - if rX > 2: promote(G, rX)
      case (Expression.CALL()) guard(name == "matrix")
      algorithm
        arg1 := match Call.arguments(exp.call)
          case {arg1} then arg1;
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR, {getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;

        if isReverse then
          current_grad := diffArguments.current_grad;
          // Rank of input A
          ty := Expression.typeOf(arg1);
          rX := if Type.isArray(ty) then Type.dimensionCount(ty) else 0;

          // Map upstream gradient back to A's shape
          grad_x := current_grad;

          // If A has rank < 2, drop trailing dims by indexing with 1
          if rX < 2 then
            for i in 1:(2 - rX) loop
              grad_x := dropLastDimIndex1(grad_x);
            end for;
          elseif rX > 2 then
            // If A has rank > 2 (with trailing singleton dims), promote G to rank rX
            grad_x := typePromoteCall(grad_x, rX);
          end if;

          // Recurse into A with mapped upstream gradient
          diffArguments.current_grad := grad_x;

          (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);

          // restore upstream
          diffArguments.current_grad := current_grad;
        else
          (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);
        end if;
        // Forward: matrix(dA/dz)
        exp.call := Call.setArguments(exp.call, {ret1});
      then exp;

      // Functions with one argument that differentiate "through"
      // through means that the derivative of the function wrt. its input is equal to the function of derivative of input
      // d/dz f(x) -> f(dx/dz)
      case (Expression.CALL()) guard(List.contains({"pre", "noEvent", "scalar", "vector", "transpose", "skew"}, name, stringEqual))
      algorithm
        arg1 := match Call.arguments(exp.call)
          case {arg1} then arg1;
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;
        (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);
        exp.call := Call.setArguments(exp.call, {ret1});
      then exp;

      // Functions with two arguments that differentiate "through"
      // df(x,y)/dz = f(dx/dz, dy/dz)
      case (Expression.CALL()) guard(List.contains({"homotopy", "$OMC$inStreamDiv"}, name, stringEqual))
      algorithm
        (arg1, arg2) := match Call.arguments(exp.call)
          case {arg1, arg2} then (arg1, arg2);
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;
        (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);
        (ret2, diffArguments) := differentiateExpression(arg2, diffArguments);
        exp.call := Call.setArguments(exp.call, {ret1, ret2});
      then exp;

      // d/dz promote(A, n) = promote(dA/dz, n)
      case (Expression.CALL()) guard(name == "promote")
      algorithm
        (arg1, arg2) := match Call.arguments(exp.call)
          case {arg1, arg2} then (arg1, arg2);
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;
        if isReverse then
          rY := if Type.isArray(Expression.typeOf(exp)) then Type.dimensionCount(Expression.typeOf(exp)) else 0;
          rX := if Type.isArray(Expression.typeOf(arg1)) then Type.dimensionCount(Expression.typeOf(arg1)) else 0;
          current_grad := diffArguments.current_grad;
          old_grad := current_grad;
          for i in 1:max(0, rY - rX) loop
            current_grad := dropLastDimIndex1(current_grad);
          end for;
          diffArguments.current_grad := current_grad;
          (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);
          diffArguments.current_grad := old_grad;
        else
          (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);
        end if;
        exp.call := Call.setArguments(exp.call, {ret1, arg2});
      then exp;

      // d/dz identity(n) = zeros(n, n)
      case (Expression.CALL()) guard(name == "identity")
      algorithm
        // diffArguments.current_grad := Expression.makeZero(Expression.typeOf(exp));?
        arg1 := match Call.arguments(exp.call)
          case {arg1} then arg1;
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;
      then Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.FILL_FUNC,
          args        = {Expression.INTEGER(0), arg1, arg1},
          variability = Variability.CONSTANT,
          purity      = NFPrefixes.Purity.PURE
        ));

      // d/dz fill(x, n1, n2, ...) = fill(dx/dz, n1, n2, ...)
      case (Expression.CALL()) guard(name == "fill")
      algorithm
        // only differentiate 1st input
        arg1 :: rest := Call.arguments(exp.call);
        if isReverse then
          rY := if Type.isArray(Expression.typeOf(exp)) then Type.dimensionCount(Expression.typeOf(exp)) else 0;
          rX := if Type.isArray(Expression.typeOf(arg1)) then Type.dimensionCount(Expression.typeOf(arg1)) else 0;
          current_grad := diffArguments.current_grad;
          old_grad := current_grad;
          for i in 1:max(0, rY - rX) loop // reduce over all added dimensions with sum (TODO: change to only sum over added dimensions)
            current_grad := typeSumCall(current_grad); // sum over first (or last?) dimension
          end for;
          diffArguments.current_grad := current_grad;
          (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);
          diffArguments.current_grad := old_grad;
        else
          (ret1, diffArguments) := differentiateExpression(arg1, diffArguments);
        end if;
        exp.call := Call.setArguments(exp.call, ret1 :: rest);
      then exp;

      // SEMI LINEAR
      // d sL(x, m1, m2)/dz = sL(x, dm1/dz, dm2/dz) + dx/dz * (if x >= 0 then m1 else m2)
      case (Expression.CALL()) guard(name == "semiLinear")
      algorithm
        (arg1, arg2, arg3) := match Call.arguments(exp.call)
          case {arg1, arg2, arg3} then (arg1, arg2, arg3);
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;
        current_grad := diffArguments.current_grad;

        if isReverse then
          cond := Expression.RELATION(
            arg1, // x
            Operator.makeGreaterEq(Expression.typeOf(arg1)),
            Expression.makeZero(Expression.typeOf(arg1)),
            -1);

          grad_x := Expression.IF(
            Expression.typeOf(arg1),
            cond,
            Expression.MULTARY({arg2, current_grad}, {}, mulOp), // d(positive_slope * x)/dx = positive_slope * current_grad
            Expression.MULTARY({arg3, current_grad}, {}, mulOp)  // d(negative_slope * x)/dx = negative_slope * current_grad
          );
          diffArguments.current_grad := grad_x;
        end if;

        // dx/dz, dm1/dz, dm2/dz
        (diffArg1, diffArguments) := differentiateExpression(arg1, diffArguments);
        diffArguments.current_grad := current_grad; // restore upstream
        (diffArg2, diffArguments) := differentiateExpression(arg2, diffArguments);
        (diffArg3, diffArguments) := differentiateExpression(arg3, diffArguments);

        // sL(x, dm1/dz, dm2/dz)
        exp.call := Call.setArguments(exp.call, {arg1, diffArg2, diffArg3});
        ret := exp;

        // only add second part if dx/dz is nonzero
        if not Expression.isZero(diffArg1) then
          ty    := Expression.typeOf(diffArg1);
          // x >= 0
          ret1  := Expression.RELATION(arg1, Operator.makeGreaterEq(ty), Expression.makeZero(ty), -1);
          // if x >= 0 then m1 else m2
          ret1  := Expression.IF(ty, ret1, arg2, arg3);
          // dx/dz * (if x >= 0 then m1 else m2)
          ret2  := Expression.MULTARY({diffArg1, ret1}, {}, mulOp);
          // sL(x, dm1/dz, dm2/dz) + dx/dz * (if x >= 0 then m1 else m2)
          ret   := Expression.MULTARY({ret, ret2}, {}, addOp);
        end if;
      then ret;

      // d/dz min(X) = (dX/dz)[argmin(X)]
      // d/dz max(X) = (dX/dz)[argmax(X)]
      // d/dz min(x,y) = if x < y then dx/dz else dy/dz
      // d/dz max(x,y) = if x > y then dx/dz else dy/dz
      case (Expression.CALL()) guard(name == "min" or name == "max")
      algorithm
        ret := match Call.arguments(exp.call)
          case {arg1} algorithm
            // dX/dz
            (diffArg1, diffArguments) := differentiateExpression(arg1, diffArguments);
            ty := Expression.typeOf(diffArg1);
            if Expression.isZero(diffArg1) then
              // make 0 of reduced type
              ret := Expression.makeZero(Type.arrayElementType(ty));
            else
              ret1 := Expression.CALL(Call.makeTypedCall(
                fn          = if name == "min" then NFBuiltinFuncs.ARG_MIN_ARR_REAL else NFBuiltinFuncs.ARG_MAX_ARR_REAL,
                args        = {arg1},
                variability = Expression.variability(arg1),
                purity      = NFPrefixes.Purity.PURE));
              ret := Expression.applySubscripts({Subscript.INDEX(ret1)}, diffArg1, true);
            end if;
          then ret;

          case {arg1, arg2} algorithm
            if isReverse then
              current_grad := diffArguments.current_grad;
              // Relation: for min use x<y; for max use x>y
              cond1 := Expression.RELATION(
                arg1,
                if name == "min" then Operator.makeLess(Expression.typeOf(arg1))
                                else Operator.makeGreater(Expression.typeOf(arg1)),
                arg2,
                -1);
              cond2 := Expression.RELATION(
                arg2,
                if name == "min" then Operator.makeLess(Expression.typeOf(arg2))
                                else Operator.makeGreater(Expression.typeOf(arg2)),
                arg1,
                -1);

              // Reverse local masks:
              // For min: grad_x = upstream if x<y else 0; grad_y = upstream if x>=y else 0
              // For max: grad_x = upstream if x>y else 0; grad_y = upstream if x<=y else 0
              zero1 := Expression.makeZero(Expression.typeOf(arg1));
              zero2 := Expression.makeZero(Expression.typeOf(arg2));

              grad_x := Expression.IF(
                Expression.typeOf(arg1),
                cond1,
                current_grad,
                zero1);

              grad_y := Expression.IF(
                Expression.typeOf(arg2),
                cond2,
                current_grad,
                zero2);

              // Reverse recurse arg1 with grad_x
              old_grad := diffArguments.current_grad;
              diffArguments.current_grad := grad_x;
              // dx/dz
              (diffArg1, diffArguments) := differentiateExpression(arg1, diffArguments);

              // Reverse recurse arg2 with grad_y
              diffArguments.current_grad := grad_y;
              // dy/dz
              (diffArg2, diffArguments) := differentiateExpression(arg2, diffArguments);

              // Restore upstream
              diffArguments.current_grad := old_grad;
            else
              // Forward: dx/dz and dy/dz
              (diffArg1, diffArguments) := differentiateExpression(arg1, diffArguments);
              (diffArg2, diffArguments) := differentiateExpression(arg2, diffArguments);
            end if;

            ty := Expression.typeOf(diffArg1);
            if Expression.isZero(diffArg1) and Expression.isZero(diffArg2) then
              ret := Expression.makeZero(ty);
            else
              // condition x < y or x > y
              ret1 := Expression.RELATION(arg1, if name == "min" then Operator.makeLess(ty) else Operator.makeGreater(ty), arg2, -1);
              // if condition then dx/dz else dy/dz
              ret := Expression.IF(ty, ret1, diffArg1, diffArg2);
            end if;
          then ret;
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;
      then ret;

      // Builtin function call with one argument
      // df(x)/dz = df/dx * dx/dz
      case (Expression.CALL()) guard List.hasOneElement(Call.arguments(exp.call))
      algorithm
        arg1 := match Call.arguments(exp.call)
          case {arg1} then arg1;
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;
        // differentiate the call df/dx
        ret := differentiateBuiltinCall1Arg(name, arg1);
        if not Expression.isZero(ret) then
          current_grad := diffArguments.current_grad;

          diffArguments.current_grad := Expression.MULTARY({current_grad, ret}, {}, mulOp);
          // differentiate the argument (inner derivative) dx/dz
          (diffArg1, diffArguments) := differentiateExpression(arg1, diffArguments);

          diffArguments.current_grad := current_grad;
          ret := Expression.MULTARY({ret, diffArg1}, {}, mulOp);
        end if;
      then ret;

      // Builtin function call with two arguments
      // df(x,y)/dz = df/dx * dx/dz + df/dy * dy/dz
      case (Expression.CALL()) guard(listLength(Call.arguments(exp.call)) == 2)
      algorithm
        (arg1, arg2) := match Call.arguments(exp.call)
          case {arg1, arg2} then (arg1, arg2);
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
          then fail();
        end match;
        // differentiate the call
        (ret1, ret2) := differentiateBuiltinCall2Arg(name, arg1, arg2);             // df/dx and df/dy
        current_grad := diffArguments.current_grad;

        diffArguments.current_grad := Expression.MULTARY({current_grad, ret1}, {}, mulOp);
        (diffArg1, diffArguments) := differentiateExpression(arg1, diffArguments);  // dx/dz

        diffArguments.current_grad := Expression.MULTARY({current_grad, ret2}, {}, mulOp);
        (diffArg2, diffArguments) := differentiateExpression(arg2, diffArguments);  // dy/dz

        diffArguments.current_grad := current_grad;
        ret1 := Expression.MULTARY({ret1, diffArg1}, {}, mulOp);                    // df/dx * dx/dz
        ret2 := Expression.MULTARY({ret2, diffArg2}, {}, mulOp);                    // df/dy * dy/dz
        ret := Expression.MULTARY({ret1,ret2}, {}, addOp);                          // df/dx * dx/dz + df/dy * dy/dz
      then ret;

      // try some simple known cases
      case (Expression.CALL()) algorithm
        ret := match Call.functionNameLast(exp.call)
          case "sample" then Expression.BOOLEAN(false);
          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp)});
          then fail();
        end match;
      then ret;

      else algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because of non-call expression: " + Expression.toString(exp)});
        then fail();
    end match;
  end differentiateBuiltinCall;

  function differentiateBuiltinCall1Arg
    "differentiate a builtin call with one argument."
    // Builds the outer derivative d name(arg) / d arg of a one-argument builtin.
    // The inner derivative of `arg` is NOT applied here; the result is only
    // the factor that depends on `name` and `arg`.
    // An internal error is reported and the function fails for unknown names.
    input String name;
    input Expression arg;
    output Expression derFuncCall;
  protected
    // scalar Real operators; these probably need to be adapted to the size and type of arg
    Operator.SizeClassification scalarSize = NFOperator.SizeClassification.SCALAR;
    Operator opPow = Operator.fromClassification((NFOperator.MathClassification.POWER, scalarSize), Type.REAL());
    Operator opAdd = Operator.fromClassification((NFOperator.MathClassification.ADDITION, scalarSize), Type.REAL());
    Operator opMul = Operator.fromClassification((NFOperator.MathClassification.MULTIPLICATION, scalarSize), Type.REAL());
    // frequently used Real literals
    Expression half = Expression.REAL(0.5);
    Expression one  = Expression.REAL(1.0);
    Expression two  = Expression.REAL(2.0);
  algorithm
    derFuncCall := match name
      local
        Expression fcall, squared, radicand, root, denom;

      // piecewise constant functions: the derivative vanishes wherever it exists
      case "sign"     then Expression.INTEGER(0);
      case "ceil"     then Expression.REAL(0.0);
      case "floor"    then Expression.REAL(0.0);
      case "integer"  then Expression.INTEGER(0);

      // d/dx abs(x) = sign(x), cast to the type of the argument
      case "abs" then Expression.CAST(
        Expression.typeOf(arg),
        Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.SIGN,
          args        = {arg},
          variability = Expression.variability(arg),
          purity      = NFPrefixes.Purity.PURE
        )));

      // d/dx sqrt(x) = 0.5 / x^0.5
      case "sqrt" algorithm
        root := Expression.BINARY(arg, opPow, half);
      then Expression.MULTARY({half}, {root}, opMul);

      // d/dx sin(x) = cos(x)
      case "sin" algorithm
        fcall := Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.COS_REAL,
          args        = {arg},
          variability = Expression.variability(arg),
          purity      = NFPrefixes.Purity.PURE
        ));
      then fcall;

      // d/dx cos(x) = -sin(x)
      case "cos" algorithm
        fcall := Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.SIN_REAL,
          args        = {arg},
          variability = Expression.variability(arg),
          purity      = NFPrefixes.Purity.PURE
        ));
      then Expression.negate(fcall);

      // d/dx tan(x) = 1 / cos(x)^2
      // kabdelhak: ToDo - investigate numerical properties: 1+tan(arg)^2 maybe better?
      case "tan" algorithm
        fcall := Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.COS_REAL,
          args        = {arg},
          variability = Expression.variability(arg),
          purity      = NFPrefixes.Purity.PURE
        ));
        squared := Expression.BINARY(fcall, opPow, two);
      then Expression.MULTARY({one}, {squared}, opMul);

      // d/dx asin(x) = 1 / (1 - x^2)^0.5
      case "asin" algorithm
        squared  := Expression.BINARY(arg, opPow, two);
        radicand := Expression.MULTARY({one}, {squared}, opAdd);
        root     := Expression.BINARY(radicand, opPow, half);
      then Expression.MULTARY({one}, {root}, opMul);

      // d/dx acos(x) = -1 / (1 - x^2)^0.5
      case "acos" algorithm
        squared  := Expression.BINARY(arg, opPow, two);
        radicand := Expression.MULTARY({one}, {squared}, opAdd);
        root     := Expression.BINARY(radicand, opPow, half);
      then Expression.MULTARY({Expression.REAL(-1.0)}, {root}, opMul);

      // d/dx atan(x) = 1 / (1 + x^2)
      case "atan" algorithm
        squared := Expression.BINARY(arg, opPow, two);
        denom   := Expression.MULTARY({one, squared}, {}, opAdd);
      then Expression.MULTARY({one}, {denom}, opMul);

      // d/dx sinh(x) = cosh(x)
      case "sinh" algorithm
        fcall := Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.COSH_REAL,
          args        = {arg},
          variability = Expression.variability(arg),
          purity      = NFPrefixes.Purity.PURE
        ));
      then fcall;

      // d/dx cosh(x) = sinh(x)
      case "cosh" algorithm
        fcall := Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.SINH_REAL,
          args        = {arg},
          variability = Expression.variability(arg),
          purity      = NFPrefixes.Purity.PURE
        ));
      then fcall;

      // d/dx tanh(x) = 1 - tanh(x)^2
      case "tanh" algorithm
        fcall := Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.TANH_REAL,
          args        = {arg},
          variability = Expression.variability(arg),
          purity      = NFPrefixes.Purity.PURE
        ));
        squared := Expression.BINARY(fcall, opPow, two);
      then Expression.MULTARY({one}, {squared}, opAdd);

      // d/dx acosh(x) = 1 / (x^2 - 1)^0.5
      case "acosh" algorithm
        squared  := Expression.BINARY(arg, opPow, two);
        radicand := Expression.MULTARY({squared}, {one}, opAdd);
        root     := Expression.BINARY(radicand, opPow, half);
      then Expression.MULTARY({one}, {root}, opMul);

      // d/dx asinh(x) = 1 / (x^2 + 1)^0.5
      case "asinh" algorithm
        squared  := Expression.BINARY(arg, opPow, two);
        radicand := Expression.MULTARY({squared, one}, {}, opAdd);
        root     := Expression.BINARY(radicand, opPow, half);
      then Expression.MULTARY({one}, {root}, opMul);

      // d/dx atanh(x) = 1 / (1 - x^2)
      case "atanh" algorithm
        squared := Expression.BINARY(arg, opPow, two);
        denom   := Expression.MULTARY({one}, {squared}, opAdd);
      then Expression.MULTARY({one}, {denom}, opMul);

      // d/dx exp(x) = exp(x)
      case "exp" algorithm
        fcall := Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.EXP_REAL,
          args        = {arg},
          variability = Expression.variability(arg),
          purity      = NFPrefixes.Purity.PURE
        ));
      then fcall;

      // d/dx log(x) = 1 / x
      case "log" then Expression.MULTARY({one}, {arg}, opMul);

      // d/dx log10(x) = 1 / (x * log(10))
      case "log10" algorithm
        // constant factor log(10)
        fcall := Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.LOG_REAL,
          args        = {Expression.REAL(10.0)},
          variability = Variability.CONSTANT,
          purity      = NFPrefixes.Purity.PURE
        ));
      then Expression.MULTARY({one}, {arg, fcall}, opMul);

      // no derivative rule known for this builtin
      else algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + name});
      then fail();
    end match;
  end differentiateBuiltinCall1Arg;

  function differentiateBuiltinCall2Arg
    "differentiate a builtin call with two arguments."
    // Builds both partial derivatives of the builtin call name(arg1, arg2):
    //   derFuncCall1 = d name(arg1, arg2) / d arg1
    //   derFuncCall2 = d name(arg1, arg2) / d arg2
    // The inner derivatives of arg1 and arg2 are NOT applied here; only the
    // partials with respect to the two arguments are returned.
    // An internal error is reported and the function fails for unknown names.
    input String name;
    input Expression arg1;
    input Expression arg2;
    output Expression derFuncCall1;
    output Expression derFuncCall2;
  protected
    // these probably need to be adapted to the size and type of arg
    Operator.SizeClassification sizeClass = NFOperator.SizeClassification.SCALAR;
    Operator powOp = Operator.fromClassification((NFOperator.MathClassification.POWER, sizeClass), Type.REAL());
    Operator addOp = Operator.fromClassification((NFOperator.MathClassification.ADDITION, sizeClass), Type.REAL());
    Operator mulOp = Operator.fromClassification((NFOperator.MathClassification.MULTIPLICATION, sizeClass), Type.REAL());
  algorithm
    (derFuncCall1, derFuncCall2) := match (name)
      local
        Expression exp1, exp2, ret1, ret2;

      // div(arg1, arg2) truncates the fractional part of arg1/arg2 so it has discrete values
      // therefore it has zero derivative where it's defined
      case ("div") then (Expression.INTEGER(0), Expression.INTEGER(0));

      // mod(arg1, arg2) = arg1 - floor(arg1/arg2)*arg2, with floor(...) piecewise constant
      // d/darg1 mod(arg1, arg2) -> 1
      // d/darg2 mod(arg1, arg2) -> -floor(arg1/arg2)
      case ("mod") algorithm
        exp2 := Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.FLOOR,
          args        = {Expression.MULTARY({arg1}, {arg2}, mulOp)},          // arg1/arg2
          variability = Prefixes.variabilityMax(Expression.variability(arg1), Expression.variability(arg2)),
          purity      = NFPrefixes.Purity.PURE
        ));                                                                   // floor(arg1/arg2)
        ret2 := Expression.negate(exp2);                                      // -floor(arg1/arg2)
      then (Expression.REAL(1.0), ret2);

      // rem(arg1, arg2) = arg1 - div(arg1, arg2)*arg2, with div(...) piecewise constant
      // d/darg1 rem(arg1, arg2) -> 1
      // d/darg2 rem(arg1, arg2) -> -div(arg1, arg2)
      case ("rem") algorithm
        exp2 := Expression.CALL(Call.makeTypedCall(
          fn          = NFBuiltinFuncs.DIV_REAL,
          args        = {arg1, arg2},
          variability = Prefixes.variabilityMax(Expression.variability(arg1), Expression.variability(arg2)),
          purity      = NFPrefixes.Purity.PURE
        ));                                                                   // div(arg1, arg2)
        ret2 := Expression.negate(exp2);                                      // -div(arg1, arg2)
      then (Expression.REAL(1.0), ret2);

      // atan2(arg1, arg2) is the quadrant-corrected atan(arg1/arg2), hence
      // d/darg1 atan2(arg1, arg2) ->  arg2/(arg1^2+arg2^2)
      // d/darg2 atan2(arg1, arg2) -> -arg1/(arg1^2+arg2^2)
      case ("atan2") algorithm
        exp1 := Expression.BINARY(arg1, powOp, Expression.REAL(2.0));         // arg1^2
        exp2 := Expression.BINARY(arg2, powOp, Expression.REAL(2.0));         // arg2^2
        exp1 := Expression.MULTARY({exp1, exp2}, {}, addOp);                  // arg1^2+arg2^2
        ret1 := Expression.MULTARY({arg2}, {exp1}, mulOp);                    //  arg2/(arg1^2+arg2^2)
        ret2 := Expression.MULTARY({Expression.negate(arg1)}, {exp1}, mulOp); // -arg1/(arg1^2+arg2^2)
      then (ret1, ret2);

      // no derivative rule known for this builtin
      else algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + name});
      then fail();
    end match;
  end differentiateBuiltinCall2Arg;

  function differentiateFunction
    input Function func;
    output Function der_func;
    input UnorderedMap<String, Boolean> interface_map;
    input output DifferentiationArguments diffArguments;
  algorithm
    der_func := match func
      local
        InstNode node;
        Pointer<Class> cls;
        Class new_cls;
        DifferentiationArguments funcDiffArgs;
        UnorderedMap<ComponentRef, ComponentRef> diff_map = UnorderedMap.new<ComponentRef>(ComponentRef.hash, ComponentRef.isEqual);
        list<Algorithm> algorithms;
        FunctionDerivative funcDer;
        Function dummy_func;
        CachedData cachedData;
        String der_func_name;
        list<InstNode> inputs, locals, outputs, local_outputs;
        list<Slot> slots;

      case der_func as Function.FUNCTION(node = node as InstNode.CLASS_NODE(cls = cls)) algorithm
        new_cls := match Pointer.access(cls)
          case new_cls as Class.INSTANCED_CLASS() algorithm
            // prepare outputs that become locals
            local_outputs     := list(InstNode.setComponentDirection(NFPrefixes.Direction.NONE, lout) for lout in der_func.outputs);
            local_outputs     := list(InstNode.protect(lout) for lout in local_outputs);

            // prepare differentiation arguments
            funcDiffArgs          := DifferentiationArguments.default();
            funcDiffArgs.diffType := DifferentiationType.FUNCTION;
            funcDiffArgs.funcMap  := diffArguments.funcMap;
            createInterfaceDerivatives(der_func.inputs, interface_map, diff_map);
            createInterfaceDerivatives(der_func.locals, interface_map, diff_map);
            createInterfaceDerivatives(der_func.outputs, interface_map, diff_map);
            funcDiffArgs.diff_map := SOME(diff_map);

            // differentiate interface arguments
            (inputs, funcDiffArgs)  := differentiateFunctionInterfaceNodes(der_func.inputs, interface_map, diff_map, funcDiffArgs, true);
            (locals, funcDiffArgs)  := differentiateFunctionInterfaceNodes(der_func.locals, interface_map, diff_map, funcDiffArgs, false);
            (outputs, funcDiffArgs) := differentiateFunctionInterfaceNodes(der_func.outputs, interface_map, diff_map, funcDiffArgs, false);

            // update inputs, outputs and locals, add old outputs to locals as they might still be used as temporary variables
            der_func.inputs   := inputs;
            der_func.locals   := List.flatten({der_func.locals, locals, local_outputs});
            der_func.outputs  := outputs;
            // also add the new locals to the class
            new_cls.elements := ClassTree.appendComponentsToFlatTree(locals, new_cls.elements);

            // differentiate slots
            (slots, funcDiffArgs) := createSlotDerivatives(der_func.slots, interface_map, diff_map, funcDiffArgs);
            der_func.slots        := listAppend(der_func.slots, slots);

            // create "fake" function with correct interface to have the interface
            // in the case of recursive differentiation (e.g. function calls itself)
            dummy_func      := func;
            node.cls        := Pointer.create(new_cls);
            der_func_name   := NBVariable.FUNCTION_DERIVATIVE_STR + intString(listLength(func.derivatives));
            node.name       := der_func_name + "." + node.name;
            node.definition := SCodeUtil.setElementName(node.definition, node.name);
            // create "fake" function from new node, update cache to get correct derivative name
            der_func.path   := AbsynUtil.prefixPath(der_func_name, der_func.path);
            cachedData      := CachedData.FUNCTION({der_func}, true, false);
            der_func.node   := InstNode.setFuncCache(node, cachedData);

            // create fake derivative
            funcDer := FunctionDerivative.FUNCTION_DER(
              derivativeFn          = der_func.node,
              derivedFn             = dummy_func.node,
              order                 = Expression.INTEGER(1),
              conditions            = {}, // possibly needs updating
              lowerOrderDerivatives = {}  // possibly needs updating
            );

            // add fake derivative to function tree
            dummy_func.derivatives := funcDer :: dummy_func.derivatives;
            UnorderedMap.add(dummy_func.path, dummy_func, funcDiffArgs.funcMap);

            // differentiate function statements (if there are any. empty for function pointer arguments)
            funcDiffArgs := match new_cls.sections
              local
                Sections sections;
              case sections as Sections.SECTIONS() algorithm
                (algorithms, funcDiffArgs) := List.mapFold(sections.algorithms, differentiateAlgorithm, funcDiffArgs);

                // add them to new node
                sections.algorithms := algorithms;
                new_cls.sections    := sections;
              then funcDiffArgs;
              else funcDiffArgs;
            end match;

            node.cls              := Pointer.create(new_cls);
            cachedData            := CachedData.FUNCTION({der_func}, true, false);
            der_func.node         := InstNode.setFuncCache(node, cachedData);
            der_func.derivatives  := {};

            // save the function tree
            diffArguments.funcMap := funcDiffArgs.funcMap;
          then new_cls;

          else algorithm
            Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for class " + Class.toFlatString(Pointer.access(cls), func.node) + "."});
          then fail();
        end match;

        // add function to function tree
        UnorderedMap.add(der_func.path, der_func, diffArguments.funcMap);
        // add new function as derivative to original function
        funcDer := FunctionDerivative.FUNCTION_DER(
          derivativeFn          = der_func.node,
          derivedFn             = func.node,
          order                 = Expression.INTEGER(1),
          conditions            = {}, // possibly needs updating
          lowerOrderDerivatives = {}  // possibly needs updating
        );
        func.derivatives := List.appendElt(funcDer, func.derivatives);
        UnorderedMap.add(func.path, func, diffArguments.funcMap);
      then der_func;

      else algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for uninstantiated function " + Function.signatureString(func) + "."});
      then fail();
    end match;
    if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
      print("\n[BEFORE] " + Function.toFlatString(func) + "\n");
      print("\n[AFTER ] " + Function.toFlatString(der_func) + "\n\n");
    end if;
  end differentiateFunction;

  function differentiateFunctionInterfaceNodes
    "Differentiates the interface nodes (inputs, outputs or locals) of a function.
    Every node whose name is not contained in interface_map is passed to
    differentiateFunctionInterfaceNode, which looks up its derivative in diff_map.
    If keepOld is true the result consists of the original nodes followed by the
    differentiated nodes, otherwise only the differentiated nodes are returned.
    The relative order of the incoming list is preserved in both parts."
    input output list<InstNode> interface_nodes;
    input UnorderedMap<String, Boolean> interface_map;
    input UnorderedMap<ComponentRef, ComponentRef> diff_map;
    input output DifferentiationArguments diffArgs;
    input Boolean keepOld;
  protected
    list<InstNode> diff_nodes = {};
    InstNode diff_node;
  algorithm
    // collect the differentiated nodes (built in reverse order)
    for orig_node in interface_nodes loop
      if not UnorderedMap.contains(InstNode.name(orig_node), interface_map) then
        (diff_node, diffArgs) := differentiateFunctionInterfaceNode(orig_node, diff_map, diffArgs);
        diff_nodes := diff_node :: diff_nodes;
      end if;
    end for;
    diff_nodes := listReverse(diff_nodes);

    // optionally keep the original nodes in front of the differentiated ones
    interface_nodes := if keepOld then listAppend(interface_nodes, diff_nodes) else diff_nodes;
  end differentiateFunctionInterfaceNodes;

  function differentiateFunctionInterfaceNode
    "Returns the differentiated counterpart of a single function interface node.
    The derivative cref is looked up in diff_map (the lookup uses getSafe with the
    current source info). If the derivative cref points to a component node, the
    binding of that component is differentiated and stored in a fresh component
    pointer. If the original node is a function (function pointer argument), the
    first cached function of the node is differentiated as well; only the updated
    differentiation arguments of that call are kept."
    input InstNode node;
    output InstNode d_node;
    input UnorderedMap<ComponentRef, ComponentRef> diff_map;
    input output DifferentiationArguments diffArgs;
  protected
    ComponentRef diff_cref;
    Component d_comp;
    Binding d_binding;
  algorithm
    // look up the derivative cref belonging to this node
    diff_cref := UnorderedMap.getSafe(ComponentRef.fromNode(node, InstNode.getType(node)), diff_map, sourceInfo());

    diff_cref := match diff_cref
      case ComponentRef.CREF(node = d_node as InstNode.COMPONENT_NODE()) algorithm
        // differentiate the binding of the derivative component (if it is a proper component)
        d_comp := Pointer.access(d_node.component);
        d_comp := match d_comp
          case d_comp as Component.COMPONENT() algorithm
            (d_binding, diffArgs) := differentiateBinding(d_comp.binding, diffArgs);
            d_comp.binding := d_binding;
          then d_comp;
          else d_comp;
        end match;
        // store the updated component in a new pointer and attach the node to the cref
        d_node.component := Pointer.create(d_comp);
        diff_cref.node := d_node;
      then diff_cref;

      // anything else is used as it is
      else diff_cref;
    end match;

    // a function node is a function pointer argument: differentiate the function itself
    if InstNode.isFunction(node) then
      (_, diffArgs) := differentiateFunction(listHead(Function.getCachedFuncs(node)), UnorderedMap.new<Boolean>(stringHashDjb2, stringEqual), diffArgs);
    end if;

    d_node := ComponentRef.node(diff_cref);
  end differentiateFunctionInterfaceNode;

  function createInterfaceDerivatives
    "Fills diff_map for a list of function interface nodes.
    For every node whose name is not contained in interface_map, the cref of the
    node is mapped to the cref created by BVariable.makeFDerVar. The same is done
    recursively for all record children of that cref.
    Has no outputs, it only updates diff_map."
    input list<InstNode> interface_nodes;
    input UnorderedMap<String, Boolean> interface_map;
    input UnorderedMap<ComponentRef, ComponentRef> diff_map;
  protected
    function addCref
      "adds the mapping cref -> makeFDerVar(cref) and recurses into the record children"
      input ComponentRef cref;
      input UnorderedMap<ComponentRef, ComponentRef> diff_map;
    algorithm
      UnorderedMap.add(cref, BVariable.makeFDerVar(cref), diff_map);
      for child in ComponentRef.getRecordChildren(cref) loop
        addCref(child, diff_map);
      end for;
    end addCref;
  algorithm
    for node in interface_nodes loop
      // names contained in the interface map are skipped
      if not UnorderedMap.contains(InstNode.name(node), interface_map) then
        addCref(ComponentRef.fromNode(node, InstNode.getType(node)), diff_map);
      end if;
    end for;
  end createInterfaceDerivatives;

  function createSlotDerivatives
    "Creates the slots of the differentiated interface nodes.
    For every slot whose node name is not contained in interface_map, a copy of
    the slot is created that carries the differentiated node. The copies are
    numbered consecutively, starting right after the number of original slots,
    and are returned in the order of the original slots.
    The original slots are not part of the result."
    input list<Slot> slots;
    output list<Slot> new_slots = {};
    input UnorderedMap<String, Boolean> interface_map;
    input UnorderedMap<ComponentRef, ComponentRef> diff_map;
    input output DifferentiationArguments diffArgs;
  protected
    InstNode diff_node;
    Slot diff_slot;
    // index of the last slot that has been handed out so far
    Integer last_index = listLength(slots);
  algorithm
    for slot in slots loop
      if not UnorderedMap.contains(InstNode.name(slot.node), interface_map) then
        (diff_node, diffArgs) := differentiateFunctionInterfaceNode(slot.node, diff_map, diffArgs);
        last_index      := last_index + 1;
        diff_slot       := slot;
        diff_slot.node  := diff_node;
        diff_slot.index := last_index;
        new_slots       := diff_slot :: new_slots;
      end if;
    end for;
    // the list was built back to front
    new_slots := listReverse(new_slots);
  end createSlotDerivatives;

  // Resolves a function that is defined as a partial derivative of another function.
  // Only acts if the class of the function node is a Class.TYPED_DERIVED whose base class
  // is an instanced class with sections. In that case the algorithms of the base class are
  // differentiated once for every input selected by der_func.derivedInputs, the function
  // interface (locals, outputs) is updated accordingly and the resulting function is
  // stored in funcMap under its path. In every other case the function is returned as it is.
  function resolvePartialDerivatives
    input output Function func;
    input UnorderedMap<Path, Function> funcMap;
  protected
    Function der_func;
    InstNode node;
    Pointer<Class> cls, tmp_cls;
    Class new_cls, wrap_cls;
    Sections sections;
    UnorderedMap<ComponentRef, ComponentRef> diff_map = UnorderedMap.new<ComponentRef>(ComponentRef.hash, ComponentRef.isEqual);
    UnorderedMap<String, Boolean> interface_map;
    DifferentiationArguments diffArgs = DifferentiationArguments.default();
    list<Algorithm> algorithms;
    CachedData cachedData;
    InstNode diffVar;
    ComponentRef diffCref;
    list<InstNode> locals, outputs, local_outputs;
    // set to true as soon as the function has actually been rewritten
    Boolean changed = false;
  algorithm
    func := match func
      case der_func as Function.FUNCTION(node = InstNode.CLASS_NODE(cls = cls)) algorithm
        wrap_cls := Pointer.access(cls);
        new_cls := match wrap_cls
          // the function class wraps a base class (node) which contains the actual sections
          case wrap_cls as Class.TYPED_DERIVED(baseClass = node as InstNode.CLASS_NODE(cls = tmp_cls)) algorithm
            new_cls :=  match Pointer.access(tmp_cls)
              case new_cls as Class.INSTANCED_CLASS(sections = sections as Sections.SECTIONS(algorithms = algorithms)) algorithm
                // prepare differentiation arguments
                diffArgs.diffType     := DifferentiationType.FUNCTION;
                diffArgs.funcMap      := funcMap;

                // initially every input name is contained in the interface map (value false);
                // names contained in this map are skipped by the interface differentiation
                interface_map := UnorderedMap.fromLists(list(InstNode.name(var) for var in der_func.inputs), List.fill(false, listLength(der_func.inputs)), stringHashDjb2, stringEqual);

                // differentiate once for each input selected by derivedInputs.
                // der_func, diff_map and algorithms are carried over to the next iteration
                for var in List.getAtIndexLst(der_func.inputs, der_func.derivedInputs) loop
                  // remove the current input from the interface map so that it gets differentiated
                  // (it is added again at the end of the iteration)
                  UnorderedMap.remove(InstNode.name(var), interface_map);

                  // prepare outputs that become locals
                  // (direction is removed and the nodes are made protected)
                  local_outputs     := list(InstNode.setComponentDirection(NFPrefixes.Direction.NONE, node) for node in der_func.outputs);
                  local_outputs     := list(InstNode.protect(node) for node in local_outputs);

                  // differentiate interface arguments
                  // (fills diff_map for the current input, the locals and the outputs)
                  createInterfaceDerivatives({var}, interface_map, diff_map);
                  createInterfaceDerivatives(der_func.locals, interface_map, diff_map);
                  createInterfaceDerivatives(der_func.outputs, interface_map, diff_map);
                  diffArgs.diff_map   := SOME(diff_map);

                  // locals keep their original nodes (keepOld = true), outputs are replaced by their derivatives
                  (locals, diffArgs)  := differentiateFunctionInterfaceNodes(der_func.locals, interface_map, diff_map, diffArgs, true);
                  (outputs, diffArgs) := differentiateFunctionInterfaceNodes(der_func.outputs, interface_map, diff_map, diffArgs, false);

                  // derivative cref of the current input, replaced after differentiating the statements
                  diffCref          := UnorderedMap.getSafe(ComponentRef.fromNode(var, InstNode.getType(var)), diff_map, sourceInfo());
                  // the former outputs are appended to the locals
                  der_func.locals   := listAppend(locals, local_outputs);
                  der_func.outputs  := outputs;

                  // differentiate function statements
                  (algorithms, diffArgs) := List.mapFold(algorithms, differentiateAlgorithm, diffArgs);
                  // replace every occurrence of the derivative of the current input by one (Expression.makeOne)
                  algorithms := Algorithm.mapExpList(algorithms, function Replacements.single(old = Expression.fromCref(diffCref), new = Expression.makeOne(ComponentRef.getSubscriptedType(diffCref))));

                  // put the current input back into the interface map
                  UnorderedMap.add(InstNode.name(var), false, interface_map);
                end for;

                // add them to new node
                // (type and restriction are taken from the wrapping class)
                sections.algorithms     := algorithms;
                new_cls.sections        := sections;
                new_cls.ty              := wrap_cls.ty;
                new_cls.restriction     := wrap_cls.restriction;
                node.cls                := Pointer.create(new_cls);
                // update the function cache of the node and reset the derivative information
                cachedData              := CachedData.FUNCTION({der_func}, true, false);
                der_func.node           := InstNode.setFuncCache(node, cachedData);
                der_func.derivatives    := {};
                der_func.derivedInputs  := {};

                changed := true;
              then new_cls;

              // base class has no sections to differentiate
              else wrap_cls;
            end match;
          then new_cls;
          // not a derived class, nothing to resolve
          else wrap_cls;
        end match;

        if changed then
          if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
            // func still refers to the function as it was passed in
            print("\n[BEFORE] " + Function.toFlatString(func) + "\n");
            print("\n[AFTER ] " + Function.toFlatString(der_func) + "\n\n");
          end if;
          // store the resolved function under its path
          UnorderedMap.add(der_func.path, der_func, funcMap);
        end if;
      then der_func;

      // not a function with a class node, return unchanged
      else func;
    end match;
  end resolvePartialDerivatives;

  function differentiateAlgorithm
    "Differentiates all statements of an algorithm.
    Each statement is mapped to a list of statements by differentiateStatement,
    the lists are concatenated in order and the algorithm is rebuilt from them.
    Inputs and outputs are recomputed from the new statements, scope and source
    are taken from the original algorithm."
    input output Algorithm alg;
    input output DifferentiationArguments diffArguments;
  protected
    list<list<Statement>> stmt_groups;
    list<Statement> new_stmts;
    list<ComponentRef> alg_inputs, alg_outputs;
  algorithm
    // one group of statements per original statement
    (stmt_groups, diffArguments) := List.mapFold(alg.statements, differentiateStatement, diffArguments);
    new_stmts := List.flatten(stmt_groups);

    // rebuild the algorithm around the new statement list
    (alg_inputs, alg_outputs) := Algorithm.getInputsOutputs(new_stmts);
    alg := Algorithm.ALGORITHM(new_stmts, alg_inputs, alg_outputs, alg.scope, alg.source);
  end differentiateAlgorithm;

  function differentiateStatement
    "Differentiates a single statement and returns the statements replacing it.
      - assignments with Real (array) element type on the left hand side:
        the differentiated assignment followed by the original assignment
      - for, while, failure, if and when: the same statement with recursively
        differentiated bodies / branches (conditions are left untouched)
      - all other known statements: the statement itself
    Fails with an internal error for any other kind of statement."
    input Statement stmt;
    output list<Statement> diff_stmts "two statements for 'Real' assignments (diff; original) and else one";
    input output DifferentiationArguments diffArguments;
  algorithm
    diff_stmts := match stmt
      local
        Statement d_stmt;
        Expression cond, d_lhs, d_rhs;
        list<Statement> body;
        list<list<Statement>> nested;
        list<tuple<Expression, list<Statement>>> acc = {};

      // 1. 'Real' assignment: emit the derivative first and keep the original afterwards
      case d_stmt as Statement.ASSIGNMENT() guard(Type.isReal(Type.arrayElementType(Expression.typeOf(d_stmt.lhs)))) algorithm
        (d_lhs, diffArguments) := differentiateExpression(d_stmt.lhs, diffArguments);
        (d_rhs, diffArguments) := differentiateExpression(d_stmt.rhs, diffArguments);
        d_stmt.lhs := d_lhs;
        d_stmt.rhs := SimplifyExp.simplifyDump(d_rhs, true, getInstanceName());
      then {d_stmt, stmt};

      // 2. statements with a single body: differentiate the body in place
      case d_stmt as Statement.FOR() algorithm
        (nested, diffArguments) := List.mapFold(d_stmt.body, differentiateStatement, diffArguments);
        d_stmt.body := List.flatten(nested);
      then {d_stmt};

      case d_stmt as Statement.WHILE() algorithm
        (nested, diffArguments) := List.mapFold(d_stmt.body, differentiateStatement, diffArguments);
        d_stmt.body := List.flatten(nested);
      then {d_stmt};

      case d_stmt as Statement.FAILURE() algorithm
        (nested, diffArguments) := List.mapFold(d_stmt.body, differentiateStatement, diffArguments);
        d_stmt.body := List.flatten(nested);
      then {d_stmt};

      // 3. statements with (condition, body) branches: differentiate each body, keep the branch order
      case d_stmt as Statement.IF() algorithm
        for branch in d_stmt.branches loop
          (cond, body) := branch;
          (nested, diffArguments) := List.mapFold(body, differentiateStatement, diffArguments);
          acc := (cond, List.flatten(nested)) :: acc;
        end for;
        d_stmt.branches := listReverse(acc);
      then {d_stmt};

      case d_stmt as Statement.WHEN() algorithm
        for branch in d_stmt.branches loop
          (cond, body) := branch;
          (nested, diffArguments) := List.mapFold(body, differentiateStatement, diffArguments);
          acc := (cond, List.flatten(nested)) :: acc;
        end for;
        d_stmt.branches := listReverse(acc);
      then {d_stmt};

      // 4. nothing to differentiate: remaining (non-Real) assignments and statements without body
      case Statement.ASSIGNMENT()           then {stmt};
      case Statement.FUNCTION_ARRAY_INIT()  then {stmt};
      case Statement.ASSERT()               then {stmt};
      case Statement.TERMINATE()            then {stmt};
      case Statement.NORETCALL()            then {stmt};
      case Statement.RETURN()               then {stmt};
      case Statement.BREAK()                then {stmt};

      // 5. unknown statement kind
      else algorithm
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for:" + Statement.toString(stmt)});
      then fail();
    end match;
  end differentiateStatement;

  function differentiateBinary
    "Differentiates a binary expression. Expression.BINARY()
    Some of this is deprecated because of Expression.MULTARY().
    Will always try to convert to MULTARY whenever possible. (commutativity)

    The returned expression is the forward derivative of the input expression.
    If diffArguments.adjoint_map is set (reverse mode), the upstream gradient
    diffArguments.current_grad is additionally replaced by the local gradient of
    each operand before recursing into that operand and restored afterwards.
    Logical and relational expressions are returned unchanged.
    Fails with an internal error for any other expression."
    input output Expression exp "Has to be Expression.BINARY()";
    input output DifferentiationArguments diffArguments;
  algorithm
    if Flags.isSet(Flags.DEBUG_ADJOINT) then
      print("differentiateBinary: " + Expression.toString(exp) + "\n");
    end if;
    (exp, diffArguments) := match exp
      local
        Expression exp1, exp2, diffExp1, diffExp2, e1, e2, e3, res;
        Operator operator, addOp, mulOp, powOp, divOp;
        Operator.SizeClassification sizeClass, powSizeClass;
        Expression current_grad;
        // Local reverse grads (to assign before recursing)
        Expression grad_exp1, grad_exp2, denom2, numUF;
        Boolean isVec1, isVec2, isMat1, isMat2;
        Type ty1, ty2;
        Integer r1, r2;
        list<Integer> dim1, dim2;
        // reverse mode is active if an adjoint map is present
        Boolean isReverse = Util.isSome(diffArguments.adjoint_map);

      // Addition calculations (ADD, ADD_EW, ...)
      // (f + g)' = f' + g'
      // Adjoint rule: ∂(f + g)/∂f = 1, ∂(f + g)/∂g = 1
      // diffArguments.current_grad = ∂Out/∂(f + g) * ∂(f + g)/∂f = current_grad * 1 = current_grad
      case Expression.BINARY(exp1 = exp1, operator = operator, exp2 = exp2)
        guard(Operator.getMathClassification(operator) == NFOperator.MathClassification.ADDITION)
        algorithm
          // the upstream gradient is passed on unchanged to both operands
          //current_grad := diffArguments.current_grad;

          //diffArguments.current_grad := current_grad; // not needed, but for clarity
          (diffExp1, diffArguments) := differentiateExpression(exp1, diffArguments);

          //diffArguments.current_grad := current_grad; // not needed, but for clarity
          (diffExp2, diffArguments) := differentiateExpression(exp2, diffArguments);

          //diffArguments.current_grad := current_grad;
      then (Expression.MULTARY({diffExp1, diffExp2}, {}, operator), diffArguments);

      // Subtraction calculations (SUB, SUB_EW, ...)
      // (f - g)' = f' - g'
      // ∂(f - g)/∂f = 1, ∂(f - g)/∂g = -1
      case Expression.BINARY(exp1 = exp1, operator = operator, exp2 = exp2)
        guard(Operator.getMathClassification(operator) == NFOperator.MathClassification.SUBTRACTION)
        algorithm
          current_grad := diffArguments.current_grad;

          // differentiate first argument
          //diffArguments.current_grad := current_grad; // not needed, but for clarity
          (diffExp1, diffArguments) := differentiateExpression(exp1, diffArguments);

          // differentiate second argument
          diffArguments.current_grad := Expression.negate(current_grad);
          (diffExp2, diffArguments) := differentiateExpression(exp2, diffArguments);

          diffArguments.current_grad := current_grad;
          // create addition operator from the size classification of original subtraction operator
          (_, sizeClass) := Operator.classify(operator);
          addOp := Operator.fromClassification((NFOperator.MathClassification.ADDITION, sizeClass), operator.ty);
      then (Expression.MULTARY({diffExp1}, {diffExp2}, addOp), diffArguments);

      // Multiplication (MUL, MUL_EW, ...)
      // (f * g)' =  fg' + f'g
      // ∂(f * g)/∂f = g, ∂(f * g)/∂g = f
      case Expression.BINARY(exp1 = exp1, operator = operator, exp2 = exp2)
        guard(Operator.getMathClassification(operator) == NFOperator.MathClassification.MULTIPLICATION)
      algorithm
        if isReverse then
          // Upstream gradient
          current_grad := diffArguments.current_grad;

          // Type / rank info
          ty1 := Expression.typeOf(exp1);
          ty2 := Expression.typeOf(exp2);
          r1 := if Type.isArray(ty1) then Type.dimensionCount(ty1) else 0;
          r2 := if Type.isArray(ty2) then Type.dimensionCount(ty2) else 0;
          dim1 := if r1 > 0 then NFDimension.sizes(Type.arrayDims(ty1)) else {};
          dim2 := if r2 > 0 then NFDimension.sizes(Type.arrayDims(ty2)) else {};

          isVec1 := (r1 == 1);
          isVec2 := (r2 == 1);
          isMat1 := (r1 == 2);
          isMat2 := (r2 == 2);

          // Original size classification (kept for forward combination)
          (_, sizeClass) := Operator.classify(operator);
          // Decide shape case
          // Inner product
          if isVec1 and isVec2 and sizeClass == NFOperator.SizeClassification.SCALAR then
            grad_exp1 := Expression.BINARY(
              current_grad,
              Operator.fromClassification(
                (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.SCALAR_ARRAY),
                operator.ty),
              exp2); // G * y
            grad_exp2 := Expression.BINARY(
              current_grad,
              Operator.fromClassification(
                (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.SCALAR_ARRAY),
                operator.ty),
              exp1); // G * x
          // outer product
          // NOTE(review): unlike the general Matrix * Matrix case below, exp2 is not
          // transposed here and the second gradient uses G^T * x instead of x^T * G.
          // Confirm that this is intended.
          elseif isMat1 and isMat2 and sizeClass == NFOperator.SizeClassification.MATRIX and listGet(dim1, 1) > 1 and listGet(dim1, 2) == 1 and listGet(dim2, 1) == 1 and listGet(dim2, 2) > 1 then
            grad_exp1 := Expression.BINARY(
              current_grad,
              Operator.fromClassification(
                (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.MATRIX),
                operator.ty),
              exp2); // G * y
            grad_exp2 := Expression.BINARY(
              typeTransposeCall(current_grad),
              Operator.fromClassification(
                (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.MATRIX),
                operator.ty),
              exp1); // G^T * x
          // Matrix * Vector
          elseif isMat1 and isVec2 then
            grad_exp1 := Expression.BINARY(
              current_grad,
              Operator.fromClassification(
                (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.MATRIX),
                operator.ty),
              typeTransposeCall(exp2)); // G * xᵀ
            grad_exp2 := Expression.BINARY(
              typeTransposeCall(exp1),
              Operator.fromClassification(
                (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.MATRIX_VECTOR),
                operator.ty),
              current_grad); // Aᵀ * G
          // Vector * Matrix
          elseif isVec1 and isMat2 then
            // grad w.r.t exp1 (x): B * Gᵀ  -> treat Gᵀ via transpose(current_grad)
            grad_exp1 := Expression.BINARY(
              exp2,
              Operator.fromClassification(
                (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.MATRIX_VECTOR),
                operator.ty),
              typeTransposeCall(current_grad));    // B * Gᵀ  (shape n)
            // grad w.r.t exp2 (B): xᵀ * G
            grad_exp2 := Expression.BINARY(
              typeTransposeCall(exp1),
              Operator.fromClassification(
                (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.MATRIX),
                operator.ty),
              current_grad);                   // xᵀ * G  (outer product)
          // Matrix * Matrix
          elseif isMat1 and isMat2 then
            grad_exp1 := Expression.BINARY(
              current_grad,
              Operator.fromClassification(
                (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.MATRIX),
                operator.ty),
              typeTransposeCall(exp2));              // G * Bᵀ
            grad_exp2 := Expression.BINARY(
              typeTransposeCall(exp1),
              Operator.fromClassification(
                (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.MATRIX),
                operator.ty),
              current_grad);                     // Aᵀ * G
          else
            // all remaining shapes: G * g for the first and G * f for the second operand
            grad_exp1 := Expression.MULTARY({current_grad, exp2}, {}, makeMulFromOperator(operator));
            grad_exp2 := Expression.MULTARY({current_grad, exp1}, {}, makeMulFromOperator(operator));
          end if;

          // Reverse recurse: exp1
          diffArguments.current_grad := grad_exp1;
          (diffExp1, diffArguments) := differentiateExpression(exp1, diffArguments);
          // Reverse recurse: exp2
          diffArguments.current_grad := grad_exp2;
          (diffExp2, diffArguments) := differentiateExpression(exp2, diffArguments);
          // Restore upstream
          diffArguments.current_grad := current_grad;
        else
          // only forward differentiation
          (diffExp1, diffArguments) := differentiateExpression(exp1, diffArguments);
          (diffExp2, diffArguments) := differentiateExpression(exp2, diffArguments);
        end if;
        // Forward derivative assembly: f*g' + f'*g
        sizeClass := Operator.classifyAddition(operator);
        addOp := Operator.fromClassification(
          (NFOperator.MathClassification.ADDITION, sizeClass),
          operator.ty);
      then (Expression.MULTARY(
              {Expression.BINARY(exp1, operator, diffExp2),
                Expression.BINARY(diffExp1, operator, exp2)},
              {},
              addOp
            ),
            diffArguments);

      // Division (DIV, DIV_EW, ...)
      // (f / g)' = (f'g - fg') / g^2
      // ∂(f / g)/∂f = 1 / g, ∂(f / g)/∂g = -f / g^2
      case Expression.BINARY(exp1 = exp1, operator = operator, exp2 = exp2)
        guard(Operator.getMathClassification(operator) == NFOperator.MathClassification.DIVISION)
        algorithm
          powSizeClass := NFOperator.SizeClassification.SCALAR;
          powOp := Operator.fromClassification((NFOperator.MathClassification.POWER, powSizeClass), Type.REAL());
          if isReverse then
            current_grad := diffArguments.current_grad; // upstream gradient
            diffArguments.current_grad := Expression.MULTARY({current_grad}, {exp2}, Operator.fromClassification(
              (NFOperator.MathClassification.MULTIPLICATION, if Type.isArray(Expression.typeOf(current_grad)) then NFOperator.SizeClassification.ARRAY_SCALAR else NFOperator.SizeClassification.SCALAR),
              operator.ty)); // z = f/g going into f
          end if;
          (diffExp1, diffArguments) := differentiateExpression(exp1, diffArguments);
          if isReverse then
            // Reverse local grad for denominator g: G_g = - ( (upstream .* f) / g^2 )
            // Build g^2
            denom2 := Expression.BINARY(exp2, powOp, Expression.REAL(2.0));

            // Build numerator = upstream .* f  with proper size classification
            numUF := Expression.BINARY(current_grad, if Type.isArray(Expression.typeOf(exp1)) then Operator.makeScalarProduct(operator.ty) else Operator.fromClassification(
              (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.SCALAR),
              Type.REAL()), exp1);

            // Divide by g^2 (array/scalar-safe)
            divOp := Operator.fromClassification(
              (NFOperator.MathClassification.DIVISION, NFOperator.SizeClassification.SCALAR),
              Type.REAL());
            diffArguments.current_grad := Expression.negate(
              Expression.BINARY(numUF, divOp, denom2));
          end if;
          (diffExp2, diffArguments) := differentiateExpression(exp2, diffArguments);

          if isReverse then
            // Restore upstream
            diffArguments.current_grad := current_grad;
          end if;
          // create subtraction and multiplication operator from the size classification of original division operator
          (_, sizeClass) := Operator.classify(operator);
          // the frontend treats multiplication equally for element and non elementwise, but pow needs to have the correct operator
          addOp := Operator.fromClassification((NFOperator.MathClassification.ADDITION, sizeClass), operator.ty);
          mulOp := Operator.fromClassification((NFOperator.MathClassification.MULTIPLICATION, sizeClass), operator.ty);
          // quotient rule: f'g is the positive and fg' the negative (inverse) term of the
          // numerator. Swapping them would flip the sign of the derivative.
      then (Expression.MULTARY(
              {Expression.MULTARY(
                {Expression.BINARY(diffExp1, mulOp, exp2)},              // f'g
                {Expression.BINARY(exp1, mulOp, diffExp2)},              // - fg'
                addOp
              )},
              {Expression.BINARY(exp2, powOp, Expression.REAL(2.0))},    // / g^2
              mulOp
            ),
            diffArguments);

      // Power (POW, POW_EW, ...) with base zero
      // (0^r)' = 0
      case Expression.BINARY(exp1 = exp1, operator = operator, exp2 = exp2)
        guard((Operator.getMathClassification(operator) == NFOperator.MathClassification.POWER) and
              Expression.isZero(exp1))
      then (Expression.makeZero(operator.ty), diffArguments);

      // Power (POW, POW_EW, ...) general case
      // ∂(f ^ g)/∂f = g * f^(g-1), ∂(f ^ g)/∂g = f^g * ln(f)
      case Expression.BINARY(exp1 = exp1, operator = operator, exp2 = exp2)
        guard((Operator.getMathClassification(operator) == NFOperator.MathClassification.POWER))
        algorithm
          (_, sizeClass) := Operator.classify(operator);
          addOp := Operator.fromClassification((NFOperator.MathClassification.ADDITION, sizeClass), operator.ty);
          current_grad := diffArguments.current_grad; // upstream gradient

          // local gradient for the base: G * g * f^(g-1)
          diffArguments.current_grad := Expression.MULTARY({current_grad, exp2, Expression.BINARY(exp1, operator, minusOne(exp2, addOp))}, {}, makeMulFromOperator(operator));
          (diffExp1, diffArguments) := differentiateExpression(exp1, diffArguments);

          // local gradient for the exponent: G * f^g * ln(f)
          diffArguments.current_grad := Expression.MULTARY({current_grad, exp, expLog(exp1)}, {}, makeMulFromOperator(operator));
          (diffExp2, diffArguments) := differentiateExpression(exp2, diffArguments);

          // restore upstream gradient
          diffArguments.current_grad := current_grad;
          // simplify so that constant operands are recognized as zero derivatives below
          diffExp1 := SimplifyExp.simplifyDump(diffExp1, true, getInstanceName());
          diffExp2 := SimplifyExp.simplifyDump(diffExp2, true, getInstanceName());
          mulOp := Operator.fromClassification((NFOperator.MathClassification.MULTIPLICATION, sizeClass), operator.ty);

          res := match (Expression.isZero(diffExp1), Expression.isZero(diffExp2))
            // Power (POW, POW_EW, ...) with constant exponent and constant base
            // (r1^r2)' = 0
            case (true, true) then Expression.makeZero(operator.ty);
            // Power (POW, POW_EW, ...) with constant exponent
            // (x^r)' = r*(x^(r-1))*x'
            case (false, true) then Expression.MULTARY({exp2, Expression.BINARY(exp1, operator, minusOne(exp2, addOp)), diffExp1}, {}, mulOp);
            // Power (POW, POW_EW, ...) with constant base
            // (r^x)'  = r^x*ln(r)*x'
            case (true, false) then Expression.MULTARY({exp, expLog(exp1), diffExp2}, {}, mulOp);
            // Power (POW, POW_EW, ...) regular case
            // (x^y)' = x^(y-1) * (x*ln(x)*y'+(y*x'))
            else algorithm
              // x^(y-1)
              e1 := Expression.BINARY(exp1, operator, minusOne(exp2, addOp));
              // x * ln(x) * y'
              e2 := Expression.MULTARY({exp1, expLog(exp1), diffExp2}, {}, mulOp);
              // y * x'
              e3 := Expression.MULTARY({exp2, diffExp1}, {}, mulOp);
            then Expression.MULTARY({e1, Expression.MULTARY({e2, e3}, {}, addOp)}, {}, mulOp);
          end match;
      then (res, diffArguments);

      // Logical and Comparing operators => just return as is
      case Expression.BINARY(operator = operator)
        guard((Operator.getMathClassification(operator) == NFOperator.MathClassification.LOGICAL) or
              (Operator.getMathClassification(operator) == NFOperator.MathClassification.RELATION))
      then (exp, diffArguments);

      else algorithm
        // maybe add failtrace here and allow failing
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp)});
      then fail();

    end match;
    // simplify?
  end differentiateBinary;

  function differentiateMultary
    "Differentiates a multary expression. Expression.MULTARY()
    Note: these can only contain commutative operators"
    // Dispatches on the math classification of the MULTARY operator:
    //   - ADDITION:                          differentiate every term
    //   - MULTIPLICATION, no inv_arguments:  product rule
    //   - MULTIPLICATION with inv_arguments: quotient rule (separate reverse-mode and forward-mode cases)
    // Reverse (adjoint) mode is active when diffArguments.adjoint_map is SOME(...). In that mode the
    // local upstream gradient is passed to the recursive calls through diffArguments.current_grad and
    // restored afterwards; the expression returned by this function is then only a placeholder.
    input output Expression exp "Has to be Expression.MULTARY()";
    input output DifferentiationArguments diffArguments;
  protected
    // true if an adjoint map is present, i.e. reverse-mode differentiation
    Boolean isReverse = Util.isSome(diffArguments.adjoint_map);
  algorithm
    if Flags.isSet(Flags.DEBUG_ADJOINT) then
      print("differentiateMultary: " + Expression.toString(exp) + "\n");
    end if;
    exp := match exp
      local
        Expression diff_arg, divisor, diff_enumerator, diff_divisor;
        list<Expression> arguments, new_arguments = {};
        list<Expression> inv_arguments, new_inv_arguments = {};
        list<Expression> diff_arguments, diff_inv_arguments;
        Operator operator, addOp, powOp, mulOp, mulEWOp, addEWOp;
        Operator.SizeClassification sizeClass, powSizeClass;
        Expression current_grad, upstream, e_over_f, term, e_over_g, numProd, denomProd;
        List<Expression> add_terms, sub_terms, arg_rest, arg_products;
        Boolean hasArray, hasArrayNum;
        Expression local_grad, localUpF, localUpG;
        Integer i;
        Type powTy;

      // Dash calculations (ADD, SUB, ADD_EW, SUB_EW, ...)
      // NOTE: Multary always contains ADDITION
      // (sum(f_i))' = sum(f_i')
      // e.g. (f + g + h - p - q)' = f' + g' + h' - p' - q'
      // Reverse-mode note:
      //  - If an argument is scalar but at least one other argument is an array,
      //    its local upstream must be sum-reduced to a scalar before recursion.
      //  - In reverse mode nothing is collected in new_arguments/new_inv_arguments, so the
      //    returned MULTARY has empty argument lists.
      case Expression.MULTARY(arguments = arguments, inv_arguments = inv_arguments, operator = operator)
        guard(Operator.getMathClassification(operator) == NFOperator.MathClassification.ADDITION)
        algorithm
          if isReverse then
            // Detect if any term is an array (for mixed scalar/array broadcasting)
            // hasArray is only assigned and only read in reverse mode
            hasArray := List.any(arguments, Expression.hasArrayType) or List.any(inv_arguments, Expression.hasArrayType);
          end if;
          // go over addition arguments
          // (iterating the reversed list and prepending restores the original order in new_arguments)
          for arg in listReverse(arguments) loop
            if isReverse then
              // save the incoming upstream gradient; it is restored after the recursive call
              current_grad := diffArguments.current_grad;
              // For scalar arg in mixed case: sum-reduce upstream to scalar
              if Expression.isScalar(arg) and hasArray then
                diffArguments.current_grad := typeSumCall(current_grad);
              else
                diffArguments.current_grad := current_grad;
              end if;
            end if;

            (diff_arg, diffArguments) := differentiateExpression(arg, diffArguments);

            if isReverse then
              diffArguments.current_grad := current_grad;
            else
              new_arguments := diff_arg :: new_arguments;
            end if;
          end for;
          // go over subtraction arguments
          for arg in listReverse(inv_arguments) loop
            if isReverse then
              current_grad := diffArguments.current_grad;

              // subtracted terms receive the negated upstream gradient
              local_grad := Expression.negate(current_grad);
              if Expression.isScalar(arg) and hasArray then
                local_grad := typeSumCall(local_grad);
              end if;
              diffArguments.current_grad := local_grad;
            end if;

            (diff_arg, diffArguments) := differentiateExpression(arg, diffArguments);

            if isReverse then
              diffArguments.current_grad := current_grad;
            else
              new_inv_arguments := diff_arg :: new_inv_arguments;
            end if;
          end for;
      then Expression.MULTARY(new_arguments, new_inv_arguments, operator);

      // Dot calculations (MUL, DIV, MUL_EW, DIV_EW, ...)
      // NOTE: Multary always contains MULTIPLICATION
      // no inverse arguments so single product rule:
      // (prod(f_i))' = sum((f_i)' * prod(f_k | k <> i))
      // e.g. (fgh)' = f'gh + fg'h + fgh'
      case Expression.MULTARY(arguments = arguments, inv_arguments = {}, operator = operator)
        guard(Operator.getMathClassification(operator) == NFOperator.MathClassification.MULTIPLICATION)
        algorithm
          // create addition operator
          sizeClass := Operator.classifyAddition(operator);
          addOp := Operator.fromClassification((NFOperator.MathClassification.ADDITION, sizeClass), operator.ty);
          // the adjoint is handled inside here
          // (in reverse mode the helper returns an empty list, so the result is an empty sum)
          (new_arguments, diffArguments) := differentiateMultaryMultiplicationArgs(arguments, diffArguments, operator);
      then Expression.MULTARY(new_arguments, {}, addOp);

      // Dot calculations (MUL, DIV, MUL_EW, DIV_EW, ...)
      // NOTE: Multary always contains MULTIPLICATION
      // (prod(f_i)) / prod(g_j))'
      // makes use of single product rule:
      // (prod(f_i))' = sum((f_i)' * prod(f_k | k <> i))
      // e.g. (abc)' = a'bc + ab'c + abc'
      // and binary division rule
      // (f / g)' = (f'g - g'f) / g^2
      // this is implemented like so:
      // E = (prod arguments) / (prod inv_arguments)
      // dE = Σ_i f_i' * (E / f_i) - Σ_j g_j' * (E / g_j)
      // Reverse mode local grads (used via current_grad):
      //   for f_i: G_i = G * (E / f_i)
      //   for g_j: G_j = -G * (E / g_j)
      // Reverse assumptions:
      //  - Broadcasting only happens in the numerator.
      //  - All denominators are scalar.
      //  - If the numerator is an array then the division by the denominator is elementwise.
      // Sum reduction is needed:
      //  - For scalar f_i in numerator if any other numerator factor is an array.
      //  - For denominator g_j (scalar) if numerator is an array.
      // This case only applies in reverse mode (guard: isReverse); it returns the placeholder
      // Expression.END() because the result is propagated through current_grad in the recursion.
      case Expression.MULTARY(arguments = arguments, inv_arguments = inv_arguments, operator = operator)
        guard(Operator.getMathClassification(operator) == NFOperator.MathClassification.MULTIPLICATION
              and (not listEmpty(inv_arguments)) and isReverse)
        algorithm
          (_, sizeClass) := Operator.classify(operator);
          // Determine operators
          // NOTE(review): addOp, mulOp and addEWOp are assigned here but not read in this case
          addOp := Operator.fromClassification(
            (NFOperator.MathClassification.ADDITION, sizeClass),
            operator.ty);
          mulOp := makeMulFromOperator(operator);

          // Use element-wise mul for reverse local upstream assembly to avoid array*scalar miscodegen
          // when upstream and partial products are arrays.
          // We keep forward terms using mulOp as before.
          mulEWOp := Operator.fromClassification(
            (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.ELEMENT_WISE),
            operator.ty);
          addEWOp := Operator.fromClassification(
            (NFOperator.MathClassification.ADDITION, NFOperator.SizeClassification.ELEMENT_WISE),
            operator.ty);

          // Does the numerator contain any arrays?
          hasArrayNum := List.any(arguments, Expression.hasArrayType);

          // numerator and denominator as plain products
          numProd := Expression.MULTARY(arguments, {}, operator);
          denomProd := Expression.MULTARY(inv_arguments, {}, operator);

          // Remember the incoming upstream gradient; it seeds every local gradient below
          // and is restored after both loops
          upstream := diffArguments.current_grad;
          // Differentiate numerator factors
          i := 1;
          for f in arguments loop
            // Remove the i-th numerator factor by position (listDelete), which leaves the
            // remaining factors prod(f_k | k <> i); this also works if f occurs multiple times
            arg_rest := listDelete(arguments, i);
            // E / f_i = prod(f_k | k <> i) / prod(g_j)
            e_over_f := Expression.MULTARY(arg_rest, {denomProd}, operator);

            // Reverse local upstream for f: G_f = upstream .* (exp / f)
            localUpF := Expression.MULTARY({upstream, e_over_f}, {}, mulEWOp);

            // If f is scalar but numerator has arrays -> sum-reduce to scalar
            if Expression.isScalar(f) and hasArrayNum then
              localUpF := typeSumCall(localUpF);
            end if;

            // Recurse into f with G_f
            diffArguments.current_grad := localUpF;
            (diff_arg, diffArguments) := differentiateExpression(f, diffArguments);

            // Forward term: f' * (exp / f)
            // NOTE(review): term is assigned but never read or accumulated in this case
            term := Expression.MULTARY({diff_arg, e_over_f}, {}, mulEWOp);
            i := i + 1;
          end for;

          // NOTE(review): sub_terms is initialized but never extended or read
          sub_terms := {};
          // Differentiate denominator factors
          i := 1;
          // NOTE(review): powSizeClass and powOp are computed here but not used in this case
          powSizeClass := if Expression.hasArrayType(listHead(inv_arguments)) then NFOperator.SizeClassification.ARRAY_SCALAR else NFOperator.SizeClassification.SCALAR;
          powOp := Operator.fromClassification((NFOperator.MathClassification.POWER, powSizeClass), Type.REAL());
          for g in inv_arguments loop
            // NOTE(review): arg_rest is not read in this loop
            arg_rest := listDelete(inv_arguments, i);
            // exp / g : add one more g to denominator list
            e_over_g := Expression.MULTARY({numProd}, g :: inv_arguments, operator);

            // Reverse local upstream for g: G_g = - upstream .* (exp / g)
            localUpG := Expression.negate(Expression.MULTARY({upstream, e_over_g}, {}, mulEWOp));

            // If numerator has arrays -> sum-reduce scalar denominator upstream
            if hasArrayNum then
              localUpG := typeSumCall(localUpG);
            end if;

            // Recurse into g with G_g
            diffArguments.current_grad := localUpG;
            (diff_arg, diffArguments) := differentiateExpression(g, diffArguments);

            // Forward term: - g' * (exp / g)
            // NOTE(review): term is assigned but never read or accumulated in this case
            term := Expression.negate(Expression.MULTARY({diff_arg, e_over_g}, {}, mulEWOp));
            i := i + 1;
          end for;
          // Restore upstream gradient
          diffArguments.current_grad := upstream;
          then (Expression.END());

      // Forward-mode quotient rule for MULTIPLICATION with inverse arguments:
      // (f / g)' = (f'g - g'f) / g^2 with f = prod(arguments), g = prod(inv_arguments)
      case Expression.MULTARY(arguments = arguments, inv_arguments = inv_arguments, operator = operator)
        guard(Operator.getMathClassification(operator) == NFOperator.MathClassification.MULTIPLICATION
              and (not listEmpty(inv_arguments)))
        algorithm
          // the frontend treats multiplication equally for elementwise and non-elementwise, but pow needs to have the correct operator
          // (only the first divisor factor is inspected to decide between array and scalar power)
          if not listEmpty(inv_arguments) and Type.isArray(Expression.typeOf(listHead(inv_arguments))) then
            powSizeClass := NFOperator.SizeClassification.ARRAY_SCALAR;
            powTy := operator.ty;
          else
            powSizeClass := NFOperator.SizeClassification.SCALAR;
            powTy := Type.REAL();
          end if;

          // check if the addition size class has to be element wise
          // (only the first numerator factor is inspected)
          if not listEmpty(arguments) and Type.isArray(Expression.typeOf(listHead(arguments))) then
            sizeClass := NFOperator.SizeClassification.ELEMENT_WISE;
          else
            (_, sizeClass) := Operator.classify(operator);
          end if;

          addOp := Operator.fromClassification((NFOperator.MathClassification.ADDITION, sizeClass), operator.ty);
          powOp := Operator.fromClassification((NFOperator.MathClassification.POWER, powSizeClass), powTy);
          // f'
          (diff_arguments, diffArguments) := differentiateMultaryMultiplicationArgs(arguments, diffArguments, operator);
          diff_enumerator := Expression.MULTARY(diff_arguments, {}, addOp);
          // g'
          (diff_inv_arguments, diffArguments) := differentiateMultaryMultiplicationArgs(inv_arguments, diffArguments, operator);
          diff_divisor := Expression.MULTARY(diff_inv_arguments, {}, addOp);
          // g
          divisor := Expression.MULTARY(inv_arguments, {}, operator);
      // result: (f'g - g'f) / g^2
      then Expression.MULTARY(
              {Expression.MULTARY(
                {Expression.MULTARY(diff_enumerator :: inv_arguments, {}, operator)},   // f'g
                {Expression.MULTARY(diff_divisor :: arguments, {}, operator)},          // -g'f
                addOp
              )},
              {Expression.BINARY(divisor, powOp, Expression.REAL(2.0))},
              operator
           );

      // any other expression or operator classification is an internal error
      else algorithm
        // maybe add failtrace here and allow failing
        Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp)});
      then fail();
    end match;
  end differentiateMultary;

  function differentiateMultaryMultiplicationArgs
    "prod_i(f_i)' = sum_i((f_i)' * prod(f_k | k <> i))
    e.g. (fgh)' = f'gh + fg'h + fgh'"
    // Forward mode: returns one product per factor, where exactly that factor is replaced by
    // its derivative. Reverse mode (adjoint_map is SOME): returns an empty list and instead
    // recurses into every factor with current_grad set to the local upstream gradient.
    input list<Expression> arguments;
    output list<Expression> new_arguments = {};
    input output DifferentiationArguments diffArguments;
    input Operator operator;
  protected
    Boolean reverseMode = Util.isSome(diffArguments.adjoint_map);
    Operator ewMul = Operator.fromClassification(
      (NFOperator.MathClassification.MULTIPLICATION, NFOperator.SizeClassification.ELEMENT_WISE),
      operator.ty);
    Integer pos = 1;
    Expression dFactor, savedGrad, seed, others;
    list<Expression> othersPerFactor, seedFactors;
    array<list<Expression>> rows;
  algorithm
    // reverse mode needs prod(f_k | k <> i) for every i, forward mode one (reversed) row per summand
    if reverseMode then
      othersPerFactor := Expression.productOfListExceptSelf(arguments, makeMulFromOperator(operator));
    else
      rows := arrayCreate(listLength(arguments), {});
    end if;

    for factor in arguments loop
      if reverseMode then
        savedGrad := diffArguments.current_grad;

        // take the product of the remaining factors that belongs to this position
        others :: othersPerFactor := othersPerFactor;

        // flatten a nested multiplication so the seed becomes a single flat product
        seedFactors := match others
          local
            Operator innerOp;
            list<Expression> innerArgs;
          case Expression.MULTARY(operator = innerOp, arguments = innerArgs)
            guard Operator.getMathClassification(innerOp) == NFOperator.MathClassification.MULTIPLICATION
          then innerArgs;
          else {others};
        end match;

        // local upstream = current_grad .* prod(f_k | k <> i)
        seed := Expression.MULTARY(savedGrad :: seedFactors, {}, ewMul); // may need to adapt this aswell to scalar when scalar

        // a scalar factor multiplied by an array-shaped rest needs a scalar upstream
        if Expression.isScalar(factor) and Expression.hasArrayType(others) then
          seed := typeSumCall(seed);
        end if;
        diffArguments.current_grad := seed;
      end if;

      (dFactor, diffArguments) := differentiateExpression(factor, diffArguments);

      if reverseMode then
        // hand the untouched upstream gradient to the next factor
        diffArguments.current_grad := savedGrad;
      else
        // row `pos` receives the derivative, every other row the original factor
        for row in 1:arrayLength(rows) loop
          if row == pos then
            rows[row] := dFactor :: rows[row];
          else
            rows[row] := factor :: rows[row];
          end if;
        end for;
      end if;
      pos := pos + 1;
    end for;

    if not reverseMode then
      // rows were built by prepending, so reverse each one; walking backwards keeps summand order
      for row in arrayLength(rows):-1:1 loop
        new_arguments := Expression.MULTARY(listReverse(rows[row]), {}, operator) :: new_arguments;
      end for;
    end if;
  end differentiateMultaryMultiplicationArgs;

  function differentiateEquationAttributes
    "Differentiates the residual variable for diffType JACOBIAN, if it exists.
    The cref has to be saved in the diff_map for this to work.
    ToDo: needs to be adapted for torn/inner equations"
    input output EquationAttributes attr;
    input DifferentiationArguments diffArguments;
  algorithm
    attr := match (attr, diffArguments)
      local
        Pointer<Variable> resVar;
        UnorderedMap<ComponentRef,ComponentRef> crefMap;
        ComponentRef diffedName;

      // residual variable present, jacobian differentiation and the residual cref is mapped:
      // replace the residual variable by the variable of the mapped cref
      case (EquationAttributes.EQUATION_ATTRIBUTES(residualVar = SOME(resVar)),
         DIFFERENTIATION_ARGUMENTS(diff_map = SOME(crefMap), diffType = DifferentiationType.JACOBIAN))
        guard(UnorderedMap.contains(BVariable.getVarName(resVar), crefMap))
        algorithm
          diffedName := UnorderedMap.getOrFail(BVariable.getVarName(resVar), crefMap);
          attr.residualVar := SOME(BVariable.getVarPointer(diffedName, sourceInfo()));
      then attr;

      // everything else is returned unchanged
      else attr;

    end match;
  end differentiateEquationAttributes;

  function differentiateBinding
    // Differentiates the expression of a binding in place.
    // A binding without an expression is returned unchanged.
    input output Binding binding;
    input output DifferentiationArguments diffArgs;
  algorithm
    _ := match Binding.getExpOpt(binding)
      local
        Expression bound, diffed;
      case SOME(bound)
        algorithm
          (diffed, diffArgs) := differentiateExpression(bound, diffArgs);
          binding := Binding.setExp(diffed, binding);
      then ();
      else ();
    end match;
  end differentiateBinding;

protected
  function minusOne
    // Returns exp - 1.
    // Literal REAL and INTEGER values are folded directly; any other expression is wrapped
    // in a MULTARY with the one-element of op.ty as inverse argument of op.
    // NOTE(review): op is presumably an addition operator, so that the inverse argument
    // is a subtraction - confirm at the call sites.
    input output Expression exp;
    input Operator op;
  algorithm
    exp := match exp
      case Expression.INTEGER() then Expression.INTEGER(exp.value - 1);
      case Expression.REAL()    then Expression.REAL(exp.value - 1.0);
      else Expression.MULTARY({exp}, {Expression.makeOne(op.ty)}, op);
    end match;
  end minusOne;

  function expLog
    // Returns the natural logarithm of exp as an expression.
    // Positive REAL and INTEGER literals are folded to a REAL literal. Every other
    // expression is wrapped in a typed call to the builtin log. This includes non-positive
    // literals: folding them would place a non-finite literal into the expression, so
    // the call is kept symbolic and the invalid argument can be reported where the
    // call is evaluated.
    input output Expression exp;
  algorithm
    exp := match exp
      local
        Real r;
        Integer i;
      // fold only where the logarithm is defined
      case Expression.REAL(value = r) guard(r > 0.0)  then Expression.REAL(log(r));
      case Expression.INTEGER(value = i) guard(i > 0) then Expression.REAL(log(i));
      else Expression.CALL(Call.makeTypedCall(
        fn          = NFBuiltinFuncs.LOG_REAL,
        args        = {exp},
        variability = Expression.variability(exp),
        purity      = NFPrefixes.Purity.PURE
      ));
    end match;
  end expLog;

  function makeMulFromOperator
    // Creates a multiplication operator that has the size classification
    // and the type of the given operator.
    input Operator operator;
    output Operator mulOp;
  protected
    Operator.SizeClassification sizeClass = Operator.getSizeClassification(operator);
  algorithm
    mulOp := Operator.fromClassification(
      (NFOperator.MathClassification.MULTIPLICATION, sizeClass),
      operator.ty);
  end makeMulFromOperator;

  function typeTransposeCall
    "Create a typed builtin transpose(mat) call without expanding mat.
     Returns mat if it is not an array with at least 2 dimensions."
    input Expression mat;
    output Expression tr;
  protected
    Type matTy = Expression.typeOf(mat);
    NFPrefixes.Variability var = Expression.variability(mat);
    NFPrefixes.Purity pur = Expression.purity(mat);
    Type elemTy, transposedTy;
    list<Type.Dimension> dims, tail;
    Type.Dimension first, second;
  algorithm
    // default: hand the input back unchanged
    tr := mat;

    if Type.isArray(matTy) then
      elemTy := Type.arrayElementType(matTy);
      dims := Type.arrayDims(matTy);

      // transposing requires at least two dimensions
      if listLength(dims) >= 2 then
        // the result type has the first two dimensions swapped, the rest is kept
        first :: second :: tail := dims;
        transposedTy := Type.ARRAY(elemTy, second :: first :: tail);
        tr := Expression.CALL(NFCall.makeTypedCall(NFBuiltinFuncs.TRANSPOSE, {mat}, var, pur, transposedTy));
      end if;
    end if;
  end typeTransposeCall;

    // Helper: build a typed builtin promote(A, n) call that appends (n - ndims(A)) singleton dims.
  function typePromoteCall
    // Builds a typed call promote(arr, n) without expanding arr.
    // The result type keeps the dimensions of arr and is filled up with dimensions of
    // size 1 on the right until it has n dimensions. If arr already has n or more
    // dimensions no dimension is added. For n <= 0 the result type is the element type.
    input Expression arr;   // A (scalar or array)
    input Integer n;        // desired rank
    output Expression promoted;
  protected
    Type inTy = Expression.typeOf(arr);
    Type elTy;
    list<Type.Dimension> inDims;
    Integer m;
    list<Type.Dimension> ones = {};
    list<Type.Dimension> resDims;
    Type resTy;
    NFCall call;
    NFPrefixes.Variability var = Expression.variability(arr);
    NFPrefixes.Purity pur = Expression.purity(arr);
  algorithm
    // a scalar is treated as an array without dimensions
    elTy := if Type.isArray(inTy) then Type.arrayElementType(inTy) else inTy;
    inDims := if Type.isArray(inTy) then Type.arrayDims(inTy) else {};
    m := listLength(inDims);

    // one singleton dimension for every missing rank
    for k in 1:max(0, n - m) loop
      ones := Dimension.fromInteger(1) :: ones;
    end for;
    // Append the singleton dims to the right: the existing dimensions come first.
    // (The previous List.append_reverse(ones, inDims) placed them in front.)
    resDims := listAppend(inDims, ones);
    resTy := if n > 0 then Type.ARRAY(elTy, resDims) else elTy;

    call := NFCall.makeTypedCall(NFBuiltinFuncs.PROMOTE, {arr, Expression.INTEGER(n)}, var, pur, resTy);
    promoted := Expression.CALL(call);
  end typePromoteCall;


  function typeSumCall
    "
    Create a typed builtin sum(A) call without expanding A.
      Semantics:
        - If A is not an array => return A (defensive fallback).
        - If A is an array => return sum over all elements, resulting in a scalar of element type.
    "
    input Expression arr;
    output Expression s;
  protected
    Type argTy = Expression.typeOf(arr);
    NFPrefixes.Variability var = Expression.variability(arr);
    NFPrefixes.Purity pur = Expression.purity(arr);
  algorithm
    if Type.isArray(argTy) then
      // the call always reduces to a scalar of the element type
      s := Expression.CALL(NFCall.makeTypedCall(
        NFBuiltinFuncs.SUM, {arr}, var, pur, Type.arrayElementType(argTy)));
    else
      // nothing to reduce (sum(x) == x)
      s := arr;
    end if;
  end typeSumCall;

  // Helper: build a BINARY multiplication a * b (e.g. matrix * vector or matrix * matrix) with a proper mul operator
  function makeMul
    // Builds the binary expression a * b. The multiplication operator is created
    // from the given size classification sc and the type ty.
    input Expression a;
    input Expression b;
    input Operator.SizeClassification sc;
    input Type ty;
    output Expression res;
  protected
    Operator mulOp = Operator.fromClassification((NFOperator.MathClassification.MULTIPLICATION, sc), ty);
  algorithm
    res := Expression.BINARY(a, mulOp, b);
  end makeMul;

  // Drop the last array dimension by indexing it with 1:
  // arr[..., 1]. If arr is not an array, return it unchanged.
  function dropLastDimIndex1
    // Subscripts the last dimension of arr with the index 1 and all dimensions
    // before it with WHOLE. Expressions that are not arrays are returned unchanged.
    input Expression arr;
    output Expression res;
  protected
    Type arrTy = Expression.typeOf(arr);
    Integer rank;
    list<Subscript> subs;
  algorithm
    // default for everything that can not be subscripted
    res := arr;

    if Type.isArray(arrTy) then
      rank := listLength(Type.arrayDims(arrTy));
      if rank > 0 then
        // start with the trailing INDEX(1) and put one WHOLE in front for each leading dimension
        subs := {Subscript.INDEX(Expression.INTEGER(1))};
        for d in 1:(rank - 1) loop
          subs := Subscript.WHOLE() :: subs;
        end for;
        res := Expression.applySubscripts(subs, arr, true);
      end if;
    end if;
  end dropLastDimIndex1;

  // Build vector[n] with elements A[i,i], i=1..n (literal array).
  function extractDiagonalVector
    // Collects the subscripted expressions A[k,k] for k = 1..n in a literal array of type vecTy.
    input Expression A;     // matrix
    input Integer n;
    input Type vecTy;       // vector[n] type
    output Expression v;
  protected
    list<Expression> diag = {};
    Subscript sub;
  algorithm
    // walk backwards so that prepending yields the elements in ascending order
    for k in n:-1:1 loop
      sub := Subscript.INDEX(Expression.INTEGER(k));
      diag := Expression.applySubscripts({sub, sub}, A, true) :: diag;
    end for;
    v := Expression.ARRAY(vecTy, listArray(diag), false);
  end extractDiagonalVector;

  function dbg
    // Prints s followed by a newline, but only if the debug flag DEBUG_ADJOINT is set.
    input String s;
  algorithm
    if not Flags.isSet(Flags.DEBUG_ADJOINT) then
      return;
    end if;
    print(s + "\n");
  end dbg;

  function updateAdjointList
    // Prepends current_grad to the list stored in oldOpt.
    // If oldOpt is NONE() a new one-element list is returned.
    input Option<list<Expression>> oldOpt;
    input Expression current_grad;
    output list<Expression> newList;
  algorithm
    if Util.isSome(oldOpt) then
      // probably the only case since empty list is used to initialize
      newList := current_grad :: Util.getOption(oldOpt);
    else
      newList := {current_grad};
    end if;
  end updateAdjointList;

  // Build a 1D one-hot array of the same type as derBaseCref:
  // zeros(n) with value placed at index idx.
  function buildOneHotVectorAdjoint
    // Creates a literal array with the subscripted type of derBaseCref that contains
    // value at position idx and the zero of the element type everywhere else.
    // Returns NONE() if that type is not an array, has not exactly one dimension,
    // or no size could be determined for it.
    input ComponentRef derBaseCref;
    input Integer idx;                // 1-based
    input Expression value;           // scalar element to place
    output Option<Expression> onehot; // NONE if sizes unknown or not vector
  protected
    Type vecTy, elemTy;
    list<Type.Dimension> dims;
    list<Integer> sizes;
    list<Expression> entries = {};
  algorithm
    // every rejected input ends up with this default
    onehot := NONE();

    vecTy := ComponentRef.getSubscriptedType(derBaseCref);
    if Type.isArray(vecTy) then
      dims := Type.arrayDims(vecTy);
      // only simple vectors are handled here
      if List.hasOneElement(dims) then
        sizes := NFDimension.sizes(dims);
        if not listEmpty(sizes) then
          elemTy := Type.arrayElementType(vecTy);
          // build [0,0,...,value,...,0] back to front, so no reversal is needed
          for pos in listHead(sizes):-1:1 loop
            if pos == idx then
              entries := value :: entries;
            else
              entries := Expression.makeZero(elemTy) :: entries;
            end if;
          end for;
          onehot := SOME(Expression.ARRAY(vecTy, listArray(entries), false));
        end if;
      end if;
    end if;
  end buildOneHotVectorAdjoint;

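  // Build a multi-hot scatter vector for a SLICE subscript:
  // result = sum_t [onehot(idx_t) * seed_elem_t]
  // Currently handled:
  //   - SLICE range lo:hi (unit step) -> sum of one-hots over lo, lo+1, ..., hi; lo and hi must be literal integers
  // Not handled yet (NONE() is returned):
  //   - WHOLE()                     (intended: return seed)
  //   - SLICE {i1,i2,...}           (intended: sum of one-hots; indices must be literal integers)
  //   - SLICE range lo:st:hi        (intended: sum over lo, lo+st, ..., hi; lo,st,hi must be literal integers)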
  function buildMultiHotVectorAdjoint
    // Scatters the seed of a sliced view back onto the full vector of derBaseCref.
    // Only a SLICE with a literal integer range lo:hi without step is handled:
    // the result is the element-wise sum of one one-hot vector per index of the range.
    // An empty range (hi < lo) results in the zero of the array type.
    // NONE() is returned for every other subscript and whenever a one-hot vector
    // can not be built.
    input ComponentRef derBaseCref;
    input Subscript sub;        // SLICE or WHOLE
    input Expression seed;      // upstream gradient for the sliced view (scalar or vector)
    output Option<Expression> scatter; // NONE() if not handled
  protected
    Type vecTy = ComponentRef.getSubscriptedType(derBaseCref);
    Boolean seedIsArray = Type.isArray(Expression.typeOf(seed));
    Operator ewAdd;
    Integer lo, hi;
    Expression seedElem, total;
    Option<Expression> hot;
  algorithm
    ewAdd := Operator.fromClassification(
      (NFOperator.MathClassification.ADDITION, NFOperator.SizeClassification.ELEMENT_WISE),
      Type.arrayElementType(vecTy)
    );

    scatter := match sub
      // TODO: SLICE with an explicit list of literal indices {i1,i2,...} is not implemented yet

      // SLICE with range lo:hi (unit step)
      case Subscript.SLICE(slice = Expression.RANGE(
          start = Expression.INTEGER(lo),
          step  = NONE(),
          stop  = Expression.INTEGER(hi)))
        algorithm
          // empty range: nothing is scattered
          if hi < lo then
            scatter := SOME(Expression.makeZero(vecTy)); return;
          end if;

          for pos in lo:hi loop
            // a vector seed contributes its (pos - lo + 1)-th element, a scalar seed is reused
            if seedIsArray then
              seedElem := Expression.applySubscripts({Subscript.INDEX(Expression.INTEGER(pos - lo + 1))}, seed, true);
            else
              seedElem := seed;
            end if;

            hot := buildOneHotVectorAdjoint(derBaseCref, pos, seedElem);
            // give up as soon as a single one-hot vector can not be built
            if not Util.isSome(hot) then
              scatter := NONE(); return;
            end if;

            // the first one-hot starts the sum, all later ones are added element-wise
            if pos == lo then
              total := Util.getOption(hot);
            else
              total := Expression.MULTARY({total, Util.getOption(hot)}, {}, ewAdd);
            end if;
          end for;
        then SOME(total);

      else NONE();
    end match;
  end buildMultiHotVectorAdjoint;

  annotation(__OpenModelica_Interface="backend");
end NBDifferentiate;
