//  MIT License

//  Copyright (c) 2020 Haik Aslanyan

//  Permission is hereby granted, free of charge, to any person obtaining a copy
//  of this software and associated documentation files (the "Software"), to deal
//  in the Software without restriction, including without limitation the rights
//  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
//  copies of the Software, and to permit persons to whom the Software is
//  furnished to do so, subject to the following conditions:

//  The above copyright notice and this permission notice shall be included in all
//  copies or substantial portions of the Software.

//  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
//  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
//  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
//  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
//  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
//  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
//  SOFTWARE.

import Accelerate

/// A dense matrix of `UInt8` values stored row by row in a flat array.
///
/// This is a reference type: `invert()` and the subscript setter mutate the
/// receiver in place.
final class RCMatrix {

    //MARK: Properties

    /// Number of rows, fixed at creation.
    let rows: Int

    /// Number of columns, derived from the storage length and `rows`.
    var columns: Int {
        data.count / rows
    }

    /// Flat storage; row `r` occupies `data[columns * r ..< columns * (r + 1)]`.
    var data: [UInt8]

    //MARK: Inits

    /// Creates a zero-filled square matrix of `size` x `size`.
    init(size: Int) {
        self.rows = size
        data = Array(repeating: 0, count: rows * rows)
    }

    /// Creates a zero-filled matrix of `rows` x `columns`.
    init(rows: Int, columns: Int) {
        self.rows = rows
        data = Array(repeating: 0, count: rows * columns)
    }

    /// Wraps `items` as a matrix with the given number of columns.
    ///
    /// Terminates with `fatalError` when `items.count` is not a multiple of `columns`.
    init(columns: Int, items: [UInt8]) {
        guard items.count % columns == 0 else {
            fatalError("number of columns are not matching")
        }
        self.rows = items.count / columns
        self.data = items
    }

    /// Wraps `items` as a matrix with the given number of rows.
    ///
    /// Terminates with `fatalError` when `items.count` is not a multiple of `rows`.
    init(rows: Int, items: [UInt8]) {
        guard items.count % rows == 0 else {
            fatalError("number of rows are not matching")
        }
        self.rows = rows
        self.data = items
    }
}

//MARK: Public methods
extension RCMatrix {

    /// Returns a `size` x `size` matrix with ones on the main diagonal and zeros elsewhere.
    static func identity(size: Int) -> RCMatrix {
        let matrix = RCMatrix(size: size)
        (0..<size).forEach { matrix[$0, $0] = 1 }
        return matrix
    }

    /// Accesses the element at the given row and column.
    ///
    /// The offset uses `columns` as the row stride, matching `row(for:)`,
    /// `column(for:)` and `rowsData()`, so non-square matrices address the
    /// correct element.
    subscript(row: Int, column: Int) -> UInt8 {
        get {
            return data[columns * row + column]
        }
        set {
            data[columns * row + column] = newValue
        }
    }

    /// Returns a copy of the row at `index`.
    func row(for index: Int) -> [UInt8] {
        let range = (columns * index)..<(columns * (index + 1))
        return Array(data[range])
    }

    /// Returns a copy of the column at `index`, top to bottom.
    func column(for index: Int) -> [UInt8] {
        // Start offsets of every row; the column element sits `index` past each one.
        return stride(from: 0, to: data.count, by: columns).map { data[index + $0] }
    }

    /// Returns the matrix contents as an array of rows.
    func rowsData() -> [[UInt8]] {
        return stride(from: 0, to: data.count, by: columns).map { start in
            Array(data[start ..< min(start + columns, data.count)])
        }
    }

    /// Returns a new matrix holding the inclusive rectangle
    /// `rowMin...rowMax` x `columnMin...columnMax` of the receiver.
    func subMatrix(rowMin: Int, columnMin: Int, rowMax: Int, columnMax: Int) -> RCMatrix {
        let items = rowsData()[rowMin...rowMax].flatMap { $0[columnMin...columnMax] }
        return RCMatrix(rows: (rowMin...rowMax).count, items: items)
    }

    /// Multiplies the receiver by `matrix` and folds every product entry back into `UInt8` range.
    ///
    /// The product is computed in `Float` with `vDSP_mmul`. Any entry greater
    /// than 255 has 255 subtracted from it repeatedly until it no longer
    /// exceeds 255 (so 255 stays 255 and 256 becomes 1).
    ///
    /// Terminates with `fatalError` when `columns != matrix.rows`.
    func reducedMultiply(by matrix: RCMatrix) -> RCMatrix {
        guard columns == matrix.rows else {
            fatalError("Cannot multiply matrices with incompatible dimensions")
        }
        var currentData = data.map { Float($0) }
        var matrixData = matrix.data.map { Float($0) }
        var result = [Float](repeating: 0, count: rows * matrix.columns)
        // vDSP_mmul(A, strideA, B, strideB, C, strideC, rowsOfA, columnsOfB, columnsOfA)
        vDSP_mmul(&currentData, 1,
                  &matrixData, 1,
                  &result, 1,
                  UInt(rows), UInt(matrix.columns), UInt(columns))
        let reduced: [UInt8] = result.map { item in
            var value = item
            while value > Float(UInt8.max) {
                value -= Float(UInt8.max)
            }
            return UInt8(value)
        }
        return RCMatrix(rows: rows, items: reduced)
    }

    /// Replaces the contents of this square matrix with its inverse.
    ///
    /// The inverse is computed in `Double` with LAPACK (`dgetrf_` for the LU
    /// factorisation, `dgetri_` for the inversion) and each entry is then
    /// converted back with `UInt8(_:)`.
    ///
    /// Terminates with `fatalError` when the matrix is not square or when
    /// LAPACK reports a non-zero status.
    func invert() {
        guard rows == columns else {
            fatalError("Only square matrices can be inverted")
        }
        var invertedData = data.map { Double($0) }
        var size = __CLPK_integer(rows)
        var pivots = [__CLPK_integer](repeating: 0, count: Int(size))
        var workspace = [Double](repeating: 0.0, count: Int(size))
        var error = __CLPK_integer(0)
        // LAPACK reads the buffer column-major, i.e. as the transpose of this
        // row-major matrix; the inverse of the transpose is the transpose of
        // the inverse, so the output can be read back row-major unchanged.
        // The same `size` pointer serves as M, N, LDA and LWORK.
        withUnsafeMutablePointer(to: &size) {
            dgetrf_($0, $0, &invertedData, $0, &pivots, &error)
            dgetri_($0, &invertedData, $0, &pivots, &workspace, $0, &error)
        }
        guard error == 0 else {
            fatalError("error inverting matrix: Error(\(error))")
        }
        // NOTE(review): UInt8(Double) truncates toward zero and traps when the
        // value falls outside 0...255 — confirm callers only invert matrices
        // whose inverse fits that range.
        for (offset, element) in invertedData.enumerated() {
            data[offset] = UInt8(element)
        }
    }
}

//MARK: CustomStringConvertible
extension RCMatrix: CustomStringConvertible {

    /// A header line with the dimensions followed by one line per row,
    /// each element prefixed with a space.
    var description: String {
        var message = "columns: \(columns), rows: \(rows) \n"
        for row in rowsData() {
            for element in row {
                message += " \(element)"
            }
            message += "\n"
        }
        return message
    }
}