From 9df9ead6038688193782010acdecefbd74cdbc7c Mon Sep 17 00:00:00 2001 From: ssmaddila <siva-sri-prasanna.maddila@inrae.fr> Date: Thu, 25 Apr 2024 11:49:36 +0200 Subject: [PATCH 1/2] Found the bug! When calling compute_single_action in from_checkpoint.py, the initial state must be supplied. Now, when used with an LSTM-based model (like the ones defined in recurrent.py), this is not the case. There is an ancillary problem as well with seq_lens that "randomly" becomes an empty list; it remains to be seen where to add the logic that resets it. --- examples/rllib_examples/from_checkpoint.py | 2 ++ examples/rllib_examples/recurrent.py | 7 +++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/rllib_examples/from_checkpoint.py b/examples/rllib_examples/from_checkpoint.py index 081f897..c7087e9 100644 --- a/examples/rllib_examples/from_checkpoint.py +++ b/examples/rllib_examples/from_checkpoint.py @@ -42,6 +42,8 @@ def get_actions(algo: Algorithm, obs: dict) -> list: else: timestep = obs[agent]["observations"][0] + # TODO: Check if using LSTM models and then supply initial state. 
+ # Calculate the single actions for this step _temp = algo.compute_single_action( observation=obs[agent], diff --git a/examples/rllib_examples/recurrent.py b/examples/rllib_examples/recurrent.py index 1f04d79..f59f8a7 100644 --- a/examples/rllib_examples/recurrent.py +++ b/examples/rllib_examples/recurrent.py @@ -191,8 +191,8 @@ class LSTMActionMaskingModel(RecurrentNetwork, nn.Module): def get_initial_state(self): return [ - torch.zeros(1, self.lstm_state_size).squeeze(0), - torch.zeros(1, self.lstm_state_size).squeeze(0), + torch.zeros(1, self.lstm_state_size).unsqueeze(0), + torch.zeros(1, self.lstm_state_size).unsqueeze(0), ] def value_function(self): @@ -225,6 +225,9 @@ class LSTMActionMaskingModel(RecurrentNetwork, nn.Module): flat_inputs = torch.cat( [flat_inputs, input_dict["prev_rewards"].unsqueeze(1)], dim=1 ) + if seq_lens is None: + # This happens when compute_single_action is called. + seq_lens = torch.ones([flat_inputs.shape[0], 1]) # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max() # as input_dict may have extra zero-padding beyond seq_lens.max(). -- GitLab From 208d4c8faa9530be43d77dffe4eba9ffe943ee8c Mon Sep 17 00:00:00 2001 From: ssmaddila <siva-sri-prasanna.maddila@inrae.fr> Date: Fri, 26 Apr 2024 09:08:36 +0200 Subject: [PATCH 2/2] Fixed the recurrent LSTMActionMaskingModel for compute_single_action This should only be important when calling compute_single_action (which we do only for from_checkpoint.py). We still need to apply this fix (if needed) for the QmixLSTMModel. 
--- examples/rllib_examples/recurrent.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/rllib_examples/recurrent.py b/examples/rllib_examples/recurrent.py index f59f8a7..73eb669 100644 --- a/examples/rllib_examples/recurrent.py +++ b/examples/rllib_examples/recurrent.py @@ -191,8 +191,8 @@ class LSTMActionMaskingModel(RecurrentNetwork, nn.Module): def get_initial_state(self): return [ - torch.zeros(1, self.lstm_state_size).unsqueeze(0), - torch.zeros(1, self.lstm_state_size).unsqueeze(0), + torch.zeros(1, self.lstm_state_size).squeeze(0), + torch.zeros(1, self.lstm_state_size).squeeze(0), ] def value_function(self): @@ -225,9 +225,12 @@ class LSTMActionMaskingModel(RecurrentNetwork, nn.Module): flat_inputs = torch.cat( [flat_inputs, input_dict["prev_rewards"].unsqueeze(1)], dim=1 ) - if seq_lens is None: - # This happens when compute_single_action is called. + if not isinstance(seq_lens, torch.Tensor) and not seq_lens: + # This happens when calling compute_single_action seq_lens = torch.ones([flat_inputs.shape[0], 1]) + if not state: + # This also happens when calling compute_single_action + state = [torch.unsqueeze(s, 0) for s in self.get_initial_state()] # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max() # as input_dict may have extra zero-padding beyond seq_lens.max(). -- GitLab