-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbackend.py
More file actions
31 lines (22 loc) · 929 Bytes
/
Copy pathbackend.py
File metadata and controls
31 lines (22 loc) · 929 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import tensorflow as tf
def createFunc_helper():
    """Build and return a new ``ActorFunc`` instance.

    NOTE(review): ``ActorFunc`` is neither defined nor imported in the
    visible part of this file, so it has to be supplied from elsewhere
    before this helper is called -- confirm where it comes from.
    """
    actor_func = ActorFunc()
    return actor_func
class NNBackend:
    """Base backend whose hooks are no-ops.

    Both ``runAction`` and ``runOptimizer`` take no arguments besides
    ``self`` and return ``None``; subclasses override them with real work.
    """

    def __init__(self):
        # Placeholder slot; nothing in this class assigns it a real value.
        self._func = None

    def runAction(self):
        """Do nothing and return ``None``."""
        return None

    def runOptimizer(self):
        """Do nothing and return ``None``."""
        return None
class TFBackend(NNBackend):
    """Backend that evaluates ops through a TensorFlow 1.x ``tf.Session``.

    Construction only seeds the graph-level RNG; the session itself is
    created by :meth:`init`, which must be called before ``runAction`` or
    ``runOptimizer``.
    """

    def __init__(self, name):
        """Set up base state, remember ``name`` and seed TensorFlow's RNG.

        Args:
            name: Identifier supplied by the caller. It is only stored on
                the instance; nothing in this class reads it back.
        """
        # Explicit base call (rather than super()) so that this also works
        # if the file is run under Python 2, where ``class NNBackend:`` is
        # an old-style class -- NOTE(review): simplify to super() if the
        # project is Python 3 only.
        NNBackend.__init__(self)
        self._name = name
        # Defined here so the attribute always exists; stays None until
        # init() has been called.
        self._sess = None
        tf.set_random_seed(0)

    def init(self):
        """Create the session, install it as default, and initialize variables."""
        tf_config = tf.ConfigProto(
            inter_op_parallelism_threads=1,
            intra_op_parallelism_threads=1,
        )
        self._sess = tf.Session(config=tf_config)
        # Entered manually and never exited here, so the session remains
        # the default one after this method returns.
        self._sess.__enter__()  # equivalent to `with sess:`
        # Operation.run() uses the default session installed just above.
        tf.global_variables_initializer().run()  # pylint: disable=E1101

    def runAction(self, ops, sy_ob_no, params):
        """Run ``ops`` in the session with ``params`` fed to ``sy_ob_no``.

        Args:
            ops: Fetches passed straight to ``Session.run``.
            sy_ob_no: Feed key (presumably the observation placeholder --
                confirm against the caller).
            params: Value fed for ``sy_ob_no``.

        Returns:
            Whatever ``Session.run`` returns for ``ops``.
        """
        feed = {sy_ob_no: params}
        return self._sess.run(ops, feed_dict=feed)

    def runOptimizer(self, updateops, sy_ob_no, sy_ac_na, sy_adv_n, ob_no, ac_na, adv_n):
        """Run ``updateops`` in the session with three values fed in.

        Args:
            updateops: Fetches passed straight to ``Session.run``.
            sy_ob_no: Feed key that receives ``ob_no``.
            sy_ac_na: Feed key that receives ``ac_na``.
            sy_adv_n: Feed key that receives ``adv_n``.
            ob_no: Value fed for ``sy_ob_no``.
            ac_na: Value fed for ``sy_ac_na``.
            adv_n: Value fed for ``sy_adv_n``.

        Returns:
            Whatever ``Session.run`` returns for ``updateops``.
        """
        feed = {sy_ob_no: ob_no, sy_ac_na: ac_na, sy_adv_n: adv_n}
        return self._sess.run(updateops, feed_dict=feed)