-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
85 lines (74 loc) · 2.72 KB
/
utils.py
File metadata and controls
85 lines (74 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import os
# Public API of this module: only the parser class is exported; the
# underscore-prefixed GPU helper functions below are module-private.
__all__ = [
'MyArgumentParser'
]
def _set_gpu_preallocation(mode: float) -> None:
    """Set the fraction of GPU memory that JAX preallocates.

    If preallocation is enabled, this makes JAX preallocate ``mode`` (a
    fraction of the total GPU memory) instead of the default 75%. Lowering the
    amount preallocated can fix OOMs that occur when the JAX program starts.

    The value is exported through the ``XLA_PYTHON_CLIENT_MEM_FRACTION``
    environment variable.

    Args:
        mode: Fraction of GPU memory to preallocate. Must be a ``float`` in
            the half-open interval ``[0., 1.)``; integers are rejected.

    Raises:
        AssertionError: If ``mode`` is not a ``float`` or lies outside
            ``[0., 1.)``.
    """
    # Raised explicitly (instead of using an ``assert`` statement) so that the
    # check is not stripped when Python runs with ``-O``; the exception type
    # is kept as ``AssertionError`` for backward compatibility.
    if not (isinstance(mode, float) and 0. <= mode < 1.):
        raise AssertionError(
            f'GPU memory preallocation must '
            f'be in [0., 1.). But got {mode}.'
        )
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(mode)
def _set_gpu_device(device_ids) -> None:
    """Export the selected GPU ids through ``CUDA_VISIBLE_DEVICES``.

    Args:
        device_ids: The device selection, in one of three forms:

            * ``int`` -- a single device id;
            * ``tuple`` / ``list`` -- several ids, joined with commas after
              converting each item with ``str``;
            * ``str`` -- used verbatim, except that the literal ``'none'``
              is translated to the empty string.

    Raises:
        ValueError: If ``device_ids`` is of any other type.
    """
    if isinstance(device_ids, int):
        visible = str(device_ids)
    elif isinstance(device_ids, (tuple, list)):
        visible = ','.join(str(d) for d in device_ids)
    elif isinstance(device_ids, str):
        # 'none' is the command-line spelling for an empty device list.
        visible = '' if device_ids == 'none' else device_ids
    else:
        raise ValueError(
            f'device_ids must be an int, a tuple/list of ids, or a str, '
            f'but got {type(device_ids).__name__}: {device_ids!r}.'
        )
    os.environ['CUDA_VISIBLE_DEVICES'] = visible
class MyArgumentParser(argparse.ArgumentParser):
    """Argument parser with built-in device and training-method options.

    Constructing the parser has side effects. It:

    1. registers ``--devices`` (default ``'0'``) and ``--method``
       (default ``'bptt'``);
    2. peeks at the process command line to read those two options,
       ignoring everything else;
    3. passes the device selection to ``_set_gpu_device`` and
       ``gpu_pre_allocate`` to ``_set_gpu_preallocation``, which export them
       as environment variables;
    4. registers ``--vjp_method`` when the method is not ``'bptt'``, and
       additionally ``--etrace_decay`` when it is also not ``'diag'``.

    Args:
        *args: Positional arguments forwarded to ``argparse.ArgumentParser``.
        gpu_pre_allocate: Fraction of GPU memory to preallocate, forwarded to
            ``_set_gpu_preallocation``.
        **kwargs: Keyword arguments forwarded to ``argparse.ArgumentParser``.
    """

    def __init__(self, *args, gpu_pre_allocate: float = 0.99, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # The early peek is done with a throwaway parser that has no
        # ``-h/--help`` action. Peeking with ``self`` would let ``--help``
        # print an incomplete help screen and exit right here, before the
        # method-specific options below (and any options the caller adds
        # after construction) have been registered.
        bootstrap = argparse.ArgumentParser(
            prog=self.prog,
            add_help=False,
            prefix_chars=self.prefix_chars,
            fromfile_prefix_chars=self.fromfile_prefix_chars,
            allow_abbrev=self.allow_abbrev,
        )
        for parser in (self, bootstrap):
            parser.add_argument(
                '--devices',
                type=str,
                default='0',
                help='The GPU device ids.'
            )
            parser.add_argument(
                "--method",
                type=str,
                default='bptt',
                help="Training method."
            )
        known, _ = bootstrap.parse_known_args()

        # device management
        _set_gpu_device(known.devices)
        _set_gpu_preallocation(gpu_pre_allocate)

        # training method
        if known.method != 'bptt':
            self.add_argument(
                "--vjp_method",
                type=str,
                default='multi-step',
                choices=['multi-step', 'single-step'],
            )
            # NOTE(review): this branch is assumed to be nested inside the
            # non-'bptt' branch, i.e. --etrace_decay is offered only for
            # methods other than 'bptt' and 'diag' -- confirm.
            if known.method != 'diag':
                self.add_argument(
                    "--etrace_decay",
                    type=float,
                    default=0.9,
                    help="The time constant of eligibility trace "
                )