| | using System; |
| | using System.Collections; |
| | using System.Collections.Generic; |
| | using System.Collections.ObjectModel; |
| | using System.Linq; |
| | using UnityEngine; |
| |
|
/// <summary>
/// Moves available to the agent on the grid. <see cref="None"/> is the
/// pseudo-action the agent returns once it stands on the goal cell.
/// </summary>
/// <remarks>
/// The numeric values are spelled out explicitly; they are the same values the
/// compiler would assign implicitly, so enumeration order is unchanged.
/// </remarks>
[Serializable]
public enum Action
{
    Up = 0,
    Down = 1,
    Left = 2,
    Right = 3,
    None = 4
}
/// <summary>
/// Grid-world agent that moves one cell at a time and updates a table of
/// state/action values with the Q-learning rule after every move.
/// </summary>
/// <remarks>
/// A state is a grid coordinate <c>(x, y)</c>; world X maps to grid X and world Z
/// maps to grid Y. Moves that would leave the grid or enter a road block are
/// marked with a Q-value of <see cref="float.NegativeInfinity"/> so the greedy
/// policy does not pick them while any finite alternative exists.
/// </remarks>
[Serializable]
public class Agent : MonoBehaviour
{
    #region Fields
    /// <summary>Reward stored for the goal cell when the reward grid is built.</summary>
    private const float GoalReward = 100f;

    /// <summary>Chance that an eligible cell receives a road block during initialization.</summary>
    private const float RoadBlockProbability = 0.2f;

    /// <summary>The four movement actions, in enum declaration order (everything except <see cref="Action.None"/>).</summary>
    private static readonly Action[] MovementActions = { Action.Up, Action.Down, Action.Left, Action.Right };

    /// <summary>
    /// Single shared generator for shuffling. Creating a new time-seeded generator
    /// per call can yield identical permutations for calls made in quick succession.
    /// </summary>
    private static readonly System.Random ShuffleRandom = new System.Random();

    [SerializeField]
    private int _step;
    [SerializeField]
    private int _iteration;
    [SerializeField]
    private int _currentGridX;
    [SerializeField]
    private int _currentGridY;

    // Percept memory of the learner. Not marked [SerializeField]: Unity's serializer
    // does not support nullable or tuple fields, so the attribute had no effect.
    private (int, int)? _previousState = null;
    private Action? _previousAction = null;
    private float? _previousReward = null;

    [SerializeField]
    private GUIController _gUIController;
    [SerializeField]
    [Range(0f, 1f)]
    private float _learningRate;
    [SerializeField]
    [Range(0f, 1f)]
    private float _discountingFactor;

    [SerializeField]
    private int _mimumumStateActionPairFrequencies;
    [SerializeField]
    private float _estimatedBestPossibleRewardValue;

    // Runtime handle only; Coroutine is not a serializable type.
    private Coroutine _waitThenActionCoroutine;

    [SerializeField]
    private bool _isPause;
    [SerializeField]
    [Range(0.001f, 30f)]
    private float _restTime;
    [SerializeField]
    private GameObject _roadBlock;
    [SerializeField]
    private GameObject _Goodies;

    public int Step { get => _step; set => _step = value; }
    public int Iteration { get => _iteration; set => _iteration = value; }
    public int CurrentGridX { get => _currentGridX; set => _currentGridX = value; }
    public int CurrentGridY { get => _currentGridY; set => _currentGridY = value; }
    public (int, int)? PreviousState { get => _previousState; set => _previousState = value; }
    public Action? PreviousAction { get => _previousAction; set => _previousAction = value; }
    public float? PreviousReward { get => _previousReward; set => _previousReward = value; }
    public GUIController GUIController { get => _gUIController; set => _gUIController = value; }
    public float LearningRate { get => _learningRate; set => _learningRate = value; }
    public float DiscountingFactor { get => _discountingFactor; set => _discountingFactor = value; }
    public int MimumumStateActionPairFrequencies { get => _mimumumStateActionPairFrequencies; set => _mimumumStateActionPairFrequencies = value; }
    public float EstimatedBestPossibleRewardValue { get => _estimatedBestPossibleRewardValue; set => _estimatedBestPossibleRewardValue = value; }
    public Coroutine WaitThenActionCoroutine { get => _waitThenActionCoroutine; set => _waitThenActionCoroutine = value; }
    public bool IsPause { get => _isPause; set => _isPause = value; }
    public float RestTime { get => _restTime; set => _restTime = value; }
    public GameObject RoadBlock { get => _roadBlock; set => _roadBlock = value; }
    public GameObject Goodies { get => _Goodies; set => _Goodies = value; }

    /// <summary>Cell the agent is placed on at the start of every episode.</summary>
    public (int, int) StartState;

    /// <summary>Goal cell; overwritten from <c>Grid.instance.goalPosition</c> in <see cref="Start"/>.</summary>
    public (int, int) FinalState = (7, 9);

    public int StartX;
    public int StartY;

    /// <summary>Number of grid cells along X.</summary>
    public int GrizSizeX;

    /// <summary>Number of grid cells along Y.</summary>
    public int GrizSizeY;

    /// <summary>Q-table keyed by (state, action).</summary>
    public Dictionary<((int, int), Action), float> StateActionPairQValue { get; set; }

    /// <summary>Reward received for standing on each cell.</summary>
    public Dictionary<(int, int), float> StateRewardGrid { get; set; }

    /// <summary>Maps each action to the method that carries it out.</summary>
    public Dictionary<Action, System.Action> ActionDelegatesDictonary { get; set; }
    #endregion

    #region Q_Learning_Agent
    /// <summary>
    /// Processes one percept: updates the Q-value of the previous state/action pair
    /// and chooses the action to take from <paramref name="currentState"/>.
    /// </summary>
    /// <param name="currentState">Cell the agent currently occupies.</param>
    /// <param name="rewardSignal">Reward associated with <paramref name="currentState"/>.</param>
    /// <returns>The chosen action, which is also remembered as <see cref="PreviousAction"/>.</returns>
    private Action Q_Learning_Agent((int, int) currentState, float rewardSignal)
    {
        UpdateStep();

        if (PreviousState.HasValue)
        {
            (int, int) previousState = PreviousState.Value;

            // NOTE(review): when the previous state was the goal, its Q-value is first
            // overwritten with the incoming reward signal (the reward of the cell the
            // agent was reset to) and then updated by the rule below. Kept as-is to
            // preserve the existing learning behaviour; confirm this is intended.
            if (previousState == FinalState)
            {
                StateActionPairQValue[(previousState, Action.None)] = rewardSignal;
            }

            ((int, int), Action) stateActionPair = (previousState, PreviousAction.Value);
            float oldValue = StateActionPairQValue[stateActionPair];
            float target = PreviousReward.Value + (DiscountingFactor * MaxStateActionPairQValue(currentState));

            // Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
            StateActionPairQValue[stateActionPair] = oldValue + LearningRate * (target - oldValue);
        }

        PreviousState = currentState;
        PreviousAction = ArgMaxActionExploration(currentState);
        PreviousReward = rewardSignal;
        return PreviousAction.Value;
    }

    /// <summary>
    /// Returns the largest Q-value available from <paramref name="currentState"/>.
    /// For the goal cell this is the value stored under <see cref="Action.None"/>.
    /// </summary>
    private float MaxStateActionPairQValue((int, int) currentState)
    {
        if (currentState == FinalState)
        {
            return StateActionPairQValue[(currentState, Action.None)];
        }

        // A maximum does not depend on iteration order, so no shuffle is needed here.
        float max = float.NegativeInfinity;
        foreach (Action action in MovementActions)
        {
            max = Mathf.Max(StateActionPairQValue[(currentState, action)], max);
        }
        return max;
    }

    /// <summary>Returns the four movement actions in a random order.</summary>
    private static Action[] ShuffledActions()
    {
        return MovementActions.OrderBy(_ => ShuffleRandom.Next()).ToArray();
    }

    /// <summary>
    /// Picks the action with the highest Q-value in <paramref name="currentState"/>.
    /// Candidates are visited in random order and a later candidate replaces an
    /// earlier one on a tie, so ties are broken randomly. Returns
    /// <see cref="Action.None"/> on the goal cell.
    /// </summary>
    private Action ArgMaxActionExploration((int, int) currentState)
    {
        if (currentState == FinalState)
        {
            return Action.None;
        }

        Action argMaxAction = Action.None;
        float max = float.NegativeInfinity;

        foreach (Action action in ShuffledActions())
        {
            float value = StateActionPairQValue[(currentState, action)];
            if (value >= max)
            {
                max = value;
                argMaxAction = action;
            }
        }
        return argMaxAction;
    }

    private void Left()
    {
        MoveBy(-1, 0);
    }

    private void Right()
    {
        MoveBy(1, 0);
    }

    private void Up()
    {
        MoveBy(0, 1);
    }

    private void Down()
    {
        MoveBy(0, -1);
    }

    /// <summary>Ends the episode: puts the agent back on the start cell and counts a new iteration.</summary>
    private void None()
    {
        ResetAgentToStart();
        UpdateIteration();
        ScheduleNextAction();
    }

    /// <summary>
    /// Shifts the agent by the given number of cells (grid X is world X, grid Y is
    /// world Z) and queues the next decision.
    /// </summary>
    private void MoveBy(int deltaX, int deltaY)
    {
        transform.position += new Vector3(deltaX, 0f, deltaY);
        CurrentGridX += deltaX;
        CurrentGridY += deltaY;
        ScheduleNextAction();
    }

    /// <summary>Starts the coroutine that waits <see cref="RestTime"/> and then acts from the current cell.</summary>
    private void ScheduleNextAction()
    {
        WaitThenActionCoroutine = StartCoroutine(WaitThenAction(RestTime, (CurrentGridX, CurrentGridY)));
    }

    /// <summary>Moves the agent to <see cref="StartState"/> and clears the grid colouring.</summary>
    private void ResetAgentToStart()
    {
        transform.position = new Vector3(StartState.Item1, 1f, StartState.Item2);
        CurrentGridX = StartState.Item1;
        CurrentGridY = StartState.Item2;
        Grid.instance.ClearColors();
    }

    /// <summary>
    /// Waits while paused, then for <paramref name="waitTime"/> seconds, then runs one
    /// learning step for <paramref name="GridCoordinate"/> and executes the chosen action.
    /// </summary>
    private IEnumerator WaitThenAction(float waitTime, (int, int) GridCoordinate)
    {
        while (IsPause)
        {
            yield return null;
        }
        yield return new WaitForSeconds(waitTime);

        Action chosen = Q_Learning_Agent(GridCoordinate, StateRewardGrid[GridCoordinate]);
        ActionDelegatesDictonary[chosen]();
    }
    #endregion

    #region Unity
    private void Start()
    {
        FinalState = Grid.instance.goalPosition;

        ActionDelegatesDictonary = new Dictionary<Action, System.Action>
        {
            [Action.Left] = Left,
            [Action.Right] = Right,
            [Action.Up] = Up,
            [Action.Down] = Down,
            [Action.None] = None
        };

        StartX = UnityEngine.Random.Range(0, GrizSizeX);
        StartY = UnityEngine.Random.Range(0, GrizSizeY);
        Initialized();
    }

    /// <summary>
    /// Builds the Q-table and reward grid from scratch, scatters road blocks and
    /// marks every move that would leave the grid or enter a road block as forbidden.
    /// </summary>
    private void Initialized()
    {
        ResetRunState();
        StateActionPairQValue = new Dictionary<((int, int), Action), float>();
        StateRewardGrid = new Dictionary<(int, int), float>();

        for (int i = 0; i < GrizSizeX; i++)
        {
            for (int j = 0; j < GrizSizeY; j++)
            {
                foreach (Action action in Enum.GetValues(typeof(Action)))
                {
                    StateActionPairQValue[((i, j), action)] = 0f;
                }
                StateRewardGrid[(i, j)] = 0f;
            }
        }

        // Every other cell keeps the 0 reward assigned above. Nothing below writes to
        // the reward grid, so a goal located on the grid edge keeps its reward.
        StateRewardGrid[FinalState] = GoalReward;

        for (int i = 0; i < GrizSizeX; i++)
        {
            for (int j = 0; j < GrizSizeY; j++)
            {
                // Cells sharing a row or column with the start or the goal are never
                // blocked. The random draw only happens for eligible cells.
                bool eligible = i != StartState.Item1 && i != FinalState.Item1
                    && j != StartState.Item2 && j != FinalState.Item2;
                if (eligible && UnityEngine.Random.Range(0f, 1f) <= RoadBlockProbability)
                {
                    PlaceRoadBlock(i, j);
                }

                ForbidMovesLeavingGrid(i, j);
            }
        }
    }

    /// <summary>Spawns a road block at the cell and forbids each in-grid neighbour from moving onto it.</summary>
    private void PlaceRoadBlock(int x, int y)
    {
        Instantiate(RoadBlock, new Vector3(x, 0.5f, y), Quaternion.identity);

        if (x + 1 < GrizSizeX)
        {
            StateActionPairQValue[((x + 1, y), Action.Left)] = float.NegativeInfinity;
        }
        if (x - 1 >= 0)
        {
            StateActionPairQValue[((x - 1, y), Action.Right)] = float.NegativeInfinity;
        }
        if (y + 1 < GrizSizeY)
        {
            StateActionPairQValue[((x, y + 1), Action.Down)] = float.NegativeInfinity;
        }
        if (y - 1 >= 0)
        {
            StateActionPairQValue[((x, y - 1), Action.Up)] = float.NegativeInfinity;
        }
    }

    /// <summary>For a cell on the grid edge, forbids the moves that would step outside the grid.</summary>
    private void ForbidMovesLeavingGrid(int x, int y)
    {
        if (x == 0)
        {
            StateActionPairQValue[((x, y), Action.Left)] = float.NegativeInfinity;
        }
        if (y == 0)
        {
            StateActionPairQValue[((x, y), Action.Down)] = float.NegativeInfinity;
        }
        if (x == GrizSizeX - 1)
        {
            StateActionPairQValue[((x, y), Action.Right)] = float.NegativeInfinity;
        }
        if (y == GrizSizeY - 1)
        {
            StateActionPairQValue[((x, y), Action.Up)] = float.NegativeInfinity;
        }
    }

    /// <summary>
    /// Clears the percept memory and the step/iteration counters and puts the agent
    /// on (<see cref="StartX"/>, <see cref="StartY"/>).
    /// </summary>
    private void ResetRunState()
    {
        PreviousAction = null;
        PreviousReward = null;
        PreviousState = null;
        Step = 0;
        Iteration = 0;
        transform.position = new Vector3(StartX, 1f, StartY);
        StartState = (StartX, StartY);
        CurrentGridX = StartState.Item1;
        CurrentGridY = StartState.Item2;
    }

    /// <summary>
    /// Forgets everything learned: resets the run state and zeroes every Q-value
    /// except the negative-infinity entries that mark forbidden moves.
    /// </summary>
    private void ReInitialized()
    {
        ResetRunState();

        for (int i = 0; i < GrizSizeX; i++)
        {
            for (int j = 0; j < GrizSizeY; j++)
            {
                foreach (Action action in Enum.GetValues(typeof(Action)))
                {
                    ((int, int), Action) key = ((i, j), action);
                    bool forbidden = StateActionPairQValue.TryGetValue(key, out float current)
                        && float.IsNegativeInfinity(current);
                    if (!forbidden)
                    {
                        StateActionPairQValue[key] = 0f;
                    }
                }
            }
        }
    }

    private void Update()
    {
        Grid.instance.UpdateColor(CurrentGridX, CurrentGridY);
    }

    /// <summary>
    /// Starts the learning loop from <see cref="StartState"/>. Any loop that is still
    /// pending is cancelled first so two loops never share the learner's state.
    /// </summary>
    public void StartExploring()
    {
        StopPendingAction();
        UpdateIteration();
        WaitThenActionCoroutine = StartCoroutine(WaitThenAction(1f, StartState));
    }

    /// <summary>Cancels the learning loop (if one was started) and resets the agent and its Q-values.</summary>
    public void Stop()
    {
        StopPendingAction();
        ReInitialized();
    }

    /// <summary>Stops the pending decision coroutine, if any. Safe to call before anything was started.</summary>
    private void StopPendingAction()
    {
        if (WaitThenActionCoroutine != null)
        {
            StopCoroutine(WaitThenActionCoroutine);
            WaitThenActionCoroutine = null;
        }
    }

    private void UpdateStep()
    {
        Step++;
        // Explicit comparison instead of ?. so that, if GUIController is a Unity
        // object, Unity's overloaded null check is honoured.
        if (_gUIController != null)
        {
            _gUIController.UpdateStepText(Step.ToString());
        }
    }

    private void UpdateIteration()
    {
        Iteration++;
        if (_gUIController != null)
        {
            _gUIController.UpdateInterationText(Iteration.ToString());
        }
    }
    #endregion
}