-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmutiple_linear_regration.go
128 lines (104 loc) · 2.57 KB
/
mutiple_linear_regration.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package main
import (
"encoding/csv"
"fmt"
"github.com/sajari/regression"
"log"
"math"
"os"
"strconv"
)
// main fits a linear regression of Sales on the TV and Radio columns of
// training.csv, prints the fitted formula, and then reports the mean absolute
// error (MAE) of the model's predictions on the rows of test.csv.
//
// Any I/O, parsing, fitting, or prediction failure terminates the program
// via log.Fatal.
func main() {
	// Load the training records; the first record is the header row.
	trainingData, err := readRecords("training.csv")
	if err != nil {
		log.Fatal(err)
	}
	if len(trainingData) < 2 {
		log.Fatal("training.csv contains no data rows")
	}

	// In this case we are going to try and model our Sales
	// by the TV and Radio features plus an intercept.
	var r regression.Regression
	r.SetObserved("Sales")
	r.SetVar(0, "TV")
	r.SetVar(1, "Radio")

	// Add every data row to the regression, skipping the header.
	for i, record := range trainingData[1:] {
		sales, features, err := parseRow(record)
		if err != nil {
			log.Fatalf("training.csv: data row %d: %v", i+1, err)
		}
		r.Train(regression.DataPoint(sales, features))
	}

	// Train/fit the regression model. A failed fit must stop the program:
	// otherwise the formula and every prediction below would be meaningless.
	if err := r.Run(); err != nil {
		log.Fatal(err)
	}

	// Output the trained model parameters.
	fmt.Printf("\nRegression Formula:\n%v\n\n", r.Formula)

	// Load the test records; again the first record is the header row.
	testData, err := readRecords("test.csv")
	if err != nil {
		log.Fatal(err)
	}
	if len(testData) < 2 {
		log.Fatal("test.csv contains no data rows")
	}

	// Predict Sales for each test row and accumulate the absolute errors.
	testRows := testData[1:]
	var absErrSum float64
	for i, record := range testRows {
		yObserved, features, err := parseRow(record)
		if err != nil {
			log.Fatalf("test.csv: data row %d: %v", i+1, err)
		}

		// Predict y with our trained model.
		yPredicted, err := r.Predict(features)
		if err != nil {
			log.Fatalf("test.csv: data row %d: %v", i+1, err)
		}

		// Add this row's absolute error to the running total.
		absErrSum += math.Abs(yObserved - yPredicted)
	}

	// Average over the number of data rows only; the header row is not an
	// observation and must not be counted in the denominator.
	mAE := absErrSum / float64(len(testRows))

	// Output the MAE to standard out.
	fmt.Printf("MAE = %0.2f\n\n", mAE)
}

// readRecords opens the CSV file at path, reads every record (including the
// header row), and closes the file before returning. Each record is required
// to have exactly 4 fields; a record with a different field count makes the
// read fail.
func readRecords(path string) ([][]string, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	reader := csv.NewReader(f)
	reader.FieldsPerRecord = 4
	records, err := reader.ReadAll()
	if err != nil {
		return nil, fmt.Errorf("reading %s: %w", path, err)
	}
	return records, nil
}

// parseRow converts one 4-field CSV record into the observed Sales value
// (column 3) and the feature slice {TV (column 0), Radio (column 1)}.
// Column 2 is not used by the model. The record must have at least 4 fields,
// which readRecords guarantees.
func parseRow(record []string) (float64, []float64, error) {
	sales, err := strconv.ParseFloat(record[3], 64)
	if err != nil {
		return 0, nil, fmt.Errorf("parsing Sales: %w", err)
	}
	tv, err := strconv.ParseFloat(record[0], 64)
	if err != nil {
		return 0, nil, fmt.Errorf("parsing TV: %w", err)
	}
	radio, err := strconv.ParseFloat(record[1], 64)
	if err != nil {
		return 0, nil, fmt.Errorf("parsing Radio: %w", err)
	}
	return sales, []float64{tv, radio}, nil
}