咪免直播高品质美女在线视频互动社区_咪免直播官方版_咪免直播直播视频在线观看免费版下载

您的位置:首頁 > 軟件教程 > 教程 > 強化學習筆記之【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】

強化學習筆記之【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】

來源:好特整理 | 時間:2024-10-18 09:46:01 | 閱讀:193 |  標簽: a T CTO AWA Ri rop Pyre S C ICY Causality AR   | 分享到:

2024年ICML文章,ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization精讀

強化學習筆記之【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】


前言:


該論文是清華項目組組內博士師兄寫的文章,項目主頁為 ACE (ace-rl.github.io) ,于2024年7月發(fā)表在ICML期刊

因為最近組內(其實只有我)需要從零開始做一個相關項目,前面的幾篇文章都是鋪墊

本文章為強化學習筆記第5篇

本文初編輯于2024.10.5,好像是這個時間,忘記了,前后寫了兩個多星期

CSDN主頁: https://blog.csdn.net/rvdgdsva

博客園主頁: https://www.cnblogs.com/hassle

博客園本文鏈接:


論文一覽

這篇強化學習論文主要介紹了一個名為 ACE 的算法,完整名稱為 Off-Policy Actor-Critic with Causality-Aware Entropy Regularization ,它通過引入因果關系分析和因果熵正則化來解決現(xiàn)有模型在不同動作維度上的不平等探索問題,旨在改進強化學習【注釋1】中探索效率和樣本效率的問題,特別是在高維度連續(xù)控制任務中的表現(xiàn)。

【注釋1】: 強化學習入門這一篇就夠了


論文摘要

在policy【注釋2】學習過程中,不同原始行為的不同意義被先前的model-free RL 算法所忽視。利用這一見解,我們探索了不同行動維度和獎勵之間的因果關系,以評估訓練過程中各種原始行為的重要性。我們引入了一個因果關系感知熵【注釋3】項(causality-aware entropy term),它可以有效地識別并優(yōu)先考慮具有高潛在影響的行為,以實現(xiàn)高效的探索。此外,為了防止過度關注特定的原始行為,我們分析了梯度休眠現(xiàn)象(gradientdormancyphenomenon),并引入了休眠引導的重置機制,以進一步增強我們方法的有效性。與無模型RL基線相比,我們提出的算法 ACE :Off-policy A ctor-criticwith C ausality-aware E ntropyregularization。在跨越7個域的29種不同連續(xù)控制任務中顯示出實質性的性能優(yōu)勢,這強調了我們方法的有效性、多功能性和高效的樣本效率。 基準測試結果和視頻可在https://ace-rl.github.io/上獲得。

【注釋2】: 強化學習算法中on-policy和off-policy

【注釋3】: 最大熵 RL:從Soft Q-Learning到SAC - 知乎


論文主要貢獻:

【1】 因果關系分析 :通過引入因果政策-獎勵結構模型,評估不同動作維度(即原始行為)對獎勵的影響大。ǚQ為“因果權重”)。這些權重反映了每個動作維度在不同學習階段的相對重要性。

作出上述改進的原因是:考慮一個簡單的例子,一個機械手最初應該學習放下手臂并抓住物體,然后將注意力轉移到學習手臂朝著最終目標的運動方向上。因此,在策略學習的不同階段強調對最重要的原始行為的探索是 至關重要的。在探索過程中刻意關注各種原始行為,可以加速智能體在每個階段對基本原始行為的學習,從而提高掌握完整運動任務的效率。

此處可供學習的資料:

【2】 因果熵正則化 :在最大熵強化學習框架的基礎上(如SAC算法),加入了 因果加權的熵正則化項 。與傳統(tǒng)熵正則化不同,這一項根據(jù)各個原始行為的因果權重動態(tài)調整,強化對重要行為的探索,減少對不重要行為的探索。

作出上述改進的原因是:論文引入了一個因果策略-獎勵結構模型來計算行動空間上的因果權重(causal weights),因果權重會引導agent進行更有效的探索, 鼓勵對因果權重較大的動作維度進行探索,表明對獎勵的重要性更大,并減少對因果權重較小的行為維度的探 索。一般的最大熵目標缺乏對不同學習階段原始行為之間區(qū)別的重要性的認識,可能導致低效的探索。為了解決這一限制,論文引入了一個由因果權重加權的策略熵作為因果關系感知的熵最大化目標,有效地加強了對重要原始行為的探索,并導致了更有效的探索。

此處可供學習的資料:

【3】 梯度“休眠”現(xiàn)象(Gradient Dormancy) :論文觀察到,模型訓練時有些梯度會在某些階段不活躍(即“休眠”)。為了防止模型過度關注某些原始行為,論文引入了 梯度休眠導向的重置機制 。該機制通過周期性地對模型進行擾動(reset),避免模型陷入局部最優(yōu),促進更廣泛的探索。

作出上述改進的原因是:該機制通過一個由梯度休眠程度決定的因素間歇性地干擾智能體的神經(jīng)網(wǎng)絡。將因果關系感知探索與這種新穎的重置機制相結合,旨在促進更高效、更有效的探索,最終提高智能體的整體性能。

通過在多個連續(xù)控制任務中的實驗,ACE 展示出了顯著優(yōu)于主流強化學習算法(如SAC、TD3)的表現(xiàn):

  • 29個不同的連續(xù)控制任務 :包括 Meta-World(12個任務)、DMControl(5個任務)、Dexterous Hand(3個任務)和其他稀疏獎勵任務(6個任務)。
  • 實驗結果 表明,ACE 在所有任務中都達到了更好的樣本效率和更高的最終性能。例如,在復雜的稀疏獎勵場景中,ACE 憑借其因果權重引導的探索策略,顯著超越了 SAC 和 TD3 等現(xiàn)有算法。

論文中的對比實驗圖表顯示了 ACE 在多種任務下的顯著優(yōu)勢,尤其是在 稀疏獎勵和高維度任務 中,ACE 憑借其探索效率的提升,能更快達到最優(yōu)策略。


論文代碼框架

在ACE原論文的第21頁,這玩意兒應該寫在正篇的,害的我看了好久的代碼去排流程

不過說實話這偽代碼有夠簡潔的,代碼多少有點糊成一坨了

這是一個強化學習(RL)算法的框架,具體是一個結合因果推斷(Causal Discovery)的離策略(Off-policy)Actor-Critic方法。下面是對每個模塊及其參數(shù)的說明:

1. 初始化模塊

  • Q網(wǎng)絡 ( \(Q_\phi\) ) :用于估計動作價值,(\phi) 是權重參數(shù)。
  • 策略網(wǎng)絡 ( $\pi_\theta $) :用于生成動作策略,(\theta) 是其權重。
  • 重放緩沖區(qū) ($ D$ ) :存儲環(huán)境交互的數(shù)據(jù),以便進行采樣。
  • 局部緩沖區(qū) ( $D_c $) :存儲因果發(fā)現(xiàn)所需的局部數(shù)據(jù)。
  • 因果權重矩陣 ($ B_{a \rightarrow r|s} $) :用于捕捉動作與獎勵之間的因果關系。
  • 擾動因子 ( \(f\) ) :用于對策略進行微小擾動,增加探索。

2. 因果發(fā)現(xiàn)模塊

  • 每 ( $$I$$ ) 步更新
    • 樣本采樣 :從局部緩沖區(qū) ( \(D_c\) ) 中抽樣 ( \(N_c\) ) 條轉移。
    • 更新因果權重矩陣 :調整 ($ B_{a \rightarrow r|s}$ ),用于反映當前策略和獎勵之間的因果關系。

3. 策略優(yōu)化模塊

  • 每個梯度步驟
    • 樣本采樣 :從重放緩沖區(qū) ( \(D\) ) 中抽樣 ($ N$ ) 條轉移。
    • 計算因果意識熵 ( \(H_c(\pi(\cdot|s))\) ) :衡量在給定狀態(tài)下策略的隨機性和確定性,用于修改策略。
    • 目標 Q 值計算 :更新目標 Q 值,用于訓練 Q 網(wǎng)絡。
    • 更新 Q 網(wǎng)絡 :減少預測的 Q 值與目標 Q 值之間的誤差。
    • 更新策略網(wǎng)絡 :最大化當前狀態(tài)下的 Q 值,以提高收益。

4. 重置機制模塊

  • 每個重置間隔
    • 計算梯度主導度 ( $\beta_\gamma $) :用來量化策略更新的影響程度。
    • 初始化隨機網(wǎng)絡 :為新的策略更新準備初始權重 ( $\phi_i $)。
    • 軟重置策略和 Q 網(wǎng)絡 :根據(jù)因果權重進行平滑更新,幫助實現(xiàn)更穩(wěn)定的優(yōu)化。
    • 重置策略和 Q 優(yōu)化器 :在重置時清空狀態(tài),以便進行新的學習過程。

論文源代碼主干

源代碼上千行呢,這里只是貼上main_casual里面的部分代碼,并且刪掉了很大一部分代碼以便理清程序脈絡

def train_loop(config, msg = "default"):
    # Agent
    agent = ACE_agent(env.observation_space.shape[0], env.action_space, config)

    memory = ReplayMemory(config.replay_size, config.seed)
    local_buffer = ReplayMemory(config.causal_sample_size, config.seed)

    for i_episode in itertools.count(1):
        done = False

        state = env.reset()
        while not done:
            if config.start_steps > total_numsteps:
                action = env.action_space.sample()  # Sample random action
            else:
                action = agent.select_action(state)  # Sample action from policy

            if len(memory) > config.batch_size:
                for i in range(config.updates_per_step):
                    #* Update parameters of causal weight
                    if (total_numsteps % config.causal_sample_interval == 0) and (len(local_buffer)>=config.causal_sample_size):
                        causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')
                        print("Current Causal Weight is: ",causal_weight)
                        
                    dormant_metrics = {}
                    # Update parameters of all the networks
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight,config.batch_size, updates)

                    updates += 1
            next_state, reward, done, info = env.step(action) # Step
            total_numsteps += 1
            episode_steps += 1
            episode_reward += reward

            #* Ignore the "done" signal if it comes from hitting the time horizon.
            if '_max_episode_steps' in dir(env):  
                mask = 1 if episode_steps == env._max_episode_steps else float(not done)
            elif 'max_path_length' in dir(env):
                mask = 1 if episode_steps == env.max_path_length else float(not done)
            else: 
                mask = 1 if episode_steps == 1000 else float(not done)

            memory.push(state, action, reward, next_state, mask) # Append transition to memory
            local_buffer.push(state, action, reward, next_state, mask) # Append transition to local_buffer
            state = next_state

        if total_numsteps > config.num_steps:
            break

        # test agent
        if i_episode % config.eval_interval == 0 and config.eval is True:
            eval_reward_list = []
            for _  in range(config.eval_episodes):
                state = env.reset()
                episode_reward = []
                done = False
                while not done:
                    action = agent.select_action(state, evaluate=True)
                    next_state, reward, done, info = env.step(action)
                    state = next_state
                    episode_reward.append(reward)
                eval_reward_list.append(sum(episode_reward))

            avg_reward = np.average(eval_reward_list)
          
    env.close() 

代碼流程解釋

  1. 初始化 :

    • 通過配置文件 config 設置環(huán)境和隨機種子。
    • 使用 ACE_agent 初始化強化學習智能體,該智能體會在后續(xù)過程中學習如何在環(huán)境中行動。
    • 創(chuàng)建存儲結果和檢查點的目錄,確保訓練過程中的配置和因果權重會被記錄下來。
    • 初始化了兩個重放緩沖區(qū): memory 用于存儲所有的歷史數(shù)據(jù), local_buffer 則用于因果權重的更新。
  2. 主訓練循環(huán) :

    • 采樣動作 :如果總步數(shù)較小,則從環(huán)境中隨機采樣動作,否則從策略中選擇動作。通過這種方式,確保早期探索和后期利用。

    • 更新因果權重 :在特定間隔內,從局部緩沖區(qū)中采樣數(shù)據(jù),通過 get_sa2r_weight 函數(shù)使用DirectLiNGAM算法計算從動作到獎勵的因果權重。這個權重會作為額外信息,幫助智能體優(yōu)化策略。

    • 更新網(wǎng)絡參數(shù) :當 memory 中的數(shù)據(jù)足夠多時,開始通過采樣更新Q網(wǎng)絡和策略網(wǎng)絡,使用計算出的因果權重來修正損失函數(shù)。

    • 記錄與保存模型 :每隔一定的步數(shù),算法會測試當前策略的性能,記錄并比較獎勵是否超過歷史最佳值,如果是,則保存模型的檢查點。

    • 使用 wandb 記錄訓練過程中的指標,例如損失函數(shù)、獎勵和因果權重的計算時間,這些信息可以幫助調試和分析訓練過程。


論文模塊代碼及實現(xiàn)

因果發(fā)現(xiàn)模塊

因果發(fā)現(xiàn)模塊 主要通過 get_sa2r_weight 函數(shù)實現(xiàn),并且與 DirectLiNGAM 模型結合,負責計算因果權重。具體代碼在訓練循環(huán)中如下:

causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')

在這個代碼段, get_sa2r_weight 函數(shù)會基于當前環(huán)境、樣本數(shù)據(jù)( local_buffer )和因果模型(這里使用的是 DirectLiNGAM ),計算與行動相關的因果權重( causal_weight )。這些權重會影響后續(xù)的策略優(yōu)化和參數(shù)更新。關鍵邏輯包括:

  1. 采樣間隔 :因果發(fā)現(xiàn)是在 total_numsteps % config.causal_sample_interval == 0 時觸發(fā),確保只在指定的步數(shù)間隔內計算因果權重,避免每一步都進行因果計算,減輕計算負擔。
  2. 局部緩沖區(qū) local_buffer 中存儲了足夠的樣本( config.causal_sample_size ),這些樣本用于因果關系的發(fā)現(xiàn)。
  3. 因果方法 DirectLiNGAM 是選擇的因果模型,用于從狀態(tài)、行動和獎勵之間推導出因果關系。

因果權重計算完成后,程序會將這些權重應用到策略優(yōu)化中,并且記錄權重及計算時間等信息。

def get_sa2r_weight(env, memory, agent, sample_size=5000, causal_method='DirectLiNGAM'):
    ······
    return weight, model._running_time

這個代碼的核心是利用DirectLiNGAM模型計算給定狀態(tài)、動作和獎勵之間的因果權重。接下來,用LaTeX公式詳細表述計算因果權重的過程:

  1. 數(shù)據(jù)預處理
    將從 memory 中采樣的 states (狀態(tài))、 actions (動作)和 rewards (獎勵)進行拼接,構建輸入數(shù)據(jù)矩陣 \(X_{\text{ori}}\)

    其中, \(S\) 代表狀態(tài), \(A\) 代表動作, \(R\) 代表獎勵。接著,構建數(shù)據(jù)框 \(X\) 來進行因果分析。

  2. 因果模型擬合

    X_ori 轉換為 X 是為了利用 pandas 數(shù)據(jù)框的便利性和靈活性

    使用 DirectLiNGAM 模型對矩陣 \(X\) 進行擬合,得到因果關系的鄰接矩陣 \(A_{\text{model}}\)

    該鄰接矩陣表示狀態(tài)、動作、獎勵之間的因果結構,特別是從動作到獎勵的影響關系。

  3. 提取動作對獎勵的因果權重
    通過鄰接矩陣提取動作對獎勵的因果權重 \(w_{\text{r}}\) ,該權重從鄰接矩陣的最后一行中選擇與動作對應的元素:

    其中, \(d_s\) 是狀態(tài)的維度, \(d_a\) 是動作的維度。

  4. 因果權重的歸一化
    對因果權重 \(w_{\text{r}}\) 進行Softmax歸一化,確保它們的總和為1:

  5. 調整權重的尺度
    最后,因果權重根據(jù)動作的數(shù)量進行縮放:

最終輸出的權重 \(w\) 表示每個動作對獎勵的因果影響,經(jīng)過歸一化和縮放處理,可以用于進一步的策略調整或分析。

策略優(yōu)化模塊

以下是對函數(shù)工作原理的逐步解釋:

策略優(yōu)化模塊 主要由 agent.update_parameters 函數(shù)實現(xiàn)。 agent.update_parameters 這個函數(shù)的主要目的是在強化學習中更新策略 ( policy ) 和價值網(wǎng)絡(critic)的參數(shù),以提升智能體的性能。這個函數(shù)實現(xiàn)了一個基于軟演員評論家(SAC, Soft Actor-Critic)的更新機制,并且加入了因果權重與"休眠"神經(jīng)元(dormant neurons)的處理,以提高模型的魯棒性和穩(wěn)定性。

critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight, config.batch_size, updates)

通過 agent.update_parameters 函數(shù),程序會更新以下幾個部分:

  1. Critic網(wǎng)絡(價值網(wǎng)絡) critic_1_loss critic_2_loss 分別是兩個 Critic 網(wǎng)絡的損失,用于評估當前策略的價值。
  2. Policy網(wǎng)絡(策略網(wǎng)絡) policy_loss 表示策略網(wǎng)絡的損失,用于優(yōu)化 agent 的行動選擇。
  3. Entropy損失 ent_loss 用來調節(jié)策略的隨機性,幫助 agent 在探索和利用之間找到平衡。
  4. Alpha :表示自適應的熵系數(shù),用于調整探索與利用之間的權衡。

這些參數(shù)的更新在每次訓練循環(huán)中被調用,并使用 wandb.log 記錄損失和其他相關的訓練數(shù)據(jù)。

update_parameters ACE_agent 類中的一個關鍵函數(shù),用于根據(jù)經(jīng)驗回放緩沖區(qū)中的樣本數(shù)據(jù)來更新模型的參數(shù)。下面是對其工作原理的詳細解釋:

1. 采樣經(jīng)驗數(shù)據(jù)

首先,函數(shù)從 memory 中采樣一批樣本( state_batch 、 action_batch 、 reward_batch 、 next_state_batch 、 mask_batch ),其中包括狀態(tài)、動作、獎勵、下一個狀態(tài)以及掩碼,用于表示是否為終止狀態(tài)。

state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
  • state_batch :當前的狀態(tài)。
  • action_batch :在當前狀態(tài)下執(zhí)行的動作。
  • reward_batch :執(zhí)行該動作后獲得的獎勵。
  • next_state_batch :執(zhí)行動作后到達的下一個狀態(tài)。
  • mask_batch :掩碼,用于表示是否為終止狀態(tài)(1 表示非終止,0 表示終止)。

2. 計算目標 Q 值

利用當前策略(policy)網(wǎng)絡,采樣下一個狀態(tài)的動作 next_state_action 和其對應的概率分布對數(shù) next_state_log_pi 。然后利用目標 Q 網(wǎng)絡 critic_target 估計下一時刻的最小 Q 值,并結合獎勵和折扣因子 \(\gamma\) 計算下一個 Q 值:

with torch.no_grad():
    next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch, causal_weight)
    qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
    min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
    next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)
  • 通過策略網(wǎng)絡 self.policy 為下一個狀態(tài) next_state_batch 采樣動作 next_state_action 和相應的策略熵 next_state_log_pi

  • 使用目標 Q 網(wǎng)絡計算 qf1_next_target qf2_next_target ,并取兩者的最小值來減少估計偏差。

  • 最終使用貝爾曼方程計算 next_q_value ,即當前的獎勵加上折扣因子 \(\gamma\) 乘以下一個狀態(tài)的 Q 值。

  • 這里, \(\alpha\) 是熵項的權重,用于平衡探索和利用的權衡,而 mask_batch 是為了處理終止狀態(tài)的情況。

    使用無偏估計來計算目標 Q 值。通過目標網(wǎng)絡 ( critic_target ) 計算出下一個狀態(tài)和動作的 Q 值,并使用獎勵和掩碼更新當前 Q 值

3. 更新 Q 網(wǎng)絡

接著,使用當前 Q 網(wǎng)絡 critic 估計當前狀態(tài)和動作下的 Q 值 \(Q_1\) \(Q_2\) ,并計算它們與目標 Q 值的均方誤差損失:

最終 Q 網(wǎng)絡的總損失是兩個 Q 網(wǎng)絡損失之和:

然后,通過反向傳播 qf_loss 來更新 Q 網(wǎng)絡的參數(shù)。

qf1, qf2 = self.critic(state_batch, action_batch)
qf1_loss = F.mse_loss(qf1, next_q_value)
qf2_loss = F.mse_loss(qf2, next_q_value)
qf_loss = qf1_loss + qf2_loss

self.critic_optim.zero_grad()
qf_loss.backward()
self.critic_optim.step()
  • qf1 qf2 是兩個 Q 網(wǎng)絡的輸出,用于減少正向估計偏差。
  • 損失函數(shù)是 Q 值的均方誤差(MSE), qf1_loss qf2_loss 分別計算兩個 Q 網(wǎng)絡的誤差,最后將兩者相加為總的 Q 損失 qf_loss
  • 通過 self.critic_optim 優(yōu)化器對損失進行反向傳播和參數(shù)更新。

4. 策略網(wǎng)絡更新

每隔若干步(通過 target_update_interval 控制),開始更新策略網(wǎng)絡 policy 。首先,重新采樣當前狀態(tài)下的策略 \(\pi(a|s)\) ,并計算 Q 值和熵權重下的策略損失:

這個損失通過反向傳播更新策略網(wǎng)絡。

if updates % self.target_update_interval == 0:
    pi, log_pi, _ = self.policy.sample(state_batch, causal_weight)
    qf1_pi, qf2_pi = self.critic(state_batch, pi)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

    self.policy_optim.zero_grad()
    policy_loss.backward()
    self.policy_optim.step()
  • 通過策略網(wǎng)絡對當前狀態(tài) state_batch 進行采樣,得到動作 pi 及其對應的策略熵 log_pi 。
  • 計算策略損失 policy_loss ,即 \(\alpha\) 倍的策略熵減去最小的 Q 值。
  • 通過 self.policy_optim 優(yōu)化器對策略損失進行反向傳播和參數(shù)更新。

5. 自適應熵調節(jié)

如果開啟了自動熵項調整( automatic_entropy_tuning ),則會進一步更新熵項 \(\alpha\) 的損失:

并通過梯度下降更新 \(\alpha\)

如果 automatic_entropy_tuning 為真,則會更新熵項:

if self.automatic_entropy_tuning:
    alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
    self.alpha_optim.zero_grad()
    alpha_loss.backward()
    self.alpha_optim.step()
    self.alpha = self.log_alpha.exp()
    alpha_tlogs = self.alpha.clone()
else:
    alpha_loss = torch.tensor(0.).to(self.device)
    alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs
  • 通過計算 alpha_loss 更新 self.alpha ,調整策略的探索-利用平衡。

6. 返回值

  • qf1_loss , qf2_loss : 兩個 Q 網(wǎng)絡的損失
  • policy_loss : 策略網(wǎng)絡的損失
  • alpha_loss : 熵權重的損失
  • alpha_tlogs : 用于日志記錄的熵權重
  • next_q_value : 平均下一個 Q 值
  • dormant_metrics : 休眠神經(jīng)元的相關度量

重置機制模塊

重置機制模塊在代碼中主要體現(xiàn)在 update_parameters 函數(shù)中,并通過 梯度主導度 (dominant metrics) 和 擾動函數(shù) (perturbation functions) 實現(xiàn)對策略網(wǎng)絡和 Q 網(wǎng)絡的重置。

重置邏輯

函數(shù)根據(jù)設定的 reset_interval 判斷是否需要對策略網(wǎng)絡和 Q 網(wǎng)絡進行擾動和重置。這里使用了"休眠"神經(jīng)元的概念,即一些在梯度更新中影響較小的神經(jīng)元,可能會被調整或重置。

函數(shù)計算了休眠度量 dormant_metrics 和因果權重差異 causal_diff ,通過擾動因子 perturb_factor 來決定是否對網(wǎng)絡進行部分或全部的擾動與重置。

重置機制模塊的原理

重置機制主要由以下部分組成:

1. 計算梯度主導度 ( $\beta_\gamma $)

在更新策略時,計算 主導梯度 ,即某些特定神經(jīng)元或參數(shù)在更新中主導作用的比率。代碼中通過調用 cal_dormant_grad(self.policy, type='policy', percentage=0.05) 實現(xiàn)這一計算,代表提取出 5% 的主導梯度來作為判斷因子。

dormant_metrics = cal_dormant_grad(self.policy, type='policy', percentage=0.05)

根據(jù)主導度 ($ \beta_\gamma$ ) 和權重 ($ w$ ),可以得到因果效應的差異。代碼里用 causal_diff 來表示因果差異:

2. 軟重置策略和 Q 網(wǎng)絡

軟重置機制通過平滑更新策略網(wǎng)絡和 Q 網(wǎng)絡,避免過大的權重更新導致的網(wǎng)絡不穩(wěn)定。這在代碼中由 soft_update 實現(xiàn):

soft_update(self.critic_target, self.critic, self.tau)

具體來說,軟更新的公式為:

其中,( \(\tau\) ) 是一個較小的常數(shù),通常介于 ( [0, 1] ) 之間,確保目標網(wǎng)絡的更新是緩慢的,以提高學習的穩(wěn)定性。

3. 策略和 Q 優(yōu)化器的重置
4. 重置機制模塊的應用

每當經(jīng)過一定的重置間隔時,判斷是否需要擾動策略和 Q 網(wǎng)絡。通過調用 perturb() dormant_perturb() 實現(xiàn)對網(wǎng)絡的擾動(perturbation)。擾動因子由梯度主導度和因果差異共同決定。

策略與 Q 網(wǎng)絡的擾動會在以下兩種情況下發(fā)生:

a. 重置間隔達成時

代碼中每當更新次數(shù) updates 達到設定的重置間隔 self.reset_interval ,并且 updates > 5000 時,才會觸發(fā)策略與 Q 網(wǎng)絡的重置邏輯。這是為了確保擾動不是頻繁發(fā)生,而是在經(jīng)過一段較長的訓練時間后進行。

具體判斷條件:

if updates % self.reset_interval == 0 and updates > 5000:
b. 主導梯度或因果效應差異滿足條件時

在達到了重置間隔后,首先會計算 梯度主導度 因果效應的差異 。這可以通過計算因果差異 causal_diff 或梯度主導度 dormant_metrics['policy_grad_dormant_ratio'] 來決定是否需要擾動。

  • 梯度主導度 計算方式通過 cal_dormant_grad() 函數(shù)實現(xiàn),如果梯度主導度較低,意味著網(wǎng)絡中的某些神經(jīng)元更新幅度過小,則需要對網(wǎng)絡進行擾動。

  • 因果效應差異 通過計算 causal_diff = np.max(causal_weight) - np.min(causal_weight) 得到,如果差異過大,則可能需要重置。

然后根據(jù)這些值通過擾動因子 factor 進行判斷:

factor = perturb_factor(dormant_metrics['policy_grad_dormant_ratio'])

如果擾動因子 ( \(\text{factor} < 1\) ),網(wǎng)絡會進行擾動:

if factor < 1:
    if self.reset == 'reset' or self.reset == 'causal_reset':
        perturb(self.policy, self.policy_optim, factor)
        perturb(self.critic, self.critic_optim, factor)
        perturb(self.critic_target, self.critic_optim, factor)
c. 總結
  • 更新次數(shù)達到設定的重置間隔 ,且經(jīng)過了一定時間的訓練( updates > 5000 )。
  • 梯度主導度 較低或 因果效應差異 過大,導致計算出的擾動因子 ( $\text{factor} < 1 $)。

這兩種條件同時滿足時,策略和 Q 網(wǎng)絡將被擾動或重置。

擾動因子的計算

在這段代碼中, factor 是基于網(wǎng)絡中梯度主導度或者因果效應差異計算出來的擾動因子。擾動因子通過函數(shù) perturb_factor() 進行計算,該函數(shù)會根據(jù)神經(jīng)元的梯度主導度( dormant_ratio )或因果效應差異( causal_diff )來調整 factor 的大小。

擾動因子(factor)

擾動因子 factor 的計算公式如下:

其中:

  • ( \(\text{dormant\_ratio}\) ) 是網(wǎng)絡中梯度主導度,即表示有多少神經(jīng)元的梯度變化較小(或者接近零),處于“休眠”狀態(tài)。

  • ( \(\text{min\_perturb\_factor}\) ) 是最小擾動因子值,代碼中設定為 0.2 。

  • ( \(\text{max\_perturb\_factor}\) ) 是最大擾動因子值,代碼中設定為 0.9 。

  • dormant_ratio :

    • 表示網(wǎng)絡中處于“休眠狀態(tài)”的梯度比例。這個比例通常通過計算神經(jīng)網(wǎng)絡中梯度幅度小于某個閾值的神經(jīng)元數(shù)量來獲得。 dormant_ratio 越大,表示越多神經(jīng)元的梯度變化很小,說明網(wǎng)絡更新不充分,需要擾動。
  • max_perturb_factor :

    • 最大擾動因子值,用來限制擾動因子的上限,代碼中設定為 0.9,意味著最大擾動幅度不會超過 90%。
  • min_perturb_factor :

    • 最小擾動因子值,用來限制擾動因子的下限,代碼中設定為 0.2,意味著即使休眠神經(jīng)元比例很低,擾動幅度也不會小于 20%。

在計算因果效應的部分,擾動因子 factor 還會根據(jù)因果效應差異 causal_diff 來調整。 causal_diff 是通過計算因果效應的最大值與最小值的差異來獲得的:

計算出的 causal_diff 會影響 causal_factor ,并進一步對 factor 進行調整:

組合擾動因子的公式

最后,如果選擇了因果重置( causal_reset ),擾動因子將使用因果差異計算出的 causal_factor 進行二次調整:

綜上所述, factor 的最終值是由梯度主導度或因果效應差異來控制的,當休眠神經(jīng)元比例較大或因果效應差異較大時, factor 會減小,導致網(wǎng)絡進行擾動。

評估代碼

這段代碼主要實現(xiàn)了在強化學習(RL)訓練過程中,定期評估智能體(agent)的性能,并在某些條件下保存最佳模型的檢查點。我們可以分段解釋該代碼:

1. 定期評估條件

if i_episode % config.eval_interval == 0 and config.eval is True:

這部分代碼用于判斷是否應該執(zhí)行智能體的評估。條件為:

  • i_episode % config.eval_interval == 0 :表示每隔 config.eval_interval 個訓練回合( i_episode 是當前回合數(shù))進行一次評估。
  • config.eval is True :確保 eval 設置為 True ,也就是說,評估功能開啟。

如果滿足這兩個條件,代碼將開始執(zhí)行評估操作。

2. 初始化評估列表

eval_reward_list = []

用于存儲每個評估回合(episode)的累計獎勵,以便之后計算平均獎勵。

3. 進行評估

for _ in range(config.eval_episodes):

評估階段將運行多個回合(由 config.eval_episodes 指定的回合數(shù)),以獲得智能體的表現(xiàn)。

3.1 回合初始化
state = env.reset()
episode_reward = []
done = False
  • env.reset() :重置環(huán)境,獲得初始狀態(tài) state 。
  • episode_reward :初始化一個列表,用于存儲當前回合中智能體獲得的所有獎勵。
  • done = False :用 done 來跟蹤當前回合是否結束。
3.2 執(zhí)行智能體動作
while not done:
    action = agent.select_action(state, evaluate=True)
    next_state, reward, done, info = env.step(action)
    state = next_state
    episode_reward.append(reward)
  • 動作選擇 agent.select_action(state, evaluate=True) 在評估模式下根據(jù)當前狀態(tài) state 選擇動作。 evaluate=True 表示該選擇是在評估模式下,通常意味著探索行為被關閉(即不進行隨機探索,而是選擇最優(yōu)動作)。

  • 環(huán)境反饋 next_state, reward, done, info = env.step(action) 通過執(zhí)行動作 action ,環(huán)境返回下一個狀態(tài) next_state ,當前獎勵 reward ,回合是否結束的標志 done ,以及附加信息 info 。

  • 狀態(tài)更新 :當前狀態(tài)被更新為 next_state ,并將獲得的獎勵 reward 存儲在 episode_reward 列表中。

循環(huán)持續(xù),直到回合結束(即 done == True )。

3.3 存儲回合獎勵
eval_reward_list.append(sum(episode_reward))

當前回合結束后,累計獎勵( sum(episode_reward) )被添加到 eval_reward_list ,用于后續(xù)計算平均獎勵。

4. 計算平均獎勵

avg_reward = np.average(eval_reward_list)

在所有評估回合結束后,計算 eval_reward_list 的平均值 avg_reward 。這是當前評估階段智能體的表現(xiàn)指標。

5. 保存最佳模型

if config.save_checkpoint:
    if avg_reward >= best_reward:
        best_reward = avg_reward
        agent.save_checkpoint(checkpoint_path, 'best')
  • 如果 config.save_checkpoint True ,則表示需要檢查是否保存模型。
  • 通過判斷 avg_reward 是否超過了之前的最佳獎勵 best_reward ,如果是,則更新 best_reward ,并保存當前模型的檢查點。
agent.save_checkpoint(checkpoint_path, 'best')

這行代碼會將智能體的狀態(tài)保存到指定的路徑 checkpoint_path ,并標記為 "best" ,表示這是性能最佳的模型。

論文復現(xiàn)結果

咳咳,可以發(fā)現(xiàn)程序只記錄了 0~1000 的數(shù)據(jù),從 1001 開始的每一個數(shù)據(jù)都顯示報錯所以被舍棄掉了。

后面重新下載了github代碼包,發(fā)生了同樣的報錯信息

報錯信息是:你在 X+1 輪次中嘗試記載 X 輪次中的信息,所以這個數(shù)據(jù)被舍棄掉了

大概是主程序哪里有問題吧,我自己也沒調 bug

不過這個項目結題了,主要負責這個項目的博士師兄也畢業(yè)了,也不好說些什么(雖然我有他微信),至少論文里面的模塊挺有用的。ㄊ謩踊

小編推薦閱讀

好特網(wǎng)發(fā)布此文僅為傳遞信息,不代表好特網(wǎng)認同期限觀點或證實其描述。

a 1.0
a 1.0
類型:休閑益智  運營狀態(tài):正式運營  語言:中文   

游戲攻略

游戲禮包

游戲視頻

游戲下載

游戲活動

《alittletotheleft》官網(wǎng)正版是一款備受歡迎的休閑益智整理游戲。玩家的任務是對日常生活中的各種雜亂物
AWA 1.40
AWA 1.40
類型:休閑益智  運營狀態(tài):未知  語言:中文   

游戲攻略

游戲禮包

游戲視頻

游戲下載

游戲活動

《AWA》安卓版是由開發(fā)商MentalLab研發(fā)的一款帶有奇幻色彩的神秘迷宮冒險游戲,華麗而精美的游戲界面,讓

相關視頻攻略

更多

掃二維碼進入好特網(wǎng)手機版本!

掃二維碼進入好特網(wǎng)微信公眾號!

本站所有軟件,都由網(wǎng)友上傳,如有侵犯你的版權,請發(fā)郵件[email protected]

湘ICP備2022002427號-10 湘公網(wǎng)安備:43070202000427號© 2013~2025 haote.com 好特網(wǎng)