Skip to content

Commit

Permalink
update bert example with Optimisers.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
chengchingwen committed Feb 23, 2023
1 parent 205221b commit f2c5fd1
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 27 deletions.
2 changes: 2 additions & 0 deletions example/BERT/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Transformers = "21ca0261-441d-5938-ace7-c90938fde4d4"
WordTokenizers = "796a5d58-b03d-544a-977e-18100b691f6e"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
18 changes: 9 additions & 9 deletions example/BERT/cola/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ using Transformers.Datasets: GLUE

using Flux
using Flux.Losses
using Flux: pullback, params
import Flux.Optimise: update!
using Zygote
import Optimisers

const Epoch = 2
const Batch = 4
Expand All @@ -26,8 +26,8 @@ const bertenc = hgf"bert-base-uncased:tokenizer"

const bert_model = todevice(_bert_model)

const ps = params(bert_model)
const opt = ADAM(1e-6)
const opt_rule = Optimisers.Adam(1e-6)
const opt = Optimisers.setup(opt_rule, bert_model)

function acc(p, label)
pred = Flux.onecold(p)
Expand All @@ -52,14 +52,14 @@ function train!()
al = zero(Float64)
while (batch = get_batch(datas, Batch)) !== nothing
input = todevice(preprocess(batch::Vector{Vector{String}}))
(l, p), back = pullback(ps) do
loss(bert_model, input)
(l, p), back = Zygote.pullback(bert_model) do model
loss(model, input)
end
a = acc(p, input.label)
al += a
grad = back((Flux.Zygote.sensitivity(l), nothing))
i+=1
update!(opt, ps, grad)
(grad,) = back((Zygote.sensitivity(l), nothing))
i += 1
Optimisers.update!(opt, bert_model, grad)
mod1(i, 16) == 1 && @info "training" loss=l accuracy=al/i
end

Expand Down
18 changes: 9 additions & 9 deletions example/BERT/mnli/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ using Transformers.Datasets: GLUE

using Flux
using Flux.Losses
using Flux: pullback, params
import Flux.Optimise: update!
using Zygote
import Optimisers

const Epoch = 2
const Batch = 4
Expand All @@ -30,8 +30,8 @@ const bertenc = load_tokenizer("bert-base-uncased"; config = bert_config)

const bert_model = todevice(_bert_model)

const ps = params(bert_model)
const opt = ADAM(1e-6)
const opt_rule = Optimisers.Adam(1e-6)
const opt = Optimisers.setup(opt_rule, bert_model)

function acc(p, label)
pred = Flux.onecold(p)
Expand All @@ -56,14 +56,14 @@ function train!()
al = zero(Float64)
while (batch = get_batch(datas, Batch)) !== nothing
input = todevice(preprocess(batch::Vector{Vector{String}}))
(l, p), back = pullback(ps) do
loss(bert_model, input)
(l, p), back = Zygote.pullback(bert_model) do model
loss(model, input)
end
a = acc(p, input.label)
al += a
grad = back((Flux.Zygote.sensitivity(l), nothing))
i+=1
update!(opt, ps, grad)
(grad,) = back((Zygote.sensitivity(l), nothing))
i += 1
Optimisers.update!(opt, bert_model, grad)
mod1(i, 16) == 1 && @info "training" loss=l accuracy=al/i
end

Expand Down
18 changes: 9 additions & 9 deletions example/BERT/mrpc/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ using Transformers.Datasets: GLUE

using Flux
using Flux.Losses
using Flux: pullback, params
import Flux.Optimise: update!
using Zygote
import Optimisers

const Epoch = 2
const Batch = 4
Expand All @@ -30,8 +30,8 @@ const bertenc = load_tokenizer("bert-base-uncased"; config = bert_config)

const bert_model = todevice(_bert_model)

const ps = params(bert_model)
const opt = ADAM(1e-6)
const opt_rule = Optimisers.Adam(1e-6)
const opt = Optimisers.setup(opt_rule, bert_model)

function acc(p, label)
pred = Flux.onecold(p)
Expand All @@ -56,14 +56,14 @@ function train!()
al = zero(Float64)
while (batch = get_batch(datas, Batch)) !== nothing
input = todevice(preprocess(batch))
(l, p), back = pullback(ps) do
loss(bert_model, input)
(l, p), back = Zygote.pullback(bert_model) do model
loss(model, input)
end
a = acc(p, input.label)
al += a
grad = back((Flux.Zygote.sensitivity(l), nothing))
i+=1
update!(opt, ps, grad)
(grad,) = back((Zygote.sensitivity(l), nothing))
i += 1
Optimisers.update!(opt, bert_model, grad)
mod1(i, 16) == 1 && @info "training" loss=l accuracy=al/i
end

Expand Down

0 comments on commit f2c5fd1

Please sign in to comment.