From bccbefbd50a2ea20bd4b6b26648c49970d441bed Mon Sep 17 00:00:00 2001 From: LYouC <1786686418@qq.com> Date: Thu, 21 Mar 2024 17:58:32 +0800 Subject: [PATCH] fixed val_ratio dose not function as expected in dataset --- diffusion_policy/dataset/blockpush_lowdim_dataset.py | 5 +++-- diffusion_policy/dataset/kitchen_lowdim_dataset.py | 5 +++-- diffusion_policy/dataset/kitchen_mjl_lowdim_dataset.py | 5 +++-- diffusion_policy/dataset/pusht_dataset.py | 5 +++-- diffusion_policy/dataset/pusht_image_dataset.py | 5 +++-- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/diffusion_policy/dataset/blockpush_lowdim_dataset.py b/diffusion_policy/dataset/blockpush_lowdim_dataset.py index 86242c224..a9c043f40 100644 --- a/diffusion_policy/dataset/blockpush_lowdim_dataset.py +++ b/diffusion_policy/dataset/blockpush_lowdim_dataset.py @@ -29,6 +29,7 @@ def __init__(self, n_episodes=self.replay_buffer.n_episodes, val_ratio=val_ratio, seed=seed) + self.val_mask = val_mask train_mask = ~val_mask self.sampler = SequenceSampler( replay_buffer=self.replay_buffer, @@ -52,9 +53,9 @@ def get_validation_dataset(self): sequence_length=self.horizon, pad_before=self.pad_before, pad_after=self.pad_after, - episode_mask=~self.train_mask + episode_mask=self.val_mask ) - val_set.train_mask = ~self.train_mask + val_set.train_mask = self.val_mask return val_set def get_normalizer(self, mode='limits', **kwargs): diff --git a/diffusion_policy/dataset/kitchen_lowdim_dataset.py b/diffusion_policy/dataset/kitchen_lowdim_dataset.py index 601e21cb1..f4fa69fd8 100644 --- a/diffusion_policy/dataset/kitchen_lowdim_dataset.py +++ b/diffusion_policy/dataset/kitchen_lowdim_dataset.py @@ -40,6 +40,7 @@ def __init__(self, n_episodes=self.replay_buffer.n_episodes, val_ratio=val_ratio, seed=seed) + self.val_mask = val_mask train_mask = ~val_mask self.sampler = SequenceSampler( replay_buffer=self.replay_buffer, @@ -60,9 +61,9 @@ def get_validation_dataset(self): sequence_length=self.horizon, pad_before=self.pad_before, pad_after=self.pad_after, - episode_mask=~self.train_mask + episode_mask=self.val_mask ) - val_set.train_mask = ~self.train_mask + val_set.train_mask = self.val_mask return val_set def get_normalizer(self, mode='limits', **kwargs): diff --git a/diffusion_policy/dataset/kitchen_mjl_lowdim_dataset.py b/diffusion_policy/dataset/kitchen_mjl_lowdim_dataset.py index e3173818c..36f60d1f9 100644 --- a/diffusion_policy/dataset/kitchen_mjl_lowdim_dataset.py +++ b/diffusion_policy/dataset/kitchen_mjl_lowdim_dataset.py @@ -61,6 +61,7 @@ def __init__(self, n_episodes=self.replay_buffer.n_episodes, val_ratio=val_ratio, seed=seed) + self.val_mask = val_mask train_mask = ~val_mask self.sampler = SequenceSampler( replay_buffer=self.replay_buffer, @@ -81,9 +82,9 @@ def get_validation_dataset(self): sequence_length=self.horizon, pad_before=self.pad_before, pad_after=self.pad_after, - episode_mask=~self.train_mask + episode_mask=self.val_mask ) - val_set.train_mask = ~self.train_mask + val_set.train_mask = self.val_mask return val_set def get_normalizer(self, mode='limits', **kwargs): diff --git a/diffusion_policy/dataset/pusht_dataset.py b/diffusion_policy/dataset/pusht_dataset.py index dc3ec1c81..cca456548 100644 --- a/diffusion_policy/dataset/pusht_dataset.py +++ b/diffusion_policy/dataset/pusht_dataset.py @@ -30,6 +30,7 @@ def __init__(self, n_episodes=self.replay_buffer.n_episodes, val_ratio=val_ratio, seed=seed) + self.val_mask = val_mask train_mask = ~val_mask train_mask = downsample_mask( mask=train_mask, @@ -58,9 +59,9 @@ def get_validation_dataset(self): sequence_length=self.horizon, pad_before=self.pad_before, pad_after=self.pad_after, - episode_mask=~self.train_mask + episode_mask=self.val_mask ) - val_set.train_mask = ~self.train_mask + val_set.train_mask = self.val_mask return val_set def get_normalizer(self, mode='limits', **kwargs): diff --git a/diffusion_policy/dataset/pusht_image_dataset.py b/diffusion_policy/dataset/pusht_image_dataset.py index f096a8f0f..e64896206 100644 --- a/diffusion_policy/dataset/pusht_image_dataset.py +++ b/diffusion_policy/dataset/pusht_image_dataset.py @@ -28,6 +28,7 @@ def __init__(self, n_episodes=self.replay_buffer.n_episodes, val_ratio=val_ratio, seed=seed) + self.val_mask = val_mask; train_mask = ~val_mask train_mask = downsample_mask( mask=train_mask, @@ -52,9 +53,9 @@ def get_validation_dataset(self): sequence_length=self.horizon, pad_before=self.pad_before, pad_after=self.pad_after, - episode_mask=~self.train_mask + episode_mask=self.val_mask ) - val_set.train_mask = ~self.train_mask + val_set.train_mask = self.val_mask return val_set def get_normalizer(self, mode='limits', **kwargs):