
Eat Up the DQN Algorithm in 30 Minutes

2023-06-22 05:33:35 Source: 程序員客棧

Tabular methods can only store a limited number of states. When facing environments such as Go or robot control, which have countless states, tabular methods run into limits in both storage and lookup efficiency. DQN removes this limitation by using a neural network to approximate the Q-table.



In essence, DQN is still a Q-learning algorithm and uses the same update rule. To explore the environment better, it is likewise trained with the epsilon-greedy method.

Building on Q-learning, DQN introduces two tricks that make the updates of the Q-network more stable (a minimal sketch combining them follows the two items below).

Experience Replay: store many transitions (s, a, r, s') in a replay buffer, then randomly draw a batch of them for training.

Fixed Q-Target: keep a copy of the Q network with the same structure, the Target-Q network, and use it to compute the Q target values.
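
The following is a minimal, self-contained sketch (not the article's code; all names and numbers here are illustrative) showing how the two tricks combine in one DQN update step: transitions are drawn at random from a replay buffer, and the TD target is computed with a frozen copy of the Q network.

import copy
import random
import collections

import torch
import torch.nn.functional as F
from torch import nn

q_net = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))
target_net = copy.deepcopy(q_net)              # Fixed Q-Target: a frozen copy of the Q network
buffer = collections.deque(maxlen=10000)       # Experience Replay: stores (s, a, r, s', done)

# pretend some transitions have already been collected
for _ in range(64):
    s, s_next = torch.randn(4), torch.randn(4)
    buffer.append((s, random.randint(0, 1), 1.0, s_next, 0.0))

batch = random.sample(buffer, 32)              # random sampling breaks temporal correlation
states = torch.stack([b[0] for b in batch])
actions = torch.tensor([b[1] for b in batch]).unsqueeze(1)
rewards = torch.tensor([b[2] for b in batch]).unsqueeze(1)
next_states = torch.stack([b[3] for b in batch])
dones = torch.tensor([b[4] for b in batch]).unsqueeze(1)

q_sa = q_net(states).gather(1, actions)        # Q(s,a) from the online network
with torch.no_grad():                          # the target never backpropagates into target_net
    q_next = target_net(next_states).max(1, keepdim=True)[0]
    target = rewards + 0.9 * (1.0 - dones) * q_next
loss = F.mse_loss(q_sa, target)
loss.backward()                                # gradients only flow through q_net

# every few hundred steps the target network is re-synchronized
target_net.load_state_dict(q_net.state_dict())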

Reply with the keyword torchkeras in the backend of the public account 算法美食屋 to get the notebook source code of this article~

If you are not familiar with reinforcement learning, we recommend reading first: Solving the Cliff Walking Problem with Q-learning.

1. Preparing the Environment

gym is a commonly used reinforcement learning test environment; environments are created with make.

An env has the methods reset, step, and render.
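
For orientation, here is a minimal sketch (not part of the original article) of the gym 0.26 calling convention used throughout this post: reset returns (obs, info), and step returns (obs, reward, terminated, truncated, info).

import gym

env = gym.make("CartPole-v1", render_mode="rgb_array")
obs, info = env.reset()                        # gym >= 0.26: reset returns (obs, info)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
frame = env.render()                           # with render_mode="rgb_array": an RGB image array
env.close()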

The CartPole (Inverted Pendulum) Problem

The environment is designed as follows:

The state space of the CartPole environment is infinite; a state is represented by a 4-dimensional vector.

The 4 dimensions have the following meanings (they can also be queried from the environment, as shown in the sketch after this list):

cart position: -2.4 ~ 2.4
cart velocity: -inf ~ inf
pole angle: -0.5 ~ 0.5 (radians)
pole angular velocity: -inf ~ inf
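
A small inspection sketch (illustrative, not from the original notebook); note that the Box bounds reported by gym are somewhat wider than the termination thresholds listed above.

import gym

env = gym.make("CartPole-v1")
print(env.observation_space.low)    # lower bounds reported for the 4 state dimensions
print(env.observation_space.high)   # upper bounds reported for the 4 state dimensions
print(env.action_space)             # Discrete(2)
env.close()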

The agent is designed as follows:

The agent has two possible actions:

0: push left
1: push right

The reward is designed as follows:

Each step the pole stays up earns a reward of +1; the episode ends once 200 steps are reached.

So the highest possible score is 200.

The goal of the CartPole problem is to train an agent that keeps the pole balanced for as long as possible.

import gym
import numpy as np
import pandas as pd
import time
import matplotlib
import matplotlib.pyplot as plt
from IPython import display

print("gym.__version__=", gym.__version__)

%matplotlib inline

# visualization helper
def show_state(env, step, info=""):
    plt.figure(num=10086, dpi=100)
    plt.clf()
    plt.imshow(env.render())
    plt.title("step: %d %s" % (step, info))
    plt.axis("off")
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.close()

env = gym.make("CartPole-v1", render_mode="rgb_array")  # CartPole-v0: expected final evaluation score > 180 (max is 200)
env.reset()
action_dim = env.action_space.n           # CartPole-v0: 2
obs_shape = env.observation_space.shape   # CartPole-v0: (4,)

gym.__version__= 0.26.2

env.reset()
done = False
step = 0
while not done:
    action = np.random.randint(0, 2)  # a random action: 0 or 1
    state, reward, done, truncated, info = env.step(action)
    step += 1
    print(state, reward)
    time.sleep(1.0)
    # env.render()
    show_state(env, step=step)
    # print("step {}: action {}, state {}, reward {}, done {}, truncated {}, info {}".format(
    #     step, action, state, reward, done, truncated, info))
    display.clear_output(wait=True)

As you can see, before any training, taking random actions keeps the pole up for only about 10 steps before the tilt angle goes out of range and the episode ends.

2. Defining the Agent

The core idea of DQN is to use a neural network to approximate the Q-table.

Model: the network structure, responsible for fitting the function Q(s,a). It mainly implements the forward method.

Agent: the agent, responsible for learning and interacting with the environment; its inputs and outputs are numpy.array. It provides the methods sample (single-step sampling), predict (single-step prediction), predict_batch (batch prediction), compute_loss (loss computation), and sync_target (parameter synchronization).

import torch
from torch import nn
import torch.nn.functional as F
import copy

class Model(nn.Module):
    def __init__(self, obs_dim, action_dim):
        # a 3-layer fully connected network
        super(Model, self).__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.fc1 = nn.Linear(obs_dim, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, action_dim)

    def forward(self, obs):
        # input: a state; output: the Q value of every action, [Q(s,a1), Q(s,a2), Q(s,a3), ...]
        x = self.fc1(obs)
        x = torch.tanh(x)
        x = self.fc2(x)
        x = torch.tanh(x)
        Q = self.fc3(x)
        return Q

model = Model(4, 2)
model_target = copy.deepcopy(model)

model.eval()
model.forward(torch.tensor([[0.2, 0.1, 0.2, 0.0], [0.3, 0.5, 0.2, 0.6]]))
model_target.eval()
model_target.forward(torch.tensor([[0.2, 0.1, 0.2, 0.0], [0.3, 0.5, 0.2, 0.6]]))

tensor([[-0.1148,  0.0068],        [-0.1311,  0.0315]], grad_fn=)

import torch
from torch import nn
import copy

class DQNAgent(nn.Module):
    def __init__(self, model,
                 gamma=0.9,
                 e_greed=0.1,
                 e_greed_decrement=0.001
                 ):
        super().__init__()

        self.model = model
        self.target_model = copy.deepcopy(model)

        self.gamma = gamma  # discount factor for rewards, usually between 0.9 and 0.999

        self.e_greed = e_greed  # probability of picking a random action (exploration)
        self.e_greed_decrement = e_greed_decrement  # gradually reduce exploration as training converges

        self.global_step = 0
        self.update_target_steps = 200  # copy the model parameters into target_model every 200 training steps

    def forward(self, obs):
        return self.model(obs)

    @torch.no_grad()
    def predict_batch(self, obs):
        """Use self.model to get [Q(s,a1), Q(s,a2), ...]"""
        self.model.eval()
        return self.forward(obs)

    # single-step sampling (epsilon-greedy)
    def sample(self, obs):
        sample = np.random.rand()  # a random float between 0 and 1
        # NOTE: the rest of this class was truncated in the source article; the code below
        # is a standard-DQN reconstruction consistent with how the methods are used later.
        if sample < self.e_greed:
            act = np.random.randint(self.model.action_dim)  # explore: random action
        else:
            act = self.predict(obs)  # exploit: greedy action
        self.e_greed = max(0.01, self.e_greed - self.e_greed_decrement)
        return act

    # single-step prediction (greedy action)
    def predict(self, obs):
        obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        q_values = self.predict_batch(obs)
        return int(q_values.argmax(dim=-1).item())

    def compute_loss(self, obs, action, reward, next_obs, done):
        # periodically copy the model parameters into the target model
        if self.global_step % self.update_target_steps == 0:
            self.sync_target()
        self.global_step += 1

        # predicted Q(s,a) for the actions that were actually taken
        action = action.reshape(-1, 1)
        reward = reward.reshape(-1, 1)
        if done.dim() > 0:
            done = done.reshape(-1, 1)
        pred_value = self.model(obs).gather(1, action)

        # TD target computed with the frozen target network
        with torch.no_grad():
            max_q = self.target_model(next_obs).max(dim=1, keepdim=True)[0]
            target = reward + (1.0 - done) * self.gamma * max_q
        return F.mse_loss(pred_value, target)

    # copy the parameters of model into target_model
    def sync_target(self):
        self.target_model.load_state_dict(self.model.state_dict())

agent = DQNAgent(model, gamma=0.9, e_greed=0.1,
                 e_greed_decrement=0.001)

agent.predict_batch(torch.tensor([[2.0, 3.0, 4.0, 2.0], [1.0, 2.0, 3.0, 4.0]]))

tensor([[-0.1596, -0.0481],
        [-0.0927,  0.0318]])

loss = agent.compute_loss(
    torch.tensor([[2.0, 3.0, 4.0, 2.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]),
    torch.tensor([[1], [0], [0]]),
    torch.tensor([[1.0], [1.0], [1.0]]),
    torch.tensor([[2.0, 3.0, 0.4, 2.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]),
    torch.tensor(0.9))
print(loss)

tensor(0.5757, grad_fn=)

3. Training the Agent

import random
import collections
import numpy as np

LEARN_FREQ = 5            # learning frequency: no need to learn at every step; accumulate some new experience first, which is more efficient
MEMORY_SIZE = 2048        # size of the replay memory; the larger, the more memory it takes
MEMORY_WARMUP_SIZE = 512  # the replay memory must be pre-filled with some experience before training starts
BATCH_SIZE = 128          # number of samples per learn call, randomly sampled from the replay memory

# experience replay
class ReplayMemory(object):
    def __init__(self, max_size):
        self.buffer = collections.deque(maxlen=max_size)

    # add one experience to the replay buffer
    def append(self, exp):
        self.buffer.append(exp)

    # sample N experiences from the replay buffer
    def sample(self, batch_size):
        mini_batch = random.sample(self.buffer, batch_size)
        obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], []
        for experience in mini_batch:
            s, a, r, s_p, done = experience
            obs_batch.append(s)
            action_batch.append(a)
            reward_batch.append(r)
            next_obs_batch.append(s_p)
            done_batch.append(done)
        return np.array(obs_batch).astype("float32"), \
            np.array(action_batch).astype("int64"), np.array(reward_batch).astype("float32"), \
            np.array(next_obs_batch).astype("float32"), np.array(done_batch).astype("float32")

    def __len__(self):
        return len(self.buffer)
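
A quick sanity check of the class above with toy transitions (illustrative values only; the variable names here are not from the original notebook):

rpm_demo = ReplayMemory(max_size=100)
for i in range(10):
    # a toy transition (s, a, r, s', done)
    rpm_demo.append((np.zeros(4, dtype="float32"), i % 2, 1.0,
                     np.ones(4, dtype="float32"), 0.0))

obs_b, act_b, rew_b, next_b, done_b = rpm_demo.sample(4)
print(obs_b.shape, act_b.shape, rew_b.shape, next_b.shape, done_b.shape)
# expected: (4, 4) (4,) (4,) (4, 4) (4,)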

from torch.utils.data import IterableDataset, DataLoader

class MyDataset(IterableDataset):
    def __init__(self, env, agent, rpm, stage="train", size=200):
        self.env = env
        self.agent = agent
        self.rpm = rpm if stage == "train" else None
        self.stage = stage
        self.size = size

    def __iter__(self):
        obs, info = self.env.reset()  # reset the environment and start a new episode
        step = 0
        batch_reward_true = []  # record the true rewards
        while True:
            step += 1
            action = self.agent.sample(obs)
            next_obs, reward, done, _, _ = self.env.step(action)  # one interaction with the environment
            batch_reward_true.append(reward)

            if self.stage == "train":
                self.rpm.append((obs, action, reward, next_obs, float(done)))
                if (len(self.rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
                    # yield batch_obs, batch_action, batch_reward, batch_next_obs, batch_done
                    yield self.rpm.sample(BATCH_SIZE), sum(batch_reward_true)
                    batch_reward_true.clear()
            else:
                obs_batch = np.array([obs]).astype("float32")
                action_batch = np.array([action]).astype("int64")
                reward_batch = np.array([reward]).astype("float32")
                next_obs_batch = np.array([next_obs]).astype("float32")
                done_batch = np.array([float(done)]).astype("float32")
                batch_data = obs_batch, action_batch, reward_batch, next_obs_batch, done_batch
                yield batch_data, sum(batch_reward_true)
                batch_reward_true.clear()

            if self.stage == "train":
                next_action = self.agent.sample(next_obs)  # use the exploration policy during training
            else:
                next_action = self.agent.predict(next_obs)  # use the model prediction during validation
            action = next_action
            obs = next_obs

            if done:
                # NOTE: the end of this loop was truncated in the source article; the lines
                # below are a reconstruction: during training keep collecting until the
                # warm-up size is reached, otherwise finish the episode.
                if self.stage == "train" and len(self.rpm) < MEMORY_WARMUP_SIZE:
                    obs, info = self.env.reset()
                else:
                    break

# NOTE: the cell that created the replay memory and the datasets was lost in the source
# article; the following is a reconstruction consistent with how they are used below.
rpm = ReplayMemory(MEMORY_SIZE)
ds_train = MyDataset(env, agent, rpm, stage="train", size=200)
ds_val = MyDataset(env, agent, None, stage="val", size=200)

# pre-fill the ReplayMemory (reconstructed: the original loop body was truncated)
while len(ds_train.rpm) < MEMORY_WARMUP_SIZE:
    for _ in ds_train:
        break

1347167272511521

def collate_fn(batch):
    samples, rewards = [x[0] for x in batch], [x[-1] for x in batch]
    samples = [torch.from_numpy(np.concatenate([x[j] for x in samples])) for j in range(5)]
    rewards = torch.from_numpy(np.array([sum(rewards)]).astype("float32"))
    return samples, rewards

dl_train = DataLoader(ds_train, batch_size=1, collate_fn=collate_fn)
dl_val = DataLoader(ds_val, batch_size=1, collate_fn=collate_fn)

# NOTE (assumption): give the dataloaders a `size` attribute, which the custom EpochRunner
# defined below uses as the number of steps per epoch; the original values were lost.
dl_train.size = 200
dl_val.size = 200

for batch in dl_train:
    break
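
As a quick check (the exact tensors will differ from run to run), the batch fetched above unpacks into the five training tensors plus the true reward accumulated while generating them:

samples, reward = batch
print([tuple(x.shape) for x in samples])  # shapes of obs, action, reward, next_obs, done
print(reward)                             # summed true reward collected while producing this batch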

import sys, datetime
from tqdm import tqdm
import numpy as np
from accelerate import Accelerator
from torchkeras import KerasModel
import pandas as pd
from copy import deepcopy

class StepRunner:
    def __init__(self, net, loss_fn, accelerator=None, stage="train", metrics_dict=None,
                 optimizer=None, lr_scheduler=None
                 ):
        self.net, self.loss_fn, self.metrics_dict, self.stage = net, loss_fn, metrics_dict, stage
        self.optimizer, self.lr_scheduler = optimizer, lr_scheduler
        self.accelerator = accelerator if accelerator is not None else Accelerator()

    def __call__(self, batch):

        samples, reward = batch
        # torch_data = ([torch.from_numpy(x) for x in batch_data])
        loss = self.net.compute_loss(*samples)

        # backward()
        if self.optimizer is not None and self.stage == "train":
            self.accelerator.backward(loss)
            if self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(self.net.parameters(), 1.0)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()

        # losses (or plain metric)
        step_losses = {self.stage + "_reward": reward.item(),
                       self.stage + "_loss": loss.item()}

        # metrics (stateful metric)
        step_metrics = {}
        if self.stage == "train":
            if self.optimizer is not None:
                step_metrics["lr"] = self.optimizer.state_dict()["param_groups"][0]["lr"]
            else:
                step_metrics["lr"] = 0.0
        return step_losses, step_metrics

class EpochRunner:
    def __init__(self, steprunner, quiet=False):
        self.steprunner = steprunner
        self.stage = steprunner.stage
        self.accelerator = steprunner.accelerator
        self.net = steprunner.net
        self.quiet = quiet

    def __call__(self, dataloader):
        dataloader.agent = self.net
        n = dataloader.size if hasattr(dataloader, "size") else len(dataloader)
        loop = tqdm(enumerate(dataloader, start=1),
                    total=n,
                    file=sys.stdout,
                    disable=not self.accelerator.is_local_main_process or self.quiet,
                    ncols=100
                    )
        epoch_losses = {}
        epoch_log = {}
        for step, batch in loop:
            # NOTE: the body of this loop was truncated in the source article; the lines
            # below reconstruct typical torchkeras EpochRunner logic: run the step,
            # accumulate a running log, and stop after n steps.
            if step <= n:
                step_losses, step_metrics = self.steprunner(batch)
                for k, v in step_losses.items():
                    epoch_losses[k] = epoch_losses.get(k, 0.0) + v
                # report accumulated rewards (epoch totals) and mean losses
                epoch_log = {k: (v if k.endswith("_reward") else v / step)
                             for k, v in epoch_losses.items()}
                epoch_log.update(step_metrics)
                loop.set_postfix(**epoch_log)
            else:
                break
        return epoch_log

# register the custom runners so that KerasModel uses them during fit
KerasModel.StepRunner = StepRunner
KerasModel.EpochRunner = EpochRunner

keras_model = KerasModel(net=agent, loss_fn=None,
                         optimizer=torch.optim.Adam(agent.model.parameters(), lr=1e-2))

dfhistory = keras_model.fit(train_data=dl_train,
                            val_data=dl_val,
                            epochs=600,
                            ckpt_path="checkpoint.pt",
                            patience=100,
                            monitor="val_reward",
                            mode="max",
                            callbacks=None,
                            plot=True,
                            cpu=True)

4. Evaluating the Agent

# evaluate the agent: run 2 episodes and average the total reward
def evaluate(env, agent, render=False):
    eval_reward = []
    for i in range(2):
        obs, info = env.reset()
        episode_reward = 0
        step = 0
        while step < 300:
            action = agent.predict(obs)  # predict the action, always choosing the best one
            obs, reward, done, _, _ = env.step(action)
            episode_reward += reward
            if render:
                show_state(env, step, info="reward=" + str(episode_reward))
            if done:
                break
            step += 1
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)

# show the animation directly
env = gym.make("CartPole-v1", render_mode="rgb_array")
evaluate(env, agent, render=True)

As you can see, after training our agent has become quite smart and can keep the pole balanced for more than 200 steps.

288.5

5. Saving the Agent

torch.save(agent.state_dict(),"dqn_agent.pt")
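
To reuse the trained agent later, the saved state dict can be loaded back into a freshly built agent of the same structure (a minimal sketch using the classes defined above):

# rebuild the same network and agent structure, then restore the trained weights
model_reloaded = Model(obs_dim=4, action_dim=2)
agent_reloaded = DQNAgent(model_reloaded, gamma=0.9, e_greed=0.1, e_greed_decrement=0.001)
agent_reloaded.load_state_dict(torch.load("dqn_agent.pt"))
agent_reloaded.eval()

# the reloaded agent can be evaluated exactly as before
env = gym.make("CartPole-v1", render_mode="rgb_array")
print(evaluate(env, agent_reloaded, render=False))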


The notebook source code of this article, together with more interesting examples, can be obtained by replying with the keyword torchkeras in the backend of the public account 算法美食屋~
