Skip to content

Commit

Permalink
AVX for phonemizer
Browse files Browse the repository at this point in the history
  • Loading branch information
neurlang authored and Your Name committed Oct 16, 2024
1 parent e4e9df6 commit 28fc6bc
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions cmd/train_phonemizer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import "flag"
import "github.com/neurlang/classifier/datasets/phonemizer"
import "github.com/neurlang/classifier/layer/majpool2d"
import "github.com/neurlang/classifier/datasets"
import "github.com/neurlang/classifier/learning"
import "github.com/neurlang/classifier/learning/avx"
import "github.com/neurlang/classifier/net/feedforward"

func error_abs(a, b uint32) uint32 {
Expand All @@ -22,6 +22,7 @@ func error_abs(a, b uint32) uint32 {
func main() {
cleantsv := flag.String("cleantsv", "", "clean tsv dataset for the language")
dstmodel := flag.String("dstmodel", "", "model destination .json.lzw file")
resume := flag.Bool("resume", false, "resume training")
flag.Parse()

var improved_success_rate = 0
Expand Down Expand Up @@ -49,7 +50,7 @@ func main() {
net.NewCombiner(majpool2d.MustNew(fanout2, 1, fanout1, 1, fanout2, 1, 1))
net.NewLayer(1, 0)

//net.ReadCompressedWeightsFromFile("output.94.json.t.lzw")


trainWorst := func(worst int) {
var tally = new(datasets.Tally)
Expand All @@ -76,7 +77,7 @@ func main() {
wg.Wait()
}

var h learning.HyperParameters
var h avx.HyperParameters
h.Threads = runtime.NumCPU()
h.Factor = 1 // affects the solution size

Expand All @@ -100,6 +101,9 @@ func main() {

h.Name = fmt.Sprint(worst)
h.SetLogger("solutions11.txt")

h.AvxLanes = 16
h.AvxSkip = 4

fmt.Println("hashtron position:", worst, "(job size:", tally.Len(), ")")

Expand Down Expand Up @@ -131,25 +135,31 @@ func main() {
success := percent * 100 / len(datakeys)
println("[success rate]", success, "%", "with", errsum, "errors")

err := net.WriteCompressedWeightsToFile("output." + fmt.Sprint(success) + ".json.t.lzw")
if err != nil {
println(err.Error())
if dstmodel == nil {
err := net.WriteCompressedWeightsToFile("output." + fmt.Sprint(success) + ".json.t.lzw")
if err != nil {
println(err.Error())
}
}

if dstmodel != nil && len(*dstmodel) > 0 && improved_success_rate < success {
improved_success_rate = success
err := net.WriteCompressedWeightsToFile(*dstmodel)
if err != nil {
println(err.Error())
if improved_success_rate > 0 {
err := net.WriteCompressedWeightsToFile(*dstmodel)
if err != nil {
println(err.Error())
}
}
improved_success_rate = success
}

if success == 100 {
println("Max accuracy or wrong data. Exiting")
os.Exit(0)
}
}

if resume != nil && *resume && dstmodel != nil {
net.ReadCompressedWeightsFromFile(*dstmodel)
}
for {
shuf := net.Shuffle(true)
evaluate()
Expand Down

0 comments on commit 28fc6bc

Please sign in to comment.