package aima.probability.reasoning;

import aima.probability.RandomVariable;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:aima/probability/reasoning/HiddenMarkovModel.class */
public class HiddenMarkovModel {
    /** Observation model; {@code asMatrix(perception)} supplies the matrix used to weight beliefs by evidence. */
    SensorModel sensorModel;

    /** Transition model; supplies the matrices used for prediction and for the backward recursion. */
    TransitionModel transitionModel;

    /** Belief over the hidden state before any perception has been incorporated (time step 0). */
    private final RandomVariable priorDistribution;

    /**
     * Creates a hidden Markov model from its three components.
     *
     * @param randomVariable  prior belief over the hidden state at time step 0
     * @param transitionModel state transition model
     * @param sensorModel     observation (sensor) model
     */
    public HiddenMarkovModel(RandomVariable randomVariable, TransitionModel transitionModel, SensorModel sensorModel) {
        this.priorDistribution = randomVariable;
        this.transitionModel = transitionModel;
        this.sensorModel = sensorModel;
    }

    /**
     * Returns the prior belief passed to the constructor.
     *
     * @return the prior belief (the same instance, not a copy)
     */
    public RandomVariable prior() {
        return this.priorDistribution;
    }

    /**
     * Projects a belief one time step forward under the given action.
     *
     * <p>Computes {@code T(action)^T · belief}. The result is NOT normalized here.
     *
     * @param randomVariable current belief; not modified
     * @param str            action whose transition matrix is used
     * @return a new belief holding the predicted values
     */
    public RandomVariable predict(RandomVariable randomVariable, String str) {
        RandomVariable duplicate = randomVariable.duplicate();
        duplicate.updateFrom(this.transitionModel.asMatrix(str).transpose().times(randomVariable.asMatrix()));
        return duplicate;
    }

    /**
     * Conditions a belief on a perception.
     *
     * <p>Computes {@code O(perception) · belief} and normalizes the result.
     *
     * @param randomVariable belief to update; not modified
     * @param str            the observed perception
     * @return a new, normalized belief
     */
    public RandomVariable perceptionUpdate(RandomVariable randomVariable, String str) {
        RandomVariable duplicate = randomVariable.duplicate();
        duplicate.updateFrom(this.sensorModel.asMatrix(str).times(randomVariable.asMatrix()));
        duplicate.normalize();
        return duplicate;
    }

    /**
     * Performs one filtering step: a prediction under {@code str} followed by a perception update with {@code str2}.
     *
     * @param randomVariable current (filtered) belief; not modified
     * @param str            action taken
     * @param str2           perception observed after the action
     * @return the new normalized filtered belief
     */
    public RandomVariable forward(RandomVariable randomVariable, String str, String str2) {
        return perceptionUpdate(predict(randomVariable, str), str2);
    }

    /**
     * Performs one filtering step using the {@code HmmConstants.DO_NOTHING} action.
     *
     * @param randomVariable current (filtered) belief; not modified
     * @param str            perception observed
     * @return the new normalized filtered belief
     */
    public RandomVariable forward(RandomVariable randomVariable, String str) {
        return forward(randomVariable, HmmConstants.DO_NOTHING, str);
    }

    /**
     * Computes the next (one step earlier) backward message of the forward-backward algorithm.
     *
     * <p>Computes {@code b_{k:t} = T · O(e_k) · b_{k+1:t}} and normalizes it. Normalizing does not affect smoothed
     * estimates, because those are normalized after combination.
     *
     * <p>The standard backward recursion does not involve the forward message. A previous version of this
     * method multiplied the result element-wise by {@code randomVariable}. That folded the filtered estimate into
     * the backward message, so evidence already counted by the forward message was counted twice in every
     * smoothed estimate. The parameter is retained only for source compatibility and is ignored.
     *
     * @param randomVariable  unused; kept so existing callers still compile
     * @param randomVariable2 the current backward message {@code b_{k+1:t}}; not modified
     * @param str             the perception {@code e_k} that separates the two messages
     * @return a new, normalized backward message {@code b_{k:t}}
     */
    public RandomVariable calculate_next_backward_message(RandomVariable randomVariable, RandomVariable randomVariable2, String str) {
        RandomVariable duplicate = randomVariable2.duplicate();
        duplicate.updateFrom(this.transitionModel.asMatrix().times(this.sensorModel.asMatrix(str).times(randomVariable2.asMatrix())));
        duplicate.normalize();
        return duplicate;
    }

    /**
     * Runs forward-backward smoothing over a sequence of perceptions.
     *
     * <p>Each forward step uses the {@code HmmConstants.DO_NOTHING} action. See {@link #forward(RandomVariable, String)}.
     *
     * @param list perceptions {@code e_1 .. e_t}, in time order
     * @return a fixed-size list of length {@code list.size() + 1}. Element {@code k} (for {@code 1 <= k <= t}) is the
     *         normalized smoothed belief for time step {@code k}. Element 0 is always {@code null}, because no
     *         smoothed estimate is produced for the prior.
     */
    public List<RandomVariable> forward_backward(List<String> list) {
        int steps = list.size();
        RandomVariable[] forwardMessages = new RandomVariable[steps + 1];
        RandomVariable[] smoothed = new RandomVariable[steps + 1];
        // Initial backward message for the last time step. NOTE(review): assumed to be the all-ones message
        // b_{t+1:t} of the standard algorithm -- confirm createUnitBelief() semantics.
        RandomVariable backward = this.priorDistribution.createUnitBelief();

        // Forward pass: forwardMessages[k] is the filtered belief after perceptions 1..k.
        forwardMessages[0] = this.priorDistribution;
        for (int k = 1; k <= steps; k++) {
            forwardMessages[k] = forward(forwardMessages[k - 1], list.get(k - 1));
        }

        // Backward pass: combine each filtered belief with the backward message from the future.
        for (int k = steps; k > 0; k--) {
            RandomVariable estimate = this.priorDistribution.duplicate();
            estimate.updateFrom(forwardMessages[k].asMatrix().arrayTimes(backward.asMatrix()));
            estimate.normalize();
            smoothed[k] = estimate;
            // Perception e_k (list index k-1) moves the backward message to time k-1. After k == 1 the message
            // would never be used, so that final step is skipped.
            if (k > 1) {
                backward = calculate_next_backward_message(forwardMessages[k], backward, list.get(k - 1));
            }
        }
        return Arrays.asList(smoothed);
    }

    /**
     * Returns the sensor model passed to the constructor.
     *
     * @return the sensor model
     */
    public SensorModel sensorModel() {
        return this.sensorModel;
    }

    /**
     * Returns the transition model passed to the constructor.
     *
     * @return the transition model
     */
    public TransitionModel transitionModel() {
        return this.transitionModel;
    }
}
