NForge — Tensor library
tensor.h
1#ifndef TENSOR_H
2#define TENSOR_H
3
4#include <cassert>
5#include <iostream>
6#include <memory>
7#include <string>
8#include <vector>
9
/// Backend that holds a Tensor's data. Tensor::to transfers data between
/// backends, and Tensor::getBackendString reports the current one as
/// "CPU" or "CUDA".
enum class Backend { CPU, CUDA };
12
18class Tensor {
19public:
20 class Impl;
21 class CPUImpl;
22 class CUDAImpl;
23
24 class View;
25 class Shape;
26
27public:
29 Tensor(const Tensor::Shape& shape, Backend backend = Backend::CPU);
30
32 Tensor(const std::initializer_list<size_t>& shape, Backend backend = Backend::CPU);
33
35 Tensor(const Tensor::Shape& shape, float value, Backend backend = Backend::CPU);
36
38 Tensor(const std::initializer_list<size_t>& shape, float value, Backend backend = Backend::CPU);
39
41 Tensor(float value, Backend backend = Backend::CPU);
42
44 Tensor(const Tensor& tensor);
45
48 Tensor(std::unique_ptr<Tensor::Impl> impl, Backend backend = Backend::CPU);
49
51 ~Tensor();
52
54 void to(Backend newBackend);
55
57 void fillAll(float value);
58
60 void fillRand();
61
63 void print() const;
64
66 void print(const std::vector<size_t>& position) const;
67
69 Tensor::Shape getShape() const;
70
72 std::string getBackendString() const;
73
75 inline Backend getBackend() const { return m_backend; }
76
78 std::string getDataString() const;
79
81 size_t getNumElements() const;
82
84 std::vector<float> toVector() const;
85
87 void set(const std::vector<size_t>& position, const Tensor::View& rhs);
88
90 bool compare(const Tensor::View& rhs) const;
91
93 bool compare(const std::vector<size_t>& position, const Tensor::View& rhs) const;
94
96 Tensor operator+(const Tensor::View& rhs) const;
97
99 Tensor operator-(const Tensor::View& rhs) const;
100
102 Tensor operator*(const Tensor::View& rhs) const;
103
105 Tensor operator/(const Tensor::View& rhs) const;
106
108 Tensor operator+(float scalar) const;
109
111 Tensor operator-(float scalar) const;
112
114 Tensor operator*(float scalar) const;
115
117 Tensor operator/(float scalar) const;
118
120 friend Tensor operator+(float scalar, const Tensor& rhs);
121
123 friend Tensor operator-(float scalar, const Tensor& rhs);
124
126 friend Tensor operator*(float scalar, const Tensor& rhs);
127
129 friend Tensor operator/(float scalar, const Tensor& rhs);
130
132 void operator+=(const Tensor::View& rhs);
133
135 void operator-=(const Tensor::View& rhs);
136
138 void operator*=(const Tensor::View& rhs);
139
141 void operator/=(const Tensor::View& rhs);
142
144 Tensor mean(size_t dim = 0) const;
145
147 Tensor sum(size_t dim = 0) const;
148
150 Tensor min(size_t dim = 0) const;
151
153 Tensor max(size_t dim = 0) const;
154
156 Tensor prod(size_t dim = 0) const;
157
159 Tensor norm() const;
160
163 Tensor all(size_t dim = 0) const;
164
167 Tensor any(size_t dim = 0) const;
168
175 Tensor matmul(const Tensor::View& rhs) const;
176
178 Tensor::View operator[](size_t idx) const;
179
182 Tensor::View subsample(std::vector<size_t> strides) const;
183
185 Tensor& operator=(const Tensor& rhs);
186
188 Tensor& operator=(const Tensor::View& rhs);
189
192 Tensor& operator=(float scalar);
193
196 bool isEqual(const Tensor::View& rhs) const;
197
200 bool isNotEqual(const Tensor::View& rhs) const;
201
203 Tensor operator==(const Tensor::View& rhs) const;
204
206 Tensor operator!=(const Tensor::View& rhs) const;
207
209 Tensor operator<(const Tensor::View& rhs) const;
210
212 Tensor operator<=(const Tensor::View& rhs) const;
213
215 Tensor operator>(const Tensor::View& rhs) const;
216
218 Tensor operator>=(const Tensor::View& rhs) const;
219
222 Tensor isClose(const Tensor::View& rhs, float tolerance = 1e-5f) const;
223
224private:
225 Backend m_backend;
226 std::unique_ptr<Impl> m_impl;
227
230 template <typename BinaryOp>
231 Tensor applyBinaryOp(const Tensor::View& rhs, BinaryOp op) const;
232
235 template <typename BinaryOp>
236 void applyInplaceBinaryOp(const Tensor::View& rhs, BinaryOp op);
237
240 template <typename ReductionOp>
241 Tensor applyReduction(size_t dim, ReductionOp op) const;
242};
243
244#endif // TENSOR_H
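class Tensor::CPUImpl
Definition tensor_impl_CPU.h:14
class Tensor::CUDAImpl
Definition tensor_impl_CUDA.h:13
class Tensor::Impl
Definition tensor_impl.h:15
class Tensor::Shape
Definition tensor_shape.h:15
class Tensor::View
Definition tensor_view.h:12
class Tensor
Definition tensor.h:18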
Tensor & operator=(const Tensor &rhs)
Copies data from another tensor.
Definition tensor.cpp:260
Tensor operator!=(const Tensor::View &rhs) const
Elementwise not equal. Returns a tensor of 0.0 / 1.0.
Definition tensor.cpp:291
std::string getDataString() const
Returns a string representation of the underlying data.
Definition tensor.cpp:100
Tensor operator==(const Tensor::View &rhs) const
Elementwise equal. Returns a tensor of 0.0 / 1.0.
Definition tensor.cpp:287
Tensor operator>(const Tensor::View &rhs) const
Elementwise greater than. Returns a tensor of 0.0 / 1.0.
Definition tensor.cpp:303
friend Tensor operator-(float scalar, const Tensor &rhs)
Elementwise subtraction of a tensor from a pure float.
Definition tensor.cpp:176
Tensor::View subsample(std::vector< size_t > strides) const
Definition tensor.cpp:255
void set(const std::vector< size_t > &position, const Tensor::View &rhs)
Replaces the block starting at position with the data from rhs.
Definition tensor.cpp:106
Tensor operator<(const Tensor::View &rhs) const
Elementwise less than. Returns a tensor of 0.0 / 1.0.
Definition tensor.cpp:295
Tensor max(size_t dim=0) const
Reduces dimensions [dim, rank) by taking the maximum. Result shape is shape[0:dim].
Definition tensor.cpp:219
void operator+=(const Tensor::View &rhs)
In-place elementwise addition with a tensor or view.
Definition tensor.cpp:191
friend Tensor operator+(float scalar, const Tensor &rhs)
Elementwise addition of a pure float and a tensor.
Definition tensor.cpp:174
Tensor any(size_t dim=0) const
Definition tensor.cpp:233
~Tensor()
Destructor.
Definition tensor.cpp:51
void fillRand()
Fills all elements with random values in [-1, 1].
Definition tensor.cpp:81
Tensor mean(size_t dim=0) const
Reduces dimensions [dim, rank) by averaging. Result shape is shape[0:dim].
Definition tensor.cpp:208
bool compare(const Tensor::View &rhs) const
Returns true if shape and every element matches rhs.
Definition tensor.cpp:120
Tensor matmul(const Tensor::View &rhs) const
Definition tensor.cpp:235
Tensor::View operator[](size_t idx) const
Indexes into the first dimension, returning a view of the sub-tensor.
Definition tensor.cpp:245
Tensor::Shape getShape() const
Returns the tensor shape.
Definition tensor.cpp:87
size_t getNumElements() const
Returns the total number of elements.
Definition tensor.cpp:102
void operator-=(const Tensor::View &rhs)
In-place elementwise subtraction with a tensor or view.
Definition tensor.cpp:193
void print() const
Prints the tensor to stdout.
Definition tensor.cpp:83
void operator/=(const Tensor::View &rhs)
In-place elementwise division by a tensor or view.
Definition tensor.cpp:197
Tensor min(size_t dim=0) const
Reduces dimensions [dim, rank) by taking the minimum. Result shape is shape[0:dim].
Definition tensor.cpp:217
void fillAll(float value)
Fills all elements with value.
Definition tensor.cpp:79
friend Tensor operator/(float scalar, const Tensor &rhs)
Elementwise division of a pure float by a tensor.
Definition tensor.cpp:180
void to(Backend newBackend)
Transfers data to a different backend. No-op if already on that backend.
Definition tensor.cpp:53
Tensor operator>=(const Tensor::View &rhs) const
Elementwise greater or equal. Returns a tensor of 0.0 / 1.0.
Definition tensor.cpp:307
Tensor sum(size_t dim=0) const
Reduces dimensions [dim, rank) by summation. Result shape is shape[0:dim].
Definition tensor.cpp:215
Backend getBackend() const
Returns the backend enum.
Definition tensor.h:75
bool isEqual(const Tensor::View &rhs) const
Definition tensor.cpp:282
Tensor isClose(const Tensor::View &rhs, float tolerance=1e-5f) const
Definition tensor.cpp:311
friend Tensor operator*(float scalar, const Tensor &rhs)
Elementwise multiplication of a pure float and a tensor.
Definition tensor.cpp:178
Tensor operator<=(const Tensor::View &rhs) const
Elementwise less or equal. Returns a tensor of 0.0 / 1.0.
Definition tensor.cpp:299
void operator*=(const Tensor::View &rhs)
In-place elementwise multiplication with a tensor or view.
Definition tensor.cpp:195
Tensor prod(size_t dim=0) const
Reduces dimensions [dim, rank) by taking the product. Result shape is shape[0:dim].
Definition tensor.cpp:221
Tensor norm() const
L2 norm (scalar tensor equal to sqrt(sum(x^2))).
Definition tensor.cpp:223
std::vector< float > toVector() const
Copies all elements into a flat vector (row-major order).
Definition tensor.cpp:104
std::string getBackendString() const
Returns "CPU" or "CUDA".
Definition tensor.cpp:89
Tensor all(size_t dim=0) const
Definition tensor.cpp:231
bool isNotEqual(const Tensor::View &rhs) const
Definition tensor.cpp:284