/*
 * Decompiled with CFR 0.152.
 */
package aima.test.learningtest;

import aima.learning.reinforcement.PassiveADPAgent;
import aima.learning.reinforcement.PassiveTDAgent;
import aima.learning.reinforcement.QLearningAgent;
import aima.learning.reinforcement.QTable;
import aima.probability.decision.MDP;
import aima.probability.decision.MDPFactory;
import aima.probability.decision.MDPPerception;
import aima.probability.decision.MDPPolicy;
import aima.probability.decision.MDPUtilityFunction;
import aima.probability.decision.cellworld.CellWorldPosition;
import aima.test.probabilitytest.MockRandomizer;
import aima.util.Pair;
import java.util.Hashtable;
import junit.framework.TestCase;

/**
 * JUnit 3 tests for the passive ADP, passive TD and Q-learning agents, all run
 * against the MDP returned by {@link MDPFactory#createFourByThreeMDP()}.
 *
 * <p>Every test drives the agents with a {@link MockRandomizer} fed a fixed
 * list of values, so the asserted numbers are tied to that exact sequence.
 */
public class ReinforcementLearningTest
extends TestCase {
    /** Tolerance applied to every floating-point comparison in this class. */
    private static final double DELTA = 0.001;

    MDP<CellWorldPosition, String> fourByThree;
    MDPPolicy<CellWorldPosition, String> policy;

    /**
     * Creates a fresh MDP and the fixed policy followed by the passive agents.
     * Positions that receive no action here are simply absent from the policy.
     */
    @Override
    public void setUp() {
        this.fourByThree = MDPFactory.createFourByThreeMDP();
        // Explicit type arguments: the raw constructor call would be an unchecked conversion.
        this.policy = new MDPPolicy<CellWorldPosition, String>();
        this.policy.setAction(new CellWorldPosition(1, 1), "up");
        this.policy.setAction(new CellWorldPosition(1, 2), "left");
        this.policy.setAction(new CellWorldPosition(1, 3), "left");
        this.policy.setAction(new CellWorldPosition(1, 4), "left");
        this.policy.setAction(new CellWorldPosition(2, 1), "up");
        this.policy.setAction(new CellWorldPosition(2, 3), "up");
        this.policy.setAction(new CellWorldPosition(3, 1), "right");
        this.policy.setAction(new CellWorldPosition(3, 2), "right");
        this.policy.setAction(new CellWorldPosition(3, 3), "right");
    }

    /**
     * Runs 100 trials of the passive ADP agent under the fixed policy and checks
     * the utilities it reports afterwards.
     */
    public void testPassiveADPAgent() {
        PassiveADPAgent<CellWorldPosition, String> agent = new PassiveADPAgent<CellWorldPosition, String>(this.fourByThree, this.policy);
        MockRandomizer r = newMockRandomizer();
        MDPUtilityFunction<CellWorldPosition> uf = null;
        for (int i = 0; i < 100; ++i) {
            agent.executeTrial(r);
            // Fetched after every trial, exactly as before, so the call sequence is unchanged.
            uf = agent.getUtilityFunction();
        }
        assertUtility(0.676, uf, 1, 1);
        assertUtility(0.626, uf, 1, 2);
        assertUtility(0.573, uf, 1, 3);
        assertUtility(0.519, uf, 1, 4);
        assertUtility(0.746, uf, 2, 1);
        assertUtility(0.865, uf, 2, 3);
        assertUtility(0.796, uf, 3, 1);
        assertUtility(0.906, uf, 3, 3);
        assertUtility(1.0, uf, 3, 4);
    }

    /**
     * Runs 200 trials of the passive TD agent under the fixed policy and checks
     * the utilities it reports afterwards.
     */
    public void testPassiveTDAgent() {
        PassiveTDAgent<CellWorldPosition, String> agent = new PassiveTDAgent<CellWorldPosition, String>(this.fourByThree, this.policy);
        MockRandomizer r = newMockRandomizer();
        MDPUtilityFunction<CellWorldPosition> uf = null;
        for (int i = 0; i < 200; ++i) {
            agent.executeTrial(r);
            // Fetched after every trial, exactly as before, so the call sequence is unchanged.
            uf = agent.getUtilityFunction();
        }
        assertUtility(0.662, uf, 1, 1);
        assertUtility(0.61, uf, 1, 2);
        assertUtility(0.553, uf, 1, 3);
        assertUtility(0.496, uf, 1, 4);
        assertUtility(0.735, uf, 2, 1);
        assertUtility(0.835, uf, 2, 3);
        assertUtility(0.789, uf, 3, 1);
        assertUtility(0.889, uf, 3, 3);
        assertUtility(1.0, uf, 3, 4);
    }

    /**
     * Disabled exploratory run: the name does not start with {@code test}, so
     * JUnit 3 never picks it up. Runs 100 Q-learning trials and prints the
     * resulting Q table and the policy derived from it; asserts nothing.
     */
    public void xtestQLearningAgent() {
        QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(this.fourByThree);
        MockRandomizer r = newMockRandomizer();
        QTable<CellWorldPosition, String> qTable = null;
        for (int i = 0; i < 100; ++i) {
            qla.executeTrial(r);
            qTable = qla.getQTable();
        }
        System.out.println(qTable);
        System.out.println(qTable.getPolicy());
    }

    /**
     * Checks the first Q-learning step from (1, 4) when the randomizer yields
     * 0.7: the agent picks "left", lands on (1, 3) with reward -0.04, and the
     * Q value of the first state/action pair is only updated by the next
     * {@code decideAction} call.
     */
    public void testFirstStepsOfQLAAgentUnderNormalProbability() {
        QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(this.fourByThree);
        // Presumably below the threshold for the intended transition (per the variable name).
        MockRandomizer alwaysLessThanEightyPercent = new MockRandomizer(new double[]{0.7});
        CellWorldPosition startingPosition = new CellWorldPosition(1, 4);
        String action = qla.decideAction(new MDPPerception<CellWorldPosition>(startingPosition, -0.04));
        assertEquals("left", action);
        assertEquals(0.0, qla.getQTable().getQValue(startingPosition, action), DELTA);
        qla.execute(action, alwaysLessThanEightyPercent);
        assertEquals(new CellWorldPosition(1, 3), qla.getCurrentState());
        assertEquals(-0.04, qla.getCurrentReward(), DELTA);
        // Executing the action alone must leave the Q value untouched.
        assertEquals(0.0, qla.getQTable().getQValue(startingPosition, action), DELTA);
        // The returned action is irrelevant here; the call is what triggers the Q update.
        qla.decideAction(new MDPPerception<CellWorldPosition>(new CellWorldPosition(1, 3), -0.04));
        assertEquals(-0.04, qla.getQTable().getQValue(startingPosition, action), DELTA);
    }

    /**
     * Checks the first Q-learning step from (1, 4) when the randomizer yields
     * 0.85: the agent picks "left" but ends on (2, 4) with reward -1.0, the
     * next {@code decideAction} call returns {@code null}, and the Q value of
     * the first state/action pair becomes -1.0.
     */
    public void testFirstStepsOfQLAAgentWhenFirstStepTerminates() {
        QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(this.fourByThree);
        CellWorldPosition startingPosition = new CellWorldPosition(1, 4);
        String action = qla.decideAction(new MDPPerception<CellWorldPosition>(startingPosition, -0.04));
        assertEquals("left", action);
        MockRandomizer betweenEightyAndNinetyPercent = new MockRandomizer(new double[]{0.85});
        qla.execute(action, betweenEightyAndNinetyPercent);
        assertEquals(new CellWorldPosition(2, 4), qla.getCurrentState());
        assertEquals(-1.0, qla.getCurrentReward(), DELTA);
        // Executing the action alone must leave the Q value untouched.
        assertEquals(0.0, qla.getQTable().getQValue(startingPosition, action), DELTA);
        String action2 = qla.decideAction(new MDPPerception<CellWorldPosition>(new CellWorldPosition(2, 4), -1.0));
        assertNull(action2);
        assertEquals(-1.0, qla.getQTable().getQValue(startingPosition, action), DELTA);
    }

    /**
     * Builds the randomizer shared by the multi-trial tests. A new instance
     * (and a new backing array) is created on every call so tests never share
     * randomizer state.
     */
    private static MockRandomizer newMockRandomizer() {
        return new MockRandomizer(new double[]{0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6, 0.5});
    }

    /**
     * Asserts that {@code uf} reports {@code expected}, within {@link #DELTA},
     * for the position built from the two given coordinates.
     *
     * @param expected the expected utility
     * @param uf the utility function under test
     * @param first first constructor argument of the {@link CellWorldPosition}
     * @param second second constructor argument of the {@link CellWorldPosition}
     */
    private static void assertUtility(double expected, MDPUtilityFunction<CellWorldPosition> uf, int first, int second) {
        assertEquals("utility of (" + first + ", " + second + ")", expected, uf.getUtility(new CellWorldPosition(first, second)), DELTA);
    }
}

