Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
230 changes: 230 additions & 0 deletions Scripts/H2O/SimpleSAC/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import TanhTransform
from torch.distributions import Normal
import math
import torch.autograd as autograd
from torch.nn.utils import spectral_norm
# from copy import deepcopy


def extend_and_repeat(tensor, dim, repeat):
    """Insert a new axis at ``dim`` and tile ``tensor`` along it.

    Args:
        tensor: Input tensor.
        dim: Position at which the new axis is inserted.
        repeat: Size of the new axis, i.e. how many copies are stacked.

    Returns:
        A tensor with one more dimension than ``tensor`` whose slices along
        ``dim`` are all equal to ``tensor``.
    """
    # All-ones tensor that is singleton everywhere except on the new axis;
    # multiplying by it broadcasts the unsqueezed input ``repeat`` times.
    broadcast_shape = [1] * (tensor.ndim + 1)
    broadcast_shape[dim] = repeat
    tiler = tensor.new_ones(broadcast_shape)
    return torch.unsqueeze(tensor, dim) * tiler


def soft_target_update(network, target_network, soft_target_update_rate):
    """Move ``target_network``'s parameters a fraction of the way to ``network``.

    Every parameter of ``network`` is matched by name to a parameter of
    ``target_network`` and the target is replaced by
    ``(1 - rate) * target + rate * online``.

    Args:
        network: Module providing the current (online) parameters.
        target_network: Module whose parameters are updated in place.
        soft_target_update_rate: Interpolation weight given to ``network``.

    Raises:
        KeyError: If ``network`` has a parameter name that
            ``target_network`` lacks.
    """
    tau = soft_target_update_rate
    targets_by_name = dict(target_network.named_parameters())
    for name, online_param in network.named_parameters():
        target_param = targets_by_name[name]
        # Work on ``.data`` so the blend is not tracked by autograd.
        blended = (1 - tau) * target_param.data + tau * online_param.data
        target_param.data = blended


def multiple_action_q_function(forward):
    """Decorator letting a Q-function ``forward`` score several actions per state.

    When ``actions`` is 3-D and ``observations`` is 2-D, each observation row
    is repeated once per action along a new axis 1, both tensors are flattened
    to 2-D, ``forward`` is called once on the flat batch, and every returned
    tensor is reshaped to ``(batch_size, -1)``. Any other combination of
    ranks is passed straight through.

    ``forward`` may return a single tensor or a tuple of tensors (the latter
    is what the Q-function in this file returns when ``is_x_grad`` is true);
    a tuple is reshaped element-wise.

    Args:
        forward: Method with signature
            ``forward(self, observations, actions, is_x_grad, **kwargs)``.

    Returns:
        The wrapped method, carrying ``forward``'s name and docstring.
    """
    # Function-scope import so the block does not depend on file-level imports.
    import functools

    @functools.wraps(forward)
    def wrapped(self, observations, actions, is_x_grad=False, **kwargs):
        batch_size = observations.shape[0]
        multiple_actions = actions.ndim == 3 and observations.ndim == 2
        if multiple_actions:
            num_actions = actions.shape[1]
            # Tile every observation once per candidate action, then flatten
            # so the wrapped forward sees an ordinary 2-D batch.
            observations = (
                observations.unsqueeze(1)
                .expand(-1, num_actions, -1)
                .reshape(-1, observations.shape[-1])
            )
            actions = actions.reshape(-1, actions.shape[-1])

        outputs = forward(self, observations, actions, is_x_grad, **kwargs)
        if not multiple_actions:
            return outputs
        if isinstance(outputs, tuple):
            # e.g. (q_values, input_gradient_norm): un-flatten each member.
            return tuple(item.reshape(batch_size, -1) for item in outputs)
        return outputs.reshape(batch_size, -1)

    return wrapped


class FullyConnectedNetwork(nn.Module):
    """Multi-layer perceptron assembled from a dash-separated width string.

    For each width in ``arch`` the network stacks ``Linear -> ReLU`` and,
    when ``is_LN`` is set, a ``LayerNorm`` after the ReLU. A final ``Linear``
    maps to ``output_dim``. With ``is_SN`` every ``Linear`` is wrapped in
    ``spectral_norm``.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        arch: Hidden widths joined by ``'-'``, e.g. ``'256-256'``.
        orthogonal_init: Use orthogonal weight init (gain ``sqrt(2)`` for
            hidden layers, ``1e-2`` for the output layer) and zero hidden
            biases. Otherwise hidden layers keep ``nn.Linear`` defaults and
            the output weight uses Xavier-uniform with gain ``1e-2``.
        is_LN: Append ``LayerNorm`` after each hidden activation.
        is_SN: Apply spectral normalisation to every linear layer.
    """

    def __init__(self, input_dim, output_dim, arch='256-256', orthogonal_init=False, is_LN=False, is_SN=False):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.arch = arch
        self.orthogonal_init = orthogonal_init

        def maybe_spectral(linear):
            # Spectral norm is applied after initialisation, as a wrapper.
            return spectral_norm(linear) if is_SN else linear

        # Parse all widths up front so a malformed ``arch`` fails before any
        # layer is created.
        widths = [int(token) for token in arch.split('-')]

        layers = []
        in_features = input_dim
        for width in widths:
            hidden = nn.Linear(in_features, width)
            if orthogonal_init:
                nn.init.orthogonal_(hidden.weight, gain=np.sqrt(2))
                nn.init.constant_(hidden.bias, 0.0)
            layers.append(maybe_spectral(hidden))
            layers.append(nn.ReLU())
            if is_LN:
                layers.append(nn.LayerNorm(width))
            in_features = width

        # Output head: small-gain init keeps initial outputs near zero.
        head = nn.Linear(in_features, output_dim)
        head_init = nn.init.orthogonal_ if orthogonal_init else nn.init.xavier_uniform_
        head_init(head.weight, gain=1e-2)
        nn.init.constant_(head.bias, 0.0)
        layers.append(maybe_spectral(head))

        self.network = nn.Sequential(*layers)

    def forward(self, input_tensor):
        """Apply the stacked layers to ``input_tensor``."""
        return self.network(input_tensor)


class ReparameterizedTanhGaussian(nn.Module):
    """Diagonal Gaussian action head with an optional tanh squashing.

    ``log_std`` is clamped to ``[log_std_min, log_std_max]`` before being
    exponentiated. With ``no_tanh=False`` the Gaussian is pushed through a
    ``TanhTransform``; with ``no_tanh=True`` the plain ``Normal`` is used.

    Args:
        log_std_min: Lower clamp applied to ``log_std``.
        log_std_max: Upper clamp applied to ``log_std``.
        no_tanh: Disable the tanh squashing.
    """

    def __init__(self, log_std_min=-20.0, log_std_max=2.0, no_tanh=False):
        super().__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.no_tanh = no_tanh

    def _build_distribution(self, mean, log_std):
        """Return ``(distribution, squash)``; ``squash`` is None if ``no_tanh``."""
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        base = Normal(mean, torch.exp(log_std))
        if self.no_tanh:
            return base, None
        # cache_size=1 lets log_prob reuse the pre-tanh value of the most
        # recently transformed sample instead of inverting tanh numerically.
        squash = TanhTransform(cache_size=1)
        return TransformedDistribution(base, squash), squash

    def log_prob(self, mean, log_std, sample):
        """Log-density of ``sample``, summed over the last dimension."""
        distribution, _ = self._build_distribution(mean, log_std)
        return torch.sum(distribution.log_prob(sample), dim=-1)

    def forward(self, mean, log_std, deterministic=False):
        """Draw an action and return it with its log-density.

        Args:
            mean: Gaussian mean (pre-squash).
            log_std: Gaussian log standard deviation (clamped internally).
            deterministic: Return the mode-like action instead of sampling:
                ``tanh(mean)`` when squashing, ``mean`` itself otherwise.

        Returns:
            ``(action_sample, log_prob)`` where ``log_prob`` is summed over
            the last dimension.
        """
        distribution, squash = self._build_distribution(mean, log_std)

        if not deterministic:
            action_sample = distribution.rsample()
        elif squash is None:
            # Without squashing the action space is the Gaussian's own space,
            # so the deterministic action is the mean, not tanh(mean).
            action_sample = mean
        else:
            # Route the mean through the cached transform: log_prob then
            # reuses ``mean`` as the pre-tanh value rather than computing
            # atanh(tanh(mean)), which is +/-inf once tanh saturates to +/-1.
            action_sample = squash(mean)

        log_prob = torch.sum(distribution.log_prob(action_sample), dim=-1)
        return action_sample, log_prob


class TanhGaussianPolicy(nn.Module):
    """Policy network producing a (optionally tanh-squashed) Gaussian action.

    A ``FullyConnectedNetwork`` maps observations to ``2 * action_dim``
    values which are split along the last dimension into a mean and a raw
    log-std. The raw log-std is rescaled by two learnable scalars
    (``log_std_multiplier * raw + log_std_offset``) before being handed to
    ``ReparameterizedTanhGaussian``.

    Args:
        observation_dim: Size of the observation vector.
        action_dim: Size of the action vector.
        arch: Hidden-layer widths, dash-separated (see
            ``FullyConnectedNetwork``).
        log_std_multiplier: Initial value of the learnable log-std scale.
        log_std_offset: Initial value of the learnable log-std shift.
        orthogonal_init: Forwarded to ``FullyConnectedNetwork``.
        no_tanh: Forwarded to ``ReparameterizedTanhGaussian``.
    """

    def __init__(self, observation_dim, action_dim, arch='256-256',
                 log_std_multiplier=1.0, log_std_offset=-1.0,
                 orthogonal_init=False, no_tanh=False):
        super().__init__()
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.arch = arch
        self.orthogonal_init = orthogonal_init
        self.no_tanh = no_tanh

        self.base_network = FullyConnectedNetwork(
            observation_dim, 2 * action_dim, arch, orthogonal_init, is_LN=False, is_SN=False
        )
        self.log_std_multiplier = Scalar(log_std_multiplier)
        self.log_std_offset = Scalar(log_std_offset)
        self.tanh_gaussian = ReparameterizedTanhGaussian(no_tanh=no_tanh)

    def _mean_and_log_std(self, observations):
        """Run the base network and return ``(mean, rescaled log_std)``."""
        base_network_output = self.base_network(observations)
        mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1)
        log_std = self.log_std_multiplier() * log_std + self.log_std_offset()
        return mean, log_std

    def log_prob(self, observations, actions):
        """Log-density of ``actions`` under the policy at ``observations``.

        If ``actions`` is 3-D, ``observations`` is first tiled along a new
        axis 1 so that there is one observation copy per action.
        """
        if actions.ndim == 3:
            observations = extend_and_repeat(observations, 1, actions.shape[1])
        mean, log_std = self._mean_and_log_std(observations)
        return self.tanh_gaussian.log_prob(mean, log_std, actions)

    def forward(self, observations, deterministic=False, repeat=None):
        """Produce actions and their log-densities.

        Args:
            observations: Observation tensor.
            deterministic: Forwarded to ``ReparameterizedTanhGaussian``.
            repeat: If given, tile ``observations`` this many times along a
                new axis 1 before evaluating the network.

        Returns:
            ``(actions, log_prob)`` as returned by
            ``ReparameterizedTanhGaussian``.

        Raises:
            AssertionError: If the observations, mean or log-std contain NaN.
                NOTE: these are ``assert`` statements and therefore disabled
                when Python runs with ``-O``.
        """
        if repeat is not None:
            observations = extend_and_repeat(observations, 1, repeat)
        # The offending tensor goes into the assertion message itself
        # (``assert cond, print(x)`` would attach ``None`` as the message).
        assert not torch.isnan(observations).any(), \
            f'NaN in observations: {observations}'
        mean, log_std = self._mean_and_log_std(observations)
        assert not torch.isnan(mean).any(), f'NaN in policy mean: {mean}'
        assert not torch.isnan(log_std).any(), f'NaN in policy log_std: {log_std}'
        return self.tanh_gaussian(mean, log_std, deterministic)


class SamplerPolicy(object):
    """Adapter that lets a torch policy be called with array-like observations.

    Args:
        policy: Callable taking ``(observations, deterministic)`` and
            returning a pair whose first element is the action tensor.
        device: Device on which the observation tensor is created.
    """

    def __init__(self, policy, device):
        self.device = device
        self.policy = policy

    def __call__(self, observations, deterministic=False):
        """Return the policy's actions for ``observations`` as a NumPy array.

        The whole call runs under ``torch.no_grad()``; the second element of
        the policy's return value is discarded.
        """
        with torch.no_grad():
            obs_tensor = torch.tensor(
                observations, dtype=torch.float32, device=self.device
            )
            sampled, _unused = self.policy(obs_tensor, deterministic)
            return sampled.cpu().numpy()


class FullyConnectedQFunction(nn.Module):
    """Q-function: an MLP over the concatenated observation and action.

    Args:
        observation_dim: Size of the observation vector.
        action_dim: Size of the action vector.
        arch: Hidden-layer widths, dash-separated.
        orthogonal_init: Forwarded to ``FullyConnectedNetwork``.
        is_LN: Forwarded to ``FullyConnectedNetwork``.
        is_SN: Forwarded to ``FullyConnectedNetwork``.
    """

    def __init__(self, observation_dim, action_dim, arch='256-256', orthogonal_init=False, is_LN=False, is_SN=False):
        super().__init__()
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.arch = arch
        self.orthogonal_init = orthogonal_init
        self.network = FullyConnectedNetwork(
            observation_dim + action_dim, 1, arch, orthogonal_init, is_LN=is_LN, is_SN=is_SN
        )

    @multiple_action_q_function
    def forward(self, observations, actions, is_x_grad=False):
        """Evaluate Q(observations, actions).

        Args:
            observations: Observation batch.
            actions: Action batch, concatenated to ``observations`` along
                the last dimension.
            is_x_grad: Also return the norm of the gradient of the network
                output with respect to its input.

        Returns:
            The Q-values with the trailing singleton dimension squeezed, or
            ``(q_values, input_gradient_norm)`` when ``is_x_grad`` is true.
        """
        state_action = torch.cat([observations, actions], dim=-1)
        q_values = torch.squeeze(self.network(state_action), dim=-1)
        if not is_x_grad:
            return q_values

        # Second pass on a detached copy that requires grad, so the gradient
        # is taken w.r.t. the network input only. create_graph=True keeps the
        # gradient differentiable.
        probe_input = state_action.detach().requires_grad_(True)
        probe_output = self.network(probe_input)
        input_grads = autograd.grad(
            probe_output,
            probe_input,
            grad_outputs=torch.ones_like(probe_output),
            retain_graph=True,
            create_graph=True,
        )
        x_grad = torch.norm(input_grads[0], dim=1)
        return q_values, x_grad


class Scalar(nn.Module):
    """Module holding one learnable float32 value.

    Args:
        init_value: Initial value of the parameter.
    """

    def __init__(self, init_value):
        super().__init__()
        initial = torch.tensor(init_value, dtype=torch.float32)
        self.constant = nn.Parameter(initial)

    def forward(self):
        """Return the underlying parameter tensor."""
        return self.constant


if __name__ == '__main__':
    # Smoke check when run as a script: build a Scalar and print its repr.
    s = Scalar(10)
    print(s)

Binary file modified Scripts/SDM/__pycache__/SDM.cpython-39.pyc
Binary file not shown.
Binary file modified Scripts/SDM/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file modified Scripts/SDM/__pycache__/leaderUpdate.cpython-39.pyc
Binary file not shown.
Binary file modified Scripts/SDM/__pycache__/objective.cpython-39.pyc
Binary file not shown.
Binary file modified Scripts/SDM/__pycache__/utils.cpython-39.pyc
Binary file not shown.
Binary file modified Scripts/SimpleSAC/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file modified Scripts/SimpleSAC/__pycache__/envs.cpython-39.pyc
Binary file not shown.
Binary file modified Scripts/SimpleSAC/__pycache__/replay_buffer.cpython-39.pyc
Binary file not shown.
Binary file modified Scripts/SimpleSAC/__pycache__/sac.cpython-39.pyc
Binary file not shown.
Binary file modified Scripts/SimpleSAC/__pycache__/sampler.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file modified Scripts/SimpleSAC/ego_policy/__pycache__/fvdm.cpython-39.pyc
Binary file not shown.
17 changes: 8 additions & 9 deletions Scripts/SimpleSAC/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def __init__(self, realdata_path, num_agents = 1, dt = 0.04, sim_horizon = 200,

self.ego_policy = ego_policy
self.adv_policy = adv_policy

self.up_cut = [0.6 * 9.8 * self.dt, np.pi / 3 * self.dt]
self.low_cut = [-0.8 * 9.8 * self.dt, -np.pi / 3 * self.dt]

self.gui = gui
if self.gui:
Expand Down Expand Up @@ -115,15 +118,7 @@ def reset(self, ego_policy, adv_policy, idx = None):
departLane=0,
departPos=10.0,
departSpeed=self.states[0][2])
elif self.ego_policy == 'sumo':
traci.vehicle.add(vehID = 'car0',
routeID = 'straight',
typeID = 'AV',
depart = cur_time,
departLane = 0,
departPos = 10.0,
arrivalLane=np.random.randint(0, 3),
departSpeed=self.states[0][2])


else:
traci.vehicle.add(vehID="car0",
Expand Down Expand Up @@ -231,6 +226,8 @@ def step(self, action_ego, action_adv):
angle = -new_state[3] * 180 / np.pi + 90,
lane = 0,
edgeID = 0)
elif self.ego_policy == "sumo":
...


'''adversary vehicles'''
Expand Down Expand Up @@ -258,6 +255,8 @@ def step(self, action_ego, action_adv):
lane = 0,
edgeID = 0)
traci.vehicle.setSpeed(vehID = "car" + str(i + 1), speed = new_state[2])
elif self.ego_policy == "sumo":
...

traci.simulationStep()
# time.sleep(0.04)
Expand Down
Binary file modified Scripts/SimpleSAC/models/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file modified Scripts/SimpleSAC/models/__pycache__/model.cpython-39.pyc
Binary file not shown.
Binary file modified Scripts/SimpleSAC/utils/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Loading