diff --git a/Scripts/H2O/SimpleSAC/__pycache__/model.cpython-39.pyc b/Scripts/H2O/SimpleSAC/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000..8339b47 Binary files /dev/null and b/Scripts/H2O/SimpleSAC/__pycache__/model.cpython-39.pyc differ diff --git a/Scripts/H2O/SimpleSAC/model.py b/Scripts/H2O/SimpleSAC/model.py new file mode 100644 index 0000000..28ce8c6 --- /dev/null +++ b/Scripts/H2O/SimpleSAC/model.py @@ -0,0 +1,230 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import TanhTransform +from torch.distributions import Normal +import math +import torch.autograd as autograd +from torch.nn.utils import spectral_norm +# from copy import deepcopy + + +def extend_and_repeat(tensor, dim, repeat): + # Extend and repeast the tensor along dim axie and repeat it + ones_shape = [1 for _ in range(tensor.ndim + 1)] + ones_shape[dim] = repeat + return torch.unsqueeze(tensor, dim) * tensor.new_ones(ones_shape) + + +def soft_target_update(network, target_network, soft_target_update_rate): + target_network_params = {k: v for k, v in target_network.named_parameters()} + for k, v in network.named_parameters(): + target_network_params[k].data = ( + (1 - soft_target_update_rate) * target_network_params[k].data + + soft_target_update_rate * v.data + ) + + +def multiple_action_q_function(forward): + # Forward the q function with multiple actions on each state, to be used as a decorator + def wrapped(self, observations, actions, is_x_grad=False, **kwargs): + multiple_actions = False + batch_size = observations.shape[0] + if actions.ndim == 3 and observations.ndim == 2: + multiple_actions = True + observations = extend_and_repeat(observations, 1, actions.shape[1]).reshape(-1, observations.shape[-1]) + actions = actions.reshape(-1, actions.shape[-1]) + q_values = forward(self, observations, actions, is_x_grad, **kwargs) + if multiple_actions: + q_values = q_values.reshape(batch_size, -1) + return q_values + return wrapped + + +class FullyConnectedNetwork(nn.Module): + def __init__(self, input_dim, output_dim, arch='256-256', orthogonal_init=False, is_LN=False, is_SN=False): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.arch = arch + self.orthogonal_init = orthogonal_init + + d = input_dim + modules = [] + hidden_sizes = [int(h) for h in arch.split('-')] + + for hidden_size in hidden_sizes: + fc = nn.Linear(d, hidden_size) + if orthogonal_init: + nn.init.orthogonal_(fc.weight, gain=np.sqrt(2)) + nn.init.constant_(fc.bias, 0.0) + if is_SN: + modules.append(spectral_norm(fc)) + else: + modules.append(fc) + modules.append(nn.ReLU()) + + if is_LN: + ln = nn.LayerNorm(hidden_size) + modules.append(ln) + + d = hidden_size + + last_fc = nn.Linear(d, output_dim) + if orthogonal_init: + nn.init.orthogonal_(last_fc.weight, gain=1e-2) + else: + nn.init.xavier_uniform_(last_fc.weight, gain=1e-2) + nn.init.constant_(last_fc.bias, 0.0) + + if is_SN: + modules.append(spectral_norm(last_fc)) + else: + modules.append(last_fc) + + self.network = nn.Sequential(*modules) + + def forward(self, input_tensor): + return self.network(input_tensor) + + +class ReparameterizedTanhGaussian(nn.Module): + + def __init__(self, log_std_min=-20.0, log_std_max=2.0, no_tanh=False): + super().__init__() + self.log_std_min = log_std_min + self.log_std_max = log_std_max + self.no_tanh = no_tanh + + def log_prob(self, mean, log_std, sample): + log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) + std = torch.exp(log_std) + if self.no_tanh: + action_distribution = Normal(mean, std) + else: + action_distribution = TransformedDistribution( + Normal(mean, std), TanhTransform(cache_size=1) + ) + return torch.sum(action_distribution.log_prob(sample), dim=-1) + + def forward(self, mean, log_std, deterministic=False): + log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) + std = torch.exp(log_std) + + if self.no_tanh: + action_distribution = Normal(mean, std) + else: + action_distribution = TransformedDistribution( + Normal(mean, std), TanhTransform(cache_size=1) + ) + + if deterministic: + action_sample = torch.tanh(mean) + else: + action_sample = action_distribution.rsample() + + log_prob = torch.sum( + action_distribution.log_prob(action_sample), dim=-1 + ) + + return action_sample, log_prob + + +class TanhGaussianPolicy(nn.Module): + + def __init__(self, observation_dim, action_dim, arch='256-256', + log_std_multiplier=1.0, log_std_offset=-1.0, + orthogonal_init=False, no_tanh=False): + super().__init__() + self.observation_dim = observation_dim + self.action_dim = action_dim + self.arch = arch + self.orthogonal_init = orthogonal_init + self.no_tanh = no_tanh + + self.base_network = FullyConnectedNetwork( + observation_dim, 2 * action_dim, arch, orthogonal_init, is_LN=False, is_SN=False + ) + self.log_std_multiplier = Scalar(log_std_multiplier) + self.log_std_offset = Scalar(log_std_offset) + self.tanh_gaussian = ReparameterizedTanhGaussian(no_tanh=no_tanh) + + def log_prob(self, observations, actions): + if actions.ndim == 3: + observations = extend_and_repeat(observations, 1, actions.shape[1]) + base_network_output = self.base_network(observations) + mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1) + log_std = self.log_std_multiplier() * log_std + self.log_std_offset() + return self.tanh_gaussian.log_prob(mean, log_std, actions) + + def forward(self, observations, deterministic=False, repeat=None): + if repeat is not None: + observations = extend_and_repeat(observations, 1, repeat) + assert torch.isnan(observations).sum() == 0, print(observations) + base_network_output = self.base_network(observations) + mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1) + log_std = self.log_std_multiplier() * log_std + self.log_std_offset() + assert torch.isnan(mean).sum() == 0, print(mean) + assert torch.isnan(log_std).sum() == 0, print(log_std) + return self.tanh_gaussian(mean, log_std, deterministic) + + +class SamplerPolicy(object): + + def __init__(self, policy, device): + self.policy = policy + self.device = device + + def __call__(self, observations, deterministic=False): + with torch.no_grad(): + observations = torch.tensor( + observations, dtype=torch.float32, device=self.device + ) + actions, _ = self.policy(observations, deterministic) + actions = actions.cpu().numpy() + return actions + + +class FullyConnectedQFunction(nn.Module): + def __init__(self, observation_dim, action_dim, arch='256-256', orthogonal_init=False, is_LN=False, is_SN=False): + super().__init__() + self.observation_dim = observation_dim + self.action_dim = action_dim + self.arch = arch + self.orthogonal_init = orthogonal_init + self.network = FullyConnectedNetwork( + observation_dim + action_dim, 1, arch, orthogonal_init, is_LN=is_LN, is_SN=is_SN + ) + + @multiple_action_q_function + def forward(self, observations, actions, is_x_grad=False): + input_tensor = torch.cat([observations, actions], dim=-1) + output = self.network(input_tensor) + if is_x_grad: + input_tensor_grad = input_tensor.detach().requires_grad_(True) + output_grad = self.network(input_tensor_grad) + x_grad = autograd.grad(output_grad, input_tensor_grad, retain_graph=True, create_graph=True, + grad_outputs=torch.ones_like(output_grad))[0] + x_grad = torch.norm(x_grad, dim=1) + return torch.squeeze(output, dim=-1), x_grad + else: + return torch.squeeze(output, dim=-1) + + +class Scalar(nn.Module): + def __init__(self, init_value): + super().__init__() + self.constant = nn.Parameter( + torch.tensor(init_value, dtype=torch.float32) + ) + + def forward(self): + return self.constant + + +if __name__ == '__main__': + s = Scalar(10) + print(s) + diff --git a/Scripts/SDM/__pycache__/SDM.cpython-39.pyc b/Scripts/SDM/__pycache__/SDM.cpython-39.pyc index eadc4a9..ec1d1d6 100644 Binary files a/Scripts/SDM/__pycache__/SDM.cpython-39.pyc and b/Scripts/SDM/__pycache__/SDM.cpython-39.pyc differ diff --git a/Scripts/SDM/__pycache__/__init__.cpython-39.pyc b/Scripts/SDM/__pycache__/__init__.cpython-39.pyc index 92e4111..dc7ba00 100644 Binary files a/Scripts/SDM/__pycache__/__init__.cpython-39.pyc and b/Scripts/SDM/__pycache__/__init__.cpython-39.pyc differ diff --git a/Scripts/SDM/__pycache__/leaderUpdate.cpython-39.pyc b/Scripts/SDM/__pycache__/leaderUpdate.cpython-39.pyc index af5fb73..faf264a 100644 Binary files a/Scripts/SDM/__pycache__/leaderUpdate.cpython-39.pyc and b/Scripts/SDM/__pycache__/leaderUpdate.cpython-39.pyc differ diff --git a/Scripts/SDM/__pycache__/objective.cpython-39.pyc b/Scripts/SDM/__pycache__/objective.cpython-39.pyc index f7748e6..118fae4 100644 Binary files a/Scripts/SDM/__pycache__/objective.cpython-39.pyc and b/Scripts/SDM/__pycache__/objective.cpython-39.pyc differ diff --git a/Scripts/SDM/__pycache__/utils.cpython-39.pyc b/Scripts/SDM/__pycache__/utils.cpython-39.pyc index 36e89ae..b646ad3 100644 Binary files a/Scripts/SDM/__pycache__/utils.cpython-39.pyc and b/Scripts/SDM/__pycache__/utils.cpython-39.pyc differ diff --git a/Scripts/SimpleSAC/__pycache__/__init__.cpython-39.pyc b/Scripts/SimpleSAC/__pycache__/__init__.cpython-39.pyc index a9dab07..06e2f5a 100644 Binary files a/Scripts/SimpleSAC/__pycache__/__init__.cpython-39.pyc and b/Scripts/SimpleSAC/__pycache__/__init__.cpython-39.pyc differ diff --git a/Scripts/SimpleSAC/__pycache__/envs.cpython-39.pyc b/Scripts/SimpleSAC/__pycache__/envs.cpython-39.pyc index 0b46756..e8821c2 100644 Binary files a/Scripts/SimpleSAC/__pycache__/envs.cpython-39.pyc and b/Scripts/SimpleSAC/__pycache__/envs.cpython-39.pyc differ diff --git a/Scripts/SimpleSAC/__pycache__/replay_buffer.cpython-39.pyc b/Scripts/SimpleSAC/__pycache__/replay_buffer.cpython-39.pyc index c282d8f..9beda97 100644 Binary files a/Scripts/SimpleSAC/__pycache__/replay_buffer.cpython-39.pyc and b/Scripts/SimpleSAC/__pycache__/replay_buffer.cpython-39.pyc differ diff --git a/Scripts/SimpleSAC/__pycache__/sac.cpython-39.pyc b/Scripts/SimpleSAC/__pycache__/sac.cpython-39.pyc index f4236e7..d6d324f 100644 Binary files a/Scripts/SimpleSAC/__pycache__/sac.cpython-39.pyc and b/Scripts/SimpleSAC/__pycache__/sac.cpython-39.pyc differ diff --git a/Scripts/SimpleSAC/__pycache__/sampler.cpython-39.pyc b/Scripts/SimpleSAC/__pycache__/sampler.cpython-39.pyc index 2dfd3f8..1fca6f2 100644 Binary files a/Scripts/SimpleSAC/__pycache__/sampler.cpython-39.pyc and b/Scripts/SimpleSAC/__pycache__/sampler.cpython-39.pyc differ diff --git a/Scripts/SimpleSAC/ego_policy/__pycache__/__init__.cpython-39.pyc b/Scripts/SimpleSAC/ego_policy/__pycache__/__init__.cpython-39.pyc index 80e65a0..66e3b67 100644 Binary files a/Scripts/SimpleSAC/ego_policy/__pycache__/__init__.cpython-39.pyc and b/Scripts/SimpleSAC/ego_policy/__pycache__/__init__.cpython-39.pyc differ diff --git a/Scripts/SimpleSAC/ego_policy/__pycache__/fvdm.cpython-39.pyc b/Scripts/SimpleSAC/ego_policy/__pycache__/fvdm.cpython-39.pyc index ab11fb3..6623fed 100644 Binary files a/Scripts/SimpleSAC/ego_policy/__pycache__/fvdm.cpython-39.pyc and b/Scripts/SimpleSAC/ego_policy/__pycache__/fvdm.cpython-39.pyc differ diff --git a/Scripts/SimpleSAC/envs.py b/Scripts/SimpleSAC/envs.py index ca72426..86546cf 100644 --- a/Scripts/SimpleSAC/envs.py +++ b/Scripts/SimpleSAC/envs.py @@ -39,6 +39,9 @@ def __init__(self, realdata_path, num_agents = 1, dt = 0.04, sim_horizon = 200, self.ego_policy = ego_policy self.adv_policy = adv_policy + + self.up_cut = [0.6 * 9.8 * self.dt, np.pi / 3 * self.dt] + self.low_cut = [-0.8 * 9.8 * self.dt, -np.pi / 3 * self.dt] self.gui = gui if self.gui: @@ -115,15 +118,7 @@ def reset(self, ego_policy, adv_policy, idx = None): departLane=0, departPos=10.0, departSpeed=self.states[0][2]) - elif self.ego_policy == 'sumo': - traci.vehicle.add(vehID = 'car0', - routeID = 'straight', - typeID = 'AV', - depart = cur_time, - departLane = 0, - departPos = 10.0, - arrivalLane=np.random.randint(0, 3), - departSpeed=self.states[0][2]) + else: traci.vehicle.add(vehID="car0", @@ -231,6 +226,8 @@ def step(self, action_ego, action_adv): angle = -new_state[3] * 180 / np.pi + 90, lane = 0, edgeID = 0) + elif self.ego_policy == "sumo": + ... '''adversary vehicles''' @@ -258,6 +255,8 @@ def step(self, action_ego, action_adv): lane = 0, edgeID = 0) traci.vehicle.setSpeed(vehID = "car" + str(i + 1), speed = new_state[2]) + elif self.ego_policy == "sumo": + ... traci.simulationStep() # time.sleep(0.04) diff --git a/Scripts/SimpleSAC/models/__pycache__/__init__.cpython-39.pyc b/Scripts/SimpleSAC/models/__pycache__/__init__.cpython-39.pyc index d3ab60e..486f8f8 100644 Binary files a/Scripts/SimpleSAC/models/__pycache__/__init__.cpython-39.pyc and b/Scripts/SimpleSAC/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/Scripts/SimpleSAC/models/__pycache__/model.cpython-39.pyc b/Scripts/SimpleSAC/models/__pycache__/model.cpython-39.pyc index 5144524..6b3bb22 100644 Binary files a/Scripts/SimpleSAC/models/__pycache__/model.cpython-39.pyc and b/Scripts/SimpleSAC/models/__pycache__/model.cpython-39.pyc differ diff --git a/Scripts/SimpleSAC/utils/__pycache__/__init__.cpython-39.pyc b/Scripts/SimpleSAC/utils/__pycache__/__init__.cpython-39.pyc index f9273a6..681aa35 100644 Binary files a/Scripts/SimpleSAC/utils/__pycache__/__init__.cpython-39.pyc and b/Scripts/SimpleSAC/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/Scripts/SimpleSAC/utils/__pycache__/car_dis_comput.cpython-39.pyc b/Scripts/SimpleSAC/utils/__pycache__/car_dis_comput.cpython-39.pyc index 3e8f5df..527e51d 100644 Binary files a/Scripts/SimpleSAC/utils/__pycache__/car_dis_comput.cpython-39.pyc and b/Scripts/SimpleSAC/utils/__pycache__/car_dis_comput.cpython-39.pyc differ diff --git a/Scripts/main_SDM_add_rb.py b/Scripts/main_SDM_add_rb.py new file mode 100644 index 0000000..b9b2524 --- /dev/null +++ b/Scripts/main_SDM_add_rb.py @@ -0,0 +1,514 @@ +import argparse +import numpy as np +from utils import define_flags_with_default, WandbLogger, get_user_flags, set_random_seed, Timer, prefix_metrics, Eval, Count_tensors +from datetime import datetime +from SimpleSAC.envs import Env +from SimpleSAC.sampler import StepSampler, TrajSampler +from SimpleSAC.replay_buffer import ReplayBuffer, GradReplayBuffer +from SimpleSAC.sac import SAC +from SimpleSAC.models.model import TanhGaussianPolicy, SamplerPolicy, FullyConnectedQFunction + +from SDM.SDM import SDM +from copy import deepcopy +import os +import absl.app +import absl.flags +import wandb +from viskit.logging import logger, setup_logger +from tqdm import trange +import ipdb +import torch + +parser = argparse.ArgumentParser() +parser.add_argument('--used_wandb', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--ego_policy', type = str, default = 'RL', choices = ['RL', 'uniform', 'sumo', 'fvdm']) +parser.add_argument('--adv_policy', type = str, default = 'sumo', choices = ['RL', 'uniform', 'sumo', 'fvdm']) +parser.add_argument('--num_agents', type = int, default = 5) +parser.add_argument('--r_ego', type = str, default = 'stackelberg', choices = ['r1', 'stackelberg']) +parser.add_argument('--r_adv', type = str, default = 'stackelberg3') +# parser.add_argument('--realdata_path', type = str, default = '../datasets/dataset/r3_dis_25_car_6') +parser.add_argument('--is_save', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--device', type = str, default = 'cuda:0') +parser.add_argument('--seed', type = int, default = 42) +parser.add_argument('--save_model', type=str, default="False", choices=["True", "False"]) +parser.add_argument('--n_adv_policy_update_gap', type = int, default = 5) +parser.add_argument('--n_ego_policy_update_gap', type = int, default = 1) +parser.add_argument('--model_name', type = str, default = 'SPG') +parser.add_argument('--gui', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--policy_arch', type = str, default = '256-256') +parser.add_argument('--qf_arch', type = str, default = '256-256') +parser.add_argument('--batch_size', type = int, default = 64) +parser.add_argument('--reg_scale', type = float, default = 0.2) +parser.add_argument('--n_epochs', type = int, default = 100) +parser.add_argument('--n_loops', type = int, default = 20) +parser.add_argument('--n_rollout_steps_per_epoch', type = int, default = 1000) +parser.add_argument('--n_train_step_per_epoch', type = int, default = 500) +parser.add_argument('--pretrain_ego', type = str, default = 'True', choices = ['True', 'False']) +parser.add_argument('--pretrain_loops', type = int, default = 2) +parser.add_argument('--pretrain_epochs', type = int, default = 100) +parser.add_argument('--pretrain_steps', type = int, default = 500) +parser.add_argument('--load_pretrain_ego', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--pretrain_ego_path', type = str, default = '') +parser.add_argument('--reset_rb', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--replay_buffer_size', type = int, default = 100000) +parser.add_argument('--pretrain_replay_buffer_size', type = int, default = 1000000) +parser.add_argument('--is_SN', type = str, default = 'True', choices = ['True', 'False']) +parser.add_argument('--is_LN', type = str, default = '') +parser.add_argument('--use_auto_alpha', default = 'False', type = str, choices = ['True', 'False']) +parser.add_argument('--backup_entropy', default = 'True', type = str, choices = ['True', 'False']) +parser.add_argument('--num_save', type = int, default = 5) +args = parser.parse_args() +args.is_save = True if args.is_save == 'True' else False + +args.used_wandb = True if args.used_wandb == 'True' else False +# ipdb.set_trace() +args.save_model = True if args.save_model == 'True' else False +args.gui = True if args.gui == 'True' else False +args.pretrain_ego = True if args.pretrain_ego == 'True' else False +args.load_pretrain_ego = True if args.load_pretrain_ego == 'True' else False +args.reset_rb = True if args.reset_rb == 'True' else False +args.is_SN = True if args.is_SN == 'True' else False +args.use_auto_alpha = True if args.use_auto_alpha == 'True' else False +args.backup_entropy = True if args.backup_entropy == 'True' else False + +realdata_paths = os.listdir('../datasets/dataset/') +def extract_last_digit(path): + # This function extracts the last digit from a string and returns it as an integer. + return int(path[-1]) + +# Sort the realdata_paths list based on the last digit in ascending order. +sorted_realdata_paths = sorted(realdata_paths, key=extract_last_digit) +realdata_path = os.path.join('../datasets/dataset', sorted_realdata_paths[args.num_agents - 1]) +# ipdb.set_trace() + + + +FLAGS_DEF = define_flags_with_default( + model_name = args.model_name, + used_wandb = args.used_wandb, + ego_policy = args.ego_policy, + adv_policy = args.adv_policy, + num_agents = args.num_agents, + reg_scale = args.reg_scale, + r_ego = args.r_ego, + r_adv = args.r_adv, + r_adv_replaybuffer = args.r_adv, + realdata_path = realdata_path, + is_save = args.is_save, + device = args.device, + seed = args.seed, + replay_buffer_size = args.replay_buffer_size, + pretrain_replay_buffer_size = args.pretrain_replay_buffer_size, + save_model = args.save_model, + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S"), + replaybuffer_ratio = 10, + real_residual_ratio = 1.0, + dis_dropout = False, + max_traj_length = 100, + batch_size = args.batch_size, + reward_scale = 1.0, + reward_bias = 0.0, + clip_action = 1.0, + joint_noise_std = 0.0, + policy_arch = args.policy_arch, + qf_arch = args.qf_arch, + orthogonal_init = False, + policy_log_std_multiplier = 1.0, + policy_log_std_offset = -1.0, + # train and evaluate policy + n_epochs = args.n_epochs, + n_loops = args.n_loops, + bc_epochs = 0, + n_rollout_steps_per_epoch = args.n_rollout_steps_per_epoch, + n_train_step_per_epoch = args.n_train_step_per_epoch, + n_adv_policy_update_gap = args.n_adv_policy_update_gap, + n_ego_policy_update_gap = args.n_ego_policy_update_gap, + eval_period = 1, + eval_n_trajs = 20, + logging = WandbLogger.get_default_config(), + gui = args.gui, + pretrain_ego = args.pretrain_ego, + pretrain_epochs = args.pretrain_epochs, + pretrain_loops = args.pretrain_loops, + pretrain_steps = args.pretrain_steps, + load_pretrain_ego = args.load_pretrain_ego, + pretrain_ego_path = args.pretrain_ego_path, + reset_rb = args.reset_rb, + cql_ego = SAC.get_default_config(), + is_SN = args.is_SN, + is_LN = args.is_LN, + use_auto_alpha = args.use_auto_alpha, + backup_entropy = args.backup_entropy, + num_save = args.num_save +) + +def argparse(): + ... + +def get_tensors_on_gpu(device): + tensors_on_gpu = [] + for obj in dir(): + if isinstance(eval(obj), torch.Tensor): + if eval(obj).device == device: + tensors_on_gpu.append(obj) + return tensors_on_gpu + +def main(argv): + + # ipdb.set_trace() + FLAGS = absl.flags.FLAGS + if FLAGS.is_save: + eval_savepath = "output/" + \ + f"bv={FLAGS.num_agents}-{FLAGS.adv_policy}_" \ + f"r-ego={FLAGS.r_ego}_r-adv={FLAGS.r_adv}_" \ + f"seed={FLAGS.seed}_time={FLAGS.current_time}" \ + f"reg_scale={FLAGS.reg_scale}_" \ + f"pretrain_adv={FLAGS.adv_policy}" \ + + "/" + if not os.path.exists('output'): + os.makedirs('output') + if not os.path.exists(eval_savepath): + os.mkdir(eval_savepath) + os.mkdir(eval_savepath + "avcrash") + os.mkdir(eval_savepath + "bvcrash") + os.mkdir(eval_savepath + "avarrive") + os.mkdir(eval_savepath + "models") + else: + eval_savepath = 'None' + + if FLAGS.used_wandb: + variant = get_user_flags(FLAGS, FLAGS_DEF) + wandb_logger = WandbLogger(config=FLAGS.logging, variant=variant, seed = FLAGS.seed) + wandb.run.name = f"{FLAGS.model_name}" \ + f"_Pretrain_Train_Eval_{FLAGS.model_name}" \ + f"bv={FLAGS.num_agents}-{FLAGS.adv_policy}_" \ + f"r-ego={FLAGS.r_ego}_r-adv={FLAGS.r_adv}_" \ + f"reg_scale={FLAGS.reg_scale}_" \ + f"pretrain_adv={FLAGS.adv_policy}_" \ + f"pretrain_epochs={FLAGS.pretrain_epochs}_" \ + f"pretrain_steps={FLAGS.pretrain_steps}_" \ + f"pretrain_rb_size={FLAGS.pretrain_replay_buffer_size}_" \ + f"reset_rb={FLAGS.reset_rb}_" \ + f"n_adv_policy_update_gap={FLAGS.n_adv_policy_update_gap}_" \ + f"n_ego_policy_update_gap={FLAGS.n_ego_policy_update_gap}_" \ + f"is_SN={FLAGS.is_SN}_" \ + f"is_auto_alpha={FLAGS.use_auto_alpha}_" \ + f"backup_entropy={FLAGS.backup_entropy}_" \ + f"seed={FLAGS.seed}_time={FLAGS.current_time}" + setup_logger( + variant=variant, + exp_id=wandb_logger.experiment_id, + seed=FLAGS.seed, + base_log_dir=FLAGS.logging.output_dir, + include_exp_prefix_sub_dir=False + ) + + set_random_seed(FLAGS.seed) + # real_env = Env(realdata_path=FLAGS.realdata_path, num_agents=FLAGS.num_agents, sim_horizon=FLAGS.max_traj_length, + # ego_policy=FLAGS.ego_policy, adv_policy=FLAGS.adv_policy, + # r_ego=FLAGS.r_ego, r_adv=FLAGS.r_adv, sim_seed=FLAGS.seed, gui=FLAGS.gui) + env = Env(realdata_path=FLAGS.realdata_path, num_agents=FLAGS.num_agents, sim_horizon=FLAGS.max_traj_length, + ego_policy=FLAGS.ego_policy, adv_policy=FLAGS.adv_policy, + r_ego=FLAGS.r_ego, r_adv=FLAGS.r_adv, sim_seed=FLAGS.seed, gui=FLAGS.gui) + pretain_env = Env(realdata_path=FLAGS.realdata_path, num_agents=FLAGS.num_agents, sim_horizon=FLAGS.max_traj_length, + ego_policy=FLAGS.ego_policy, adv_policy=FLAGS.adv_policy, + r_ego = 'r1', r_adv = 'r3', sim_seed=FLAGS.seed, gui=FLAGS.gui) + pretrain_sampler = StepSampler(pretain_env, max_traj_length=FLAGS.max_traj_length) + train_sampler = StepSampler(env, max_traj_length=FLAGS.max_traj_length) + eval_sampler = TrajSampler(env, rootsavepath=eval_savepath, max_traj_length=FLAGS.max_traj_length) + + # replay buffer + num_state = env.state_space[0] + num_action_adv = env.action_space_adv[0] + num_action_ego = env.action_space_ego[0] + # ipdb.set_trace() + num_action = num_action_ego + num_action_adv + pretrain_replay_buffer = ReplayBuffer(num_state, num_action_ego, num_action_adv, FLAGS.pretrain_replay_buffer_size, device=FLAGS.device) + replay_buffer = GradReplayBuffer(num_state, num_action_ego, num_action_adv, FLAGS.replay_buffer_size, device=FLAGS.device) \ + if FLAGS.model_name == 'SPG' else ReplayBuffer(num_state, num_action_ego, num_action_adv, FLAGS.replay_buffer_size, device=FLAGS.device) + + # ipdb.set_trace() + + + + + ego_policy = TanhGaussianPolicy( + num_state, + num_action_ego, + arch=FLAGS.policy_arch, + log_std_multiplier=FLAGS.policy_log_std_multiplier, + log_std_offset=FLAGS.policy_log_std_offset, + orthogonal_init=FLAGS.orthogonal_init, + is_SN = False, + is_LN = FLAGS.is_LN + ) + qf1_ego = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=False + ) + target_qf1_ego = deepcopy(qf1_ego) + qf2_ego = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=False + ) + target_qf2_ego = deepcopy(qf2_ego) + sampler_ego_policy = SamplerPolicy(ego_policy, FLAGS.device) + + + + adv_policy = TanhGaussianPolicy( + num_state, + num_action_adv, + arch=FLAGS.policy_arch, + log_std_multiplier=FLAGS.policy_log_std_multiplier, + log_std_offset=FLAGS.policy_log_std_offset, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=FLAGS.is_SN + ) + qf1_adv = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=FLAGS.is_SN + ) + target_qf1_adv = deepcopy(qf1_adv) + qf2_adv = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=FLAGS.is_SN + ) + target_qf2_adv = deepcopy(qf1_adv) + sampler_adv_policy = SamplerPolicy(adv_policy, FLAGS.device) + + if FLAGS.model_name == 'SPG': + model = SDM(None, + ego_policy = ego_policy, + adv_policy = adv_policy, + qf1_ego = qf1_ego, + qf2_ego = qf2_ego, + target_qf1_ego = target_qf1_ego, + target_qf2_ego = target_qf2_ego, + qf1_adv = qf1_adv, + qf2_adv = qf2_adv, + target_qf1_adv = target_qf1_adv, + target_qf2_adv = target_qf2_adv, + device = FLAGS.device, + reg_scale = FLAGS.reg_scale, + use_automatic_entropy_tuning = FLAGS.use_auto_alpha, + backup_entropy = FLAGS.backup_entropy,) + else: + return + model.torch_to_device(FLAGS.device) + + if FLAGS.pretrain_ego: + if not FLAGS.load_pretrain_ego: + model_pre_ego = SAC( + FLAGS.cql_ego, + policy = ego_policy, + qf1 = qf1_ego, + qf2 = qf2_ego, + target_qf1 = target_qf1_ego, + target_qf2 = target_qf2_ego + ) + model_pre_ego.torch_to_device(FLAGS.device) + else: + model_pre_ego = torch.load(FLAGS.pretrain_ego_path) + + viskit_metrics = {} + + # TODO: Pretrain Ego Policy on fvdm BV using sac + # TODO: Check bv + # TODO: Check sac + if FLAGS.pretrain_ego: + pretrain_replay_buffer.reset() + sampler_pretrain_ego_policy = SamplerPolicy(model_pre_ego.policy, FLAGS.device) + sampler_pretrain_ego_policy.set_grad(False) + + for i in range(FLAGS.pretrain_loops): + + for epoch in trange(FLAGS.pretrain_epochs): + metrics = {} + pretrain_sampler.env.adv_policy = FLAGS.adv_policy # sumo + pretrain_sampler.env.ego_policy = 'RL' + # while True: + pretrain_sampler.sample( + ego_policy=sampler_pretrain_ego_policy, adv_policy=None, n_steps=FLAGS.n_rollout_steps_per_epoch, + deterministic=False, replay_buffer=pretrain_replay_buffer, + joint_noise_std=FLAGS.joint_noise_std + ) + metrics['epoch'] = epoch + for batch_idx in trange(FLAGS.pretrain_steps): + batch = pretrain_replay_buffer.sample(FLAGS.batch_size) + # if FLAGS.used_wandb: + # wandb_logger.log(prefix_metrics(model_pre_ego.train(batch), 'SAC_Pretrain')) + metrics.update(prefix_metrics(model_pre_ego.train(batch), 'SAC_Pretrain')) + + if FLAGS.used_wandb: + wandb_logger.log(metrics) + # eval + + if epoch == 0 or (epoch + 1) % FLAGS.eval_period == 0: + eval_ego_policy = 'RL' + eval_sampler.env.ego_policy = eval_ego_policy + eval_sampler.env.adv_policy = FLAGS.adv_policy + if adv_policy != 'RL': + s_a = None + else: + s_a = sampler_adv_policy + # ipdb.set_trace() + trajs, _ = eval_sampler.sample( + ego_policy=sampler_pretrain_ego_policy, adv_policy=s_a, + n_trajs=FLAGS.eval_n_trajs, deterministic=True + ) + # TODO: add speed + Eval(metrics, eval_ego_policy, FLAGS.adv_policy, trajs) + + if FLAGS.used_wandb: + wandb_logger.log(metrics) + + # pretrain_replay_buffer.reset() + + if FLAGS.save_model: + torch.save(model_pre_ego, os.path.join(eval_savepath, 'models', 'pretrain_ego.pth')) + if FLAGS.used_wandb: + # wandb_logger.log(metrics) + pre_save_data = {'model_pre_ego': model_pre_ego} + wandb_logger.save_pickle(pre_save_data, 'pre_model.pkl') + sampler_pretrain_ego_policy = deepcopy(sampler_pretrain_ego_policy) # freezing the pretrain policy + # return + # TODO: apply cross learning method + replay_buffer.reset() + # add data of pretrain replay buffer into replay buffer + actions = [(action_ago, action_adv) for action_ago, action_adv in \ + zip(torch.Tensor(pretrain_replay_buffer.action_ego), torch.Tensor(pretrain_replay_buffer.action_adv))] + for i in trange(pretrain_replay_buffer.size): + replay_buffer.append( + pretrain_replay_buffer.state[i], actions[i], + pretrain_replay_buffer.reward[i], pretrain_replay_buffer.next_state[i], + pretrain_replay_buffer.done[i] + ) + print('finished adding') + + if FLAGS.model_name == 'SPG': + sampler_ego_policy.set_grad(True) + sampler_adv_policy.set_grad(True) + freeze_ego = False + freeze_adv = False + # = 1 + print('training start') + # ipdb.set_trace() + for l in range(FLAGS.n_loops): + for epoch in trange(FLAGS.n_epochs): + + '''leader and follower''' + metrics = {} + + # metrics['epoch'] = epoch + # TODO: Train from the mixed data + with Timer() as train_timer: + train_sampler.env.adv_policy = "RL" + train_sampler.env.ego_policy = "RL" + if FLAGS.reset_rb: + replay_buffer.reset() + + train_sampler.sample( + ego_policy=sampler_ego_policy, adv_policy=sampler_adv_policy, n_steps=FLAGS.n_rollout_steps_per_epoch, + deterministic=False, replay_buffer=replay_buffer, + joint_noise_std=FLAGS.joint_noise_std + ) # Sample trajectories from the simulator using the current policy pi_1, pi_2 + + for batch_idx in trange(FLAGS.n_train_step_per_epoch): # at each step of a epoch: + + batch = replay_buffer.sample(FLAGS.batch_size) # Draw actions a_t^1, a_t^2 from their distributions pi_1, pi_2 + + # at the end of each step, train the policy and Q function + freeze_ego = not ((batch_idx % FLAGS.n_ego_policy_update_gap) == 0) + freeze_adv = not ((batch_idx % FLAGS.n_adv_policy_update_gap) == 0) + metrics.update(prefix_metrics(model.train(batch, freeze_ego = freeze_ego, freeze_adv = freeze_adv), FLAGS.model_name)) + if FLAGS.used_wandb: + wandb_logger.log(metrics) + + # TODO: Evaluate in the real world + with Timer() as eval_timer: + if epoch == 0 or (epoch + 1) % FLAGS.eval_period == 0: + # Eval ego policy + for adv_policy in ['sumo', 'fvdm', 'RL']: + # ipdb.set_trace() + eval_ego_policy = 'RL' + eval_sampler.env.ego_policy = eval_ego_policy + eval_sampler.env.adv_policy = adv_policy + if adv_policy != 'RL': + s_a = None + else: + s_a = sampler_adv_policy + # ipdb.set_trace() + trajs, _ = eval_sampler.sample( + ego_policy=sampler_ego_policy, adv_policy=s_a, + n_trajs=FLAGS.eval_n_trajs, deterministic=True + ) + Eval(metrics, eval_ego_policy, adv_policy, trajs) + + # Eval adv policy + for ego_policy in ['sumo', 'fvdm', 'RL']: # this RL ego is pretrained ego + eval_adv_policy = 'RL' + eval_sampler.env.ego_policy = ego_policy + eval_sampler.env.adv_policy = eval_adv_policy + if ego_policy != 'RL': + s_e = None + else: + s_e = sampler_pretrain_ego_policy + ego_policy = 'pretrainedRL' + trajs, _ = eval_sampler.sample( + ego_policy=s_e, adv_policy=sampler_adv_policy, + n_trajs=FLAGS.eval_n_trajs, deterministic=True + ) + Eval(metrics, ego_policy, eval_adv_policy, trajs) + if FLAGS.used_wandb: + wandb_logger.log(metrics) + + # metrics['rollout_time'] = rollout_timer() + metrics['train_time'] = train_timer() + metrics['eval_time'] = eval_timer() + metrics['epoch_time'] = train_timer() + eval_timer() + if FLAGS.used_wandb: + save_data = {FLAGS.model_name: model, + 'variant': variant, 'epoch': epoch} + wandb_logger.save_pickle(save_data, 'model.pkl') + viskit_metrics.update(metrics) + logger.record_dict(viskit_metrics) + logger.dump_tabular(with_prefix=False, with_timestamp=False) + + # save model for matric Eval + if FLAGS.save_model and l % (FLAGS.n_loops / FLAGS.num_save) == 0 or l == FLAGS.n_loops - 1: + # ipdb.set_trace() + torch.save(model, os.path.join(eval_savepath, 'models', f'loop_{l+1}.pth')) + + + if FLAGS.save_model and FLAGS.used_wandb: + save_data = {FLAGS.model_name: model, + 'variant': variant, 'epoch': epoch} + wandb_logger.save_pickle(save_data, 'model.pkl') + if FLAGS.save_model: + torch.save(model, os.path.join(eval_savepath, 'models', 'trained_model.pth')) + +if __name__ == '__main__': + absl.app.run(main) + diff --git a/Scripts/main_SDM_vs_re2h2o.py b/Scripts/main_SDM_vs_re2h2o.py new file mode 100644 index 0000000..1a3e1f5 --- /dev/null +++ b/Scripts/main_SDM_vs_re2h2o.py @@ -0,0 +1,523 @@ +import argparse +import numpy as np +from utils import define_flags_with_default, WandbLogger, get_user_flags, set_random_seed, Timer, prefix_metrics, Eval, Count_tensors +from datetime import datetime +from SimpleSAC.envs import Env +from SimpleSAC.sampler import StepSampler, TrajSampler +from SimpleSAC.replay_buffer import ReplayBuffer, GradReplayBuffer +from SimpleSAC.sac import SAC +from SimpleSAC.models.model import TanhGaussianPolicy, SamplerPolicy, FullyConnectedQFunction + +from SDM.SDM import SDM +from copy import deepcopy +import os +import absl.app +import absl.flags +import wandb +from viskit.logging import logger, setup_logger +from tqdm import trange +import ipdb +import torch + +parser = argparse.ArgumentParser() +parser.add_argument('--used_wandb', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--ego_policy', type = str, default = 'RL', choices = ['RL', 'uniform', 'sumo', 'fvdm']) +parser.add_argument('--adv_policy', type = str, default = 'sumo', choices = ['RL', 'uniform', 'sumo', 'fvdm']) +parser.add_argument('--num_agents', type = int, default = 5) +parser.add_argument('--r_ego', type = str, default = 'stackelberg', choices = ['r1', 'stackelberg']) +parser.add_argument('--r_adv', type = str, default = 'stackelberg3') +# parser.add_argument('--realdata_path', type = str, default = '../datasets/dataset/r3_dis_25_car_6') +parser.add_argument('--is_save', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--device', type = str, default = 'cuda:0') +parser.add_argument('--seed', type = int, default = 42) +parser.add_argument('--save_model', type=str, default="False", choices=["True", "False"]) +parser.add_argument('--n_adv_policy_update_gap', type = int, default = 5) +parser.add_argument('--n_ego_policy_update_gap', type = int, default = 1) +parser.add_argument('--model_name', type = str, default = 'SPG') +parser.add_argument('--gui', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--policy_arch', type = str, default = '256-256') +parser.add_argument('--qf_arch', type = str, default = '256-256') +parser.add_argument('--batch_size', type = int, default = 64) +parser.add_argument('--reg_scale', type = float, default = 0.2) +parser.add_argument('--n_epochs', type = int, default = 100) +parser.add_argument('--n_loops', type = int, default = 20) +parser.add_argument('--n_rollout_steps_per_epoch', type = int, default = 1000) +parser.add_argument('--n_train_step_per_epoch', type = int, default = 500) +parser.add_argument('--pretrain_ego', type = str, default = 'True', choices = ['True', 'False']) +parser.add_argument('--pretrain_loops', type = int, default = 2) +parser.add_argument('--pretrain_epochs', type = int, default = 100) +parser.add_argument('--pretrain_steps', type = int, default = 500) +parser.add_argument('--load_pretrain_ego', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--pretrain_ego_path', type = str, default = '') +parser.add_argument('--reset_rb', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--replay_buffer_size', type = int, default = 100000) +parser.add_argument('--pretrain_replay_buffer_size', type = int, default = 1000000) +parser.add_argument('--is_SN', type = str, default = 'True', choices = ['True', 'False']) +parser.add_argument('--is_LN', type = str, default = '') +parser.add_argument('--use_auto_alpha', default = 'False', type = str, choices = ['True', 'False']) +parser.add_argument('--backup_entropy', default = 'True', type = str, choices = ['True', 'False']) +parser.add_argument('--num_save', type = int, default = 5) +args = parser.parse_args() +args.is_save = True if args.is_save == 'True' else False + +args.used_wandb = True if args.used_wandb == 'True' else False +# ipdb.set_trace() +args.save_model = True if args.save_model == 'True' else False +args.gui = True if args.gui == 'True' else False +args.pretrain_ego = True if args.pretrain_ego == 'True' else False +args.load_pretrain_ego = True if args.load_pretrain_ego == 'True' else False +args.reset_rb = True if args.reset_rb == 'True' else False +args.is_SN = True if args.is_SN == 'True' else False +args.use_auto_alpha = True if args.use_auto_alpha == 'True' else False +args.backup_entropy = True if args.backup_entropy == 'True' else False + +realdata_paths = os.listdir('../datasets/dataset/') +def extract_last_digit(path): + # This function extracts the last digit from a string and returns it as an integer. + return int(path[-1]) + +# Sort the realdata_paths list based on the last digit in ascending order. +sorted_realdata_paths = sorted(realdata_paths, key=extract_last_digit) +realdata_path = os.path.join('../datasets/dataset', sorted_realdata_paths[args.num_agents - 1]) +# ipdb.set_trace() + + + +FLAGS_DEF = define_flags_with_default( + model_name = args.model_name, + used_wandb = args.used_wandb, + ego_policy = args.ego_policy, + adv_policy = args.adv_policy, + num_agents = args.num_agents, + reg_scale = args.reg_scale, + r_ego = args.r_ego, + r_adv = args.r_adv, + r_adv_replaybuffer = args.r_adv, + realdata_path = realdata_path, + is_save = args.is_save, + device = args.device, + seed = args.seed, + replay_buffer_size = args.replay_buffer_size, + pretrain_replay_buffer_size = args.pretrain_replay_buffer_size, + save_model = args.save_model, + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S"), + replaybuffer_ratio = 10, + real_residual_ratio = 1.0, + dis_dropout = False, + max_traj_length = 100, + batch_size = args.batch_size, + reward_scale = 1.0, + reward_bias = 0.0, + clip_action = 1.0, + joint_noise_std = 0.0, + policy_arch = args.policy_arch, + qf_arch = args.qf_arch, + orthogonal_init = False, + policy_log_std_multiplier = 1.0, + policy_log_std_offset = -1.0, + # train and evaluate policy + n_epochs = args.n_epochs, + n_loops = args.n_loops, + bc_epochs = 0, + n_rollout_steps_per_epoch = args.n_rollout_steps_per_epoch, + n_train_step_per_epoch = args.n_train_step_per_epoch, + n_adv_policy_update_gap = args.n_adv_policy_update_gap, + n_ego_policy_update_gap = args.n_ego_policy_update_gap, + eval_period = 10, + eval_n_trajs = 20, + logging = WandbLogger.get_default_config(), + gui = args.gui, + pretrain_ego = args.pretrain_ego, + pretrain_epochs = args.pretrain_epochs, + pretrain_loops = args.pretrain_loops, + pretrain_steps = args.pretrain_steps, + load_pretrain_ego = args.load_pretrain_ego, + pretrain_ego_path = args.pretrain_ego_path, + reset_rb = args.reset_rb, + cql_ego = SAC.get_default_config(), + is_SN = args.is_SN, + is_LN = args.is_LN, + use_auto_alpha = args.use_auto_alpha, + backup_entropy = args.backup_entropy, + num_save = args.num_save +) + +def argparse(): + ... + +def get_tensors_on_gpu(device): + tensors_on_gpu = [] + for obj in dir(): + if isinstance(eval(obj), torch.Tensor): + if eval(obj).device == device: + tensors_on_gpu.append(obj) + return tensors_on_gpu + +def main(argv): + + # ipdb.set_trace() + FLAGS = absl.flags.FLAGS + if FLAGS.is_save: + eval_savepath = "output/" + \ + f"bv={FLAGS.num_agents}-{FLAGS.adv_policy}_" \ + f"r-ego={FLAGS.r_ego}_r-adv={FLAGS.r_adv}_" \ + f"seed={FLAGS.seed}_time={FLAGS.current_time}" \ + f"reg_scale={FLAGS.reg_scale}_" \ + f"pretrain_adv={FLAGS.adv_policy}" \ + + "/" + if not os.path.exists('output'): + os.makedirs('output') + if not os.path.exists(eval_savepath): + os.mkdir(eval_savepath) + os.mkdir(eval_savepath + "avcrash") + os.mkdir(eval_savepath + "bvcrash") + os.mkdir(eval_savepath + "avarrive") + os.mkdir(eval_savepath + "models") + else: + eval_savepath = 'None' + + if FLAGS.used_wandb: + variant = get_user_flags(FLAGS, FLAGS_DEF) + wandb_logger = WandbLogger(config=FLAGS.logging, variant=variant, seed = FLAGS.seed) + wandb.run.name = f"{FLAGS.model_name}" \ + f"_Pretrain_Train_Eval_{FLAGS.model_name}" \ + f"bv={FLAGS.num_agents}-{FLAGS.adv_policy}_" \ + f"r-ego={FLAGS.r_ego}_r-adv={FLAGS.r_adv}_" \ + f"reg_scale={FLAGS.reg_scale}_" \ + f"pretrain_adv={FLAGS.adv_policy}_" \ + f"pretrain_epochs={FLAGS.pretrain_epochs}_" \ + f"pretrain_steps={FLAGS.pretrain_steps}_" \ + f"pretrain_rb_size={FLAGS.pretrain_replay_buffer_size}_" \ + f"reset_rb={FLAGS.reset_rb}_" \ + f"n_adv_policy_update_gap={FLAGS.n_adv_policy_update_gap}_" \ + f"n_ego_policy_update_gap={FLAGS.n_ego_policy_update_gap}_" \ + f"is_SN={FLAGS.is_SN}_" \ + f"is_auto_alpha={FLAGS.use_auto_alpha}_" \ + f"backup_entropy={FLAGS.backup_entropy}_" \ + f"seed={FLAGS.seed}_time={FLAGS.current_time}" + setup_logger( + variant=variant, + exp_id=wandb_logger.experiment_id, + seed=FLAGS.seed, + base_log_dir=FLAGS.logging.output_dir, + include_exp_prefix_sub_dir=False + ) + + set_random_seed(FLAGS.seed) + # real_env = Env(realdata_path=FLAGS.realdata_path, num_agents=FLAGS.num_agents, sim_horizon=FLAGS.max_traj_length, + # ego_policy=FLAGS.ego_policy, adv_policy=FLAGS.adv_policy, + # r_ego=FLAGS.r_ego, r_adv=FLAGS.r_adv, sim_seed=FLAGS.seed, gui=FLAGS.gui) + env = Env(realdata_path=FLAGS.realdata_path, num_agents=FLAGS.num_agents, sim_horizon=FLAGS.max_traj_length, + ego_policy=FLAGS.ego_policy, adv_policy=FLAGS.adv_policy, + r_ego=FLAGS.r_ego, r_adv=FLAGS.r_adv, sim_seed=FLAGS.seed, gui=FLAGS.gui) + pretain_env = Env(realdata_path=FLAGS.realdata_path, num_agents=FLAGS.num_agents, sim_horizon=FLAGS.max_traj_length, + ego_policy=FLAGS.ego_policy, adv_policy=FLAGS.adv_policy, + r_ego = 'r1', r_adv = 'r3', sim_seed=FLAGS.seed, gui=FLAGS.gui) + pretrain_sampler = StepSampler(pretain_env, max_traj_length=FLAGS.max_traj_length) + train_sampler = StepSampler(env, max_traj_length=FLAGS.max_traj_length) + eval_sampler = TrajSampler(env, rootsavepath=eval_savepath, max_traj_length=FLAGS.max_traj_length) + + # replay buffer + num_state = env.state_space[0] + num_action_adv = env.action_space_adv[0] + num_action_ego = env.action_space_ego[0] + # ipdb.set_trace() + num_action = num_action_ego + num_action_adv + pretrain_replay_buffer = ReplayBuffer(num_state, num_action_ego, num_action_adv, FLAGS.pretrain_replay_buffer_size, device=FLAGS.device) + replay_buffer = GradReplayBuffer(num_state, num_action_ego, num_action_adv, FLAGS.replay_buffer_size, device=FLAGS.device) \ + if FLAGS.model_name == 'SPG' else ReplayBuffer(num_state, num_action_ego, num_action_adv, FLAGS.replay_buffer_size, device=FLAGS.device) + + # ipdb.set_trace() + + + + + ego_policy = TanhGaussianPolicy( + num_state, + num_action_ego, + arch=FLAGS.policy_arch, + log_std_multiplier=FLAGS.policy_log_std_multiplier, + log_std_offset=FLAGS.policy_log_std_offset, + orthogonal_init=FLAGS.orthogonal_init, + is_SN = False, + is_LN = FLAGS.is_LN + ) + qf1_ego = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=False + ) + target_qf1_ego = deepcopy(qf1_ego) + qf2_ego = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=False + ) + target_qf2_ego = deepcopy(qf2_ego) + sampler_ego_policy = SamplerPolicy(ego_policy, FLAGS.device) + + + + adv_policy = TanhGaussianPolicy( + num_state, + num_action_adv, + arch=FLAGS.policy_arch, + log_std_multiplier=FLAGS.policy_log_std_multiplier, + log_std_offset=FLAGS.policy_log_std_offset, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=FLAGS.is_SN + ) + qf1_adv = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=FLAGS.is_SN + ) + target_qf1_adv = deepcopy(qf1_adv) + qf2_adv = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=FLAGS.is_SN + ) + target_qf2_adv = deepcopy(qf1_adv) + sampler_adv_policy = SamplerPolicy(adv_policy, FLAGS.device) + + if FLAGS.model_name == 'SPG': + model = SDM(None, + ego_policy = ego_policy, + adv_policy = adv_policy, + qf1_ego = qf1_ego, + qf2_ego = qf2_ego, + target_qf1_ego = target_qf1_ego, + target_qf2_ego = target_qf2_ego, + qf1_adv = qf1_adv, + qf2_adv = qf2_adv, + target_qf1_adv = target_qf1_adv, + target_qf2_adv = target_qf2_adv, + device = FLAGS.device, + reg_scale = FLAGS.reg_scale, + use_automatic_entropy_tuning = FLAGS.use_auto_alpha, + backup_entropy = FLAGS.backup_entropy,) + else: + return + model.torch_to_device(FLAGS.device) + + if FLAGS.pretrain_ego: + if not FLAGS.load_pretrain_ego: + model_pre_ego = SAC( + FLAGS.cql_ego, + policy = ego_policy, + qf1 = qf1_ego, + qf2 = qf2_ego, + target_qf1 = target_qf1_ego, + target_qf2 = target_qf2_ego + ) + model_pre_ego.torch_to_device(FLAGS.device) + else: + model_pre_ego = torch.load(FLAGS.pretrain_ego_path) + + viskit_metrics = {} + + # TODO: Pretrain Ego Policy on fvdm BV using sac + # TODO: Check bv + # TODO: Check sac + if FLAGS.pretrain_ego: + pretrain_replay_buffer.reset() + sampler_pretrain_ego_policy = SamplerPolicy(model_pre_ego.policy, FLAGS.device) + sampler_pretrain_ego_policy.set_grad(False) + + for i in range(FLAGS.pretrain_loops): + + for epoch in trange(FLAGS.pretrain_epochs): + metrics = {} + pretrain_sampler.env.adv_policy = FLAGS.adv_policy # sumo + pretrain_sampler.env.ego_policy = 'RL' + # while True: + pretrain_sampler.sample( + ego_policy=sampler_pretrain_ego_policy, adv_policy=None, n_steps=FLAGS.n_rollout_steps_per_epoch, + deterministic=False, replay_buffer=pretrain_replay_buffer, + joint_noise_std=FLAGS.joint_noise_std + ) + metrics['epoch'] = epoch + for batch_idx in trange(FLAGS.pretrain_steps): + batch = pretrain_replay_buffer.sample(FLAGS.batch_size) + # if FLAGS.used_wandb: + # wandb_logger.log(prefix_metrics(model_pre_ego.train(batch), 'SAC_Pretrain')) + metrics.update(prefix_metrics(model_pre_ego.train(batch), 'SAC_Pretrain')) + + if FLAGS.used_wandb: + wandb_logger.log(metrics) + # eval + + if epoch == 0 or (epoch + 1) % FLAGS.eval_period == 0: + eval_ego_policy = 'RL' + eval_sampler.env.ego_policy = eval_ego_policy + eval_sampler.env.adv_policy = FLAGS.adv_policy + if adv_policy != 'RL': + s_a = None + else: + s_a = sampler_adv_policy + # ipdb.set_trace() + trajs, _ = eval_sampler.sample( + ego_policy=sampler_pretrain_ego_policy, adv_policy=s_a, + n_trajs=FLAGS.eval_n_trajs, deterministic=True + ) + # TODO: add speed + Eval(metrics, eval_ego_policy, FLAGS.adv_policy, trajs) + + if FLAGS.used_wandb: + wandb_logger.log(metrics) + + # pretrain_replay_buffer.reset() + + if FLAGS.save_model: + torch.save(model_pre_ego, os.path.join(eval_savepath, 'models', 'pretrain_ego.pth')) + if FLAGS.used_wandb: + # wandb_logger.log(metrics) + pre_save_data = {'model_pre_ego': model_pre_ego} + wandb_logger.save_pickle(pre_save_data, 'pre_model.pkl') + sampler_pretrain_ego_policy = deepcopy(sampler_pretrain_ego_policy) # freezing the pretrain policy + # return + # TODO: apply cross learning method + replay_buffer.reset() + if FLAGS.model_name == 'SPG': + sampler_ego_policy.set_grad(True) + sampler_adv_policy.set_grad(True) + freeze_ego = False + freeze_adv = False + # = 1 + # ipdb.set_trace() + + # load trained re2h2o model + map_location = { + 'cuda:0': FLAGS.device, + 'cuda:1': FLAGS.device, + 'cuda:2': FLAGS.device, + 'cuda:3': FLAGS.device + } + # model_adv_re2h2o_policy = torch.load('models_re2h2o_bv//BV0_bv=1.pkl', map_location=map_location) + model_adv_re2h2o_policy = torch.load(f'models_re2h2o_bv//BV0_bv={FLAGS.num_agents}.pkl', map_location=map_location) + sampler_adv_re2h2o_policy = SamplerPolicy(model_adv_re2h2o_policy, device=FLAGS.device) + + for l in range(FLAGS.n_loops): + for epoch in trange(FLAGS.n_epochs): + + '''leader and follower''' + metrics = {} + + # metrics['epoch'] = epoch + # TODO: Train from the mixed data + with Timer() as train_timer: + train_sampler.env.adv_policy = "RL" + train_sampler.env.ego_policy = "RL" + if FLAGS.reset_rb: + replay_buffer.reset() + + train_sampler.sample( + ego_policy=sampler_ego_policy, adv_policy=sampler_adv_policy, n_steps=FLAGS.n_rollout_steps_per_epoch, + deterministic=False, replay_buffer=replay_buffer, + joint_noise_std=FLAGS.joint_noise_std + ) # Sample trajectories from the simulator using the current policy pi_1, pi_2 + + for batch_idx in trange(FLAGS.n_train_step_per_epoch): # at each step of a epoch: + + batch = replay_buffer.sample(FLAGS.batch_size) # Draw actions a_t^1, a_t^2 from their distributions pi_1, pi_2 + + # at the end of each step, train the policy and Q function + freeze_ego = not ((batch_idx % FLAGS.n_ego_policy_update_gap) == 0) + freeze_adv = not ((batch_idx % FLAGS.n_adv_policy_update_gap) == 0) + metrics.update(prefix_metrics(model.train(batch, freeze_ego = freeze_ego, freeze_adv = freeze_adv), FLAGS.model_name)) + if FLAGS.used_wandb: + wandb_logger.log(metrics) + + # TODO: Evaluate in the real world + with Timer() as eval_timer: + if epoch == 0 or (epoch + 1) % FLAGS.eval_period == 0: + # Eval ego policy + for adv_policy in ['sumo', 'fvdm', 'RL', 're2h2o']: + # ipdb.set_trace() + eval_ego_policy = 'RL' + eval_sampler.env.ego_policy = eval_ego_policy + if adv_policy == 're2h2o': + eval_sampler.env.adv_policy = 'RL' + else: + eval_sampler.env.adv_policy = adv_policy + if adv_policy == 'RL': + s_a = sampler_adv_policy + elif adv_policy == 're2h2o': + s_a = sampler_adv_re2h2o_policy + else: + s_a = None + # ipdb.set_trace() + trajs, _ = eval_sampler.sample( + ego_policy=sampler_ego_policy, adv_policy=s_a, + n_trajs=FLAGS.eval_n_trajs, deterministic=True + ) + Eval(metrics, eval_ego_policy, adv_policy, trajs) + + # Eval adv policy + for ego_policy in ['sumo', 'fvdm', 'RL']: # this RL ego is pretrained ego + eval_adv_policy = 'RL' + eval_sampler.env.ego_policy = ego_policy + eval_sampler.env.adv_policy = eval_adv_policy + if ego_policy != 'RL': + s_e = None + else: + s_e = sampler_pretrain_ego_policy + ego_policy = 'pretrainedRL' + trajs, _ = eval_sampler.sample( + ego_policy=s_e, adv_policy=sampler_adv_policy, + n_trajs=FLAGS.eval_n_trajs, deterministic=True + ) + Eval(metrics, ego_policy, eval_adv_policy, trajs) + if FLAGS.used_wandb: + wandb_logger.log(metrics) + + # metrics['rollout_time'] = rollout_timer() + metrics['train_time'] = train_timer() + metrics['eval_time'] = eval_timer() + metrics['epoch_time'] = train_timer() + eval_timer() + if FLAGS.used_wandb: + save_data = {FLAGS.model_name: model, + 'variant': variant, 'epoch': epoch} + wandb_logger.save_pickle(save_data, 'model.pkl') + viskit_metrics.update(metrics) + logger.record_dict(viskit_metrics) + logger.dump_tabular(with_prefix=False, with_timestamp=False) + + # save model for matric Eval + if FLAGS.save_model and l % (FLAGS.n_loops / FLAGS.num_save) == 0 or l == FLAGS.n_loops - 1: + # ipdb.set_trace() + torch.save(model, os.path.join(eval_savepath, 'models', f'loop_{l+1}.pth')) + + + if FLAGS.save_model and FLAGS.used_wandb: + save_data = {FLAGS.model_name: model, + 'variant': variant, 'epoch': epoch} + wandb_logger.save_pickle(save_data, 'model.pkl') + if FLAGS.save_model: + torch.save(model, os.path.join(eval_savepath, 'models', 'trained_model.pth')) + +if __name__ == '__main__': + absl.app.run(main) + + + + + diff --git a/Scripts/main_pretrainRL_vs_re2h2o.py b/Scripts/main_pretrainRL_vs_re2h2o.py new file mode 100644 index 0000000..a1a446c --- /dev/null +++ b/Scripts/main_pretrainRL_vs_re2h2o.py @@ -0,0 +1,467 @@ +import argparse +import numpy as np +from utils import define_flags_with_default, WandbLogger, get_user_flags, set_random_seed, Timer, prefix_metrics, Eval, Count_tensors +from datetime import datetime +from SimpleSAC.envs import Env +from SimpleSAC.sampler import StepSampler, TrajSampler +from SimpleSAC.replay_buffer import ReplayBuffer, GradReplayBuffer +from SimpleSAC.sac import SAC +from SimpleSAC.models.model import TanhGaussianPolicy, SamplerPolicy, FullyConnectedQFunction + +from SDM.SDM import SDM +from copy import deepcopy +import os +import absl.app +import absl.flags +import wandb +from viskit.logging import logger, setup_logger +from tqdm import trange +import ipdb +import torch + +parser = argparse.ArgumentParser() +parser.add_argument('--used_wandb', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--ego_policy', type = str, default = 'RL', choices = ['RL', 'uniform', 'sumo', 'fvdm']) +parser.add_argument('--adv_policy', type = str, default = 'sumo', choices = ['RL', 'uniform', 'sumo', 'fvdm']) +parser.add_argument('--num_agents', type = int, default = 5) +parser.add_argument('--r_ego', type = str, default = 'stackelberg', choices = ['r1', 'stackelberg']) +parser.add_argument('--r_adv', type = str, default = 'stackelberg3') +# parser.add_argument('--realdata_path', type = str, default = '../datasets/dataset/r3_dis_25_car_6') +parser.add_argument('--is_save', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--device', type = str, default = 'cuda:0') +parser.add_argument('--seed', type = int, default = 42) +parser.add_argument('--save_model', type=str, default="False", choices=["True", "False"]) +parser.add_argument('--n_adv_policy_update_gap', type = int, default = 5) +parser.add_argument('--n_ego_policy_update_gap', type = int, default = 1) +parser.add_argument('--model_name', type = str, default = 'SPG') +parser.add_argument('--gui', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--policy_arch', type = str, default = '256-256') +parser.add_argument('--qf_arch', type = str, default = '256-256') +parser.add_argument('--batch_size', type = int, default = 64) +parser.add_argument('--reg_scale', type = float, default = 0.2) +parser.add_argument('--n_epochs', type = int, default = 100) +parser.add_argument('--n_loops', type = int, default = 20) +parser.add_argument('--n_rollout_steps_per_epoch', type = int, default = 1000) +parser.add_argument('--n_train_step_per_epoch', type = int, default = 500) +parser.add_argument('--pretrain_ego', type = str, default = 'True', choices = ['True', 'False']) +parser.add_argument('--pretrain_loops', type = int, default = 2) +parser.add_argument('--pretrain_epochs', type = int, default = 100) +parser.add_argument('--pretrain_steps', type = int, default = 500) +parser.add_argument('--load_pretrain_ego', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--pretrain_ego_path', type = str, default = '') +parser.add_argument('--reset_rb', type = str, default = 'False', choices = ['True', 'False']) +parser.add_argument('--replay_buffer_size', type = int, default = 100000) +parser.add_argument('--pretrain_replay_buffer_size', type = int, default = 1000000) +parser.add_argument('--is_SN', type = str, default = 'True', choices = ['True', 'False']) +parser.add_argument('--is_LN', type = str, default = '') +parser.add_argument('--use_auto_alpha', default = 'False', type = str, choices = ['True', 'False']) +parser.add_argument('--backup_entropy', default = 'True', type = str, choices = ['True', 'False']) +parser.add_argument('--num_save', type = int, default = 5) +args = parser.parse_args() +args.is_save = True if args.is_save == 'True' else False + +args.used_wandb = True if args.used_wandb == 'True' else False +# ipdb.set_trace() +args.save_model = True if args.save_model == 'True' else False +args.gui = True if args.gui == 'True' else False +args.pretrain_ego = True if args.pretrain_ego == 'True' else False +args.load_pretrain_ego = True if args.load_pretrain_ego == 'True' else False +args.reset_rb = True if args.reset_rb == 'True' else False +args.is_SN = True if args.is_SN == 'True' else False +args.use_auto_alpha = True if args.use_auto_alpha == 'True' else False +args.backup_entropy = True if args.backup_entropy == 'True' else False + +realdata_paths = os.listdir('../datasets/dataset/') +def extract_last_digit(path): + # This function extracts the last digit from a string and returns it as an integer. + return int(path[-1]) + +# Sort the realdata_paths list based on the last digit in ascending order. +sorted_realdata_paths = sorted(realdata_paths, key=extract_last_digit) +realdata_path = os.path.join('../datasets/dataset', sorted_realdata_paths[args.num_agents - 1]) +# ipdb.set_trace() + + + +FLAGS_DEF = define_flags_with_default( + model_name = args.model_name, + used_wandb = args.used_wandb, + ego_policy = args.ego_policy, + adv_policy = args.adv_policy, + num_agents = args.num_agents, + reg_scale = args.reg_scale, + r_ego = args.r_ego, + r_adv = args.r_adv, + r_adv_replaybuffer = args.r_adv, + realdata_path = realdata_path, + is_save = args.is_save, + device = args.device, + seed = args.seed, + replay_buffer_size = args.replay_buffer_size, + pretrain_replay_buffer_size = args.pretrain_replay_buffer_size, + save_model = args.save_model, + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S"), + replaybuffer_ratio = 10, + real_residual_ratio = 1.0, + dis_dropout = False, + max_traj_length = 100, + batch_size = args.batch_size, + reward_scale = 1.0, + reward_bias = 0.0, + clip_action = 1.0, + joint_noise_std = 0.0, + policy_arch = args.policy_arch, + qf_arch = args.qf_arch, + orthogonal_init = False, + policy_log_std_multiplier = 1.0, + policy_log_std_offset = -1.0, + # train and evaluate policy + n_epochs = args.n_epochs, + n_loops = args.n_loops, + bc_epochs = 0, + n_rollout_steps_per_epoch = args.n_rollout_steps_per_epoch, + n_train_step_per_epoch = args.n_train_step_per_epoch, + n_adv_policy_update_gap = args.n_adv_policy_update_gap, + n_ego_policy_update_gap = args.n_ego_policy_update_gap, + eval_period = 10, + eval_n_trajs = 20, + logging = WandbLogger.get_default_config(), + gui = args.gui, + pretrain_ego = args.pretrain_ego, + pretrain_epochs = args.pretrain_epochs, + pretrain_loops = args.pretrain_loops, + pretrain_steps = args.pretrain_steps, + load_pretrain_ego = args.load_pretrain_ego, + pretrain_ego_path = args.pretrain_ego_path, + reset_rb = args.reset_rb, + cql_ego = SAC.get_default_config(), + is_SN = args.is_SN, + is_LN = args.is_LN, + use_auto_alpha = args.use_auto_alpha, + backup_entropy = args.backup_entropy, + num_save = args.num_save +) + +def argparse(): + ... + +def get_tensors_on_gpu(device): + tensors_on_gpu = [] + for obj in dir(): + if isinstance(eval(obj), torch.Tensor): + if eval(obj).device == device: + tensors_on_gpu.append(obj) + return tensors_on_gpu + +def main(argv): + + # ipdb.set_trace() + FLAGS = absl.flags.FLAGS + if FLAGS.is_save: + eval_savepath = "output/" + \ + f"bv={FLAGS.num_agents}-{FLAGS.adv_policy}_" \ + f"r-ego={FLAGS.r_ego}_r-adv={FLAGS.r_adv}_" \ + f"seed={FLAGS.seed}_time={FLAGS.current_time}" \ + f"reg_scale={FLAGS.reg_scale}_" \ + f"pretrain_adv={FLAGS.adv_policy}" \ + + "/" + if not os.path.exists('output'): + os.makedirs('output') + if not os.path.exists(eval_savepath): + os.mkdir(eval_savepath) + os.mkdir(eval_savepath + "avcrash") + os.mkdir(eval_savepath + "bvcrash") + os.mkdir(eval_savepath + "avarrive") + os.mkdir(eval_savepath + "models") + else: + eval_savepath = 'None' + + if FLAGS.used_wandb: + variant = get_user_flags(FLAGS, FLAGS_DEF) + wandb_logger = WandbLogger(config=FLAGS.logging, variant=variant, seed = FLAGS.seed) + wandb.run.name = f"{FLAGS.model_name}" \ + f"_Pretrain_Train_Eval_{FLAGS.model_name}" \ + f"bv={FLAGS.num_agents}-{FLAGS.adv_policy}_" \ + f"r-ego={FLAGS.r_ego}_r-adv={FLAGS.r_adv}_" \ + f"reg_scale={FLAGS.reg_scale}_" \ + f"pretrain_adv={FLAGS.adv_policy}_" \ + f"pretrain_epochs={FLAGS.pretrain_epochs}_" \ + f"pretrain_steps={FLAGS.pretrain_steps}_" \ + f"pretrain_rb_size={FLAGS.pretrain_replay_buffer_size}_" \ + f"reset_rb={FLAGS.reset_rb}_" \ + f"n_adv_policy_update_gap={FLAGS.n_adv_policy_update_gap}_" \ + f"n_ego_policy_update_gap={FLAGS.n_ego_policy_update_gap}_" \ + f"is_SN={FLAGS.is_SN}_" \ + f"is_auto_alpha={FLAGS.use_auto_alpha}_" \ + f"backup_entropy={FLAGS.backup_entropy}_" \ + f"seed={FLAGS.seed}_time={FLAGS.current_time}" + setup_logger( + variant=variant, + exp_id=wandb_logger.experiment_id, + seed=FLAGS.seed, + base_log_dir=FLAGS.logging.output_dir, + include_exp_prefix_sub_dir=False + ) + + set_random_seed(FLAGS.seed) + # real_env = Env(realdata_path=FLAGS.realdata_path, num_agents=FLAGS.num_agents, sim_horizon=FLAGS.max_traj_length, + # ego_policy=FLAGS.ego_policy, adv_policy=FLAGS.adv_policy, + # r_ego=FLAGS.r_ego, r_adv=FLAGS.r_adv, sim_seed=FLAGS.seed, gui=FLAGS.gui) + env = Env(realdata_path=FLAGS.realdata_path, num_agents=FLAGS.num_agents, sim_horizon=FLAGS.max_traj_length, + ego_policy=FLAGS.ego_policy, adv_policy=FLAGS.adv_policy, + r_ego=FLAGS.r_ego, r_adv=FLAGS.r_adv, sim_seed=FLAGS.seed, gui=FLAGS.gui) + pretain_env = Env(realdata_path=FLAGS.realdata_path, num_agents=FLAGS.num_agents, sim_horizon=FLAGS.max_traj_length, + ego_policy=FLAGS.ego_policy, adv_policy=FLAGS.adv_policy, + r_ego = 'r1', r_adv = 'r3', sim_seed=FLAGS.seed, gui=FLAGS.gui) + pretrain_sampler = StepSampler(pretain_env, max_traj_length=FLAGS.max_traj_length) + train_sampler = StepSampler(env, max_traj_length=FLAGS.max_traj_length) + eval_sampler = TrajSampler(env, rootsavepath=eval_savepath, max_traj_length=FLAGS.max_traj_length) + + # replay buffer + num_state = env.state_space[0] + num_action_adv = env.action_space_adv[0] + num_action_ego = env.action_space_ego[0] + # ipdb.set_trace() + num_action = num_action_ego + num_action_adv + pretrain_replay_buffer = ReplayBuffer(num_state, num_action_ego, num_action_adv, FLAGS.pretrain_replay_buffer_size, device=FLAGS.device) + replay_buffer = GradReplayBuffer(num_state, num_action_ego, num_action_adv, FLAGS.replay_buffer_size, device=FLAGS.device) \ + if FLAGS.model_name == 'SPG' else ReplayBuffer(num_state, num_action_ego, num_action_adv, FLAGS.replay_buffer_size, device=FLAGS.device) + + # ipdb.set_trace() + + + + + ego_policy = TanhGaussianPolicy( + num_state, + num_action_ego, + arch=FLAGS.policy_arch, + log_std_multiplier=FLAGS.policy_log_std_multiplier, + log_std_offset=FLAGS.policy_log_std_offset, + orthogonal_init=FLAGS.orthogonal_init, + is_SN = False, + is_LN = FLAGS.is_LN + ) + qf1_ego = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=False + ) + target_qf1_ego = deepcopy(qf1_ego) + qf2_ego = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=False + ) + target_qf2_ego = deepcopy(qf2_ego) + sampler_ego_policy = SamplerPolicy(ego_policy, FLAGS.device) + + + + adv_policy = TanhGaussianPolicy( + num_state, + num_action_adv, + arch=FLAGS.policy_arch, + log_std_multiplier=FLAGS.policy_log_std_multiplier, + log_std_offset=FLAGS.policy_log_std_offset, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=FLAGS.is_SN + ) + qf1_adv = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=FLAGS.is_SN + ) + target_qf1_adv = deepcopy(qf1_adv) + qf2_adv = FullyConnectedQFunction( + num_state, + num_action_ego, + num_action_adv, + arch=FLAGS.qf_arch, + orthogonal_init=FLAGS.orthogonal_init, + is_LN=FLAGS.is_LN, + is_SN=FLAGS.is_SN + ) + target_qf2_adv = deepcopy(qf1_adv) + sampler_adv_policy = SamplerPolicy(adv_policy, FLAGS.device) + + if FLAGS.model_name == 'SPG': + model = SDM(None, + ego_policy = ego_policy, + adv_policy = adv_policy, + qf1_ego = qf1_ego, + qf2_ego = qf2_ego, + target_qf1_ego = target_qf1_ego, + target_qf2_ego = target_qf2_ego, + qf1_adv = qf1_adv, + qf2_adv = qf2_adv, + target_qf1_adv = target_qf1_adv, + target_qf2_adv = target_qf2_adv, + device = FLAGS.device, + reg_scale = FLAGS.reg_scale, + use_automatic_entropy_tuning = FLAGS.use_auto_alpha, + backup_entropy = FLAGS.backup_entropy,) + else: + return + model.torch_to_device(FLAGS.device) + + if FLAGS.pretrain_ego: + if not FLAGS.load_pretrain_ego: + model_pre_ego = SAC( + FLAGS.cql_ego, + policy = ego_policy, + qf1 = qf1_ego, + qf2 = qf2_ego, + target_qf1 = target_qf1_ego, + target_qf2 = target_qf2_ego + ) + model_pre_ego.torch_to_device(FLAGS.device) + else: + model_pre_ego = torch.load(FLAGS.pretrain_ego_path) + + viskit_metrics = {} + + # TODO: Pretrain Ego Policy on fvdm BV using sac + # TODO: Check bv + # TODO: Check sac + if FLAGS.pretrain_ego: + pretrain_replay_buffer.reset() + sampler_pretrain_ego_policy = SamplerPolicy(model_pre_ego.policy, FLAGS.device) + sampler_pretrain_ego_policy.set_grad(False) + + for i in range(FLAGS.pretrain_loops): + + for epoch in trange(FLAGS.pretrain_epochs): + metrics = {} + pretrain_sampler.env.adv_policy = FLAGS.adv_policy # sumo + pretrain_sampler.env.ego_policy = 'RL' + # while True: + pretrain_sampler.sample( + ego_policy=sampler_pretrain_ego_policy, adv_policy=None, n_steps=FLAGS.n_rollout_steps_per_epoch, + deterministic=False, replay_buffer=pretrain_replay_buffer, + joint_noise_std=FLAGS.joint_noise_std + ) + metrics['epoch'] = epoch + for batch_idx in trange(FLAGS.pretrain_steps): + batch = pretrain_replay_buffer.sample(FLAGS.batch_size) + # if FLAGS.used_wandb: + # wandb_logger.log(prefix_metrics(model_pre_ego.train(batch), 'SAC_Pretrain')) + metrics.update(prefix_metrics(model_pre_ego.train(batch), 'SAC_Pretrain')) + + if FLAGS.used_wandb: + wandb_logger.log(metrics) + # eval + + if epoch == 0 or (epoch + 1) % FLAGS.eval_period == 0: + eval_ego_policy = 'RL' + eval_sampler.env.ego_policy = eval_ego_policy + eval_sampler.env.adv_policy = FLAGS.adv_policy + if adv_policy != 'RL': + s_a = None + else: + s_a = sampler_adv_policy + # ipdb.set_trace() + trajs, _ = eval_sampler.sample( + ego_policy=sampler_pretrain_ego_policy, adv_policy=s_a, + n_trajs=FLAGS.eval_n_trajs, deterministic=True + ) + # TODO: add speed + Eval(metrics, eval_ego_policy, FLAGS.adv_policy, trajs) + + if FLAGS.used_wandb: + wandb_logger.log(metrics) + + # pretrain_replay_buffer.reset() + + if FLAGS.save_model: + torch.save(model_pre_ego, os.path.join(eval_savepath, 'models', 'pretrain_ego.pth')) + if FLAGS.used_wandb: + # wandb_logger.log(metrics) + pre_save_data = {'model_pre_ego': model_pre_ego} + wandb_logger.save_pickle(pre_save_data, 'pre_model.pkl') + sampler_pretrain_ego_policy = deepcopy(sampler_pretrain_ego_policy) # freezing the pretrain policy + # return + # = 1 + # ipdb.set_trace() + + # load trained re2h2o model + map_location = { + 'cuda:0': FLAGS.device, + 'cuda:1': FLAGS.device, + 'cuda:2': FLAGS.device, + 'cuda:3': FLAGS.device + } + # model_adv_re2h2o_policy = torch.load('models_re2h2o_bv//BV0_bv=1.pkl', map_location=map_location) + model_adv_re2h2o_policy = torch.load(f'models_re2h2o_bv//BV0_bv={FLAGS.num_agents}.pkl', map_location=map_location) + sampler_adv_re2h2o_policy = SamplerPolicy(model_adv_re2h2o_policy, device=FLAGS.device) + + for l in range(FLAGS.n_loops): + for epoch in trange(FLAGS.n_epochs): + + '''leader and follower''' + metrics = {} + + # TODO: Evaluate in the real world + with Timer() as eval_timer: + if epoch == 0 or (epoch + 1) % FLAGS.eval_period == 0: + # Eval + eval_sampler.env.ego_policy = 'RL' + eval_sampler.env.adv_policy = 'RL' + s_e = sampler_pretrain_ego_policy + ego_policy = 'pretrainedRL' + adv_policy = 're2h2o' + s_a = sampler_adv_re2h2o_policy + trajs, _ = eval_sampler.sample( + ego_policy=s_e, adv_policy=s_a, + n_trajs=FLAGS.eval_n_trajs, deterministic=True + ) + Eval(metrics, ego_policy, adv_policy, trajs) + + if FLAGS.used_wandb: + wandb_logger.log(metrics) + + # metrics['rollout_time'] = rollout_timer() + # metrics['train_time'] = train_timer() + metrics['eval_time'] = eval_timer() + # metrics['epoch_time'] = train_timer() + eval_timer() + if FLAGS.used_wandb: + save_data = {FLAGS.model_name: model, + 'variant': variant, 'epoch': epoch} + wandb_logger.save_pickle(save_data, 'model.pkl') + viskit_metrics.update(metrics) + logger.record_dict(viskit_metrics) + logger.dump_tabular(with_prefix=False, with_timestamp=False) + + # save model for matric Eval + # if FLAGS.save_model and l % (FLAGS.n_loops / FLAGS.num_save) == 0 or l == FLAGS.n_loops - 1: + # # ipdb.set_trace() + # torch.save(model, os.path.join(eval_savepath, 'models', f'loop_{l+1}.pth')) + + + if FLAGS.save_model and FLAGS.used_wandb: + save_data = {FLAGS.model_name: model, + 'variant': variant, 'epoch': epoch} + wandb_logger.save_pickle(save_data, 'model.pkl') + # if FLAGS.save_model: + # torch.save(model, os.path.join(eval_savepath, 'models', 'trained_model.pth')) + +if __name__ == '__main__': + absl.app.run(main) + + + + + diff --git a/Scripts/main_sumo_vs_re2h2o.py b/Scripts/main_sumo_vs_re2h2o.py new file mode 100644 index 0000000..604489c --- /dev/null +++ b/Scripts/main_sumo_vs_re2h2o.py @@ -0,0 +1,127 @@ +import argparse +import numpy as np +import absl.app +import absl.flags +import torch +import wandb +import os + +from datetime import datetime + +from copy import deepcopy +from tqdm import trange +from viskit.logging import logger, setup_logger +from utils import define_flags_with_default, WandbLogger, get_user_flags, set_random_seed, Timer, prefix_metrics, Eval + + +from SDM.SDM import SDM +from SimpleSAC.sac import SAC +from SimpleSAC.envs import Env +from SimpleSAC.sampler import StepSampler, TrajSampler +from SimpleSAC.replay_buffer import ReplayBuffer, GradReplayBuffer +from SimpleSAC.models.model import TanhGaussianPolicy, SamplerPolicy, FullyConnectedQFunction + + +parser = argparse.ArgumentParser() +parser.add_argument('--used_wandb', type=str, default='False', choices=['True', 'False']) +parser.add_argument('--device', type=str, default='cuda:0') +parser.add_argument('--seed', type=int, default=42) +parser.add_argument('--model_name', type=str, default='sumo_vs_re2h2o') +parser.add_argument('--num_agents', type=int, default=5) +parser.add_argument('--ego_policy', type=str, default='sumo', choices=['sumo']) +parser.add_argument('--adv_policy', type=str, default='re2h2o', choices=['re2h2o', 'sumo']) +parser.add_argument('--r_ego', type = str, default = 'stackelberg', choices = ['r1', 'stackelberg']) +parser.add_argument('--r_adv', type = str, default = 'stackelberg3') +parser.add_argument('--n_epochs', type = int, default = 100) +parser.add_argument('--n_loops', type = int, default = 20) + +args = parser.parse_args() +args.used_wandb = True if args.used_wandb == 'True' else False + +realdata_paths = os.listdir('../datasets/dataset/') +def extract_last_digit(path): + # This function extracts the last digit from a string and returns it as an integer. + return int(path[-1]) + +# Sort the realdata_paths list based on the last digit in ascending order. +sorted_realdata_paths = sorted(realdata_paths, key=extract_last_digit) +realdata_path = os.path.join('../datasets/dataset', sorted_realdata_paths[args.num_agents - 1]) + + +FLAGS_DEF = define_flags_with_default( + model_name=args.model_name, + used_wandb=args.used_wandb, + device=args.device, + seed=args.seed, + num_agents=args.num_agents, + ego_policy=args.ego_policy, + adv_policy=args.adv_policy, + r_ego = args.r_ego, + r_adv = args.r_adv, + realdata_path=realdata_path, + logging = WandbLogger.get_default_config(), + n_epochs = args.n_epochs, + n_loops = args.n_loops, + max_traj_length=100, + eval_n_trajs=20, + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") +) + + +def main(argv): + FLAGS = absl.flags.FLAGS + + if FLAGS.used_wandb: + variant = get_user_flags(FLAGS, FLAGS_DEF) + wandb_logger = WandbLogger(config=FLAGS.logging, variant=variant, seed = FLAGS.seed) + wandb.run.name = f"{FLAGS.model_name}" \ + f"bv={FLAGS.num_agents}-{FLAGS.adv_policy}_" \ + f"seed={FLAGS.seed}_time={FLAGS.current_time}" + setup_logger( + variant=variant, + exp_id=wandb_logger.experiment_id, + seed=FLAGS.seed, + base_log_dir=FLAGS.logging.output_dir, + include_exp_prefix_sub_dir=False + ) + + set_random_seed(FLAGS.seed) + env = Env(realdata_path=FLAGS.realdata_path, num_agents=FLAGS.num_agents, sim_horizon=FLAGS.max_traj_length, + ego_policy=FLAGS.ego_policy, adv_policy=FLAGS.adv_policy, + r_ego=FLAGS.r_ego, r_adv=FLAGS.r_adv, sim_seed=FLAGS.seed, gui=False) + eval_sampler = TrajSampler(env, rootsavepath='None', max_traj_length=FLAGS.max_traj_length) + + # load trained re2h2o model + map_location = { + 'cuda:0': FLAGS.device, + 'cuda:1': FLAGS.device, + 'cuda:2': FLAGS.device, + 'cuda:3': FLAGS.device + } + model_adv_re2h2o_policy = torch.load('models_re2h2o_bv//BV0_bv=5.pkl', map_location=map_location) + sampler_adv_re2h2o_policy = SamplerPolicy(model_adv_re2h2o_policy, device=FLAGS.device) + + for l in range(FLAGS.n_loops): + for epoch in trange(FLAGS.n_epochs): + metrics = {} + + with Timer() as eval_timer: + ego_policy = FLAGS.ego_policy + adv_policy = FLAGS.adv_policy + eval_sampler.env.ego_policy = ego_policy + eval_sampler.env.adv_policy = 'RL' + s_a = sampler_adv_re2h2o_policy + s_e = None + trajs, _ = eval_sampler.sample( + ego_policy=s_e, adv_policy=s_a, + n_trajs=FLAGS.eval_n_trajs, deterministic=True + ) + Eval(metrics, ego_policy, adv_policy, trajs) + if FLAGS.used_wandb: + wandb_logger.log(metrics) + + metrics['eval_time'] = eval_timer() + metrics['epoch_time'] = eval_timer() + +if __name__ == '__main__': + absl.app.run(main) \ No newline at end of file diff --git a/Scripts/models_re2h2o_bv/BV0_bv=1.pkl b/Scripts/models_re2h2o_bv/BV0_bv=1.pkl new file mode 100644 index 0000000..57030d0 Binary files /dev/null and b/Scripts/models_re2h2o_bv/BV0_bv=1.pkl differ diff --git a/Scripts/models_re2h2o_bv/BV0_bv=2.pkl b/Scripts/models_re2h2o_bv/BV0_bv=2.pkl new file mode 100644 index 0000000..388a091 Binary files /dev/null and b/Scripts/models_re2h2o_bv/BV0_bv=2.pkl differ diff --git a/Scripts/models_re2h2o_bv/BV0_bv=3.pkl b/Scripts/models_re2h2o_bv/BV0_bv=3.pkl new file mode 100644 index 0000000..b122897 Binary files /dev/null and b/Scripts/models_re2h2o_bv/BV0_bv=3.pkl differ diff --git a/Scripts/models_re2h2o_bv/BV0_bv=4.pkl b/Scripts/models_re2h2o_bv/BV0_bv=4.pkl new file mode 100644 index 0000000..5fb565c Binary files /dev/null and b/Scripts/models_re2h2o_bv/BV0_bv=4.pkl differ diff --git a/Scripts/models_re2h2o_bv/BV0_bv=5.pkl b/Scripts/models_re2h2o_bv/BV0_bv=5.pkl new file mode 100644 index 0000000..6a83f53 Binary files /dev/null and b/Scripts/models_re2h2o_bv/BV0_bv=5.pkl differ diff --git a/Scripts/models_re2h2o_bv/BV0_bv=6.pkl b/Scripts/models_re2h2o_bv/BV0_bv=6.pkl new file mode 100644 index 0000000..f73db60 Binary files /dev/null and b/Scripts/models_re2h2o_bv/BV0_bv=6.pkl differ diff --git a/Scripts/models_re2h2o_bv/model_bv_loop0.pkl b/Scripts/models_re2h2o_bv/model_bv_loop0.pkl new file mode 100644 index 0000000..6a83f53 Binary files /dev/null and b/Scripts/models_re2h2o_bv/model_bv_loop0.pkl differ diff --git a/Scripts/utils.py b/Scripts/utils.py index 479fa3c..7411a11 100644 --- a/Scripts/utils.py +++ b/Scripts/utils.py @@ -48,7 +48,7 @@ def get_default_config(updates=None): config.online = True config.prefix = '' config.project = 'SDM' - config.entity = 'ml_cat' + config.entity = 'lyy1912696485' config.output_dir = './experiment_output' config.random_delay = 0.0 config.experiment_id = config_dict.placeholder(str) diff --git a/Scripts/viskit/__pycache__/__init__.cpython-39.pyc b/Scripts/viskit/__pycache__/__init__.cpython-39.pyc index 69736f8..87adbf3 100644 Binary files a/Scripts/viskit/__pycache__/__init__.cpython-39.pyc and b/Scripts/viskit/__pycache__/__init__.cpython-39.pyc differ diff --git a/Scripts/viskit/__pycache__/logging.cpython-39.pyc b/Scripts/viskit/__pycache__/logging.cpython-39.pyc index 96def04..0284b16 100644 Binary files a/Scripts/viskit/__pycache__/logging.cpython-39.pyc and b/Scripts/viskit/__pycache__/logging.cpython-39.pyc differ diff --git a/Scripts/viskit/__pycache__/tabulate.cpython-39.pyc b/Scripts/viskit/__pycache__/tabulate.cpython-39.pyc index 3278b0a..b8ee6e3 100644 Binary files a/Scripts/viskit/__pycache__/tabulate.cpython-39.pyc and b/Scripts/viskit/__pycache__/tabulate.cpython-39.pyc differ