Skip to content

Commit

Permalink
AVX
Browse files Browse the repository at this point in the history
  • Loading branch information
neurlang authored and Your Name committed Oct 16, 2024
1 parent 581baae commit e4e9df6
Show file tree
Hide file tree
Showing 6 changed files with 628 additions and 1 deletion.
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ go 1.16

//replace gorgonia.org/cu => /home/m2/go/src/example.com/repo.git/cu

require gorgonia.org/cu v0.9.6
require (
github.com/klauspost/cpuid/v2 v2.2.8
gorgonia.org/cu v0.9.6
)
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E
github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM=
github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
Expand Down Expand Up @@ -121,6 +123,8 @@ golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
Expand Down
34 changes: 34 additions & 0 deletions hash/hashvectorized.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package hash

import "github.com/klauspost/cpuid/v2"

func init() {
// Check if the CPU supports AVX512
if cpuid.CPU.Supports(cpuid.AVX512F, cpuid.AVX512DQ) {
HashVectorized = hashAVX512Vectorized
} else {
HashVectorized = hashNotVectorized
}
}

// HashVectorized implement many Neurlang hashes in parallel, using something like AVX-512 or similar
var HashVectorized func(out []uint32, n []uint32, s []uint32, max uint32)

func hashNotVectorized(out []uint32, n []uint32, s []uint32, max uint32) {
for i := range out {
out[i] = Hash(n[i], s[i], max)
}
}
func hashAVX512Vectorized(out []uint32, n []uint32, s []uint32, max uint32) {
hashVectorizedAVX512(&out[0], &n[0], &s[0], max, uint32(len(out)))
// self-checking
//for i := range out {
// var ok = Hash(n[i], s[i], max)
// if out[i] != ok {
// println("result is wrong", i, out[i], ok)
// out[i] = ok
// }
//}
}

func hashVectorizedAVX512(out *uint32, n *uint32, s *uint32, max, length uint32)
130 changes: 130 additions & 0 deletions hash/hashvectorized.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#include "textflag.h"

// func hashVectorizedAVX512(out *uint32, n *uint32, s *uint32, max uint32, length uint32)
TEXT ·hashVectorizedAVX512(SB), NOSPLIT, $0-40
MOVQ out+0(FP), DI
MOVQ n+8(FP), SI
MOVQ s+16(FP), DX
MOVL max+24(FP), R8
MOVL len+28(FP), CX

// Preserve length for bounds checking
MOVL CX, R9

// Broadcast max to Z31
VPBROADCASTD R8, Z31

// Check if we have at least 16 elements
CMPQ R9, $16
JL remainder_loop

// Process 16 elements at a time
SHRQ $4, CX
JZ remainder_loop

loop:
// Load 16 elements from n and s
VMOVDQU32 (SI), Z0
VMOVDQU32 (DX), Z1

// m = n - s
VPSUBD Z1, Z0, Z2

// Hashing stage
VPSLLD $2, Z2, Z3
VPXORD Z3, Z2, Z2
VPSLLD $3, Z2, Z3
VPXORD Z3, Z2, Z2
VPSRLD $5, Z2, Z3
VPXORD Z3, Z2, Z2
VPSRLD $7, Z2, Z3
VPXORD Z3, Z2, Z2
VPSLLD $11, Z2, Z3
VPXORD Z3, Z2, Z2
VPSLLD $13, Z2, Z3
VPXORD Z3, Z2, Z2
VPSRLD $17, Z2, Z3
VPXORD Z3, Z2, Z2
VPSLLD $19, Z2, Z3
VPXORD Z3, Z2, Z2

// m += s
VPADDD Z1, Z2, Z2

// Modular reduction: (uint64(m) * uint64(max)) >> 32
// first multiply (even lanes)
VPMULUDQ Z31, Z2, Z3
// prepare odd lanes multiply
VPSRLQ $32, Z3, Z3
VPSRLQ $32, Z2, Z2
// second multiply (odd lanes)
VPMULUDQ Z31, Z2, Z2
// clear wrong lane
VPSRLQ $32, Z2, Z2
VPSLLQ $32, Z2, Z2
// combine odd and even lanes
VPORQ Z2, Z3, Z3

// Store result
VMOVDQU32 Z3, (DI)

ADDQ $64, SI
ADDQ $64, DX
ADDQ $64, DI
SUBQ $16, R9
DECQ CX
JNZ loop

remainder_loop:
CMPQ R9, $0
JE end_loop // Exit if no elements left

MOVL (SI), AX // Load n (scalar)
MOVL (DX), BX // Load s (scalar)
SUBL BX, AX // m = n - s

// Hashing stage: XOR shifts
MOVL AX, R10
SHLL $2, R10
XORL R10, AX
MOVL AX, R10
SHLL $3, R10
XORL R10, AX
MOVL AX, R10
SHRL $5, R10
XORL R10, AX
MOVL AX, R10
SHRL $7, R10
XORL R10, AX
MOVL AX, R10
SHLL $11, R10
XORL R10, AX
MOVL AX, R10
SHLL $13, R10
XORL R10, AX
MOVL AX, R10
SHRL $17, R10
XORL R10, AX
MOVL AX, R10
SHLL $19, R10
XORL R10, AX

// Second mixing stage: Add s
ADDL BX, AX // m += s

// Modular reduction using multiply-shift method
MOVL AX, R11 // Save m in R11
MOVL $0, R10 // Clear upper 32 bits of R10:R11
MOVL R8, AX // Move max to AX
MULL R11 // Multiply m by max, result in EDX:EAX
MOVL DX, (DI) // Store high 32 bits (EDX) to output

ADDQ $4, SI // Move to next n (advance pointer)
ADDQ $4, DX // Move to next s (advance pointer)
ADDQ $4, DI // Move to next out (advance pointer)
DECQ R9 // Decrease remaining element count
JNZ remainder_loop // Continue if remaining elements

end_loop:
VZEROUPPER // Clear upper parts of YMM registers
RET
27 changes: 27 additions & 0 deletions learning/avx/hyperparameters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package avx

import (
"log"
"os"
)

import "github.com/neurlang/classifier/learning"

// SetLogger sets the output logger file where hashtron golang code programs are written
func (h *HyperParameters) SetLogger(filename string) {
outfile, _ := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
h.l = log.New(outfile, "", 0)
}

type HyperParameters struct {
learning.HyperParameters

AvxLanes uint32 // should be set to 16 for AVX512
AvxSkip uint32 // should be set to 1 to not skip work, >1 to skip some possibly repeated work

l *log.Logger
}

func (h *HyperParameters) H() *learning.HyperParameters {
return &h.HyperParameters
}
Loading

0 comments on commit e4e9df6

Please sign in to comment.