-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwrappers.py
More file actions
90 lines (70 loc) · 3.47 KB
/
wrappers.py
File metadata and controls
90 lines (70 loc) · 3.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gymnasium
"""
This class implements a wrapper for a standard gymnasium environment, it has two main purposes:
> Fix the environments with only one cost function, returned as float in the info dictionary, in this
case we force the standard form for our approach, i.e., an array of one element for each cost function;
(e.g., info['cost]=0 => info['cost]=[0]).
> It also implements a monitor to avoid problems with the normalization of the cost and reward functions,
it takes track of the non-normalized cost function, useful for the computation of the lambda loss but
also for performance monitoring (the real value is stored in the 'cost_monitor' list).
"""
class MultiCostWrapper( gymnasium.Wrapper ):
"""
Initialization of the wrapper, basically it is a copy of the standard environment
where we overrite the basic function (e.g., reset and step) to perform additional
operations.
"""
def __init__ ( self, env ):
super().__init__(env)
self.reward_monitor = None
self.cost_monitor = None
"""
Override the 'reset' function to reset internal monitoring variables; notice that we must return
the same value returned by the original 'reset' function.
"""
def reset( self, **kwargs ):
# First recall the 'true' reset function of the environment
obs, info = self.env.reset( **kwargs )
# Fix for single cost and initialization of the cost-monitor; this init
# is performed only once after each reset to match the number of cost functions
# of the enviornment. Fix also for the environments without
# the cost (i.e., classical gymnasium like 'CartPole')
if not "cost" in info.keys(): info["cost"] = [0.0]
if type(info["cost"]) is not list: info["cost"] = [info["cost"]]
# Fix also for collision and goal-reached
if not "goal-reached" in info.keys(): info["goal-reached"] = False
if not "collision" in info.keys(): info["collision"] = False
# Reset of the monitoring variables and return the 'true' results
self.reward_monitor = 0
self.cost_monitor = None
return obs, info
"""
Override the 'step' function to update the internal monitors; this method performs also
the 'single cost' fix (i.e., info['cost]=0 => info['cost]=[0]).
"""
def step( self, action):
# First recall the 'true' step function of the environment
# storing the results
obs, reward, terminated, truncated, info = self.env.step(action)
# Fix for single cost and initialization of the cost-monitor; this init
# is performed only once after each reset to match the number of cost functions
# of the enviornment. Fix also for the environments without
# the cost (i.e., classical gymnasium like 'CartPole')
if not "cost" in info.keys(): info["cost"] = [0.0]
if type(info["cost"]) is not list: info["cost"] = [info["cost"]]
if self.cost_monitor is None: self.cost_monitor = [0.0 for _ in info["cost"]]
# Fix also for collision and goal-reached
if not "goal-reached" in info.keys(): info["goal-reached"] = False
if not "collision" in info.keys(): info["collision"] = False
# Update of cost and reward monitors; notice that this operation must be perfomed
# before the normalization steps!
self.reward_monitor += reward
for idx, c in enumerate(info["cost"]): self.cost_monitor[idx] += c
# In the last episode, we returned the results of the monitor as part of
# the info dictionary
info['custom_info'] = {
'tot_reward': self.reward_monitor,
'tot_costs': self.cost_monitor
}
#
return obs, reward, terminated, truncated, info