/**
 * Copyright (c) 2021, 2026 Contributors to the Eclipse Foundation
 *
 * This program and the accompanying materials are made
 * available under the terms of the Eclipse Public License 2.0
 * which is available at https://www.eclipse.org/legal/epl-2.0/
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package org.eclipse.lsat.timing.util

import activity.Move
import activity.SimpleAction
import distributions.CalculationMode
import distributions.Distribution
import distributions.DistributionsAdapter
import expressions.BigDecimalConstant
import expressions.Expression
import java.math.BigDecimal
import java.math.MathContext
import java.math.RoundingMode
import java.util.Collection
import java.util.IdentityHashMap
import java.util.Map
import machine.ActionType
import machine.IResource
import machine.Peripheral
import machine.SetPoint
import org.apache.commons.math3.random.JDKRandomGenerator
import org.apache.commons.math3.random.RandomGenerator
import org.eclipse.lsat.common.graph.directed.editable.Node
import org.eclipse.lsat.common.xtend.annotations.IntermediateProperty
import org.eclipse.lsat.motioncalculator.MotionException
import org.eclipse.lsat.motioncalculator.util.MotionSegmentUtilities
import org.eclipse.lsat.timing.calculator.MotionCalculatorExtension
import setting.PhysicalSettings
import setting.Settings
import timing.Array
import timing.Scalar

import static extension org.eclipse.lsat.common.util.IterableUtil.*

/**
 * {@link ITimingCalculator} that derives the duration of actions from the timing settings in a {@link Settings}
 * model and the duration of moves from a {@link MotionCalculatorExtension}.
 * <p>
 * Results can optionally be cached per node/move (identity based); {@link #reset()} clears those caches, the
 * {@link Array} pointers and the distributions.
 */
class TimingCalculator implements ITimingCalculator {
    /** Rounding used for computed times: 6 digits after the decimal point (microseconds, as the name suggests), HALF_UP. */
    static val MICRO_SECONDS = new MathContext(6, RoundingMode.HALF_UP)
    /** Half of the smallest time step (0.5e-6); tolerance when matching sample times to segment boundaries. */
    static val HALF_MICRO_SECOND = 0.5 * Math.pow(10, -MICRO_SECONDS.precision)

    /** Cursor per {@link Array} used in LINEAIR mode; starts at -1 so that the first lookup yields index 0. */
    @IntermediateProperty(Array)
    val Integer pointer = -1

    /** Fixed seed so that sampling of distributions is reproducible. */
    val RandomGenerator distributionRandom = new JDKRandomGenerator(1618033989)
    val MotionCalculatorHelper motionCalculatorHelper
    val CalculationMode mode
    /** When true all axes of a peripheral are calculated as one group, otherwise set points are grouped by their axes. */
    val boolean synchronizeAxes
    /** Duration cache per node (identity based); null when caching is disabled or the mode is not MEAN. */
    val Map<Node, BigDecimal> nodeTimes
    /** Motion time cache per move (identity based); null when caching is disabled. */
    val Map<Move, Map<SetPoint, BigDecimal>> motionTimes
    val DistributionsAdapter distributions
    /** Motion data cache per node (identity based); null when caching is disabled. */
    val Map<Node, Map<String, MotionData>> motionData

    /** Creates a calculator in MEAN mode with synchronized axes and caching enabled. */
    new(Settings settings, MotionCalculatorExtension motionCalculator) {
        this(settings, motionCalculator, CalculationMode.MEAN, true, true)
    }

    /** Creates a calculator in the given mode with synchronized axes and caching enabled. */
    new(Settings settings, MotionCalculatorExtension motionCalculator, CalculationMode mode) {
        this(settings, motionCalculator, mode, true, true)
    }

    /**
     * @param settings the settings providing timing and physical settings
     * @param motionCalculator the calculator used for move durations and motion data
     * @param mode how timing values (arrays, distributions) are evaluated
     * @param synchronizeAxes whether all axes of a peripheral are calculated together
     * @param useCache whether results are cached per node/move (node durations only in MEAN mode)
     */
    new(Settings settings, MotionCalculatorExtension motionCalculator, CalculationMode mode, boolean synchronizeAxes,
        boolean useCache) {
        this.motionCalculatorHelper = new MotionCalculatorHelper(settings, motionCalculator)
        this.mode = mode
        this.synchronizeAxes = synchronizeAxes

        nodeTimes = if(useCache && mode == CalculationMode.MEAN) new IdentityHashMap()
        // Motion does not support distribution/linear yet and calculations are 'expensive',
        // thus cache is independent of calculation mode
        motionTimes = if(useCache) new IdentityHashMap()
        distributions = DistributionsAdapter.getAdapter(settings)
        motionData = if (useCache) new IdentityHashMap()
        distributions.calculationMode = mode
        distributions.randomGenerator = distributionRandom
    }

    def Settings getSettings() {
        return motionCalculatorHelper.settings
    }

    def MotionCalculatorExtension getMotionCalculator() {
        return motionCalculatorHelper.motionCalculator
    }

    /** Clears all caches, resets the {@link Array} pointers and resets the distributions of the settings' resource set. */
    override void reset() {
        _IntermediateProperty_pointer.clear
        if(nodeTimes !== null) nodeTimes.clear
        if(motionTimes !== null) motionTimes.clear
        DistributionsAdapter.resetDistributions(settings.eResource?.resourceSet)
        if (motionData !== null) motionData.clear
    }

    /** Returns the duration of the node, cached per node when node caching is enabled. */
    override BigDecimal calculateDuration(Node node) throws MotionException {
        if (null === nodeTimes)
            node.doCalculateDuration
        else
            nodeTimes.computeIfAbsent(node)[doCalculateDuration]
    }

    /** Nodes other than simple actions and moves take no time. */
    protected def dispatch BigDecimal doCalculateDuration(Node node) {
        BigDecimal.ZERO
    }

    protected def dispatch BigDecimal doCalculateDuration(SimpleAction action) {
        action.type.doCalculateValue(action.resource, action.peripheral)
    }

    /**
     * Evaluates the timing setting of the action type in the physical settings of the resource/peripheral.
     *
     * @throws SpecificationException if no physical settings are specified for the peripheral
     */
    protected def BigDecimal doCalculateValue(ActionType actionType, IResource resource, Peripheral peripheral) {
        getPhysicalSettings(resource, peripheral).timingSettings.get(actionType).doCalculateValue
    }

    protected def dispatch BigDecimal doCalculateValue(Scalar scalar) {
        doCalculateValue(scalar.valueExp)
    }

    /**
     * In LINEAIR mode the values are visited round robin (one per call, per array); in all other modes the
     * average of all values is returned.
     */
    protected def dispatch BigDecimal doCalculateValue(Array array) {
        switch (mode) {
            case LINEAIR: {
                // wrap on the size of the list that is actually indexed
                array.pointer = (array.pointer + 1) % array.valuesExp.size
                doCalculateValue(array.valuesExp.get(array.pointer))
            }
            default : array.valuesExp.map[doCalculateValue].average
        }
    }

    protected def dispatch BigDecimal doCalculateValue(Distribution distribution) {
        distribution.evaluate.normalize
    }

    protected def dispatch BigDecimal doCalculateValue(Expression expression) {
        return expression.evaluate
    }

    /** The duration of a move is the longest of its set point times (ZERO is passed as fallback to max). */
    protected def dispatch BigDecimal doCalculateDuration(Move move) {
        move.calculateMotionTime.values.max(BigDecimal.ZERO)
    }

    /**
     * Returns the duration per set point of the move. When the physical settings define move adjustments for the
     * move's profile (or, failing that, a profile-less adjustment), the adjust expression is applied to every time.
     *
     * @throws MotionException if the motion calculation fails
     * @throws SpecificationException if no physical settings are specified for the move's peripheral
     */
    override Map<SetPoint, BigDecimal> calculateMotionTime(Move move) throws MotionException {
        val Map<SetPoint, BigDecimal> motionTime = if (null === motionTimes) {
            move.doCalculateMotionTimes.get(move)
        } else {
            if (!motionTimes.containsKey(move)) {
                // the calculation yields the times of all moves of the concatenated move; cache them all
                motionTimes.putAll(move.doCalculateMotionTimes)
            }
            motionTimes.get(move)
        }

        val profileName = move.profile.name
        val physicalSettings = getPhysicalSettings(move.resource, move.peripheral)
        // a profile specific adjustment takes precedence over the profile-less one
        val adjustments = physicalSettings.moveAdjustments.findFirst[it.profile !== null && it.profile.name == profileName]
            ?: physicalSettings.moveAdjustments.findFirst[it.profile === null]
        if (adjustments === null) {
            return motionTime
        }

        // The time declaration holds a constant that is temporarily filled with each calculated value; the adjust
        // expression may reference it (nested) and is evaluated against that value.
        val bdConst = adjustments.timeDeclaration.expression as BigDecimalConstant
        val Map<SetPoint, BigDecimal> adjustedTimes = newLinkedHashMap()
        try {
            for (element : motionTime.entrySet) {
                bdConst.value = element.value
                adjustedTimes.put(element.key, adjustments.adjustExpression.evaluate)
            }
        } finally {
            // restore the constant to zero so the model is left as it was found
            bdConst.value = BigDecimal.ZERO
        }
        return adjustedTimes
    }

    /** Returns the motion data per set point id of the node, cached per node when caching is enabled. */
    override Map<String, MotionData> calculateMotionData(Node node) throws MotionException {
        if (null === motionData) return node.doCalculateMotionData.get(node)

        if (!motionData.containsKey(node)) {
            motionData.putAll(node.doCalculateMotionData)
        }
        return motionData.get(node)
    }

    /**
     * Computes the duration of each segment in the (concatenated) move.
     *
     * @return per move of the concatenated move, the duration per set point of the move's peripheral
     */
    protected def Map<Move, ? extends Map<SetPoint, BigDecimal>> doCalculateMotionTimes(Move move) {
        val axesAndSetPoints = move.axesAndSetPoints
        val concatenatedMove = motionCalculatorHelper.getConcatenatedMove(move)
        val result = new IdentityHashMap(concatenatedMove.size)
        for (e : axesAndSetPoints.entrySet) {
            val motionSegments = motionCalculatorHelper.createMotionSegments(concatenatedMove, e.key)
            // the returned times are treated as cumulative end times, one per segment
            val segmentEndTimes = motionCalculatorHelper.motionCalculator.calculateTimes(motionSegments)
            var startTime = BigDecimal.ZERO
            for (var i = 0; i < segmentEndTimes.size(); i++) {
                val endTime = segmentEndTimes.get(i).normalize
                result.computeIfAbsent(concatenatedMove.get(i)) [
                    new IdentityHashMap(move.peripheral.type.setPoints.size)
                ].fill(e.value, endTime.subtract(startTime))
                startTime = endTime
            }
        }
        return result
    }

    /** Nodes other than moves have no motion data. */
    protected def dispatch Map<? extends Node, ? extends Map<String, MotionData>> doCalculateMotionData(Node node) {
        val result = new IdentityHashMap()
        result.put(node, new IdentityHashMap())
        return result
    }

    /**
     * Compute the motion data (position, velocity, acceleration, etc.) over time for the given move.
     * The samples are split per move of the concatenated move and their times are made relative to the start of
     * that move.
     */
    protected def dispatch Map<? extends Node, ? extends Map<String, MotionData>> doCalculateMotionData(Move move) {
        val axesAndSetPoints = move.axesAndSetPoints
        val concatenatedMove = motionCalculatorHelper.getConcatenatedMove(move)
        val result = new IdentityHashMap()
        for (e : axesAndSetPoints.entrySet) {
            val motionSegments = motionCalculatorHelper.createMotionSegments(concatenatedMove, e.key)
            val motionProfile = MotionSegmentUtilities.getMotionProfiles(motionSegments).iterator().next()
            val parameterNames = motionProfile.getParameters().map[getName].toList
            for (positionInfo : motionCalculatorHelper.motionCalculator.getPositionInfo(motionSegments)) {
                val allData = positionInfo.getAllData()
                var startTime = BigDecimal.ZERO
                for (moveSegment : concatenatedMove) {
                    // duration of this move for the set point this position info belongs to (computed once)
                    val segmentDuration = calculateMotionTime(moveSegment).filter[key, value | key.name == positionInfo.setPointId].values.get(0)
                    val startTimeD = startTime.doubleValue
                    val endTimeD = (startTime + segmentDuration).doubleValue
                    // NOTE(review): array.set(0, ...) shifts the time of the samples returned by getAllData() in
                    // place. A sample within the tolerance of two adjacent moves is shifted again by the later move,
                    // which also changes the time seen through the earlier move's data (same sample object), and
                    // may drop it from the move after that. TODO: copy the sample before shifting once its type is
                    // confirmed.
                    val filteredData = allData
                            .filter[array | array.get(0) >= (startTimeD - HALF_MICRO_SECOND) && array.get(0) <= (endTimeD + HALF_MICRO_SECOND)]
                            .map[array | array.set(0, array.get(0) - startTimeD) array]
                            .toList
                    val segmentData = new MotionData()
                    segmentData.setTimeData(filteredData)
                    segmentData.setParameterNames((asList("Time", "Position") + parameterNames).toList)
                    result.computeIfAbsent(moveSegment) [
                        new IdentityHashMap()
                    ].put(positionInfo.setPointId, segmentData)
                    startTime += segmentDuration
                }
            }
        }
        return result
    }

    /**
     * Groups the set points of the move's peripheral by the axes they are calculated for: one group holding all
     * axes and set points when axes are synchronized, otherwise one group per distinct axes of the set points.
     */
    private def getAxesAndSetPoints(Move move) {
        if (synchronizeAxes) {
            newHashMap(move.peripheral.type.axes -> move.peripheral.type.setPoints)
        } else {
            move.peripheral.type.setPoints.groupBy[axes]
        }
    }

    /**
     * @throws SpecificationException if no physical settings are specified for the peripheral
     */
    protected def PhysicalSettings getPhysicalSettings(IResource resource, Peripheral peripheral) throws SpecificationException {
        val physicalSettings = settings.getPhysicalSettings(resource, peripheral)
        if (null === physicalSettings) {
            throw new SpecificationException('''Physical settings not specified for peripheral: «peripheral.fqn»''',
                settings)
        }
        return physicalSettings
    }

    /**
     * Fills the map with the specified value for all specified keys.
     */
    static def <K, V> fill(Map<K, V> map, Collection<? extends K> keys, V value) {
        keys.forEach[map.put(it, value)]
    }

    /**
     * Returns the average of the values, rounded to 6 decimal places (HALF_UP) like {@link #normalize(Number)}.
     *
     * @throws IllegalArgumentException if values is null or empty
     */
    static def BigDecimal getAverage(Collection<BigDecimal> values) {
        if (values === null || values.empty) {
            throw new IllegalArgumentException('Cannot compute the average of an empty collection')
        }
        val sum = values.reduce[a, b | a + b]
        // divide to a fixed scale: a MathContext would limit significant digits, not decimal places
        return sum.divide(new BigDecimal(values.size), MICRO_SECONDS.precision, MICRO_SECONDS.roundingMode)
    }

    /**
     * Converts the number to a non-negative BigDecimal with 6 decimal places (HALF_UP).
     */
    static def BigDecimal normalize(Number number) {
        // BigDecimal input is used as-is; a round trip through double could lose precision
        val result = if (number instanceof BigDecimal) number else BigDecimal.valueOf(number.doubleValue)
        return result.max(BigDecimal.ZERO).setScale(MICRO_SECONDS.precision, MICRO_SECONDS.roundingMode)
    }
}
