From 9df9ead6038688193782010acdecefbd74cdbc7c Mon Sep 17 00:00:00 2001
From: ssmaddila <siva-sri-prasanna.maddila@inrae.fr>
Date: Thu, 25 Apr 2024 11:49:36 +0200
Subject: [PATCH 1/2] Found the bug!

When calling compute_single_action in from_checkpoint.py, the initial
state must be supplied. When used with an LSTM-based model (like the
ones defined in recurrent.py), this is currently not done. There is
also an ancillary problem with seq_lens, which "randomly" becomes an
empty list; it remains to be seen where to add the logic that resets it.
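
For reference, a minimal sketch of how get_actions could supply the
recurrent state (the per-agent state bookkeeping shown here is an
assumption, not code that exists in from_checkpoint.py yet):

    # Hypothetical sketch: fetch the initial LSTM state once per agent
    # and thread it through the subsequent calls.
    policy = algo.get_policy()
    state = policy.get_initial_state()  # [] for non-recurrent models

    # When a state is passed, compute_single_action returns a tuple of
    # (action, state_out, extra_fetches); state_out feeds the next step.
    action, state, _ = algo.compute_single_action(
        observation=obs[agent],
        state=state,
        explore=False,
    )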
---
 examples/rllib_examples/from_checkpoint.py | 2 ++
 examples/rllib_examples/recurrent.py       | 7 +++++--
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/examples/rllib_examples/from_checkpoint.py b/examples/rllib_examples/from_checkpoint.py
index 081f897..c7087e9 100644
--- a/examples/rllib_examples/from_checkpoint.py
+++ b/examples/rllib_examples/from_checkpoint.py
@@ -42,6 +42,8 @@ def get_actions(algo: Algorithm, obs: dict) -> list:
         else:
             timestep = obs[agent]["observations"][0]
 
+        # TODO: Check if using LSTM models and then supply initial state.
+
         # Calculate the single actions for this step
         _temp = algo.compute_single_action(
             observation=obs[agent],
diff --git a/examples/rllib_examples/recurrent.py b/examples/rllib_examples/recurrent.py
index 1f04d79..f59f8a7 100644
--- a/examples/rllib_examples/recurrent.py
+++ b/examples/rllib_examples/recurrent.py
@@ -191,8 +191,8 @@ class LSTMActionMaskingModel(RecurrentNetwork, nn.Module):
 
     def get_initial_state(self):
         return [
-            torch.zeros(1, self.lstm_state_size).squeeze(0),
-            torch.zeros(1, self.lstm_state_size).squeeze(0),
+            torch.zeros(1, self.lstm_state_size).unsqueeze(0),
+            torch.zeros(1, self.lstm_state_size).unsqueeze(0),
         ]
 
     def value_function(self):
@@ -225,6 +225,9 @@ class LSTMActionMaskingModel(RecurrentNetwork, nn.Module):
             flat_inputs = torch.cat(
                 [flat_inputs, input_dict["prev_rewards"].unsqueeze(1)], dim=1
             )
+        if seq_lens is None:
+            # This happens when compute_single_action is called.
+            seq_lens = torch.ones([flat_inputs.shape[0], 1])
 
         # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max()
         # as input_dict may have extra zero-padding beyond seq_lens.max().
-- 
GitLab


From 208d4c8faa9530be43d77dffe4eba9ffe943ee8c Mon Sep 17 00:00:00 2001
From: ssmaddila <siva-sri-prasanna.maddila@inrae.fr>
Date: Fri, 26 Apr 2024 09:08:36 +0200
Subject: [PATCH 2/2] Fixed the recurrent LSTMActionMaskingModel for
 compute_single_action

This should only matter when calling compute_single_action (which we do
only in from_checkpoint.py). We still need to apply this fix (if
needed) to the QmixLSTMModel.
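
If the same symptom shows up there, the guard would presumably be the
same; a rough sketch (placement and attribute names are assumed to
mirror LSTMActionMaskingModel, not checked against QmixLSTMModel):

    # At the top of QmixLSTMModel.forward_rnn (signature assumed to match
    # RLlib's RecurrentNetwork.forward_rnn(inputs, state, seq_lens)):
    if not isinstance(seq_lens, torch.Tensor) and not seq_lens:
        # compute_single_action passes no sequence information
        seq_lens = torch.ones([inputs.shape[0], 1])
    if not state:
        # compute_single_action can also pass an empty state list
        state = [torch.unsqueeze(s, 0) for s in self.get_initial_state()]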
---
 examples/rllib_examples/recurrent.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/examples/rllib_examples/recurrent.py b/examples/rllib_examples/recurrent.py
index f59f8a7..73eb669 100644
--- a/examples/rllib_examples/recurrent.py
+++ b/examples/rllib_examples/recurrent.py
@@ -191,8 +191,8 @@ class LSTMActionMaskingModel(RecurrentNetwork, nn.Module):
 
     def get_initial_state(self):
         return [
-            torch.zeros(1, self.lstm_state_size).unsqueeze(0),
-            torch.zeros(1, self.lstm_state_size).unsqueeze(0),
+            torch.zeros(1, self.lstm_state_size).squeeze(0),
+            torch.zeros(1, self.lstm_state_size).squeeze(0),
         ]
 
     def value_function(self):
@@ -225,9 +225,12 @@ class LSTMActionMaskingModel(RecurrentNetwork, nn.Module):
             flat_inputs = torch.cat(
                 [flat_inputs, input_dict["prev_rewards"].unsqueeze(1)], dim=1
             )
-        if seq_lens is None:
-            # This happens when compute_single_action is called.
+        if not isinstance(seq_lens, torch.Tensor) and not seq_lens:
+            # This happens when calling compute_single_action
             seq_lens = torch.ones([flat_inputs.shape[0], 1])
+        if not state:
+            # This also happens when calling compute_single_action
+            state = [torch.unsqueeze(s, 0) for s in self.get_initial_state()]
 
         # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max()
         # as input_dict may have extra zero-padding beyond seq_lens.max().
-- 
GitLab