Skip to content

Commit 1756dd9

Browse files
committed
Introduce critic network
1 parent 0f0c407 commit 1756dd9

1 file changed

Lines changed: 140 additions & 7 deletions

File tree

src/ppo/main.clj

Lines changed: 140 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
;;
4141
;; In order to use PPO with a simulation environment in Clojure and also in order to get a better understanding of PPO, I decided to do an implementation of PPO in Clojure.
4242
;;
43-
;; ## Pendulum environment
43+
;; ## Pendulum Environment
4444
;;
4545
;; ![screenshot of pendulum environment](pendulum.png)
4646
;;
@@ -82,7 +82,7 @@
8282
;; Here a pendulum is initialised to be pointing down and with an angular velocity of 0.5.
8383
(setup (/ PI 2) 0.5)
8484

85-
;; ### State updates
85+
;; ### State Updates
8686
;;
8787
;; The angular acceleration due to gravitation is implemented as follows.
8888
(defn pendulum-gravity
@@ -192,7 +192,7 @@
192192
(* velocity-weight (sqr velocity))
193193
(* control-weight (sqr control)))))
194194

195-
;; ### Environment protocol
195+
;; ### Environment Protocol
196196
;;
197197
;; Finally we are able to implement the pendulum as a generic environment.
198198
(defrecord Pendulum [config state]
@@ -269,14 +269,14 @@
269269

270270
;; ![manually controlled pendulum](manual.gif)
271271

272-
;; ## Neural networks
272+
;; ## Neural Networks
273273
;;
274274
;; PPO is a machine learning technique using backpropagation to learn the parameters of two neural networks.
275275
;;
276276
;; * The **actor** network takes an observation as an input and outputs the parameters of a probability distribution for sampling the next action to take.
277277
;; * The **critic** takes an observation as an input and outputs the expected cumulative reward for the current state.
278278
;;
279-
;; ### Pytorch
279+
;; ### Import Pytorch
280280
;;
281281
;; For implementing the neural networks and backpropagation, I am using the Python-Clojure bridge [libpython-clj2](https://github.com/clj-python/libpython-clj) and [Pytorch](https://pytorch.org/).
282282
;; The Pytorch library is quite comprehensive, is free software, and you can find a lot of documentation on how to use it.
@@ -343,11 +343,144 @@
343343
'[torch.optim :as optim]
344344
'[torch.distributions :refer (Beta)])
345345

346+
;; ### Tensor Conversion
347+
;;
348+
;; First we implement a few methods for converting nested Clojure vectors to Pytorch tensors and back.
349+
;;
350+
;; #### Clojure to Pytorch
351+
;;
352+
;; The method `tensor` is for converting a Clojure datatype to a Pytorch tensor.
353+
(defn tensor
354+
"Convert nested vector to tensor"
355+
([data]
356+
(tensor data torch/float32))
357+
([data dtype]
358+
(torch/tensor data :dtype dtype)))
359+
360+
(tensor PI)
361+
(tensor [2.0 3.0 5.0])
362+
(tensor [[1.0 2.0] [3.0 4.0] [5.0 6.0]])
363+
(tensor [1 2 3] torch/long)
364+
365+
;; #### Pytorch to Clojure
366+
;;
367+
;; The next method is for converting a Pytorch tensor back to a Clojure datatype.
368+
(defn tolist
369+
"Convert tensor to nested vector"
370+
[tensor]
371+
(py/->jvm (py. tensor tolist)))
372+
373+
(tolist (tensor [2.0 3.0 5.0]))
374+
(tolist (tensor [[1.0 2.0] [3.0 4.0] [5.0 6.0]]))
375+
376+
;; #### Pytorch scalar to Clojure
377+
;;
378+
;; A tensor with no dimensions can also be converted using `toitem`.
379+
(defn toitem
380+
"Convert torch scalar value to float"
381+
[tensor]
382+
(py. tensor item))
383+
384+
(toitem (tensor PI))
385+
386+
;; ### Critic Network
387+
;;
388+
;; The critic network is a fully connected neural network with an input layer of size `observation-size` and two hidden layers of size `hidden-units` with `tanh` activation functions.
389+
;; The critic output is a single value (an estimate for the expected cumulative return achievable by the given observed state).
390+
(def Critic
391+
(py/create-class
392+
"Critic" [nn/Module]
393+
{"__init__"
394+
(py/make-instance-fn
395+
(fn [self observation-size hidden-units]
396+
(py. nn/Module __init__ self)
397+
(py/set-attrs!
398+
self
399+
{"fc1" (nn/Linear observation-size hidden-units)
400+
"fc2" (nn/Linear hidden-units hidden-units)
401+
"fc3" (nn/Linear hidden-units 1)})
402+
nil))
403+
"forward"
404+
(py/make-instance-fn
405+
(fn [self x]
406+
(let [x (py. self fc1 x)
407+
x (torch/tanh x)
408+
x (py. self fc2 x)
409+
x (torch/tanh x)
410+
x (py. self fc3 x)]
411+
(torch/squeeze x -1))))}))
412+
413+
;; When running inference, you need to run the network with gradient accumulation disabled, otherwise gradients get accumulated and can leak into a subsequent training step.
414+
;; In Python this looks like this.
415+
;;
416+
;; ```Python
417+
;; with torch.no_grad():
418+
;; ...
419+
;; ```
420+
;;
421+
;; Here we create a Clojure macro to do the same job.
422+
(defmacro without-gradient
423+
"Execute body without gradient calculation"
424+
[& body]
425+
`(let [no-grad# (torch/no_grad)]
426+
(try
427+
(py. no-grad# ~'__enter__)
428+
~@body
429+
(finally
430+
(py. no-grad# ~'__exit__ nil nil nil)))))
431+
432+
;; Now we can create a network and try it out.
433+
;; Note that the network creates non-zero outputs because Pytorch performs random initialisation of the weights for us.
434+
(def critic (Critic 3 64))
435+
(without-gradient
436+
(toitem (critic (tensor [-1 0 0]))))
437+
438+
;; We can also create a wrapper for using the neural network with Clojure datatypes.
439+
(defn critic-observation
440+
"Use critic with Clojure datatypes"
441+
[critic]
442+
(fn [observation]
443+
(without-gradient (toitem (critic (tensor observation))))))
444+
445+
((critic-observation critic) [-1 0 0])
446+
447+
;; ### Training
448+
;;
449+
;; Training a neural network is done by defining a loss function.
450+
;; The loss of the network then is calculated for a mini-batch of training data.
451+
;; One can then use Pytorch's backpropagation to compute the gradient of the loss value with respect to every single parameter of the network.
452+
;; The gradient then is used to perform gradient descent steps.
453+
;; A popular gradient descent method is the [Adam optimizer](https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Adam).
454+
455+
;; Here is a wrapper for the Adam optimizer.
456+
(defn adam-optimizer
457+
"Adam optimizer"
458+
[model learning-rate weight-decay]
459+
(optim/Adam (py. model parameters) :lr learning-rate :weight_decay weight-decay))
460+
461+
;; Pytorch also provides the mean square error (MSE) loss function.
462+
(defn mse-loss
463+
"Mean square error cost function"
464+
[]
465+
(nn/MSELoss))
346466

467+
;; A training step can be performed as follows.
468+
(def optimizer (adam-optimizer critic 0.001 0.0))
469+
(def criterion (mse-loss))
470+
(def mini-batch [(tensor [[-1 0 0]]) (tensor [1.0])])
471+
(def prediction (critic (first mini-batch)))
472+
(def loss (criterion prediction (second mini-batch)))
473+
(py. optimizer zero_grad)
474+
(py. loss backward)
475+
(py. optimizer step)
347476

477+
;; As you can see, the output of the network for the observation `[-1 0 0]` is now closer to 1.0.
478+
((critic-observation critic) [-1 0 0])
348479

349-
350-
;; TODO
480+
;; # TODO
481+
;;
482+
;; * neural networks
483+
;; * ppo
351484
;;
352485
;; $\hat{A}_{T-1} = -V(S_{T-1}) + r_{T-1} + \gamma V(S_T)$
353486
;;

0 commit comments

Comments
 (0)