NForge
Tensor library
Loading...
Searching...
No Matches
cuda_context.h
1#ifndef CUDA_CONTEXT_H
2#define CUDA_CONTEXT_H
3
4#include <curand.h>
5
6#include "cuda_error.h"
7
9public:
10 static CudaContext& get() {
11 static CudaContext instance;
12 return instance;
13 }
14
15 cudaStream_t stream() const { return m_stream; }
16
17 curandGenerator_t& rng() { return m_gen; }
18
19 CudaContext(const CudaContext&) = delete;
20 CudaContext& operator=(const CudaContext&) = delete;
21 CudaContext(CudaContext&&) = delete;
22 CudaContext& operator=(CudaContext&&) = delete;
23
24private:
25 cudaStream_t m_stream;
26 curandGenerator_t m_gen;
27
28 CudaContext() {
29 CUDA_CHECK(cudaSetDevice(0));
30 CUDA_CHECK(cudaFree(0));
31
32 CUDA_CHECK(cudaStreamCreate(&m_stream));
33
34 CURAND_CHECK(curandCreateGenerator(&m_gen, CURAND_RNG_PSEUDO_DEFAULT));
35 CURAND_CHECK(curandSetStream(m_gen, m_stream));
36
37 CUDA_CHECK(cudaGetLastError());
38 }
39
40 ~CudaContext() {
41 curandDestroyGenerator(m_gen);
42
43 cudaStreamDestroy(m_stream);
44 }
45};
46
47#endif // CUDA_CONTEXT_H
Definition cuda_context.h:8