#pragma once #include "StarArray.hpp" #include "StarList.hpp" namespace Star { STAR_EXCEPTION(MultiArrayException, StarException); // Multidimensional array class that wraps a vector as a simple contiguous N // dimensional array. Values are stored so that the highest dimension is the // dimension with stride 0, and the lowest dimension has the largest stride. // // Due to usage of std::vector, ElementT = bool means that the user must use // set() and get() rather than operator() template class MultiArray { public: typedef List Storage; typedef ElementT Element; static size_t const Rank = RankN; typedef Array IndexArray; typedef Array SizeArray; typedef typename Storage::iterator iterator; typedef typename Storage::const_iterator const_iterator; typedef Element value_type; MultiArray(); template explicit MultiArray(size_t i, T... rest); explicit MultiArray(SizeArray const& shape); explicit MultiArray(SizeArray const& shape, Element const& c); SizeArray const& size() const; size_t size(size_t dimension) const; void clear(); void resize(SizeArray const& shape); void resize(SizeArray const& shape, Element const& c); template void resize(size_t i, T... rest); void fill(Element const& element); // Does not preserve previous element position, array contents will be // invalid. void setSize(SizeArray const& shape); void setSize(SizeArray const& shape, Element const& c); template void setSize(size_t i, T... rest); Element& operator()(IndexArray const& index); Element const& operator()(IndexArray const& index) const; template Element& operator()(size_t i1, T... rest); template Element const& operator()(size_t i1, T... rest) const; // Throws exception if out of bounds Element& at(IndexArray const& index); Element const& at(IndexArray const& index) const; template Element& at(size_t i1, T... rest); template Element const& at(size_t i1, T... rest) const; // Throws an exception of out of bounds void set(IndexArray const& index, Element element); // Returns default element if out of bounds. Element get(IndexArray const& index, Element def = Element()); // Auto-resizes array if out of bounds void setResize(IndexArray const& index, Element element); // Copy the given array element for element into this array. The shape of // this array must be at least as large in every dimension as the source // array void copy(MultiArray const& source); void copy(MultiArray const& source, IndexArray const& sourceMin, IndexArray const& sourceMax, IndexArray const& targetMin); // op will be called with IndexArray and Element parameters. template void forEach(IndexArray const& min, SizeArray const& size, OpType&& op); template void forEach(IndexArray const& min, SizeArray const& size, OpType&& op) const; // Shortcut for calling forEach on the entire array template void forEach(OpType&& op); template void forEach(OpType&& op) const; template void print(OStream& os) const; // Api for more direct access to elements. size_t count() const; Element const& atIndex(size_t index) const; Element& atIndex(size_t index); Element const* data() const; Element* data(); private: size_t storageIndex(IndexArray const& index) const; template void subPrint(OStream& os, IndexArray index, size_t dim) const; template void subForEach(IndexArray const& min, SizeArray const& size, OpType&& op, IndexArray& index, size_t offset, size_t dim) const; template void subForEach(IndexArray const& min, SizeArray const& size, OpType&& op, IndexArray& index, size_t offset, size_t dim); void subCopy(MultiArray const& source, IndexArray const& sourceMin, IndexArray const& sourceMax, IndexArray const& targetMin, IndexArray& sourceIndex, IndexArray& targetIndex, size_t dim); Storage m_data; SizeArray m_shape; }; typedef MultiArray MultiArray2I; typedef MultiArray MultiArray2S; typedef MultiArray MultiArray2U; typedef MultiArray MultiArray2F; typedef MultiArray MultiArray2D; typedef MultiArray MultiArray3I; typedef MultiArray MultiArray3S; typedef MultiArray MultiArray3U; typedef MultiArray MultiArray3F; typedef MultiArray MultiArray3D; typedef MultiArray MultiArray4I; typedef MultiArray MultiArray4S; typedef MultiArray MultiArray4U; typedef MultiArray MultiArray4F; typedef MultiArray MultiArray4D; template std::ostream& operator<<(std::ostream& os, MultiArray const& array); template MultiArray::MultiArray() { m_shape = SizeArray::filled(0); } template MultiArray::MultiArray(SizeArray const& shape) { setSize(shape); } template MultiArray::MultiArray(SizeArray const& shape, Element const& c) { setSize(shape, c); } template template MultiArray::MultiArray(size_t i, T... rest) { setSize(SizeArray{i, rest...}); } template typename MultiArray::SizeArray const& MultiArray::size() const { return m_shape; } template size_t MultiArray::size(size_t dimension) const { return m_shape[dimension]; } template void MultiArray::clear() { setSize(SizeArray::filled(0)); } template void MultiArray::resize(SizeArray const& shape) { if (m_data.empty()) { setSize(shape); return; } bool equal = true; for (size_t i = 0; i < Rank; ++i) equal = equal && (m_shape[i] == shape[i]); if (equal) return; MultiArray newArray(shape); newArray.copy(*this); std::swap(*this, newArray); } template void MultiArray::resize(SizeArray const& shape, Element const& c) { if (m_data.empty()) { setSize(shape, c); return; } bool equal = true; for (size_t i = 0; i < Rank; ++i) equal = equal && (m_shape[i] == shape[i]); if (equal) return; MultiArray newArray(shape, c); newArray.copy(*this); *this = std::move(newArray); } template template void MultiArray::resize(size_t i, T... rest) { resize(SizeArray{i, rest...}); } template void MultiArray::fill(Element const& element) { std::fill(m_data.begin(), m_data.end(), element); } template void MultiArray::setSize(SizeArray const& shape) { size_t storageSize = 1; for (size_t i = 0; i < Rank; ++i) { m_shape[i] = shape[i]; storageSize *= shape[i]; } m_data.resize(storageSize); } template void MultiArray::setSize(SizeArray const& shape, Element const& c) { size_t storageSize = 1; for (size_t i = 0; i < Rank; ++i) { m_shape[i] = shape[i]; storageSize *= shape[i]; } m_data.resize(storageSize, c); } template template void MultiArray::setSize(size_t i, T... rest) { setSize({i, rest...}); } template Element& MultiArray::operator()(IndexArray const& index) { return m_data[storageIndex(index)]; } template Element const& MultiArray::operator()(IndexArray const& index) const { return m_data[storageIndex(index)]; } template template Element& MultiArray::operator()(size_t i1, T... rest) { return m_data[storageIndex(IndexArray(i1, rest...))]; } template template Element const& MultiArray::operator()(size_t i1, T... rest) const { return m_data[storageIndex(IndexArray(i1, rest...))]; } template Element const& MultiArray::at(IndexArray const& index) const { for (size_t i = Rank; i != 0; --i) { if (index[i - 1] >= m_shape[i - 1]) throw MultiArrayException(strf("Out of bounds on MultiArray::at({})", index)); } return m_data[storageIndex(index)]; } template Element& MultiArray::at(IndexArray const& index) { for (size_t i = Rank; i != 0; --i) { if (index[i - 1] >= m_shape[i - 1]) throw MultiArrayException(strf("Out of bounds on MultiArray::at({})", index)); } return m_data[storageIndex(index)]; } template template Element& MultiArray::at(size_t i1, T... rest) { return at(IndexArray(i1, rest...)); } template template Element const& MultiArray::at(size_t i1, T... rest) const { return at(IndexArray(i1, rest...)); } template void MultiArray::set(IndexArray const& index, Element element) { for (size_t i = Rank; i != 0; --i) { if (index[i - 1] >= m_shape[i - 1]) throw MultiArrayException(strf("Out of bounds on MultiArray::set({})", index)); } m_data[storageIndex(index)] = std::move(element); } template Element MultiArray::get(IndexArray const& index, Element def) { for (size_t i = Rank; i != 0; --i) { if (index[i - 1] >= m_shape[i - 1]) return std::move(def); } return m_data[storageIndex(index)]; } template void MultiArray::setResize(IndexArray const& index, Element element) { SizeArray newShape; for (size_t i = 0; i < Rank; ++i) newShape[i] = std::max(m_shape[i], index[i] + 1); resize(newShape); m_data[storageIndex(index)] = std::move(element); } template void MultiArray::copy(MultiArray const& source) { IndexArray max; for (size_t i = 0; i < Rank; ++i) max[i] = std::min(size(i), source.size(i)); copy(source, IndexArray::filled(0), max, IndexArray::filled(0)); } template void MultiArray::copy(MultiArray const& source, IndexArray const& sourceMin, IndexArray const& sourceMax, IndexArray const& targetMin) { IndexArray sourceIndex; IndexArray targetIndex; subCopy(source, sourceMin, sourceMax, targetMin, sourceIndex, targetIndex, 0); } template template void MultiArray::forEach(IndexArray const& min, SizeArray const& size, OpType&& op) { IndexArray index; subForEach(min, size, std::forward(op), index, 0, 0); } template template void MultiArray::forEach(IndexArray const& min, SizeArray const& size, OpType&& op) const { IndexArray index; subForEach(min, size, std::forward(op), index, 0, 0); } template template void MultiArray::forEach(OpType&& op) { forEach(IndexArray::filled(0), size(), std::forward(op)); } template template void MultiArray::forEach(OpType&& op) const { forEach(IndexArray::filled(0), size(), std::forward(op)); } template template void MultiArray::print(OStream& os) const { subPrint(os, IndexArray(), 0); } template size_t MultiArray::count() const { return m_data.size(); } template Element const& MultiArray::atIndex(size_t index) const { return m_data[index]; } template Element& MultiArray::atIndex(size_t index) { return m_data[index]; } template Element const* MultiArray::data() const { return m_data.ptr(); } template Element* MultiArray::data() { return m_data.ptr(); } template size_t MultiArray::storageIndex(IndexArray const& index) const { size_t loc = index[0]; starAssert(index[0] < m_shape[0]); for (size_t i = 1; i < Rank; ++i) { loc = loc * m_shape[i] + index[i]; starAssert(index[i] < m_shape[i]); } return loc; } template template void MultiArray::subPrint(OStream& os, IndexArray index, size_t dim) const { if (dim == Rank - 1) { for (size_t i = 0; i < m_shape[dim]; ++i) { index[dim] = i; os << m_data[storageIndex(index)] << ' '; } os << std::endl; } else { for (size_t i = 0; i < m_shape[dim]; ++i) { index[dim] = i; subPrint(os, index, dim + 1); } os << std::endl; } } template template void MultiArray::subForEach(IndexArray const& min, SizeArray const& size, OpType&& op, IndexArray& index, size_t offset, size_t dim) { size_t minIndex = min[dim]; size_t maxIndex = minIndex + size[dim]; for (size_t i = minIndex; i < maxIndex; ++i) { index[dim] = i; if (dim == Rank - 1) op(index, m_data[offset + i]); else subForEach(min, size, std::forward(op), index, (offset + i) * m_shape[dim + 1], dim + 1); } } template template void MultiArray::subForEach(IndexArray const& min, SizeArray const& size, OpType&& op, IndexArray& index, size_t offset, size_t dim) const { size_t minIndex = min[dim]; size_t maxIndex = minIndex + size[dim]; for (size_t i = minIndex; i < maxIndex; ++i) { index[dim] = i; if (dim == Rank - 1) op(index, m_data[offset + i]); else subForEach(min, size, std::forward(op), index, (offset + i) * m_shape[dim + 1], dim + 1); } } template void MultiArray::subCopy(MultiArray const& source, IndexArray const& sourceMin, IndexArray const& sourceMax, IndexArray const& targetMin, IndexArray& sourceIndex, IndexArray& targetIndex, size_t dim) { size_t w = sourceMax[dim] - sourceMin[dim]; if (dim < Rank - 1) { for (size_t i = 0; i < w; ++i) { sourceIndex[dim] = i + sourceMin[dim]; targetIndex[dim] = i + targetMin[dim]; subCopy(source, sourceMin, sourceMax, targetMin, sourceIndex, targetIndex, dim + 1); } } else { sourceIndex[dim] = sourceMin[dim]; targetIndex[dim] = targetMin[dim]; size_t sourceStorageStart = source.storageIndex(sourceIndex); size_t targetStorageStart = storageIndex(targetIndex); for (size_t i = 0; i < w; ++i) m_data[targetStorageStart + i] = source.m_data[sourceStorageStart + i]; } } template std::ostream& operator<<(std::ostream& os, MultiArray const& array) { array.print(os); return os; } } template struct fmt::formatter> : ostream_formatter {};