package wann

import (
	"fmt"
	"time"

	"github.com/dave/jennifer/jen"
	"github.com/xyproto/af"
)

// ActivationFunctionIndex is a number that represents a specific activation function
type ActivationFunctionIndex int

const (
	// Step is a step function: first 0, then abruptly up to 1
	Step ActivationFunctionIndex = iota
	// Linear is the linear activation function, gradually from 0 to 1
	Linear
	// Sin is the sinusoid activation function
	Sin
	// Gauss is the Gaussian function, with a mean of 0 and a sigma of 1
	Gauss
	// Tanh is math.Tanh
	Tanh
	// Sigmoid is the optimized sigmoid function from github.com/xyproto/swish
	Sigmoid
	// Inv is the inverse linear function
	Inv
	// Abs is math.Abs
	Abs
	// ReLU is the rectified linear unit: first 0, then the linear function
	ReLU
	// Cos is the cosinusoid activation function
	Cos
	// Squared increases rapidly
	Squared
	// Swish is x * sigmoid(x), a smooth activation function that came after ReLU
	Swish
	// SoftPlus is log(1 + exp(x))
	SoftPlus
)

// ActivationFunctions is a collection of activation functions, where the keys are the constants defined above.
// https://github.com/google/brain-tokyo-workshop/blob/master/WANNRelease/WANN/wann_src/ind.py
var ActivationFunctions = map[ActivationFunctionIndex](func(float64) float64){
	Step:     af.Step,       // Unsigned Step Function
	Linear:   af.Linear,     // Linear
	Sin:      af.Sin,        // Sin
	Gauss:    af.Gaussian01, // Gaussian with mean 0 and sigma 1
	Tanh:     af.Tanh,       // Hyperbolic Tangent (signed?)
	Sigmoid:  af.Sigmoid,    // Sigmoid (unsigned?)
	Inv:      af.Inv,        // Inverse
	Abs:      af.Abs,        // Absolute value
	ReLU:     af.ReLU,       // Rectified linear unit
	Cos:      af.Cos,        // Cosine
	Squared:  af.Squared,    // Squared
	Swish:    af.Swish,      // Swish
	SoftPlus: af.SoftPlus,   // SoftPlus
}

// ComplexityEstimate is a map that holds an estimate of how computationally complex each
// activation function is, based on a quick benchmark of each function.
// The estimates will vary from run to run, depending on the machine and its current load.
var ComplexityEstimate = make(map[ActivationFunctionIndex]float64)

// estimateComplexity benchmarks every activation function over the range [0, 1] and stores a
// relative score in ComplexityEstimate, where 1.0 means the function took as long as the slowest one.
func (config *Config) estimateComplexity() {
	if config.Verbose {
		fmt.Print("Estimating activation function complexity...")
	}
	startEstimate := time.Now()
	resolution := 0.0001
	durationMap := make(map[ActivationFunctionIndex]time.Duration)
	var maxDuration time.Duration
	for i, f := range ActivationFunctions {
		start := time.Now()
		for x := 0.0; x <= 1.0; x += resolution {
			_ = f(x)
		}
		duration := time.Since(start)
		durationMap[i] = duration
		if duration > maxDuration {
			maxDuration = duration
		}
	}
	for i := range ActivationFunctions {
		// 1.0 means the function took maxDuration
		ComplexityEstimate[i] = float64(durationMap[i]) / float64(maxDuration)
	}
	estimateDuration := time.Since(startEstimate)
	if config.Verbose {
		fmt.Printf(" done. (In %v)\n", estimateDuration)
	}
}
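
// complexityEstimateSketch is an illustrative usage sketch and not part of the original file.
// It assumes that a zero-valued Config is acceptable here, runs the benchmark above, and then
// prints the relative cost of each activation function (1.0 marks the slowest one).
func complexityEstimateSketch() {
	var config Config
	config.estimateComplexity()
	for afi := range ActivationFunctions {
		fmt.Printf("%-12s complexity estimate: %.2f\n", afi.Name(), ComplexityEstimate[afi])
	}
}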

// Call runs an activation function with the given float64 value.
// The activation function is chosen by one of the constants above.
func (afi ActivationFunctionIndex) Call(x float64) float64 {
	if f, ok := ActivationFunctions[afi]; ok {
		return f(x)
	}
	// Use the linear function by default
	return af.Linear(x)
}
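
// callSketch is an illustrative usage sketch and not part of the original file. It applies a few
// of the activation function constants above to the same input value via Call, labelled with Name.
func callSketch() {
	x := 0.5
	for _, afi := range []ActivationFunctionIndex{Step, ReLU, Swish} {
		fmt.Printf("%s(%v) = %v\n", afi.Name(), x, afi.Call(x))
	}
}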

// Name returns a name for each activation function
func (afi ActivationFunctionIndex) Name() string {
	switch afi {
	case Step:
		return "Step"
	case Linear:
		return "Linear"
	case Sin:
		return "Sinusoid"
	case Gauss:
		return "Gaussian"
	case Tanh:
		return "Tanh"
	case Sigmoid:
		return "Sigmoid"
	case Inv:
		return "Inverted"
	case Abs:
		return "Absolute"
	case ReLU:
		return "ReLU"
	case Cos:
		return "Cosinusoid"
	case Squared:
		return "Squared"
	case Swish:
		return "Swish"
	case SoftPlus:
		return "SoftPlus"
	default:
		return "Untitled"
	}
}

// goExpression returns the Go expression for this activation function, using the given variable name string as the input variable name
func (afi ActivationFunctionIndex) goExpression(varName string) string {
	switch afi {
	case Step:
		// Using s to not confuse it with the varName
		return "func(s float64) float64 { if s >= 0 { return 1 } else { return 0 } }(" + varName + ")"
	case Linear:
		return varName
	case Sin:
		return "math.Sin(math.Pi * " + varName + ")"
	case Gauss:
		return "math.Exp(-(" + varName + " * " + varName + ") / 2.0)"
	case Tanh:
		return "math.Tanh(" + varName + ")"
	case Sigmoid:
		return "(1.0 / (1.0 + math.Exp(-" + varName + ")))"
	case Inv:
		return "-" + varName
	case Abs:
		return "math.Abs(" + varName + ")"
	case ReLU:
		// Using r to not confuse it with the varName
		return "func(r float64) float64 { if r >= 0 { return r } else { return 0 } }(" + varName + ")"
	case Cos:
		return "math.Cos(math.Pi * " + varName + ")"
	case Squared:
		return "(" + varName + " * " + varName + ")"
	case Swish:
		return "(" + varName + " / (1.0 + math.Exp(-" + varName + ")))"
	case SoftPlus:
		return "math.Log(1.0 + math.Exp(" + varName + "))"
	default:
		return varName
	}
}

// String returns the Go expression for this activation function, using "x" as the input variable name
func (afi ActivationFunctionIndex) String() string {
	return afi.goExpression("x")
}
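
// stringSketch is an illustrative usage sketch and not part of the original file. Since
// ActivationFunctionIndex implements fmt.Stringer, printing a constant yields the generated
// Go expression for that activation function, with "x" as the input variable.
func stringSketch() {
	fmt.Println(Sigmoid)  // (1.0 / (1.0 + math.Exp(-x)))
	fmt.Println(SoftPlus) // math.Log(1.0 + math.Exp(x))
}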

// Statement returns the jennifer statement for this activation function, wrapping the given inner statement
func (afi ActivationFunctionIndex) Statement(inner *jen.Statement) *jen.Statement {
	switch afi {
	case Step:
		// func(s float64) float64 { if s >= 0 { return 1 } else { return 0 } }(inner)
		// Using s to not confuse it with the inner expression
		return jen.Func().Params(jen.Id("s").Id("float64")).Id("float64").Block(
			jen.If(jen.Id("s").Op(">=").Id("0")).Block(
				jen.Return(jen.Lit(1)),
			).Else().Block(
				jen.Return(jen.Lit(0)),
			),
		).Call(inner)
	case Cos:
		// math.Cos((inner) * math.Pi)
		return jen.Qual("math", "Cos").Call(jen.Parens(inner).Op("*").Id("math").Dot("Pi"))
	case Sin:
		// math.Sin((inner) * math.Pi)
		return jen.Qual("math", "Sin").Call(jen.Parens(inner).Op("*").Id("math").Dot("Pi"))
	case Gauss:
		// math.Exp(-(math.Pow(inner, 2.0)) / 2.0)
		return jen.Qual("math", "Exp").Call(jen.Op("-").Parens(
			// Using math.Pow ensures the inner expression is only calculated once, if it's a large expression
			//inner.Op("*").Add(inner),
			jen.Qual("math", "Pow").Params(
				inner,
				jen.Lit(2.0),
			),
		).Op("/").Lit(2.0))
	case Tanh:
		// math.Tanh(inner)
		return jen.Qual("math", "Tanh").Call(inner)
	case Sigmoid:
		// (1.0 / (1.0 + math.Exp(-(inner))))
		return jen.Lit(1.0).Op("/").Parens(jen.Lit(1.0).Op("+").Qual("math", "Exp").Call(jen.Op("-").Parens(inner)))
	case Inv:
		// -(inner)
		return jen.Op("-").Parens(inner)
	case Abs:
		// math.Abs(inner)
		return jen.Qual("math", "Abs").Call(inner)
	case ReLU:
		// func(r float64) float64 { if r >= 0 { return r } else { return 0 } }(inner)
		// Using r to not confuse it with the inner expression
		return jen.Func().Params(jen.Id("r").Id("float64")).Id("float64").Block(
			jen.If(jen.Id("r").Op(">=").Id("0")).Block(
				jen.Return(jen.Id("r")),
			).Else().Block(
				jen.Return(jen.Lit(0)),
			),
		).Call(inner)
	case Squared:
		// inner^2
		//return inner.Op("*").Add(inner)
		// Using math.Pow ensures the inner expression is only calculated once, if it's a large expression
		return jen.Qual("math", "Pow").Call(inner, jen.Lit(2.0))
	case Swish:
		// (inner / (1.0 + math.Exp(-(inner))))
		// Build on a fresh statement with jen.Add so that the passed-in inner statement is not
		// mutated and then nested inside itself when it is reused in the denominator below.
		return jen.Parens(jen.Add(inner).Op("/").Parens(jen.Lit(1.0).Op("+").Qual("math", "Exp").Call(jen.Op("-").Parens(inner))))
	case SoftPlus:
		// math.Log(1.0 + math.Exp(inner))
		return jen.Qual("math", "Log").Call(jen.Lit(1.0).Op("+").Qual("math", "Exp").Call(inner))
	case Linear:
		// This is also the default case: (inner)
		fallthrough
	default:
		// (inner)
		return jen.Parens(inner)
	}
}
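
// statementSketch is an illustrative usage sketch and not part of the original file. It builds
// the jennifer statement for a single activation function and renders it to a string, assuming
// that jennifer's %#v formatting verb is enough to render a bare *jen.Statement.
func statementSketch() string {
	stmt := Tanh.Statement(jen.Id("x")) // expected to render as: math.Tanh(x)
	return fmt.Sprintf("%#v", stmt)
}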

// GoRun first constructs the expression using jennifer, then evaluates the result using "go run" and a source file in /tmp
func (afi ActivationFunctionIndex) GoRun(x float64) (float64, error) {
	return RunStatementX(afi.Statement(jen.Id("x")), x)
}
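
// goRunSketch is an illustrative usage sketch and not part of the original file. It compares the
// in-process result of Call with the result of compiling and running the generated expression via
// GoRun. Since GoRun shells out to "go run", it is much slower and may return an error.
func goRunSketch() {
	x := 0.42
	direct := Sigmoid.Call(x)
	generated, err := Sigmoid.GoRun(x)
	if err != nil {
		fmt.Println("go run failed:", err)
		return
	}
	// The two values may differ slightly if af.Sigmoid is an optimized approximation
	fmt.Printf("direct=%v generated=%v\n", direct, generated)
}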