add walsh transform verify tests

This commit is contained in:
Stephen Chang 2016-12-07 16:16:58 -05:00
parent 21c77d7e61
commit f2c6b581fe
4 changed files with 177 additions and 4 deletions

View File

@ -25,6 +25,7 @@
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth"
"synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify"
"synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify"
"synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth")
"synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth"
"synthcl3-walsh-verify-tests.rkt" "SynthCL Walsh Transform: verify")
(do-tests "bv-ref-tests.rkt" "BV SDSL - Hacker's Delight synthesis")

View File

@ -7,6 +7,5 @@
"synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth"
"synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify"
"synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify"
"synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth")
"synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth"
"synthcl3-walsh-verify-tests.rkt" "SynthCL Walsh Transform: verify")

View File

@ -0,0 +1,147 @@
#lang s-exp "../../../rosette/synthcl3.rkt"
(require "../../rackunit-typechecking.rkt")
; Compute the number of steps for the algorithm,
; assuming that v is a power of 2. See the log2
; algorithm from http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog
(procedure int (steps [int v])
(: int r)
(= r 0)
($= r (<< (!= 0 (& v #xAAAAAAAA)) 0))
($= r (<< (!= 0 (& v #xCCCCCCCC)) 1))
($= r (<< (!= 0 (& v #xF0F0F0F0)) 2))
($= r (<< (!= 0 (& v #xFF00FF00)) 3))
($= r (<< (!= 0 (& v #xFFFF0000)) 4))
r)
; Reference implementation for Fast Walsh Transform. This implementation
; requires the length of the input array to be a power of 2, and it modifies
; the input array in place.
(procedure float* (fwt [float* tArray] [int length])
(for [(: int i in (range 0 (steps length)))]
(: int step)
(= step (<< 1 i))
(for [(: int group in (range 0 step))
(: int pair in (range group length (<< step 1)))]
(: int match)
(: float t1 t2)
(= match (+ pair step))
(= t1 [tArray pair])
(= t2 [tArray match])
(= [tArray pair] (+ t1 t2))
(= [tArray match] (- t1 t2))))
tArray)
; Scalar host for Fast Walsh Transform. This implementation
; requires the length of the input array to be a power of 2. The
; input array is not modified; the output is a new array that holds
; the result of the transform.
(procedure float* (fwtScalarHost [float* input] [int length])
(: cl_context context)
(: cl_command_queue command_queue)
(: cl_program program)
(: cl_kernel kernel)
(: cl_mem tBuffer)
(: float* tArray)
(: int dim global)
(= dim (* length (sizeof float)))
(= global (/ length 2))
(= tArray ((float*) (malloc dim)))
(= context (clCreateContext))
(= command_queue (clCreateCommandQueue context))
(= tBuffer (clCreateBuffer context CL_MEM_READ_WRITE dim))
(= program (clCreateProgramWithSource context "walsh-verify-kernel.rkt"))
(clEnqueueWriteBuffer command_queue tBuffer 0 dim input)
(= kernel (clCreateKernel program "fwtKernel"))
(clSetKernelArg kernel 0 tBuffer)
(for [(: int i in (range 0 (steps length)))]
(: int step)
(= step (<< 1 i))
(clSetKernelArg kernel 1 step)
(clEnqueueNDRangeKernel command_queue kernel 1 NULL (@ global) NULL))
(clEnqueueReadBuffer command_queue tBuffer 0 dim tArray)
tArray)
; Vectorized host for Fast Walsh Transform. This implementation
; requires the length of the input array to be a power of 2. The
; input array is not modified; the output is a new array that holds
; the result of the transform.
(procedure float* (fwtVectorHost [float* input] [int length])
(: cl_context context)
(: cl_command_queue command_queue)
(: cl_program program)
(: cl_mem tBuffer)
(: float* tArray)
(: int dim global n)
(= dim (* length (sizeof float)))
(= global (/ length 2))
(= tArray ((float*) (malloc dim)))
(= context (clCreateContext))
(= command_queue (clCreateCommandQueue context))
(= tBuffer (clCreateBuffer context CL_MEM_READ_WRITE dim))
(= program (clCreateProgramWithSource context "walsh-verify-kernel.rkt"))
(clEnqueueWriteBuffer command_queue tBuffer 0 dim input)
(= n (steps length))
(runKernel command_queue (clCreateKernel program "fwtKernel") tBuffer global 0 (?: (< n 2) n 2))
(if (> n 2)
{ (/= global 4)
(runKernel command_queue (clCreateKernel program "fwtKernel4") tBuffer global 2 n) })
(clEnqueueReadBuffer command_queue tBuffer 0 dim tArray)
tArray)
(procedure void (runKernel [cl_command_queue command_queue] [cl_kernel kernel] [cl_mem tBuffer]
[int global] [int start] [int end])
(clSetKernelArg kernel 0 tBuffer)
(for [(: int i in (range start end))]
(: int step)
(= step (<< 1 i))
(clSetKernelArg kernel 1 step)
(clEnqueueNDRangeKernel command_queue kernel 1 NULL (@ global) NULL)))
; Given two arrays of the same size, checks that they hold the same
; values at each index.
(procedure void (check [int* actual] [int* expected] [int SIZE])
(assert (>= SIZE 0))
(for [(: int i in (range SIZE))]
(assert (== [actual i] [expected i]))))
(procedure void (verify_scalar)
(verify #:forall [(: int logLength in (range 0 7))
(: int length in (range (<< 1 logLength) (+ 1 (<< 1 logLength))))
(: float[length] tArray)]
#:ensure (check (fwtScalarHost tArray length)
(fwt tArray length)
length)))
(procedure void (verify_vector)
(verify #:forall [(: int logLength in (range 0 7))
(: int length in (range (<< 1 logLength) (+ 1 (<< 1 logLength))))
(: float[length] tArray)]
#:ensure (check (fwtVectorHost tArray length)
(fwt tArray length)
length)))
(check-type
(with-output-to-string (λ () (verify_scalar)))
: CString -> "no counterexample found\n")
(check-type
(with-output-to-string (λ () (verify_vector)))
: CString -> "no counterexample found\n")

View File

@ -0,0 +1,26 @@
#lang s-exp "../../../rosette/synthcl3.rkt"
(kernel void (fwtKernel [float* tArray] [int step])
(: int tid group pair match)
(: float t1 t2)
(= tid (get_global_id 0))
(= group (% tid step))
(= pair (+ (* (<< step 1) (/ tid step)) group))
(= match (+ pair step))
(= t1 [tArray pair])
(= t2 [tArray match])
(= [tArray pair] (+ t1 t2))
(= [tArray match] (- t1 t2)))
(kernel void (fwtKernel4 [float4* tArray] [int step])
(: int tid group pair match)
(: float4 t1 t2)
(= tid (get_global_id 0))
(= step (/ step 4))
(= group (% tid step))
(= pair (+ (* (<< step 1) (/ tid step)) group))
(= match (+ pair step))
(= t1 [tArray pair])
(= t2 [tArray match])
(= [tArray pair] (+ t1 t2))
(= [tArray match] (- t1 t2)))