Refactor DCGAN training script to use Flux.setup/update! unified API #411

@josemanuel22

Motivation and description

Description:
The current train.jl script manages optimizer state by hand, either through direct Optimisers.setup/Optimisers.update! calls or through implicit-parameter params(model) loops. We should refactor it to use Flux.setup and Flux.update!, the explicit-style API added during the Flux 0.13 series, for a cleaner and more idiomatic training loop.

Current Behavior:

  • Optimizers are initialized and updated using the explicit Optimisers.jl API:
    opt = ADAM(lr)
    ps  = params(model)
    st  = Optimisers.setup(opt, ps)
    ...
    st, _ = Optimisers.update!(st, ps, grads)
  • Training loops manually handle gradient collection and state updates.

Desired Behavior:

  • Use Flux's unified optimizer API:
    opt_state = Flux.setup(Flux.Adam(lr), model)
    
    # inside training loop
    loss, grads = Flux.withgradient(model) do m
      loss_fn(m(x))
    end
    Flux.update!(opt_state, model, grads[1])
  • Simplify train_discriminator! and train_generator! to use Flux.setup/Flux.update! instead of explicit parameter/state management.
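
As a rough illustration of what the simplified functions could look like, here is a minimal sketch. It is not the script's actual code: the tiny Dense models, the argument lists of train_discriminator! and train_generator!, the learning rate, and the choice of Flux.logitbinarycrossentropy as the loss are all assumptions made for the example. The real DCGAN uses convolutional networks and its own loss helpers.

```julia
using Flux

# Hypothetical stand-in models; the real script builds convolutional DCGAN networks.
latent_dim = 8
gen  = Chain(Dense(latent_dim => 16, relu), Dense(16 => 4, tanh))
dscr = Chain(Dense(4 => 16, relu), Dense(16 => 1))

# One optimizer state tree per model, created once before the training loop.
opt_gen  = Flux.setup(Flux.Adam(2f-4), gen)
opt_dscr = Flux.setup(Flux.Adam(2f-4), dscr)

# Assumed signature: one discriminator step on a real batch and a noise batch.
function train_discriminator!(gen, dscr, opt_dscr, real_x, noise)
    # Fake samples are generated outside the gradient call,
    # so only the discriminator is differentiated.
    fake_x = gen(noise)
    loss, grads = Flux.withgradient(dscr) do d
        real_loss = Flux.logitbinarycrossentropy(d(real_x), 1f0)
        fake_loss = Flux.logitbinarycrossentropy(d(fake_x), 0f0)
        real_loss + fake_loss
    end
    Flux.update!(opt_dscr, dscr, grads[1])
    return loss
end

# Assumed signature: one generator step that tries to fool the discriminator.
function train_generator!(gen, dscr, opt_gen, noise)
    loss, grads = Flux.withgradient(gen) do g
        Flux.logitbinarycrossentropy(dscr(g(noise)), 1f0)
    end
    Flux.update!(opt_gen, gen, grads[1])
    return loss
end

# Usage with random data standing in for a real image batch.
real_x = randn(Float32, 4, 32)
noise  = randn(Float32, latent_dim, 32)
d_loss = train_discriminator!(gen, dscr, opt_dscr, real_x, noise)
g_loss = train_generator!(gen, dscr, opt_gen, noise)
```

Each function differentiates only the model it updates, so no params(model) collection or separately threaded state variable is needed; the state returned by Flux.setup is mutated in place by Flux.update!.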

Proposed Changes:

  1. Replace all explicit Optimisers.setup and Optimisers.update! calls with Flux.setup and Flux.update! on the Chain models.
  2. Remove manual params(model) and state-tracking variables where no longer needed.

Acceptance Criteria:

  • Code compiles and runs without deprecation warnings under Flux 0.14+.
  • Discriminator and generator training functions use Flux.setup and Flux.update! exclusively.
  • Existing functionality and performance are preserved.
