Skip to content

Commit

Permalink
refactoring linreg
Browse files Browse the repository at this point in the history
  • Loading branch information
campoy committed Aug 21, 2018
1 parent 4a1fae4 commit bdc8b4e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 53 deletions.
32 changes: 32 additions & 0 deletions 38-linreg-vanilla/linreg/linreg.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Package linreg provides a basic implementation of linear regression
// with gradient descent on two dimensional data.
package linreg

import "fmt"

// LinearRegression fits the line y = m*x + c to the points (xs[i], ys[i])
// by running the requested number of iterations of gradient descent and
// returns the latest approximated coefficients.
//
// Both coefficients start at zero. On every iteration they are moved against
// the gradient reported by Gradient, scaled by the learning rate alpha. When
// iterations is zero or negative no step is taken and (0, 0) is returned.
//
// Progress is written to standard output on each iteration i for which 10*i
// is a multiple of iterations. Every progress line reports the coefficients
// together with the cost evaluated at those same coefficients.
func LinearRegression(xs, ys []float64, iterations int, alpha float64) (m, c float64) {
	for i := 0; i < iterations; i++ {
		cost, dm, dc := Gradient(xs, ys, m, c)

		// Report before stepping: cost was computed for the current m and c,
		// so printing after the update would pair it with the wrong values.
		if (10*i)%iterations == 0 {
			fmt.Printf("cost(%.2f, %.2f) = %.2f\n", m, c, cost)
		}

		m += -dm * alpha
		c += -dc * alpha
	}

	return m, c
}

// Gradient computes the mean squared error of the line y = m*x + c over the
// points (xs[i], ys[i]), together with its partial derivatives with respect
// to m and c:
//
//	cost = 1/N * sum((y - (m*x + c))^2)
//	dm   = 2/N * sum(-x * (y - (m*x + c)))
//	dc   = 2/N * sum(-(y - (m*x + c)))
//
// N is len(xs). When xs is empty all three results are zero, so that callers
// iterating on the gradient are not handed NaN values produced by dividing
// by zero. ys must hold at least len(xs) values; a shorter ys makes the
// function panic with an index out of range error, and any extra values in
// ys are ignored.
func Gradient(xs, ys []float64, m, c float64) (cost, dm, dc float64) {
	if len(xs) == 0 {
		return 0, 0, 0
	}

	for i := range xs {
		// d is the residual: observed y minus the value predicted by the line.
		d := ys[i] - (xs[i]*m + c)
		cost += d * d
		dm += -xs[i] * d
		dc += -d
	}

	n := float64(len(xs))
	return cost / n, 2 / n * dm, 2 / n * dc
}
72 changes: 19 additions & 53 deletions 38-linreg-vanilla/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,32 @@ import (
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/vg/draw"
)

var iterations int
"github.com/campoy/justforfunc/39-linreg/linreg"
)

func main() {
flag.IntVar(&iterations, "n", 1000, "number of iterations")
iterations := flag.Int("n", 1000, "number of iterations")
flag.Parse()

xys, err := readData("data.txt")
xs, ys, err := readData("data.txt")
if err != nil {
log.Fatalf("could not read data.txt: %v", err)
}
_ = xys

err = plotData("out.png", xys)
err = plotData("out.png", xs, ys, *iterations)
if err != nil {
log.Fatalf("could not plot data: %v", err)
}
}

type xy struct{ x, y float64 }

func readData(path string) (plotter.XYs, error) {
func readData(path string) (xs, ys []float64, err error) {
f, err := os.Open(path)
if err != nil {
return nil, err
return nil, nil, err
}
defer f.Close()

var xys plotter.XYs
s := bufio.NewScanner(f)
for s.Scan() {
var x, y float64
Expand All @@ -49,15 +45,21 @@ func readData(path string) (plotter.XYs, error) {
log.Printf("discarding bad data point %q: %v", s.Text(), err)
continue
}
xys = append(xys, struct{ X, Y float64 }{x, y})
xs = append(xs, x)
ys = append(ys, y)
}
if err := s.Err(); err != nil {
return nil, fmt.Errorf("could not scan: %v", err)
return nil, nil, fmt.Errorf("could not scan: %v", err)
}
return xys, nil
return xs, ys, nil
}

func plotData(path string, xys plotter.XYs) error {
// xyer pairs two parallel coordinate slices so that they can be handed to
// plotter.NewScatter as a single collection of points.
type xyer struct {
	xs []float64
	ys []float64
}

// Len reports the number of points, taken from the length of xs.
func (p xyer) Len() int {
	return len(p.xs)
}

// XY returns the x and y coordinates stored at index i.
func (p xyer) XY(i int) (float64, float64) {
	x, y := p.xs[i], p.ys[i]
	return x, y
}

func plotData(path string, xs, ys []float64, iterations int) error {
f, err := os.Create(path)
if err != nil {
return fmt.Errorf("could not create %s: %v", path, err)
Expand All @@ -69,15 +71,15 @@ func plotData(path string, xys plotter.XYs) error {
}

// create scatter with all data points
s, err := plotter.NewScatter(xys)
s, err := plotter.NewScatter(xyer{xs, ys})
if err != nil {
return fmt.Errorf("could not create scatter: %v", err)
}
s.GlyphStyle.Shape = draw.CrossGlyph{}
s.Color = color.RGBA{R: 255, A: 255}
p.Add(s)

x, c := linearRegression(xys, 0.01)
x, c := linreg.LinearRegression(xs, ys, iterations, 0.01)

// create fake linear regression result
l, err := plotter.NewLine(plotter.XYs{
Expand All @@ -102,39 +104,3 @@ func plotData(path string, xys plotter.XYs) error {
}
return nil
}

// linearRegression runs gradient descent over xys with learning rate alpha
// and returns the resulting slope m and intercept c, both starting at zero.
// The number of steps is read from the iterations variable declared outside
// this function. The cost is printed after every step and once more after
// the loop has finished.
func linearRegression(xys plotter.XYs, alpha float64) (m, c float64) {
	const format = "cost(%.2f, %.2f) = %.2f\n"

	for step := 0; step < iterations; step++ {
		gradM, gradC := computeGradient(xys, m, c)
		m += -gradM * alpha
		c += -gradC * alpha
		fmt.Printf(format, m, c, computeCost(xys, m, c))
	}

	// Report the final state once more after the last step.
	fmt.Printf(format, m, c, computeCost(xys, m, c))
	return m, c
}

// computeCost returns the mean of the squared residuals of the line
// y = m*x + c over all points in xys:
//
//	cost = 1/N * sum((y - (m*x+c))^2)
func computeCost(xys plotter.XYs, m, c float64) float64 {
	var total float64
	for i := range xys {
		d := xys[i].Y - (xys[i].X*m + c)
		total += d * d
	}
	return total / float64(len(xys))
}

// computeGradient returns the partial derivatives of the mean squared error
// of the line y = m*x + c with respect to m and c:
//
//	cost    = 1/N * sum((y - (m*x+c))^2)
//	cost/dm = 2/N * sum(-x * (y - (m*x+c)))
//	cost/dc = 2/N * sum(-(y - (m*x+c)))
func computeGradient(xys plotter.XYs, m, c float64) (dm, dc float64) {
	for i := range xys {
		x, y := xys[i].X, xys[i].Y
		residual := y - (x*m + c)
		dm += -x * residual
		dc += -residual
	}

	// Both sums share the same 2/N factor.
	scale := 2 / float64(len(xys))
	return scale * dm, scale * dc
}

0 comments on commit bdc8b4e

Please sign in to comment.