From f2c6b581fe289163c3497790dc195a8d0129350c Mon Sep 17 00:00:00 2001 From: Stephen Chang Date: Wed, 7 Dec 2016 16:16:58 -0500 Subject: [PATCH] add walsh transform verify tests --- .../rosette3/run-all-rosette-tests-script.rkt | 3 +- .../rosette3/run-all-synthcl-tests.rkt | 5 +- .../rosette3/synthcl3-walsh-verify-tests.rkt | 147 ++++++++++++++++++ .../rosette/rosette3/walsh-verify-kernel.rkt | 26 ++++ 4 files changed, 177 insertions(+), 4 deletions(-) create mode 100644 turnstile/examples/tests/rosette/rosette3/synthcl3-walsh-verify-tests.rkt create mode 100644 turnstile/examples/tests/rosette/rosette3/walsh-verify-kernel.rkt diff --git a/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt b/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt index 324c3a9..65c21c4 100644 --- a/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt +++ b/turnstile/examples/tests/rosette/rosette3/run-all-rosette-tests-script.rkt @@ -25,6 +25,7 @@ "synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth" "synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify" "synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify" - "synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth") + "synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth" + "synthcl3-walsh-verify-tests.rkt" "SynthCL Walsh Transform: verify") (do-tests "bv-ref-tests.rkt" "BV SDSL - Hacker's Delight synthesis") diff --git a/turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt b/turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt index 54713d7..9c52fd9 100644 --- a/turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt +++ b/turnstile/examples/tests/rosette/rosette3/run-all-synthcl-tests.rkt @@ -7,6 +7,5 @@ "synthcl3-matrix-synth-tests.rkt" "SynthCL Matrix Mult: synth" "synthcl3-matrix-verify-tests.rkt" "SynthCL Matrix Mult: verify" "synthcl3-matrix-verify-buggy-tests.rkt" "SynthCL buggy Matrix Mult: verify" - "synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth") - - + "synthcl3-walsh-synth-tests.rkt" "SynthCL Walsh Transform: synth" + "synthcl3-walsh-verify-tests.rkt" "SynthCL Walsh Transform: verify") diff --git a/turnstile/examples/tests/rosette/rosette3/synthcl3-walsh-verify-tests.rkt b/turnstile/examples/tests/rosette/rosette3/synthcl3-walsh-verify-tests.rkt new file mode 100644 index 0000000..2f73413 --- /dev/null +++ b/turnstile/examples/tests/rosette/rosette3/synthcl3-walsh-verify-tests.rkt @@ -0,0 +1,147 @@ +#lang s-exp "../../../rosette/synthcl3.rkt" +(require "../../rackunit-typechecking.rkt") +; Compute the number of steps for the algorithm, +; assuming that v is a power of 2. See the log2 +; algorithm from http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog +(procedure int (steps [int v]) + (: int r) + (= r 0) + ($= r (<< (!= 0 (& v #xAAAAAAAA)) 0)) + ($= r (<< (!= 0 (& v #xCCCCCCCC)) 1)) + ($= r (<< (!= 0 (& v #xF0F0F0F0)) 2)) + ($= r (<< (!= 0 (& v #xFF00FF00)) 3)) + ($= r (<< (!= 0 (& v #xFFFF0000)) 4)) + r) + +; Reference implementation for Fast Walsh Transform. This implementation +; requires the length of the input array to be a power of 2, and it modifies +; the input array in place. +(procedure float* (fwt [float* tArray] [int length]) + (for [(: int i in (range 0 (steps length)))] + (: int step) + (= step (<< 1 i)) + (for [(: int group in (range 0 step)) + (: int pair in (range group length (<< step 1)))] + (: int match) + (: float t1 t2) + (= match (+ pair step)) + (= t1 [tArray pair]) + (= t2 [tArray match]) + (= [tArray pair] (+ t1 t2)) + (= [tArray match] (- t1 t2)))) + tArray) + +; Scalar host for Fast Walsh Transform. This implementation +; requires the length of the input array to be a power of 2. The +; input array is not modified; the output is a new array that holds +; the result of the transform. +(procedure float* (fwtScalarHost [float* input] [int length]) + (: cl_context context) + (: cl_command_queue command_queue) + (: cl_program program) + (: cl_kernel kernel) + (: cl_mem tBuffer) + (: float* tArray) + (: int dim global) + + (= dim (* length (sizeof float))) + (= global (/ length 2)) + + (= tArray ((float*) (malloc dim))) + + (= context (clCreateContext)) + + (= command_queue (clCreateCommandQueue context)) + + (= tBuffer (clCreateBuffer context CL_MEM_READ_WRITE dim)) + (= program (clCreateProgramWithSource context "walsh-verify-kernel.rkt")) + + (clEnqueueWriteBuffer command_queue tBuffer 0 dim input) + + (= kernel (clCreateKernel program "fwtKernel")) + (clSetKernelArg kernel 0 tBuffer) + + (for [(: int i in (range 0 (steps length)))] + (: int step) + (= step (<< 1 i)) + (clSetKernelArg kernel 1 step) + (clEnqueueNDRangeKernel command_queue kernel 1 NULL (@ global) NULL)) + + (clEnqueueReadBuffer command_queue tBuffer 0 dim tArray) + tArray) + +; Vectorized host for Fast Walsh Transform. This implementation +; requires the length of the input array to be a power of 2. The +; input array is not modified; the output is a new array that holds +; the result of the transform. +(procedure float* (fwtVectorHost [float* input] [int length]) + (: cl_context context) + (: cl_command_queue command_queue) + (: cl_program program) + (: cl_mem tBuffer) + (: float* tArray) + (: int dim global n) + + (= dim (* length (sizeof float))) + (= global (/ length 2)) + + (= tArray ((float*) (malloc dim))) + + (= context (clCreateContext)) + + (= command_queue (clCreateCommandQueue context)) + + (= tBuffer (clCreateBuffer context CL_MEM_READ_WRITE dim)) + (= program (clCreateProgramWithSource context "walsh-verify-kernel.rkt")) + + (clEnqueueWriteBuffer command_queue tBuffer 0 dim input) + + (= n (steps length)) + + (runKernel command_queue (clCreateKernel program "fwtKernel") tBuffer global 0 (?: (< n 2) n 2)) + (if (> n 2) + { (/= global 4) + (runKernel command_queue (clCreateKernel program "fwtKernel4") tBuffer global 2 n) }) + + (clEnqueueReadBuffer command_queue tBuffer 0 dim tArray) + tArray) + +(procedure void (runKernel [cl_command_queue command_queue] [cl_kernel kernel] [cl_mem tBuffer] + [int global] [int start] [int end]) + (clSetKernelArg kernel 0 tBuffer) + (for [(: int i in (range start end))] + (: int step) + (= step (<< 1 i)) + (clSetKernelArg kernel 1 step) + (clEnqueueNDRangeKernel command_queue kernel 1 NULL (@ global) NULL))) + +; Given two arrays of the same size, checks that they hold the same +; values at each index. +(procedure void (check [int* actual] [int* expected] [int SIZE]) + (assert (>= SIZE 0)) + (for [(: int i in (range SIZE))] + (assert (== [actual i] [expected i])))) + +(procedure void (verify_scalar) + (verify #:forall [(: int logLength in (range 0 7)) + (: int length in (range (<< 1 logLength) (+ 1 (<< 1 logLength)))) + (: float[length] tArray)] + #:ensure (check (fwtScalarHost tArray length) + (fwt tArray length) + length))) + +(procedure void (verify_vector) + (verify #:forall [(: int logLength in (range 0 7)) + (: int length in (range (<< 1 logLength) (+ 1 (<< 1 logLength)))) + (: float[length] tArray)] + #:ensure (check (fwtVectorHost tArray length) + (fwt tArray length) + length))) + + +(check-type + (with-output-to-string (λ () (verify_scalar))) + : CString -> "no counterexample found\n") +(check-type + (with-output-to-string (λ () (verify_vector))) + : CString -> "no counterexample found\n") diff --git a/turnstile/examples/tests/rosette/rosette3/walsh-verify-kernel.rkt b/turnstile/examples/tests/rosette/rosette3/walsh-verify-kernel.rkt new file mode 100644 index 0000000..b66b607 --- /dev/null +++ b/turnstile/examples/tests/rosette/rosette3/walsh-verify-kernel.rkt @@ -0,0 +1,26 @@ +#lang s-exp "../../../rosette/synthcl3.rkt" + +(kernel void (fwtKernel [float* tArray] [int step]) + (: int tid group pair match) + (: float t1 t2) + (= tid (get_global_id 0)) + (= group (% tid step)) + (= pair (+ (* (<< step 1) (/ tid step)) group)) + (= match (+ pair step)) + (= t1 [tArray pair]) + (= t2 [tArray match]) + (= [tArray pair] (+ t1 t2)) + (= [tArray match] (- t1 t2))) + +(kernel void (fwtKernel4 [float4* tArray] [int step]) + (: int tid group pair match) + (: float4 t1 t2) + (= tid (get_global_id 0)) + (= step (/ step 4)) + (= group (% tid step)) + (= pair (+ (* (<< step 1) (/ tid step)) group)) + (= match (+ pair step)) + (= t1 [tArray pair]) + (= t2 [tArray match]) + (= [tArray pair] (+ t1 t2)) + (= [tArray match] (- t1 t2)))