1 module oclcv.convolution;
2 
3 import oclcv.clcore;
4 import core.stdc.stdlib, core.stdc.stdio;
5 import dplug.core.nogc;
6 
/**
 * Wraps the "convolution" OpenCL kernel built from `CTKernel.KCONV`.
 *
 * Input and filter dimensions are fixed at construction time. Each call to
 * `run` allocates a fresh output buffer described by
 * `BufferMeta(FLOAT, inputHeight, inputWidth, inputDepth)`, launches the kernel
 * and waits on `context_.finish(0)` before returning it.
 */
final class Convolution {
public:
@nogc nothrow:
    /**
     * Stores the input/filter dimensions and builds the OpenCL program.
     *
     * Params:
     *   inputHeight  = input height, also used as the output height.
     *   inputWidth   = input width, also used as the output width.
     *   inputDepth   = input depth (channel count); selects the launch shape in `_conv`.
     *   filterHeight = filter height, forwarded to the kernel.
     *   filterWidth  = filter width, forwarded to the kernel.
     *   ctx          = OpenCL context; must be non-null.
     *
     * Terminates the process with `exit(-1)` if `initialize` fails, which
     * happens when `ctx` is null.
     */
    this(int inputHeight, int inputWidth, int inputDepth, int filterHeight, int filterWidth, CLContext ctx){
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.inputDepth = inputDepth;
        this.filterHeight = filterHeight;
        this.filterWidth = filterWidth;

        if(!initialize(ctx)){
            // Trailing newline so the message is not glued to subsequent terminal output.
            printf("Problem initializing the OpenCL kernel %s\n", __FILE__.ptr);
            exit(-1);
        }
    }

    /**
     * Frees the `CLProgram` created by `initialize`.
     *
     * NOTE(review): `_kernel` is not released here; presumably it is owned by
     * `prog_` — confirm against `CLProgram.getKernel`.
     */
    ~this(){
        destroyFree(prog_);
    }

    /**
     * Builds the OpenCL program from `CTKernel.KCONV` on `ctx` and looks up the
     * "convolution" kernel.
     *
     * Safe to call again after construction: the previously built program is
     * freed before being replaced.
     *
     * Returns: `false` if `ctx` is null (state left untouched), `true` otherwise.
     */
    bool initialize(CLContext ctx){
        if(!ctx)
            return false;
        context_ = ctx;

        // Release a program left over from an earlier initialize() call; otherwise it leaks.
        if(prog_ !is null)
            destroyFree(prog_);
        prog_ = mallocNew!CLProgram(CTKernel.KCONV, context_);
        _kernel = prog_.getKernel("convolution");

        return true;
    }

    /**
     * Convolves `d_src` with `d_filter` and returns the result.
     *
     * In debug builds, asserts that `d_src` holds FLOAT data with 1 to 3 channels.
     *
     * Params:
     *   d_src    = input buffer (FLOAT).
     *   d_filter = filter buffer, passed straight to the kernel.
     *
     * Returns: a newly `mallocNew`-allocated buffer of
     * `BufferMeta(FLOAT, inputHeight, inputWidth, inputDepth)`. This class keeps no
     * reference to it, so the caller owns it (presumably to be released with
     * `destroyFree` — confirm).
     */
    CLBuffer run(CLBuffer d_src, CLBuffer d_filter){
        debug
        {
            _assert(d_src.metaData.dataType == FLOAT, "Input type must be float");
            // Range check instead of `[1,2,3].canFind(...)`: an array literal GC-allocates.
            immutable channels = d_src.metaData.numberOfChannels;
            _assert(channels >= 1 && channels <= 3, "Input's channel count must be 1,2, or 3");
        }

        CLBuffer d_out = mallocNew!CLBuffer(context_, BufferMeta(FLOAT, inputHeight, inputWidth, inputDepth));

        // Argument order must match the "convolution" kernel signature in CTKernel.KCONV.
        _kernel.setArgs(d_src, d_filter, d_out, inputWidth, inputHeight, inputDepth, filterWidth, filterHeight);

        _conv();

        return d_out;
    }

    /**
     * Launches the kernel with `blockSize` x `blockSize` blocks, then calls
     * `context_.finish(0)` (presumably blocking until queue 0 drains — confirm).
     *
     * The grid covers `ceil(inputWidth / blockSize)` x `ceil(inputHeight / blockSize)`
     * blocks, plus a third dimension of 3 when `inputDepth == 3`.
     *
     * NOTE(review): a 2-channel input (allowed by the debug check in `run`) gets a
     * 2D launch like the 1-channel case — confirm the kernel handles depth 2 that way.
     */
    void _conv(){
        // Ceil-divide so the grid still covers the image when a side is not a multiple of blockSize.
        immutable int gridX = divUp(inputWidth, blockSize);
        immutable int gridY = divUp(inputHeight, blockSize);

        if(inputDepth == 3)
            _kernel.launch(0, GridDim(gridX, gridY, 3), BlockDim(blockSize, blockSize));
        else
            _kernel.launch(0, GridDim(gridX, gridY), BlockDim(blockSize, blockSize));
        context_.finish(0);
    }

private:
    /// Edge length of the square thread block used by `_conv`.
    enum int blockSize = 16;

    /// Integer ceiling division `ceil(a / b)` for non-negative `a` and positive `b`.
    static int divUp(int a, int b){
        return (a + b - 1) / b;
    }

    int inputHeight;
    int inputWidth;
    int inputDepth;
    int filterHeight;
    int filterWidth;

    CLContext context_;
    CLProgram prog_;

    CLKernel _kernel;
}