NForge
Tensor library
Loading...
Searching...
No Matches
tensor_shape.h
1#ifndef TENSOR_SHAPE_H
2#define TENSOR_SHAPE_H
3
4#include <string>
5#include <vector>
6
7#include "nforge/core/tensor.h"
8
9struct TensorLayout;
10
16public:
18 Shape() = default;
19
21 Shape(const std::vector<size_t>& dims);
22
24 Shape(const std::initializer_list<size_t>& dims);
25
27 Shape(const TensorLayout& layout);
28
30 bool operator==(const Tensor::Shape& other) const;
31
33 bool operator!=(const Tensor::Shape& other) const;
34
36 size_t getNumDims() const;
37
39 size_t getNumElements() const;
40
42 size_t getDim(size_t idx) const;
43
45 bool isScalar() const;
46
49 Tensor::Shape operator[](size_t index) const;
50
53 Tensor::Shape operator[](const std::vector<size_t>& position) const;
54
57
59 Tensor::Shape getSlice(size_t start, size_t end) const;
60
62 std::string toString() const;
63
65 std::vector<size_t> toVector() const;
66
68 std::vector<size_t> withoutTrailingOnes() const;
69
72
74 std::vector<size_t> getContiguousStrides() const;
75
76private:
77 std::vector<size_t> m_dimensions;
78};
79
80#endif // TENSOR_SHAPE_H
Tensor::Shape
Definition tensor_shape.h:15
bool isScalar() const
Returns true if the shape is {1}.
Definition tensor_shape.cpp:60
TensorLayout toContiguousLayout() const
Creates a row-major contiguous layout (the last dimension has stride 1).
Definition tensor_shape.cpp:96
std::vector< size_t > getContiguousStrides() const
Returns the row-major strides as a vector.
Definition tensor_shape.cpp:111
Tensor::Shape operator[](size_t index) const
Definition tensor_shape.cpp:27
Tensor::Shape removeLeadingDimension() const
Returns a copy of the shape with the first dimension removed. Throws if the shape is empty.
Definition tensor_shape.cpp:62
bool operator!=(const Tensor::Shape &other) const
Negation of operator==.
Definition tensor_shape.cpp:23
size_t getDim(size_t idx) const
Returns the extent of dimension idx.
Definition tensor_shape.cpp:58
std::vector< size_t > toVector() const
Returns the dimension sizes as a vector.
Definition tensor_shape.cpp:85
bool operator==(const Tensor::Shape &other) const
Tests equality while ignoring trailing ones, so {3, 4, 1} == {3, 4}.
Definition tensor_shape.cpp:19
size_t getNumElements() const
Returns the product of all dimension sizes.
Definition tensor_shape.cpp:47
size_t getNumDims() const
Returns the number of dimensions.
Definition tensor_shape.cpp:25
std::vector< size_t > withoutTrailingOnes() const
Returns the dimension sizes with trailing 1s stripped, always keeping at least one dimension.
Definition tensor_shape.cpp:87
Tensor::Shape getSlice(size_t start, size_t end) const
Returns the shape formed by the dimensions in the half-open range [start, end).
Definition tensor_shape.cpp:69
Shape()=default
Default constructor. Creates an empty (0-dim) shape.
std::string toString() const
Returns a string like "{ 3 4 5 }".
Definition tensor_shape.cpp:76
TensorLayout
Definition tensor_layout.h:15