Motivation and description
Description:
The current train.jl script manages optimizer state by hand, combining the low-level Optimisers.jl API (Optimisers.setup / Optimisers.update!) with explicit params(model) collections. We should refactor it to use Flux's unified Flux.setup and Flux.update! API, available since Flux 0.13, for a cleaner and more idiomatic training loop.
Current Behavior:
- Optimizers are initialized and updated using the explicit Optimisers.jl API:

    opt = ADAM(lr)
    ps = params(model)
    st = Optimisers.setup(opt, ps)
    ...
    st, _ = Optimisers.update!(st, ps, grads)

- Training loops manually handle gradient collection and state updates.
Desired Behavior:
- Use Flux's unified optimizer API:

    opt_state = Flux.setup(Flux.Adam(lr), model)
    # inside training loop
    loss, grads = Flux.withgradient(model) do m
        loss_fn(m(x))
    end
    Flux.update!(opt_state, model, grads[1])

- Simplify train_discriminator! and train_generator! to use Flux.setup/Flux.update! instead of explicit parameter/state management.
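A minimal sketch of what the two training functions could look like after the refactor. The model architectures, `latent_dim`, the learning rate, the logit-based binary cross-entropy loss, and the function signatures are all assumptions made for illustration; the actual definitions in train.jl are not shown in this issue.

```julia
using Flux

# Hypothetical stand-ins for the real generator and discriminator in train.jl.
latent_dim = 8
gen  = Chain(Dense(latent_dim => 16, relu), Dense(16 => 4))
disc = Chain(Dense(4 => 16, relu), Dense(16 => 1))

# One optimizer state tree per model, built once before the training loop.
opt_gen  = Flux.setup(Flux.Adam(2f-4), gen)
opt_disc = Flux.setup(Flux.Adam(2f-4), disc)

# Assumed signature: `real` is a features x batch matrix.
function train_discriminator!(gen, disc, opt_disc, real, latent_dim)
    batch_size = size(real, 2)
    noise = randn(Float32, latent_dim, batch_size)
    fake = gen(noise)  # computed outside the gradient call, so only `disc` is differentiated
    real_target = ones(Float32, 1, batch_size)
    fake_target = zeros(Float32, 1, batch_size)
    loss, grads = Flux.withgradient(disc) do d
        Flux.logitbinarycrossentropy(d(real), real_target) +
            Flux.logitbinarycrossentropy(d(fake), fake_target)
    end
    Flux.update!(opt_disc, disc, grads[1])  # mutates the model and its optimizer state in place
    return loss
end

# Assumed signature: the generator tries to make `disc` label its samples as real.
function train_generator!(gen, disc, opt_gen, batch_size, latent_dim)
    noise = randn(Float32, latent_dim, batch_size)
    target = ones(Float32, 1, batch_size)
    loss, grads = Flux.withgradient(gen) do g
        Flux.logitbinarycrossentropy(disc(g(noise)), target)
    end
    Flux.update!(opt_gen, gen, grads[1])
    return loss
end

# Usage with a random "real" batch, just to exercise both steps once.
real = randn(Float32, 4, 32)
d_loss = train_discriminator!(gen, disc, opt_disc, real, latent_dim)
g_loss = train_generator!(gen, disc, opt_gen, 32, latent_dim)
```

Because Flux.update! mutates the model and the optimizer state in place, neither function needs to return or reassign any state, which is what lets the manual state-tracking variables go away.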
Proposed Changes:
- Replace all explicit Optimisers.setup and Optimisers.update! calls with Flux.setup and Flux.update! on the Chain models.
- Remove manual params(model) and state-tracking variables where no longer needed.
Acceptance Criteria:
- Code compiles and runs without deprecation warnings under Flux 0.14+.
- Discriminator and generator training functions use Flux.setup and Flux.update! exclusively.
- Existing functionality and performance are preserved.
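One possible way to check these criteria is a small smoke test along the following lines. It is only a sketch: the toy model, the data, the mean-squared-error loss, and the step count are placeholders, not taken from train.jl.

```julia
using Flux

# Placeholder model and data; substitute the real models from train.jl.
model = Chain(Dense(3 => 4, relu), Dense(4 => 1))
x = randn(Float32, 3, 16)
y = randn(Float32, 1, 16)

opt_state = Flux.setup(Flux.Adam(1f-2), model)
bias_before = copy(model[2].bias)

losses = Float32[]
for _ in 1:50
    loss, grads = Flux.withgradient(m -> Flux.mse(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
    push!(losses, loss)
end

# The loop ran to completion with finite losses, and the parameters actually moved.
@assert all(isfinite, losses)
@assert model[2].bias != bias_before
```

Running such a script with `julia --depwarn=error` turns deprecations raised through `Base.depwarn` into errors, which is one way to verify the "no deprecation warnings" criterion.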
References: