In [None]:
using PGFPlots
# LaTeX preamble additions for the PGFPlots figures in this notebook: the fillbetween
# library, the named colors referenced by the plot `style` strings, a TikZ arrow style,
# and amsmath. Each line is pushed in the order listed.
for preamble_line in (
    "\\usepgfplotslibrary{fillbetween}",
    "\\definecolor{pastelMagenta}{HTML}{FF48CF}",
    "\\definecolor{pastelPurple}{HTML}{8770FE}",
    "\\definecolor{pastelBlue}{HTML}{1BA1EA}",
    "\\definecolor{pastelSeaGreen}{HTML}{14B57F}",
    "\\definecolor{pastelGreen}{HTML}{3EAA0D}",
    "\\definecolor{pastelOrange}{HTML}{C38D09}",
    "\\definecolor{pastelRed}{HTML}{F5615C}",
    "\\definecolor{darkColor}{HTML}{300A24}",
    "\\tikzset{myarrow/.style={line width = 0.05cm, ->, rounded corners=5mm}}",
    "\\usepackage{amsmath}",
)
    pushPGFPlotsPreamble(preamble_line)
end
using Test
using Random
using ColorSchemes
# Inverted viridis colormap sampled at 500 interpolation levels. Despite the "jet" in
# its name, the underlying scheme is ColorSchemes.viridis.
# NOTE(review): not referenced by any of the cells below — use it or drop it.
pasteljet = ColorMaps.RGBArrayMap(ColorSchemes.viridis, interpolation_levels=500, invert=true);

In [None]:
include("cartpole_actor_critic.jl")

# Problem instance plus freshly constructed policy and value networks
# (constructed in this order, so any RNG draws happen in the same sequence).
ùí´ = CART_POLE
œÄ = get_stochastic_policy()
U = get_value_function()

n_iters = 500

# NOTE(review): the positional ActorCritic arguments (1000, 128, 1.0, 3.0, 1e-3, 1e-3)
# are defined in cartpole_actor_critic.jl; the two 1e-3 values are presumably the
# `w_reg_œÄ` / `w_reg_U` coefficients read by the training loop — confirm there.
M = ActorCritic(ùí´, 1000, 128, 1.0, 3.0, 1e-3, 1e-3)
mem = ActorCriticMemory(M, œÄ, U)

# One Adam optimizer per network, both with learning rate 0.01.
opt_œÄ, opt_U = Adam(0.01), Adam(0.01)

# Per-epoch training records; every slot is filled in by the training loop.
datas = Vector{ActorCriticData}(undef, n_iters)

# Trainable parameter collections of the two networks.
params_œÄ, params_U = Flux.params(œÄ), Flux.params(U)

# Best score seen so far and independent (deep) copies of both networks' parameters
# taken at that score. The training loop scores each epoch as reward-to-go mu - sigma.
best_score = -Inf
params_œÄ_best, params_U_best = deepcopy(params_œÄ), deepcopy(params_U)

"""
    copy_params!(dst, src)

Copy each parameter array of `src` into the matching array of `dst` (paired in
iteration order) and return `dst`. Both collections must hold arrays of equal length.
"""
function copy_params!(dst, src)
    for (d, s) in zip(dst, src)
        copyto!(d, s)
    end
    return dst
end

t_last_print = time()
t_start = time()
for i in 1:n_iters
    # These top-level bindings are reassigned inside the loop. Declaring them `global`
    # keeps the loop correct when this cell is run as a script (non-interactive soft
    # scope), where they would otherwise become new, unassigned loop locals.
    global t_last_print, best_score

    # Progress report at most once every 10 seconds.
    t_now = time()
    if t_now - t_last_print > 10.0
        println("$i / $(n_iters): $(round(t_now - t_start, digits=3)) sec")
        flush(stdout)  # push the line to the notebook output right away
        t_last_print = t_now
    end

    # One epoch with the current networks: returns gradients for both networks plus
    # the epoch's statistics record.
    grads_policy, grads_value, data = run_epoch(M, œÄ, U, mem)
    datas[i] = data

    # Score the networks that produced `data` by reward-to-go mu - sigma, and snapshot
    # them (before this epoch's update) when they beat the best score so far.
    score = data.r_togo_Œº1 - data.r_togo_œÉ1
    if score > best_score
        println("$i / $(n_iters): Storing new network with score $(round(score, digits=3)), max depth reached = $(round(mean(data.max_depth_reached), digits=2))")
        flush(stdout)
        best_score = score
        copy_params!(params_œÄ_best, params_œÄ)
        copy_params!(params_U_best, params_U)
    end

    # Optimizer step with L2 regularization: adding w_reg * param to each gradient is the
    # gradient of (w_reg / 2) * ||param||^2.
    for (param, grad) in zip(params_œÄ, grads_policy)
        Flux.update!(opt_œÄ, param, grad + M.w_reg_œÄ * param)
    end
    for (param, grad) in zip(params_U, grads_value)
        Flux.update!(opt_U, param, grad + M.w_reg_U * param)
    end
end
println("$n_iters / $(n_iters): $(round(time() - t_start, digits=3)) sec")
# NOTE(review): `œÄ` and `U` now hold the final-iteration weights, not the best-scoring
# snapshots. To continue with the best networks instead, run
# `copy_params!(params_œÄ, params_œÄ_best)` and `copy_params!(params_U, params_U_best)`.

In [None]:
# x-axis shared by the training-curve plots: epoch numbers 1 through length(datas).
epochs = collect(1:length(datas))

# Per-epoch mean and standard deviation of each record's `max_depth_reached`.
max_depth_reached_Œº = [mean(d.max_depth_reached) for d in datas]
max_depth_reached_œÉ = [std(d.max_depth_reached) for d in datas]

# Solid mean curve framed by lighter mean + std and mean - std curves.
let
    upper = max_depth_reached_Œº .+ max_depth_reached_œÉ
    lower = max_depth_reached_Œº .- max_depth_reached_œÉ
    band_style = "solid, thick, pastelBlue!40, mark=none"
    curves = [
        Plots.Linear(epochs, upper; style=band_style),
        Plots.Linear(epochs, max_depth_reached_Œº; style="solid, thick, pastelBlue,    mark=none"),
        Plots.Linear(epochs, lower; style=band_style),
    ]
    Axis(curves; xlabel="epoch", ylabel="max depth reached",
         width="15cm", height="8cm", style="enlarge x limits=0")
end

In [None]:
# Per-epoch reward-to-go statistics `r_togo_Œº1` / `r_togo_œÉ1` from each training record.
r_togo_Œº1 = [d.r_togo_Œº1 for d in datas]
r_togo_œÉ1 = [d.r_togo_œÉ1 for d in datas]

# Solid mu curve framed by lighter mu + sigma and mu - sigma curves.
let
    band_style = "solid, thick, pastelSeaGreen!40, mark=none"
    curves = [
        Plots.Linear(epochs, r_togo_Œº1 .+ r_togo_œÉ1; style=band_style),
        Plots.Linear(epochs, r_togo_Œº1; style="solid, thick, pastelSeaGreen,    mark=none"),
        Plots.Linear(epochs, r_togo_Œº1 .- r_togo_œÉ1; style=band_style),
    ]
    Axis(curves; xlabel="epoch", ylabel="reward to go",
         width="15cm", height="8cm", style="enlarge x limits=0")
end

In [None]:
# Per-epoch gradient norms of the policy network and the value network, as recorded
# by training (2-norms, per the legend labels).
‚àá_norm_œÄ = [d.‚àá_norm_œÄ for d in datas]
‚àá_norm_U = [d.‚àá_norm_U for d in datas]

let
    policy_curve = Plots.Linear(epochs, ‚àá_norm_œÄ;
        style="solid, thick, pastelBlue, mark=none", legendentry=L"\|\nabla\pi\|_2")
    value_curve = Plots.Linear(epochs, ‚àá_norm_U;
        style="solid, thick, pastelRed,  mark=none", legendentry=L"\|\nabla U\|_2")
    Axis([policy_curve, value_curve]; xlabel="epoch", ylabel="gradient norm",
         width="15cm", height="8cm", style="enlarge x limits=0")
end

In [None]:
# Per-epoch weight norms of the policy network and the value network, as recorded
# by training (2-norms, per the legend labels).
w_norm_œÄ = [d.w_norm_œÄ for d in datas]
w_norm_U = [d.w_norm_U for d in datas]

let
    policy_curve = Plots.Linear(epochs, w_norm_œÄ;
        style="solid, thick, pastelPurple, mark=none", legendentry=L"\|\pi\|_2")
    value_curve = Plots.Linear(epochs, w_norm_U;
        style="solid, thick, pastelOrange,  mark=none", legendentry=L"\|U\|_2")
    Axis([policy_curve, value_curve]; xlabel="epoch", ylabel="network norm",
         width="15cm", height="8cm",
         style="enlarge x limits=0, legend pos=outer north east")
end

In [None]:
# Reward-to-go normalization values per epoch, drawn as two separate curves
# (mu and sigma). These are the same fields read for the reward-to-go plot above.
r_togo_Œº1s = [d.r_togo_Œº1 for d in datas]
r_togo_œÉ1s = [d.r_togo_œÉ1 for d in datas]

let
    mu_curve = Plots.Linear(epochs, r_togo_Œº1s;
        style="solid, thick, pastelGreen,      mark=none", legendentry=L"\mu")
    sigma_curve = Plots.Linear(epochs, r_togo_œÉ1s;
        style="solid, thick, pastelSeaGreen,  mark=none", legendentry=L"\sigma")
    Axis([mu_curve, sigma_curve]; xlabel="epoch",
         ylabel="reward-to-go normalization values",
         width="15cm", height="8cm", style="enlarge x limits=0")
end

In [None]:
# Per-epoch mean and standard deviation of each record's `advantage` values.
advantage_Œº = [mean(d.advantage) for d in datas]
advantage_œÉ = [std(d.advantage) for d in datas]

# Solid mean curve framed by lighter mean + std and mean - std curves.
let
    upper = advantage_Œº .+ advantage_œÉ
    lower = advantage_Œº .- advantage_œÉ
    band_style = "solid, thick, pastelPurple!40, mark=none"
    curves = [
        Plots.Linear(epochs, upper; style=band_style),
        Plots.Linear(epochs, advantage_Œº; style="solid, thick, pastelPurple,    mark=none"),
        Plots.Linear(epochs, lower; style=band_style),
    ]
    Axis(curves; xlabel="epoch", ylabel="advantage",
         width="15cm", height="8cm", style="enlarge x limits=0")
end

In [None]:
# Per-epoch `t_elapsed` stored in each training record returned by `run_epoch`
# (units are not visible here — presumably seconds; confirm in cartpole_actor_critic.jl).
Axis(
    [
        Plots.Linear(epochs, [d.t_elapsed for d in datas], style="solid, thick, pastelRed, mark=none"),
    ],
    xlabel="epoch", ylabel="calculation time", width="15cm", height="8cm", style="enlarge x limits=0"
)

In [None]:
using Interact

# Roll out the current networks `œÄ` and `U` on `M.ùí´`, with `M.d` passed as the last
# argument (presumably the depth limit — confirm in cartpole_actor_critic.jl).
# These are the final training-iteration weights, not the `params_œÄ_best` /
# `params_U_best` snapshots. The return value is the depth reached; judging by the
# reads below, the visited states are left in `mem.states`.
max_depth_reached = rollout(M.ùí´, œÄ, U, mem, M.d)

# Pre-render the stored states so the slider only indexes into `frames`.
# NOTE(review): this draws every entry of `mem.states`, although only the first
# max_depth_reached+1 are reachable from the slider below.
frames = [draw_state(M.ùí´, s) for s in mem.states]
# Interactive slider over frames 1..max_depth_reached+1 (presumably the initial state
# plus one state per step — TODO confirm against `rollout`).
@manipulate for frame_index in 1:max_depth_reached+1
    frames[frame_index]
end

In [None]:
# Export the rolled-out states as an animated GIF. "cartpole_varying_spec" is presumably
# the output name, and 50 is passed straight through (presumably a frame rate or delay —
# confirm in `render_states_to_gif`).
# NOTE(review): this renders states 1..max_depth_reached, one fewer than the interactive
# viewer exposes (1..max_depth_reached+1) — confirm whether the last state belongs in it.
let states_to_render = mem.states[1:max_depth_reached]
    render_states_to_gif(M.ùí´, states_to_render, "cartpole_varying_spec", 50)
end