diff --git a/cmd/train_phonemizer/main.go b/cmd/train_phonemizer/main.go index 3ec7749..ef65360 100644 --- a/cmd/train_phonemizer/main.go +++ b/cmd/train_phonemizer/main.go @@ -9,7 +9,7 @@ import "flag" import "github.com/neurlang/classifier/datasets/phonemizer" import "github.com/neurlang/classifier/layer/majpool2d" import "github.com/neurlang/classifier/datasets" -import "github.com/neurlang/classifier/learning" +import "github.com/neurlang/classifier/learning/avx" import "github.com/neurlang/classifier/net/feedforward" func error_abs(a, b uint32) uint32 { @@ -22,6 +22,7 @@ func error_abs(a, b uint32) uint32 { func main() { cleantsv := flag.String("cleantsv", "", "clean tsv dataset for the language") dstmodel := flag.String("dstmodel", "", "model destination .json.lzw file") + resume := flag.Bool("resume", false, "resume training") flag.Parse() var improved_success_rate = 0 @@ -49,7 +50,7 @@ func main() { net.NewCombiner(majpool2d.MustNew(fanout2, 1, fanout1, 1, fanout2, 1, 1)) net.NewLayer(1, 0) - //net.ReadCompressedWeightsFromFile("output.94.json.t.lzw") + trainWorst := func(worst int) { var tally = new(datasets.Tally) @@ -76,7 +77,7 @@ func main() { wg.Wait() } - var h learning.HyperParameters + var h avx.HyperParameters h.Threads = runtime.NumCPU() h.Factor = 1 // affects the solution size @@ -100,6 +101,9 @@ func main() { h.Name = fmt.Sprint(worst) h.SetLogger("solutions11.txt") + + h.AvxLanes = 16 + h.AvxSkip = 4 fmt.Println("hashtron position:", worst, "(job size:", tally.Len(), ")") @@ -131,17 +135,21 @@ func main() { success := percent * 100 / len(datakeys) println("[success rate]", success, "%", "with", errsum, "errors") - err := net.WriteCompressedWeightsToFile("output." + fmt.Sprint(success) + ".json.t.lzw") - if err != nil { - println(err.Error()) + if dstmodel == nil { + err := net.WriteCompressedWeightsToFile("output." + fmt.Sprint(success) + ".json.t.lzw") + if err != nil { + println(err.Error()) + } } if dstmodel != nil && len(*dstmodel) > 0 && improved_success_rate < success { - improved_success_rate = success - err := net.WriteCompressedWeightsToFile(*dstmodel) - if err != nil { - println(err.Error()) + if improved_success_rate > 0 { + err := net.WriteCompressedWeightsToFile(*dstmodel) + if err != nil { + println(err.Error()) + } } + improved_success_rate = success } if success == 100 { @@ -149,7 +157,9 @@ func main() { os.Exit(0) } } - + if resume != nil && *resume && dstmodel != nil { + net.ReadCompressedWeightsFromFile(*dstmodel) + } for { shuf := net.Shuffle(true) evaluate()