├── .gitignore ├── LICENSE ├── README.MD ├── arithmetic.go ├── arithmetic_test.go ├── arraysetops.go ├── compare_opt.go ├── compare_opt_test.go ├── condition_opt.go ├── distrubution.go ├── error_set.go ├── go.mod ├── go.sum ├── index_opt.go ├── internal ├── LICENSE.md ├── arithmetic_amd64.go ├── arithmetic_amd64.s ├── arithmetic_nasm.go ├── boolOps_amd64.go ├── boolOps_amd64.s ├── boolOps_nasm.go ├── matmul_amd64.go ├── matmul_amd64.s └── matmul_nasm.go ├── logical_opt.go ├── numeric_arrb.go ├── numeric_arrb_test.go ├── numeric_arrf.go ├── numeric_arrf_test.go ├── shape.go ├── shape_test.go ├── stats.go ├── stats_test.go ├── utils.go └── utils_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | blas/lib/* 2 | blas/include/* 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | 最近在提升可用性,如果您发现了bug,请一定告知。 2 | 3 | ### type Arrf 4 | func Abs(b *Arrf) *Arrf 5 | func Acos(b *Arrf) *Arrf 6 | func Acosh(b *Arrf) *Arrf 7 | func Add(a, b *Arrf) *Arrf 8 | func Arange(vals ...float64) (a *Arrf) 9 | func ArgMax(a *Arrf, axis ...int) *Arrf 10 | func ArgMin(a *Arrf, axis ...int) *Arrf 11 | func Array(data []float64, shape ...int) (a *Arrf) 12 | func Asin(b *Arrf) *Arrf 13 | func Asinh(b *Arrf) *Arrf 14 | func Atan(b *Arrf) *Arrf 15 | func Atanh(b *Arrf) *Arrf 16 | func Ceil(b *Arrf) *Arrf 17 | func CopySign(a, b *Arrf) *Arrf 18 | func Cos(b *Arrf) *Arrf 19 | func Cosh(b *Arrf) *Arrf 20 | func Div(a, b *Arrf) *Arrf 21 | func Empty(shape ...int) (a *Arrf) 22 | func EmptyLike(a *Arrf) *Arrf 23 | func Exp(b *Arrf) *Arrf 24 | func Eye(n int) *Arrf 25 | func Floor(b *Arrf) *Arrf 26 | func Full(fullValue float64, shape 
...int) *Arrf 27 | func Identity(n int) *Arrf 28 | func Log(b *Arrf) *Arrf 29 | func Log10(b *Arrf) *Arrf 30 | func Log1p(b *Arrf) *Arrf 31 | func Log2(b *Arrf) *Arrf 32 | func Max(a *Arrf, axis ...int) *Arrf 33 | func Maximum(a, b *Arrf) *Arrf 34 | func Mean(a *Arrf, axis ...int) *Arrf 35 | func Min(a *Arrf, axis ...int) *Arrf 36 | func Minimum(a, b *Arrf) *Arrf 37 | func Mod(a, b *Arrf) *Arrf 38 | func Modf(b *Arrf) (*Arrf, *Arrf) 39 | func Mul(a, b *Arrf) *Arrf 40 | func Ones(shape ...int) *Arrf 41 | func OnesLike(a *Arrf) *Arrf 42 | func Pow(a, b *Arrf) *Arrf 43 | func Round(b *Arrf, places int) *Arrf 44 | func Sign(b *Arrf) *Arrf 45 | func Sin(b *Arrf) *Arrf 46 | func Sinh(b *Arrf) *Arrf 47 | func Sort(a *Arrf, axis ...int) *Arrf 48 | func Sqrt(b *Arrf) *Arrf 49 | func Square(b *Arrf) *Arrf 50 | func Std(a *Arrf, axis ...int) *Arrf 51 | func Sub(a, b *Arrf) *Arrf 52 | func Sum(a *Arrf, axis ...int) *Arrf 53 | func Tan(b *Arrf) *Arrf 54 | func Tanh(b *Arrf) *Arrf 55 | func Var(a *Arrf, axis ...int) *Arrf 56 | func Where(cond *Arrb, tv, fv interface{}) *Arrf 57 | func Zeros(shape ...int) *Arrf 58 | func ZerosLike(a *Arrf) *Arrf 59 | func (a *Arrf) Add(b *Arrf) *Arrf 60 | func (a *Arrf) AddC(b float64) *Arrf 61 | func (a *Arrf) ArgMax(axis ...int) *Arrf 62 | func (a *Arrf) ArgMin(axis ...int) *Arrf 63 | func (a *Arrf) At(index ...int) float64 64 | func (a *Arrf) Copy() *Arrf 65 | func (a *Arrf) Count(axis ...int) int 66 | func (a *Arrf) Div(b *Arrf) *Arrf 67 | func (a *Arrf) DivC(b float64) *Arrf 68 | func (a *Arrf) DotProd(b *Arrf) float64 69 | func (a *Arrf) Equal(b *Arrf) *Arrb 70 | func (a *Arrf) Flatten() *Arrf 71 | func (a *Arrf) Get(index ...int) float64 72 | func (a *Arrf) Greater(b *Arrf) *Arrb 73 | func (a *Arrf) GreaterEqual(b *Arrf) *Arrb 74 | func (a *Arrf) Index(ranges ...Range) *Arrf 75 | func (a *Arrf) Less(b *Arrf) *Arrb 76 | func (a *Arrf) LessEqual(b *Arrf) *Arrb 77 | func (a *Arrf) MatProd(b *Arrf) *Arrf 78 | func (a *Arrf) Max(axis ...int) 
*Arrf 79 | func (a *Arrf) Mean(axis ...int) *Arrf 80 | func (a *Arrf) Min(axis ...int) *Arrf 81 | func (a *Arrf) Mul(b *Arrf) *Arrf 82 | func (a *Arrf) MulC(b float64) *Arrf 83 | func (a *Arrf) Ndims() int 84 | func (a *Arrf) NotEqual(b *Arrf) *Arrb 85 | func (a *Arrf) Reshape(shape ...int) *Arrf 86 | func (a *Arrf) Set(val float64, index ...int) *Arrf 87 | func (a *Arrf) Sort(axis ...int) *Arrf 88 | func (a *Arrf) Std(axis ...int) *Arrf 89 | func (a *Arrf) String() (s string) 90 | func (a *Arrf) Sub(b *Arrf) *Arrf 91 | func (a *Arrf) SubC(b float64) *Arrf 92 | func (a *Arrf) Sum(axis ...int) *Arrf 93 | func (a *Arrf) Transpose(axes ...int) *Arrf 94 | func (a *Arrf) Values() []float64 95 | func (a *Arrf) Var(axis ...int) *Arrf 96 | 97 | ### type Arrb 98 | func ArrayB(data []bool, shape ...int) (a *Arrb) 99 | func EmptyB(shape ...int) (a *Arrb) 100 | func Equal(a, b *Arrf) *Arrb 101 | func FullB(value bool, shape ...int) *Arrb 102 | func Greater(a, b *Arrf) *Arrb 103 | func GreaterEqual(a, b *Arrf) *Arrb 104 | func IsFinit(b *Arrf) *Arrb 105 | func IsInf(b *Arrf) *Arrb 106 | func IsNaN(b *Arrf) *Arrb 107 | func Less(a, b *Arrf) *Arrb 108 | func LessEqual(a, b *Arrf) *Arrb 109 | func LogicalAnd(a, b *Arrb) *Arrb 110 | func LogicalNot(a *Arrb) *Arrb 111 | func LogicalOr(a, b *Arrb) *Arrb 112 | func NotEqual(a, b *Arrf) *Arrb 113 | func (ab *Arrb) All() bool 114 | func (ab *Arrb) Any() bool 115 | func (a *Arrb) LogicalAnd(b *Arrb) *Arrb 116 | func (a *Arrb) LogicalNot() *Arrb 117 | func (a *Arrb) LogicalOr(b *Arrb) *Arrb 118 | func (a *Arrb) String() (s string) 119 | func (a *Arrb) Sum() int 120 | 121 | ### Helper functions 122 | func ContainsFloat64(s []float64, e float64) bool 123 | func Hargmax(ln int, data []float64) 124 | func Hargmin(ln int, data []float64) 125 | func Hmax(ln int, data []float64) 126 | func Hmin(ln int, data []float64) 127 | func Hsort(ln int, data []float64) 128 | func ProductIntSlice(slice []int) int 129 | func ReverseIntSlice(slice []int) []int 130 | 
func Roundf(val float64, places int) float64 131 | func Vargmax(ln int, a []float64) 132 | func Vargmin(ln int, a []float64) 133 | func Vmax(a, b []float64) 134 | func Vmin(a, b []float64) 135 | func Vsort(ln int, a []float64) -------------------------------------------------------------------------------- /arithmetic.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "math" 5 | 6 | asm "github.com/ledao/arrgo/internal" 7 | ) 8 | 9 | //多维数组和标量相加,结果为新的多维数组,不修改原数组。 10 | func (a *Arrf) AddC(b float64) *Arrf { 11 | if a == nil || a.Size() == 0 { 12 | panic(SHAPE_ERROR) 13 | } 14 | ta := a.Copy() 15 | asm.AddC(b, ta.data) 16 | return ta 17 | } 18 | 19 | //两个多维数组相加,结果为新的多维数组,不修改原数组。 20 | //加法过程中间会发生广播,对矩阵运算有极大帮助。 21 | //fixme : by ledao 广播机制会进行额外的运算,对于简单的场景最好有判断,避免广播。 22 | func (a *Arrf) Add(b *Arrf) *Arrf { 23 | if a.SameShapeTo(b) { 24 | var ta = a.Copy() 25 | asm.Add(ta.data, b.data) 26 | return ta 27 | } 28 | var ta, tb, err = Boardcast(a, b) 29 | if err != nil { 30 | panic(err) 31 | } 32 | return ta.Add(tb) 33 | } 34 | 35 | //多维数组和标量相减,结果为新的多维数组,不修改原数组。 36 | func (a *Arrf) SubC(b float64) *Arrf { 37 | ta := a.Copy() 38 | asm.SubtrC(b, ta.data) 39 | return ta 40 | } 41 | 42 | //两个多维数组相减,结果为新的多维数组,不修改原数组。 43 | //减法过程中间会发生广播,对矩阵运算有极大帮助。 44 | //fixme : by ledao 广播机制会进行额外的运算,对于简单的场景最好有判断,避免广播。 45 | func (a *Arrf) Sub(b *Arrf) *Arrf { 46 | if a.SameShapeTo(b) { 47 | var ta = a.Copy() 48 | asm.Subtr(ta.data, b.data) 49 | return ta 50 | } 51 | var ta, tb, err = Boardcast(a, b) 52 | if err != nil { 53 | panic(err) 54 | } 55 | return ta.Sub(tb) 56 | } 57 | 58 | func (a *Arrf) MulC(b float64) *Arrf { 59 | ta := a.Copy() 60 | asm.MultC(b, ta.data) 61 | return ta 62 | } 63 | 64 | func (a *Arrf) Mul(b *Arrf) *Arrf { 65 | if a.SameShapeTo(b) { 66 | var ta = a.Copy() 67 | asm.Mult(ta.data, b.data) 68 | return ta 69 | } 70 | var ta, tb, err = Boardcast(a, b) 71 | if err != nil { 72 | panic(err) 73 | } 74 | 
return ta.Mul(tb) 75 | } 76 | 77 | func (a *Arrf) DivC(b float64) *Arrf { 78 | ta := a.Copy() 79 | asm.DivC(b, ta.data) 80 | return ta 81 | } 82 | 83 | func (a *Arrf) Div(b *Arrf) *Arrf { 84 | if a.SameShapeTo(b) { 85 | var ta = a.Copy() 86 | asm.Div(ta.data, b.data) 87 | return ta 88 | } 89 | var ta, tb, err = Boardcast(a, b) 90 | if err != nil { 91 | panic(err) 92 | } 93 | return ta.Div(tb) 94 | } 95 | 96 | func (a *Arrf) DotProd(b *Arrf) float64 { 97 | switch { 98 | case a.Ndims() == 1 && b.Ndims() == 1 && a.Length() == b.Length(): 99 | return asm.DotProd(a.data, b.data) 100 | } 101 | panic(SHAPE_ERROR) 102 | } 103 | 104 | func (a *Arrf) MatProd(b *Arrf) *Arrf { 105 | switch { 106 | case a.Ndims() == 2 && b.Ndims() == 2 && a.shape[1] == b.shape[0]: 107 | ret := Zeros(a.shape[0], b.shape[1]) 108 | for i := 0; i < a.shape[0]; i++ { 109 | for j := 0; j < a.shape[1]; j++ { 110 | ret.Set(a.Index(Range{i, i + 1}).DotProd(b.Index(Range{0, b.shape[0]}, Range{j, j + 1})), i, j) 111 | } 112 | } 113 | return ret 114 | } 115 | panic(SHAPE_ERROR) 116 | } 117 | 118 | func Abs(b *Arrf) *Arrf { 119 | tb := b.Copy() 120 | for i, v := range tb.data { 121 | tb.data[i] = math.Abs(v) 122 | } 123 | return tb 124 | } 125 | 126 | func Sqrt(b *Arrf) *Arrf { 127 | tb := b.Copy() 128 | for i, v := range tb.data { 129 | tb.data[i] = math.Sqrt(v) 130 | } 131 | return tb 132 | } 133 | 134 | func Square(b *Arrf) *Arrf { 135 | var tb = b.Copy() 136 | for i, v := range tb.data { 137 | tb.data[i] = math.Pow(v, 2) 138 | } 139 | return tb 140 | } 141 | 142 | func Exp(b *Arrf) *Arrf { 143 | var tb = b.Copy() 144 | for i, v := range tb.data { 145 | tb.data[i] = math.Exp(v) 146 | } 147 | return tb 148 | } 149 | 150 | func Log(b *Arrf) *Arrf { 151 | var tb = b.Copy() 152 | for i, v := range tb.data { 153 | tb.data[i] = math.Log(v) 154 | } 155 | return tb 156 | } 157 | 158 | func Log10(b *Arrf) *Arrf { 159 | var tb = b.Copy() 160 | for i, v := range tb.data { 161 | tb.data[i] = math.Log10(v) 162 | } 
163 | return tb 164 | } 165 | 166 | func Log2(b *Arrf) *Arrf { 167 | var tb = b.Copy() 168 | for i, v := range tb.data { 169 | tb.data[i] = math.Log2(v) 170 | } 171 | return tb 172 | } 173 | 174 | func Log1p(b *Arrf) *Arrf { 175 | var tb = b.Copy() 176 | for i, v := range tb.data { 177 | tb.data[i] = math.Log1p(v) 178 | } 179 | return tb 180 | } 181 | 182 | func Sign(b *Arrf) *Arrf { 183 | var tb = b.Copy() 184 | var sign float64 = 0 185 | for i, v := range tb.data { 186 | if v > 0 { 187 | sign = 1 188 | } else if v < 0 { 189 | sign = -1 190 | } 191 | tb.data[i] = sign 192 | } 193 | return tb 194 | } 195 | 196 | func Ceil(b *Arrf) *Arrf { 197 | var tb = b.Copy() 198 | for i, v := range tb.data { 199 | tb.data[i] = math.Ceil(v) 200 | } 201 | return tb 202 | } 203 | 204 | func Floor(b *Arrf) *Arrf { 205 | var tb = b.Copy() 206 | for i, v := range tb.data { 207 | tb.data[i] = math.Floor(v) 208 | } 209 | return tb 210 | } 211 | 212 | func Round(b *Arrf, places int) *Arrf { 213 | var tb = b.Copy() 214 | for i, v := range tb.data { 215 | tb.data[i] = Roundf(v, places) 216 | } 217 | return tb 218 | } 219 | 220 | func Modf(b *Arrf) (*Arrf, *Arrf) { 221 | var tb = b.Copy() 222 | var tbFrac = b.Copy() 223 | for i, v := range tb.data { 224 | r, f := math.Modf(v) 225 | tb.data[i] = r 226 | tbFrac.data[i] = f 227 | } 228 | return tb, tbFrac 229 | } 230 | 231 | func IsNaN(b *Arrf) *Arrb { 232 | var tb = EmptyB(b.shape...) 233 | for i, v := range b.data { 234 | tb.data[i] = math.IsNaN(v) 235 | } 236 | return tb 237 | } 238 | 239 | func IsInf(b *Arrf) *Arrb { 240 | var tb = EmptyB(b.shape...) 241 | for i, v := range b.data { 242 | tb.data[i] = math.IsInf(v, 0) 243 | } 244 | return tb 245 | } 246 | 247 | func IsFinit(b *Arrf) *Arrb { 248 | var tb = EmptyB(b.shape...) 
249 | for i, v := range b.data { 250 | tb.data[i] = !math.IsInf(v, 0) 251 | } 252 | return tb 253 | } 254 | 255 | func Cos(b *Arrf) *Arrf { 256 | var tb = b.Copy() 257 | for i, v := range tb.data { 258 | tb.data[i] = math.Cos(v) 259 | } 260 | return tb 261 | } 262 | 263 | func Cosh(b *Arrf) *Arrf { 264 | var tb = b.Copy() 265 | for i, v := range tb.data { 266 | tb.data[i] = math.Cosh(v) 267 | } 268 | return tb 269 | } 270 | 271 | func Acos(b *Arrf) *Arrf { 272 | var tb = b.Copy() 273 | for i, v := range tb.data { 274 | tb.data[i] = math.Acos(v) 275 | } 276 | return tb 277 | } 278 | 279 | func Acosh(b *Arrf) *Arrf { 280 | var tb = b.Copy() 281 | for i, v := range tb.data { 282 | tb.data[i] = math.Acosh(v) 283 | } 284 | return tb 285 | } 286 | 287 | func Sin(b *Arrf) *Arrf { 288 | var tb = b.Copy() 289 | for i, v := range tb.data { 290 | tb.data[i] = math.Sin(v) 291 | } 292 | return tb 293 | } 294 | 295 | func Sinh(b *Arrf) *Arrf { 296 | var tb = b.Copy() 297 | for i, v := range tb.data { 298 | tb.data[i] = math.Sinh(v) 299 | } 300 | return tb 301 | } 302 | 303 | func Asin(b *Arrf) *Arrf { 304 | var tb = b.Copy() 305 | for i, v := range tb.data { 306 | tb.data[i] = math.Asin(v) 307 | } 308 | return tb 309 | } 310 | 311 | func Asinh(b *Arrf) *Arrf { 312 | var tb = b.Copy() 313 | for i, v := range tb.data { 314 | tb.data[i] = math.Asinh(v) 315 | } 316 | return tb 317 | } 318 | 319 | func Tan(b *Arrf) *Arrf { 320 | var tb = b.Copy() 321 | for i, v := range tb.data { 322 | tb.data[i] = math.Tan(v) 323 | } 324 | return tb 325 | } 326 | 327 | func Tanh(b *Arrf) *Arrf { 328 | var tb = b.Copy() 329 | for i, v := range tb.data { 330 | tb.data[i] = math.Tanh(v) 331 | } 332 | return tb 333 | } 334 | 335 | func Atan(b *Arrf) *Arrf { 336 | var tb = b.Copy() 337 | for i, v := range tb.data { 338 | tb.data[i] = math.Atan(v) 339 | } 340 | return tb 341 | } 342 | 343 | func Atanh(b *Arrf) *Arrf { 344 | var tb = b.Copy() 345 | for i, v := range tb.data { 346 | tb.data[i] = 
math.Atanh(v) 347 | } 348 | return tb 349 | } 350 | 351 | func Add(a, b *Arrf) *Arrf { 352 | return a.Add(b) 353 | } 354 | 355 | func Sub(a, b *Arrf) *Arrf { 356 | return a.Sub(b) 357 | } 358 | 359 | func Mul(a, b *Arrf) *Arrf { 360 | return a.Mul(b) 361 | } 362 | 363 | func Div(a, b *Arrf) *Arrf { 364 | return a.Div(b) 365 | } 366 | 367 | func Pow(a, b *Arrf) *Arrf { 368 | var t = ZerosLike(a) 369 | for i, v := range a.data { 370 | t.data[i] = math.Pow(v, b.data[i]) 371 | } 372 | return t 373 | } 374 | 375 | func Maximum(a, b *Arrf) *Arrf { 376 | var t = a.Copy() 377 | for i, v := range t.data { 378 | if v < b.data[i] { 379 | v = b.data[i] 380 | } 381 | t.data[i] = v 382 | } 383 | return t 384 | } 385 | 386 | func Minimum(a, b *Arrf) *Arrf { 387 | var t = a.Copy() 388 | for i, v := range t.data { 389 | if v > b.data[i] { 390 | v = b.data[i] 391 | } 392 | t.data[i] = v 393 | } 394 | return t 395 | } 396 | 397 | func Mod(a, b *Arrf) *Arrf { 398 | var t = a.Copy() 399 | for i, v := range t.data { 400 | t.data[i] = math.Mod(v, b.data[i]) 401 | } 402 | return t 403 | } 404 | 405 | func CopySign(a, b *Arrf) *Arrf { 406 | ta := Abs(a) 407 | sign := Sign(b) 408 | return ta.Mul(sign) 409 | } 410 | 411 | func Boardcast(a, b *Arrf) (*Arrf, *Arrf, error) { 412 | if a.Ndims() < b.Ndims() { 413 | return nil, nil, SHAPE_ERROR 414 | } 415 | var bNewShape []int 416 | if a.Ndims() == b.Ndims() { 417 | bNewShape = b.shape 418 | } else { 419 | bNewShape = make([]int, len(a.shape)) 420 | for i := range bNewShape { 421 | bNewShape[i] = 1 422 | } 423 | copy(bNewShape[len(a.shape)-len(b.shape):], b.shape) 424 | } 425 | 426 | var aChangeIndex = make([]int, 0) 427 | var aChangeNum = make([]int, 0) 428 | var bChangeIndex = make([]int, 0) 429 | var bChangeNum = make([]int, 0) 430 | for i := range bNewShape { 431 | if a.shape[i] == bNewShape[i] { 432 | continue 433 | } else if a.shape[i] == 1 { 434 | aChangeIndex = append(aChangeIndex, i) 435 | aChangeNum = append(aChangeNum, bNewShape[i]) 
436 | } else if bNewShape[i] == 1 { 437 | bChangeIndex = append(bChangeIndex, i) 438 | bChangeNum = append(bChangeNum, a.shape[i]) 439 | } else { 440 | return nil, nil, SHAPE_ERROR 441 | } 442 | } 443 | 444 | var aNew, bNew *Arrf 445 | if len(aChangeNum) == 0 { 446 | aNew = a 447 | } else { 448 | var baseNum = a.Length() 449 | var expandTimes = ProductIntSlice(aChangeNum) 450 | var expandData = make([]float64, baseNum*expandTimes) 451 | for i := 0; i < expandTimes; i++ { 452 | copy(expandData[i*baseNum:(i+1)*baseNum], a.data) 453 | } 454 | var newPos = make([]int, len(aChangeIndex), len(a.shape)) 455 | var expandShape = make([]int, len(aChangeNum), len(a.shape)) 456 | copy(newPos, aChangeIndex) 457 | copy(expandShape, aChangeNum) 458 | for i := range a.shape { 459 | if !ContainsInt(aChangeIndex, i) { 460 | newPos = append(newPos, i) 461 | expandShape = append(expandShape, a.shape[i]) 462 | } 463 | } 464 | aNew = Array(expandData, expandShape...).Transpose(newPos...) 465 | } 466 | 467 | if len(bChangeNum) == 0 { 468 | bNew = b 469 | } else { 470 | var baseNum = b.Length() 471 | var expandTimes = ProductIntSlice(bChangeNum) 472 | var expandData = make([]float64, baseNum*expandTimes) 473 | for i := 0; i < expandTimes; i++ { 474 | copy(expandData[i*baseNum:(i+1)*baseNum], b.data) 475 | } 476 | var newPos = make([]int, len(bChangeIndex), len(bNewShape)) 477 | var expandShape = make([]int, len(bChangeNum), len(bNewShape)) 478 | copy(newPos, bChangeIndex) 479 | copy(expandShape, bChangeNum) 480 | for i := range bNewShape { 481 | if !ContainsInt(bChangeIndex, i) { 482 | newPos = append(newPos, i) 483 | expandShape = append(expandShape, bNewShape[i]) 484 | } 485 | } 486 | bNew = Array(expandData, expandShape...).Transpose(newPos...) 
487 | } 488 | 489 | return aNew, bNew, nil 490 | } 491 | -------------------------------------------------------------------------------- /arithmetic_test.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestArrf_AddC(t *testing.T) { 8 | arr := Arange(0, 10, 2) 9 | add := arr.AddC(2) 10 | if !add.Equal(Array([]float64{2, 4, 6, 8, 10})).AllTrues() { 11 | t.Error("Expected [2,4,6,8,10], got ", add) 12 | } 13 | } 14 | 15 | //测试nil 16 | func TestArrf_AddC_SHAPEERROR(t *testing.T) { 17 | var arr *Arrf = nil 18 | 19 | defer func() { 20 | var rec = recover() 21 | if rec != SHAPE_ERROR { 22 | t.Error("Expected SHAPE ERROR, got ", rec) 23 | } 24 | }() 25 | arr.AddC(10) 26 | } 27 | 28 | //测试空array 29 | func TestArrf_AddC_SHAPEERROR2(t *testing.T) { 30 | var arr *Arrf = Array([]float64{}) 31 | 32 | defer func() { 33 | var rec = recover() 34 | if rec != SHAPE_ERROR { 35 | t.Error("Expected SHAPE ERROR, got ", rec) 36 | } 37 | }() 38 | arr.AddC(10) 39 | } 40 | 41 | func TestArrf_Add(t *testing.T) { 42 | var a = Array([]float64{1, 2, 3, 4, 5, 6}, 2, 3) 43 | var b = Array([]float64{6, 5, 4, 3, 2, 1}, 2, 3) 44 | var c = a.Add(b) 45 | if !c.Equal(Fill(7, 2, 3)).AllTrues() { 46 | t.Error("Expected [[7,7,7],[7,7,7]], got ", c) 47 | } 48 | } 49 | 50 | //func TestArrf_Add_NilException(t *testing.T) { 51 | // var a = Array([]float64{1,2,3,4,5,6}, 2, 3) 52 | // 53 | // defer func(){ 54 | // var rec = recover() 55 | // if rec != SHAPE_ERROR { 56 | // t.Error("Expected SHAPE ERROR, got ", rec) 57 | // } 58 | // }() 59 | // a.Add(nil) 60 | //} 61 | 62 | func TestArrf_Add_NDimException(t *testing.T) { 63 | var a = Array([]float64{1, 2, 3, 4, 5, 6}) 64 | var b = Array([]float64{1, 2, 3}, 3, 1) 65 | defer func() { 66 | var rec = recover() 67 | if rec != SHAPE_ERROR { 68 | t.Error("Expected SHAPE ERROR, got ", rec) 69 | } 70 | }() 71 | a.Add(b) 72 | } 73 | 74 | func BenchmarkDotProd(b 
*testing.B) { 75 | a := Array([]float64{1, 2, 3}) 76 | c := Array([]float64{4, 5, 6}) 77 | for i := 0; i < 100; i++ { 78 | a.DotProd(c) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /arraysetops.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "sort" 5 | ) 6 | 7 | //Find the unique elements of an array. 8 | // 9 | //Returns the sorted unique elements of an array. There are three optional 10 | //outputs in addition to the unique elements: the indices of the input array 11 | //that give the unique values, the indices of the unique array that 12 | //reconstruct the input array, and the number of times each unique value 13 | //comes up in the input array. 14 | // 15 | //Parameters 16 | //---------- 17 | //ar : array_like 18 | //Input array. This will be flattened if it is not already 1-D. 19 | //return_index : bool, optional 20 | //If True, also return the indices of `ar` that result in the unique 21 | //array. 22 | //return_inverse : bool, optional 23 | //If True, also return the indices of the unique array that can be used 24 | //to reconstruct `ar`. 25 | //return_counts : bool, optional 26 | //If True, also return the number of times each unique value comes up 27 | //in `ar`. 28 | // 29 | //.. versionadded:: 1.9.0 30 | // 31 | //Returns 32 | //------- 33 | //unique : ndarray 34 | //The sorted unique values. 35 | //unique_indices : ndarray, optional 36 | //The indices of the first occurrences of the unique values in the 37 | //(flattened) original array. Only provided if `return_index` is True. 38 | //unique_inverse : ndarray, optional 39 | //The indices to reconstruct the (flattened) original array from the 40 | //unique array. Only provided if `return_inverse` is True. 41 | //unique_counts : ndarray, optional 42 | //The number of times each of the unique values comes up in the 43 | //original array. 
Only provided if `return_counts` is True. 44 | func Unique(a *Arrf) *Arrf { 45 | uniques := make([]float64, 0, a.Length()) 46 | for _, v := range a.Values() { 47 | if !ContainsFloat64(uniques, v) { 48 | uniques = append(uniques, v) 49 | } 50 | } 51 | sort.Float64s(uniques) 52 | return Array(uniques) 53 | } 54 | 55 | //Find the intersection of two arrays. 56 | // Return the sorted, unique values that are in both of the input arrays. 57 | // Parameters 58 | // ---------- 59 | // ar1, ar2 : array_like 60 | // Input arrays. 61 | // assume_unique : bool 62 | // If True, the input arrays are both assumed to be unique, which 63 | // can speed up the calculation. Default is False. 64 | // Returns 65 | // ------- 66 | // intersect1d : ndarray 67 | // Sorted 1D array of common and unique elements. 68 | //func Intersect1d(a, b *Arrf) *Arrf { 69 | // ar1 := Unique(a) 70 | // ar2 := Unique(b) 71 | // 72 | //} 73 | -------------------------------------------------------------------------------- /compare_opt.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import "fmt" 4 | 5 | func (a *Arrf) Greater(b *Arrf) *Arrb { 6 | if len(a.data) == 0 || len(b.data) == 0 { 7 | panic(EMPTY_ARRAY_ERROR) 8 | } 9 | var t = EmptyB(a.shape...) 10 | for i, v := range a.data { 11 | t.data[i] = v > b.data[i] 12 | } 13 | return t 14 | } 15 | 16 | func (a *Arrf) GreaterEqual(b *Arrf) *Arrb { 17 | if len(a.data) == 0 || len(b.data) == 0 { 18 | panic(EMPTY_ARRAY_ERROR) 19 | } 20 | var t = EmptyB(a.shape...) 21 | for i, v := range a.data { 22 | t.data[i] = v >= b.data[i] 23 | } 24 | return t 25 | } 26 | 27 | func (a *Arrf) Less(b *Arrf) *Arrb { 28 | if len(a.data) == 0 || len(b.data) == 0 { 29 | panic(EMPTY_ARRAY_ERROR) 30 | } 31 | var t = EmptyB(a.shape...) 
32 | for i, v := range a.data { 33 | t.data[i] = v < b.data[i] 34 | } 35 | return t 36 | } 37 | 38 | func (a *Arrf) LessEqual(b *Arrf) *Arrb { 39 | if len(a.data) == 0 || len(b.data) == 0 { 40 | panic(EMPTY_ARRAY_ERROR) 41 | } 42 | var t = EmptyB(a.shape...) 43 | for i, v := range a.data { 44 | t.data[i] = v <= b.data[i] 45 | } 46 | return t 47 | } 48 | 49 | //判断两个Array相对位置的元素是否相同,返回Arrb。 50 | //如果两个Array任一为空,或者形状不同,则抛出异常。 51 | func (a *Arrf) Equal(b *Arrf) *Arrb { 52 | if len(a.data) == 0 || len(b.data) == 0 { 53 | fmt.Println("empty array.") 54 | panic(EMPTY_ARRAY_ERROR) 55 | } 56 | if !SameIntSlice(a.shape, b.shape) { 57 | fmt.Println("shape not same.") 58 | panic(SHAPE_ERROR) 59 | } 60 | var t = EmptyB(a.shape...) 61 | for i, v := range a.data { 62 | t.data[i] = v == b.data[i] 63 | } 64 | return t 65 | } 66 | 67 | func (a *Arrf) NotEqual(b *Arrf) *Arrb { 68 | if len(a.data) == 0 || len(b.data) == 0 { 69 | panic(EMPTY_ARRAY_ERROR) 70 | } 71 | var t = EmptyB(a.shape...) 72 | for i, v := range a.data { 73 | t.data[i] = v != b.data[i] 74 | } 75 | return t 76 | } 77 | func Greater(a, b *Arrf) *Arrb { 78 | return a.Greater(b) 79 | } 80 | 81 | func GreaterEqual(a,b *Arrf) *Arrb { 82 | return a.GreaterEqual(b) 83 | } 84 | 85 | func Less(a, b *Arrf) *Arrb { 86 | return a.Less(b) 87 | } 88 | 89 | func LessEqual(a, b *Arrf) *Arrb { 90 | return a.LessEqual(b) 91 | } 92 | 93 | func Equal(a, b *Arrf) *Arrb { 94 | return a.Equal(b) 95 | } 96 | 97 | func NotEqual(a, b *Arrf) *Arrb { 98 | return a.NotEqual(b) 99 | } 100 | 101 | func (a *Arrf) Sort(axis ...int) *Arrf { 102 | ax := -1 103 | if len(axis) == 0 { 104 | ax = a.Ndims() - 1 105 | } else { 106 | ax = axis[0] 107 | } 108 | 109 | axisShape, axisSt, axis1St := a.shape[ax], a.strides[ax], a.strides[ax + 1] 110 | if axis1St == 1 { 111 | Hsort(axisSt, a.data) 112 | } else { 113 | Vsort(axis1St, a.data[0:axisShape * axis1St]) 114 | } 115 | 116 | return a 117 | } 118 | 119 | func Sort(a *Arrf, axis ...int) *Arrf { 120 | return 
a.Copy().Sort(axis...) 121 | } 122 | 123 | func (a *Arrf) Size() int { 124 | return ProductIntSlice(a.shape) 125 | } -------------------------------------------------------------------------------- /compare_opt_test.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestEqualEmptyArrayException(t *testing.T) { 8 | a := Array(nil) 9 | b := Array(nil) 10 | defer func() { 11 | r := recover() 12 | if r != EMPTY_ARRAY_ERROR { 13 | t.Error("Expected EMPTY_ARRAY_ERROR, got ", r) 14 | } 15 | }() 16 | a.Equal(b) 17 | } 18 | 19 | func TestEqualShapeException(t *testing.T) { 20 | a := Array(nil, 3,4) 21 | b := Array(nil, 1,2) 22 | defer func() { 23 | r := recover() 24 | if r != SHAPE_ERROR { 25 | t.Error("Expected SHAPE_ERROR, got ", r) 26 | } 27 | }() 28 | a.Equal(b) 29 | } 30 | 31 | func TestEqual(t *testing.T) { 32 | a := Array([]float64{1,2,3}, ) 33 | b := Array([]float64{1,2,4}, ) 34 | 35 | var compares = a.Equal(b) 36 | if compares.data[2] != false { 37 | t.Error("Expected [true, true, false], got ", compares) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /condition_opt.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | func Where(cond *Arrb, tv, fv interface{}) *Arrf { 4 | t := Zeros(cond.shape...) 
5 | for i, v := range cond.data { 6 | if v { 7 | switch tv.(type) { 8 | case float64: 9 | t.data[i] = tv.(float64) 10 | case float32: 11 | t.data[i] = float64(tv.(float32)) 12 | case int: 13 | t.data[i] = float64(tv.(int)) 14 | case *Arrf: 15 | t.data[i] = tv.(*Arrf).data[i] 16 | default: 17 | panic(TYPE_ERROR) 18 | } 19 | } else { 20 | switch fv.(type) { 21 | case float64: 22 | t.data[i] = fv.(float64) 23 | case float32: 24 | t.data[i] = float64(fv.(float32)) 25 | case int: 26 | t.data[i] = float64(fv.(int)) 27 | case *Arrf: 28 | t.data[i] = fv.(*Arrf).data[i] 29 | default: 30 | panic(TYPE_ERROR) 31 | } 32 | } 33 | } 34 | return t 35 | } 36 | -------------------------------------------------------------------------------- /distrubution.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "time" 5 | "math/rand" 6 | ) 7 | 8 | var( 9 | r = rand.New(rand.NewSource(time.Now().UnixNano())) 10 | ) 11 | 12 | func Seed(seed int64) { 13 | r.Seed(seed) 14 | } 15 | 16 | //Return a random matrix with data from the "standard normal" distribution. 17 | // 18 | //`randn` generates a matrix filled with random floats sampled from a 19 | //univariate "normal" (Gaussian) distribution of mean 0 and variance 1. 20 | // 21 | //Parameters 22 | //---------- 23 | //\\*args : Arguments 24 | //Shape of the output. 25 | //If given as N integers, each integer specifies the size of one 26 | //dimension. If given as a tuple, this tuple gives the complete shape. 27 | // 28 | //Returns 29 | //------- 30 | //Z : matrix of floats 31 | //A matrix of floating-point samples drawn from the standard normal 32 | //distribution. 33 | func Randn(shape ...int) *Arrf { 34 | a := Zeros(shape...) 
35 | for i := range a.Values() { 36 | a.Values()[i] = r.NormFloat64() 37 | } 38 | 39 | return a 40 | } 41 | -------------------------------------------------------------------------------- /error_set.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import "fmt" 4 | 5 | var ( 6 | INDEX_ERROR error = fmt.Errorf("INDEX ERROR") 7 | SHAPE_ERROR error = fmt.Errorf("SHAPE ERROR") 8 | DIMENTION_ERROR error = fmt.Errorf("DIMENTION ERROR") 9 | TYPE_ERROR error = fmt.Errorf("DATA TYPE ERROR") 10 | EMPTY_ARRAY_ERROR error = fmt.Errorf("EMPTY ARRAY ERROR") 11 | PARAMETER_ERROR error = fmt.Errorf("PARAMETER ERROR") 12 | 13 | UNIMPLEMENT_ERROR error = fmt.Errorf("UNIMPLEMENT ERROR") 14 | ) 15 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/ledao/arrgo 2 | 3 | 4 | go 1.13 5 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/ledao/arrgo v0.0.1 h1:zCBUq/7eOMgFwefX8/BaqoHtIBaU3fdzamzM7tF0QJ0= 2 | github.com/ledao/arrgo v0.0.1/go.mod h1:47TkIWuZtHjXESB9ij9tHbmL3htKDj7VMzb//gSWYd0= 3 | -------------------------------------------------------------------------------- /index_opt.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | type Range struct { 4 | Start, Stop int 5 | } 6 | 7 | func (a *Arrf) Index(ranges ...Range) *Arrf { 8 | var ndim = len(a.shape) 9 | totalRanges := make([]Range, ndim) 10 | copy(totalRanges, ranges) 11 | if len(ranges) < ndim { 12 | for i := len(ranges); i < ndim; i++ { 13 | totalRanges[i] = Range{Start:0, Stop:a.shape[i]} 14 | } 15 | } 16 | 17 | shape := make([]int, ndim) 18 | for i := range shape { 19 | shape[i] = totalRanges[i].Stop - totalRanges[i].Start 20 | } 
21 | 22 | b := Zeros(shape...) 23 | 24 | totalCount := 1 25 | for i := 0; i < ndim; i++ { 26 | totalCount *= shape[i] 27 | } 28 | 29 | counterSrc := make([]int, ndim) 30 | counterDst := make([]int, ndim) 31 | for i := range counterSrc { 32 | counterSrc[i] = totalRanges[i].Start 33 | counterDst[i] = counterSrc[i] - totalRanges[i].Start 34 | } 35 | 36 | for index := 0; index < totalCount; index++ { 37 | var v = a.At(counterSrc...) 38 | b.Set(v, counterDst...) 39 | counterSrc[ndim-1]++ 40 | counterDst[ndim-1] = counterSrc[ndim-1] - totalRanges[ndim-1].Start 41 | var j = ndim -1 42 | for{ 43 | if j > 0 && counterSrc[j] == totalRanges[j].Stop { 44 | counterSrc[j-1] ++ 45 | counterSrc[j] = totalRanges[j].Start 46 | counterDst[j-1] = counterSrc[j-1] - totalRanges[j-1].Start 47 | counterDst[j] = counterSrc[j] - totalRanges[j].Start 48 | j-- 49 | } else { 50 | break 51 | } 52 | } 53 | } 54 | 55 | return b 56 | } 57 | -------------------------------------------------------------------------------- /internal/LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, Chad Kunde 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of numgo nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 
17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /internal/arithmetic_amd64.go: -------------------------------------------------------------------------------- 1 | //+build !noasm,!appengine 2 | 3 | package asm 4 | 5 | var ( 6 | Sse3Supt, AvxSupt, Avx2Supt, FmaSupt bool 7 | ) 8 | 9 | func init() { 10 | initasm() 11 | } 12 | 13 | func initasm() 14 | 15 | func AddC(c float64, d []float64) 16 | 17 | func SubtrC(c float64, d []float64) 18 | 19 | func MultC(c float64, d []float64) 20 | 21 | func DivC(c float64, d []float64) 22 | 23 | func Add(a, b []float64) 24 | 25 | func Vadd(a, b []float64) 26 | 27 | func Hadd(st uint64, a []float64) 28 | 29 | func Subtr(a, b []float64) 30 | 31 | func Mult(a, b []float64) 32 | 33 | func Div(a, b []float64) 34 | 35 | func Fma12(a float64, x, b []float64) 36 | 37 | func Fma21(a float64, x, b []float64) 38 | -------------------------------------------------------------------------------- /internal/arithmetic_amd64.s: -------------------------------------------------------------------------------- 1 | // +build !noasm !appengine 2 | 3 | #define NOSPLIT 7 4 | 5 | // func initasm()(a,a2 bool) 6 | // pulled 
from runtime/asm_amd64.s 7 | TEXT ·initasm(SB), NOSPLIT, $0 8 | MOVQ $1, R15 9 | 10 | MOVQ $1, AX 11 | CPUID 12 | 13 | ANDL $0x1, CX 14 | CMPL CX, $0x1 15 | CMOVQEQ R15, R9 16 | MOVB R9, ·Sse3Supt(SB) 17 | XORQ R9, R9 18 | 19 | MOVQ $1, AX 20 | CPUID 21 | ANDL $0x18001000, CX 22 | CMPL CX, $0x18001000 23 | CMOVQEQ R15, R9 24 | MOVB R9, ·FmaSupt(SB) // set numgo·fmaSupt 25 | XORQ R9, R9 26 | 27 | ANDL $0x18000000, CX 28 | CMPL CX, $0x18000000 29 | JNE noavx 30 | 31 | // For XGETBV, OSXSAVE bit is required and sufficient 32 | MOVQ $0, CX 33 | 34 | // Check for FMA capability 35 | // XGETBV 36 | BYTE $0x0F; BYTE $0x01; BYTE $0xD0 37 | 38 | ANDL $6, AX 39 | CMPL AX, $6 // Check for OS support of YMM registers 40 | JNE noavx 41 | MOVB $1, ·AvxSupt(SB) // set numgo·avxSupt 42 | 43 | // Check for AVX2 capability 44 | MOVL $7, AX 45 | XORQ CX, CX 46 | CPUID 47 | ANDL $0x20, BX // check for AVX2 bit 48 | CMPL BX, $0x20 49 | CMOVQEQ R15, R9 50 | MOVB R9, ·Avx2Supt(SB) // set numgo·avx2Supt 51 | XORQ R9, R9 52 | RET 53 | 54 | noavx: 55 | MOVB $0, ·FmaSupt(SB) // set numgo·fmaSupt 56 | MOVB $0, ·AvxSupt(SB) // set numgo·avxSupt 57 | MOVB $0, ·Avx2Supt(SB) // set numgo·avx2Supt 58 | RET 59 | 60 | // func AddC(c float64, d []float64) 61 | TEXT ·AddC(SB), NOSPLIT, $0 62 | // data ptr 63 | MOVQ d+8(FP), R10 64 | 65 | // n = data len 66 | MOVQ d_len+16(FP), SI 67 | 68 | // zero len return 69 | CMPQ SI, $0 70 | JE ACEND 71 | 72 | // check tail 73 | SUBQ $4, SI 74 | JL ACTAIL 75 | 76 | // avx support test 77 | LEAQ c+0(FP), R9 78 | CMPB ·AvxSupt(SB), $1 79 | JE AVX_AC 80 | CMPB ·Avx2Supt(SB), $1 81 | JE AVX2_AC 82 | 83 | // load multiplier 84 | MOVSD (R9), X0 85 | SHUFPD $0, X0, X0 86 | 87 | ACLOOP: // Unrolled x2 d[i]|d[i+1] += c 88 | MOVUPD 0(R10), X1 89 | MOVUPD 16(R10), X2 90 | ADDPD X0, X1 91 | ADDPD X0, X2 92 | MOVUPD X1, 0(R10) 93 | MOVUPD X2, 16(R10) 94 | ADDQ $32, R10 95 | SUBQ $4, SI 96 | JGE ACLOOP 97 | JMP ACTAIL 98 | 99 | // NEED AVX INSTRUCTION CODING FOR THIS TO WORK 
100 | AVX2_AC: // Until AVX2 is known 101 | AVX_AC: 102 | // VBROADCASTD (R9), Y0 103 | BYTE $0xC4; BYTE $0xC2; BYTE $0x7D; BYTE $0x19; BYTE $0x01 104 | 105 | AVX_ACLOOP: 106 | // VADDPD (R10),Y0,Y1 107 | BYTE $0xC4; BYTE $0xC1; BYTE $0x7D; BYTE $0x58; BYTE $0x0A 108 | 109 | // VMOVDQA Y1, (R10) 110 | BYTE $0xC4; BYTE $0xC1; BYTE $0x7E; BYTE $0x7F; BYTE $0x0A 111 | 112 | ADDQ $32, R10 113 | SUBQ $4, SI 114 | JGE AVX_ACLOOP 115 | //VZEROUPPER 116 | BYTE $0xC5; BYTE $0xF8; BYTE $0x77 117 | 118 | ACTAIL: // Catch len % 4 == 0 119 | ADDQ $4, SI 120 | JE ACEND 121 | 122 | ACTL: // Calc the last values individually d[i] += c 123 | MOVSD 0(R10), X1 124 | ADDSD X0, X1 125 | MOVSD X1, 0(R10) 126 | ADDQ $8, R10 127 | SUBQ $1, SI 128 | JG ACTL 129 | 130 | ACEND: 131 | RET 132 | 133 | // func subtrC(c float64, d []float64) 134 | TEXT ·SubtrC(SB), NOSPLIT, $0 135 | // data ptr 136 | MOVQ d+8(FP), R10 137 | 138 | // n = data len 139 | MOVQ d_len+16(FP), SI 140 | 141 | // zero len return 142 | CMPQ SI, $0 143 | JE SCEND 144 | 145 | // check tail 146 | SUBQ $4, SI 147 | JL SCTAIL 148 | 149 | // load multiplier 150 | MOVSD c+0(FP), X0 151 | SHUFPD $0, X0, X0 152 | 153 | SCLOOP: // load d[i] | d[i+1] 154 | MOVUPD 0(R10), X1 155 | MOVUPD 16(R10), X2 156 | SUBPD X0, X1 157 | SUBPD X0, X2 158 | MOVUPD X1, 0(R10) 159 | MOVUPD X2, 16(R10) 160 | ADDQ $32, R10 161 | SUBQ $4, SI 162 | JGE SCLOOP 163 | 164 | SCTAIL: 165 | ADDQ $4, SI 166 | JE SCEND 167 | 168 | SCTL: 169 | MOVSD 0(R10), X1 170 | SUBSD X0, X1 171 | MOVSD X1, 0(R10) 172 | ADDQ $8, R10 173 | SUBQ $1, SI 174 | JG SCTL 175 | 176 | SCEND: 177 | RET 178 | 179 | // func multC(c float64, d []float64) 180 | TEXT ·MultC(SB), NOSPLIT, $0 181 | MOVQ d_base+8(FP), R10 182 | MOVQ d_len+16(FP), SI 183 | 184 | // zero len return 185 | CMPQ SI, $0 186 | JE MCEND 187 | SUBQ $4, SI 188 | JL MCTAIL 189 | 190 | // load multiplier 191 | MOVSD c+0(FP), X0 192 | SHUFPD $0, X0, X0 193 | 194 | MCLOOP: // load d[i] | d[i+1] 195 | MOVUPD 0(R10), X1 196 | 
MOVUPD 16(R10), X2 197 | MULPD X0, X1 198 | MULPD X0, X2 199 | MOVUPD X1, 0(R10) 200 | MOVUPD X2, 16(R10) 201 | ADDQ $32, R10 202 | SUBQ $4, SI 203 | JGE MCLOOP 204 | 205 | MCTAIL: 206 | ADDQ $4, SI 207 | JE MCEND 208 | 209 | MCTL: 210 | MOVSD 0(R10), X1 211 | MULSD X0, X1 212 | MOVSD X1, 0(R10) 213 | ADDQ $8, R10 214 | SUBQ $1, SI 215 | JG MCTL 216 | 217 | MCEND: 218 | RET 219 | 220 | // func divC(c float64, d []float64) 221 | TEXT ·DivC(SB), NOSPLIT, $0 222 | // data ptr 223 | MOVQ d+8(FP), R10 224 | 225 | // n = data len 226 | MOVQ d_len+16(FP), SI 227 | 228 | // zero len return 229 | CMPQ SI, $0 230 | JE DCEND 231 | 232 | // check tail 233 | SUBQ $4, SI 234 | JL DCTAIL 235 | 236 | // load multiplier 237 | MOVSD c+0(FP), X0 238 | SHUFPD $0, X0, X0 239 | 240 | DCLOOP: // load d[i] | d[i+1] 241 | MOVUPD 0(R10), X1 242 | MOVUPD 16(R10), X2 243 | DIVPD X0, X1 244 | DIVPD X0, X2 245 | MOVUPD X1, 0(R10) 246 | MOVUPD X2, 16(R10) 247 | ADDQ $32, R10 248 | SUBQ $4, SI 249 | JGE DCLOOP 250 | 251 | DCTAIL: 252 | ADDQ $4, SI 253 | JE DCEND 254 | 255 | DCTL: 256 | MOVSD 0(R10), X1 257 | DIVSD X0, X1 258 | MOVSD X1, 0(R10) 259 | ADDQ $8, R10 260 | SUBQ $1, SI 261 | JG DCTL 262 | 263 | DCEND: 264 | RET 265 | 266 | // func add(a,b []float64) 267 | TEXT ·Add(SB), NOSPLIT, $0 268 | // a data ptr 269 | MOVQ a_base+0(FP), R8 270 | 271 | // a len 272 | MOVQ a_len+8(FP), SI 273 | 274 | // b data ptr 275 | MOVQ b_base+24(FP), R9 276 | MOVQ R9, R10 277 | 278 | // b len 279 | MOVQ b_len+32(FP), DI 280 | MOVQ DI, R11 281 | 282 | // zero len return 283 | CMPQ SI, $0 284 | JE AEND 285 | 286 | // check tail 287 | SUBQ $2, SI 288 | JL ATAIL 289 | 290 | ALD: 291 | CMPQ DI, $1 292 | JE ALT 293 | SUBQ $2, DI 294 | JGE ALO 295 | MOVQ R10, R9 296 | MOVQ R11, DI 297 | SUBQ $2, DI 298 | 299 | ALO: 300 | MOVUPD (R9), X1 301 | ADDQ $16, R9 302 | JMP ALOOP 303 | 304 | ALT: 305 | MOVLPD (R9), X1 306 | MOVQ R10, R9 307 | MOVQ R11, DI 308 | MOVHPD (R9), X1 309 | SUBQ $1, DI 310 | ADDQ $8, R9 311 | 312 | 
ALOOP: 313 | MOVUPD (R8), X0 314 | ADDPD X1, X0 315 | MOVUPD X0, (R8) 316 | ADDQ $16, R8 317 | SUBQ $2, SI 318 | JGE ALD 319 | 320 | ATAIL: 321 | ADDQ $2, SI 322 | JE AEND 323 | 324 | ATL: 325 | MOVSD (R8), X0 326 | MOVSD (R9), X1 327 | ADDSD X1, X0 328 | MOVSD X0, (R8) 329 | ADDQ $8, R8 330 | ADDQ $8, R9 331 | SUBQ $1, SI 332 | JG ATL 333 | 334 | AEND: 335 | RET 336 | 337 | // func vadd(a,b[]float64) 338 | // req: len(a) == len(b) 339 | TEXT ·Vadd(SB), NOSPLIT, $0 340 | // a data ptr 341 | MOVQ a_base+0(FP), R8 342 | 343 | // a len 344 | MOVQ a_len+8(FP), SI 345 | 346 | // b data ptr 347 | MOVQ b_base+24(FP), R9 348 | 349 | // zero len return 350 | CMPQ SI, $0 351 | JE vadd_exit 352 | 353 | // check tail 354 | SUBQ $8, SI 355 | JL vadd_tail 356 | 357 | // AVX vs SSE 358 | CMPB ·AvxSupt(SB), $1 359 | JE vadd_avx_loop 360 | 361 | vadd_loop: 362 | MOVUPD (R9), X1 363 | MOVUPD 16(R9), X3 364 | MOVUPD 32(R9), X5 365 | MOVUPD 48(R9), X7 366 | 367 | MOVUPD (R8), X0 368 | ADDPD X1, X0 369 | MOVUPD 16(R8), X2 370 | ADDPD X3, X2 371 | MOVUPD 32(R8), X4 372 | ADDPD X5, X4 373 | MOVUPD 48(R8), X6 374 | ADDPD X7, X6 375 | 376 | MOVUPD X0, (R8) 377 | MOVUPD X2, 16(R8) 378 | MOVUPD X4, 32(R8) 379 | MOVUPD X6, 48(R8) 380 | ADDQ $64, R8 381 | ADDQ $64, R9 382 | SUBQ $8, SI 383 | JGE vadd_loop 384 | 385 | vadd_tail: 386 | ADDQ $8, SI 387 | JE vadd_exit 388 | 389 | vadd_tail_loop: 390 | MOVSD (R8), X15 391 | MOVSD (R9), X14 392 | ADDSD X14, X15 393 | MOVSD X15, (R8) 394 | ADDQ $8, R8 395 | ADDQ $8, R9 396 | SUBQ $1, SI 397 | JGE vadd_tail_loop 398 | JMP vadd_exit 399 | 400 | vadd_avx_loop: 401 | //VMOVDQU (R9), Y0 402 | BYTE $0xC4; BYTE $0xC1; BYTE $0x7E; BYTE $0x6F; BYTE $0x01 403 | //VMOVDQU 32(R9), Y1 404 | BYTE $0xC4; BYTE $0xC1; BYTE $0x7E; BYTE $0x6F; BYTE $0x49; BYTE $0x20 405 | 406 | // VADDPD (R8),Y0,Y0 407 | BYTE $0xC4; BYTE $0xC1; BYTE $0x7D; BYTE $0x58; BYTE $0x00 408 | // VADDPD 32(R10),Y1,Y1 409 | BYTE $0xC4; BYTE $0xC1; BYTE $0x75; BYTE $0x58; BYTE $0x48; BYTE $0x20 
410 | 411 | //VMOVDQA Y0, (R8) 412 | BYTE $0xC4; BYTE $0xC1; BYTE $0x7E; BYTE $0x7F; BYTE $0x00 413 | //VMOVDQA Y1, 32(R8) 414 | BYTE $0xC4; BYTE $0xC1; BYTE $0x7E; BYTE $0x7F; BYTE $0x48; BYTE $0x20 415 | 416 | 417 | ADDQ $64, R8 418 | ADDQ $64, R9 419 | SUBQ $8, SI 420 | JGE vadd_avx_loop 421 | //VZEROUPPER 422 | BYTE $0xC5; BYTE $0xF8; BYTE $0x77 423 | ADDQ $8, SI 424 | JE vadd_exit 425 | JMP vadd_tail_loop 426 | 427 | vadd_exit: 428 | RET 429 | 430 | // func hadd(st uint64, a []float64) 431 | // req: len(a) == len(b) 432 | TEXT ·Hadd(SB), NOSPLIT, $0 433 | // a data ptr 434 | MOVQ a_base+8(FP), R8 435 | MOVQ R8, R9 436 | 437 | // a len 438 | MOVQ a_len+16(FP), SI 439 | MOVQ st+0(FP), CX 440 | MOVQ CX, DI 441 | ANDQ $1, DI 442 | 443 | 444 | CMPQ CX, $1 445 | JE hadd_exit 446 | CMPQ SI, $0 447 | JE hadd_exit 448 | CMPQ CX, $8 449 | JG hadd_big_stride 450 | CMPB ·Sse3Supt(SB), $1 451 | JE hadd_sse3_head 452 | 453 | hadd_big_stride: 454 | // AVX vs SSE 455 | CMPB ·AvxSupt(SB), $1 456 | //JE hadd_avx_head 457 | CMPB ·Sse3Supt(SB), $1 458 | JE hadd_sse3_head 459 | hadd_head: 460 | PXOR X0, X0 461 | MOVQ CX, DI 462 | SUBQ $1, DI 463 | hadd_loop: 464 | ADDPD (R8), X0 465 | ADDQ $16, R8 466 | SUBQ $2, DI 467 | JG hadd_loop 468 | JZ hadd_tail 469 | MOVAPD X0, X1 470 | UNPCKHPD X1, X0 471 | ADDPD X1,X0 472 | MOVQ X0, (R9) 473 | ADDQ $8, R9 474 | SUBQ CX, SI 475 | JG hadd_head 476 | JMP hadd_exit 477 | hadd_tail: 478 | ADDSD (R8), X0 479 | MOVAPD X0, X1 480 | UNPCKHPD X1, X0 481 | ADDPD X1,X0 482 | MOVQ X0, (R9) 483 | ADDQ $8, R9 484 | SUBQ CX, SI 485 | JZ hadd_exit 486 | MOVQ 8(R8), X0 487 | MOVQ CX, DI 488 | SUBQ $2, DI 489 | ADDQ $16, R8 490 | JMP hadd_loop 491 | hadd_sse3_head: 492 | PXOR X0, X0 493 | MOVQ CX, DI 494 | SUBQ $1, DI 495 | hadd_sse3_loop: 496 | ADDPD (R8), X0 497 | ADDQ $16, R8 498 | SUBQ $2, DI 499 | JG hadd_sse3_loop 500 | JZ hadd_sse3_tail 501 | BYTE $0x66; BYTE $0x0F; BYTE $0x7C; BYTE $0xC0 502 | // HADDPD X0, X0 //Added in 1.6 503 | MOVQ X0, (R9) 504 
| ADDQ $8, R9 505 | SUBQ CX, SI 506 | JG hadd_sse3_head 507 | JMP hadd_exit 508 | hadd_sse3_tail: 509 | ADDSD (R8), X0 510 | BYTE $0x66; BYTE $0x0F; BYTE $0x7C; BYTE $0xC0 511 | // HADDPD X0, X0 //Added in 1.6 512 | MOVQ X0, (R9) 513 | ADDQ $8, R9 514 | SUBQ CX, SI 515 | JZ hadd_exit 516 | MOVQ 8(R8), X0 517 | MOVQ CX, DI 518 | SUBQ $2, DI 519 | ADDQ $16, R8 520 | JMP hadd_sse3_loop 521 | hadd_exit: 522 | RET 523 | 524 | 525 | // func subtr(a,b []float64) 526 | TEXT ·Subtr(SB), NOSPLIT, $0 527 | // a data ptr 528 | MOVQ a_base+0(FP), R8 529 | 530 | // a len 531 | MOVQ a_len+8(FP), SI 532 | 533 | // b data ptr 534 | MOVQ b_base+24(FP), R9 535 | MOVQ R9, R10 536 | 537 | // b len 538 | MOVQ b_len+32(FP), DI 539 | MOVQ DI, R11 540 | 541 | // zero len return 542 | MOVQ $0, AX 543 | CMPQ AX, SI 544 | JE SEND 545 | 546 | // check tail 547 | SUBQ $2, SI 548 | JL STAIL 549 | 550 | SLD: 551 | SUBQ $1, DI 552 | JE SLT 553 | SUBQ $1, DI 554 | JGE SLO 555 | MOVQ R10, R9 556 | MOVQ R11, DI 557 | SUBQ $2, DI 558 | 559 | SLO: 560 | MOVUPD 0(R9), X1 561 | ADDQ $16, R9 562 | JMP SLOOP 563 | 564 | SLT: 565 | MOVLPD 0(R9), X1 566 | MOVQ R10, R9 567 | MOVQ R11, DI 568 | MOVHPD 0(R9), X1 569 | SUBQ $1, DI 570 | ADDQ $8, R9 571 | 572 | SLOOP: 573 | MOVUPD 0(R8), X0 574 | SUBPD X1, X0 575 | MOVUPD X0, 0(R8) 576 | ADDQ $16, R8 577 | SUBQ $2, SI 578 | JGE SLD 579 | 580 | STAIL: 581 | ADDQ $2, SI 582 | JE SEND 583 | 584 | STL: 585 | MOVSD 0(R8), X0 586 | MOVSD 0(R9), X1 587 | SUBSD X1, X0 588 | MOVSD X0, 0(R8) 589 | ADDQ $8, R8 590 | ADDQ $8, R9 591 | SUBQ $1, SI 592 | JG STL 593 | 594 | SEND: 595 | RET 596 | 597 | // func mult(a,b []float64) 598 | TEXT ·Mult(SB), NOSPLIT, $0 599 | // a data ptr 600 | MOVQ a_base+0(FP), R8 601 | 602 | // a len 603 | MOVQ a_len+8(FP), SI 604 | 605 | // b data ptr 606 | MOVQ b_base+24(FP), R9 607 | MOVQ R9, R10 608 | 609 | // b len 610 | MOVQ b_len+32(FP), DI 611 | MOVQ DI, R11 612 | 613 | // zero len return 614 | MOVQ $0, AX 615 | CMPQ AX, SI 616 | JE MEND 
617 | 618 | // check tail 619 | SUBQ $2, SI 620 | JL MTAIL 621 | 622 | MLD: 623 | SUBQ $1, DI 624 | JE MLT 625 | SUBQ $1, DI 626 | JGE MLO 627 | MOVQ R10, R9 628 | MOVQ R11, DI 629 | SUBQ $2, DI 630 | 631 | MLO: 632 | MOVUPD 0(R9), X1 633 | ADDQ $16, R9 634 | JMP MLOOP 635 | 636 | MLT: 637 | MOVLPD 0(R9), X1 638 | MOVQ R10, R9 639 | MOVQ R11, DI 640 | MOVHPD 0(R9), X1 641 | SUBQ $1, DI 642 | ADDQ $8, R9 643 | 644 | MLOOP: 645 | MOVUPD 0(R8), X0 646 | MULPD X1, X0 647 | MOVUPD X0, 0(R8) 648 | ADDQ $16, R8 649 | SUBQ $2, SI 650 | JGE MLD 651 | 652 | MTAIL: 653 | ADDQ $2, SI 654 | JE MEND 655 | 656 | MTL: 657 | MOVSD 0(R8), X0 658 | MOVSD 0(R9), X1 659 | MULSD X1, X0 660 | MOVSD X0, 0(R8) 661 | ADDQ $8, R8 662 | ADDQ $8, R9 663 | SUBQ $1, SI 664 | JG MTL 665 | 666 | MEND: 667 | RET 668 | 669 | // func div(a,b []float64) 670 | TEXT ·Div(SB), NOSPLIT, $0 671 | // a data ptr 672 | MOVQ a_base+0(FP), R8 673 | 674 | // a len 675 | MOVQ a_len+8(FP), SI 676 | 677 | // b data ptr 678 | MOVQ b_base+24(FP), R9 679 | MOVQ R9, R10 680 | 681 | // b len 682 | MOVQ b_len+32(FP), DI 683 | MOVQ DI, R11 684 | 685 | // zero len return 686 | MOVQ $0, AX 687 | CMPQ AX, SI 688 | JE DEND 689 | 690 | // check tail 691 | SUBQ $2, SI 692 | JL DTAIL 693 | 694 | DLD: 695 | SUBQ $1, DI 696 | JE DLT 697 | SUBQ $1, DI 698 | JGE DLO 699 | MOVQ R10, R9 700 | MOVQ R11, DI 701 | SUBQ $2, DI 702 | 703 | DLO: 704 | MOVUPD 0(R9), X1 705 | ADDQ $16, R9 706 | JMP DLOOP 707 | DLT: 708 | MOVLPD 0(R9), X1 709 | MOVQ R10, R9 710 | MOVQ R11, DI 711 | MOVHPD 0(R9), X1 712 | SUBQ $1, DI 713 | ADDQ $8, R9 714 | 715 | DLOOP: 716 | MOVUPD 0(R8), X0 717 | DIVPD X1, X0 718 | MOVUPD X0, 0(R8) 719 | ADDQ $16, R8 720 | SUBQ $2, SI 721 | JGE DLD 722 | 723 | DTAIL: 724 | ADDQ $2, SI 725 | JE DEND 726 | DTL: 727 | MOVSD 0(R8), X0 728 | MOVSD 0(R9), X1 729 | DIVSD X1, X0 730 | MOVSD X0, 0(R8) 731 | ADDQ $8, R8 732 | ADDQ $8, R9 733 | SUBQ $1, SI 734 | JG DTL 735 | 736 | DEND: 737 | RET 738 | 739 | // func fma12(a float64, x,b 
[]float64) 740 | // x[i] = a*x[i]+b[i] 741 | TEXT ·Fma12(SB), NOSPLIT, $0 742 | // a ptr 743 | MOVSD a+0(FP), X2 744 | SHUFPD $0, X2, X2 745 | 746 | // x data ptr 747 | MOVQ x_base+8(FP), R8 748 | 749 | // x len 750 | MOVQ x_len+16(FP), SI 751 | 752 | // b data ptr 753 | MOVQ b_base+32(FP), R9 754 | MOVQ R9, R10 755 | 756 | // b len 757 | MOVQ b_len+40(FP), DI 758 | MOVQ DI, R11 759 | 760 | // zero len return 761 | CMPQ SI, $0 762 | JE F12END 763 | 764 | // check tail 765 | SUBQ $2, SI 766 | JL F12TAIL 767 | 768 | F12LD: 769 | CMPQ DI, $1 770 | JE F12LT 771 | SUBQ $2, DI 772 | JGE F12LO 773 | MOVQ R10, R9 774 | MOVQ R11, DI 775 | SUBQ $2, DI 776 | 777 | F12LO: 778 | MOVUPD (R9), X1 779 | ADDQ $16, R9 780 | JMP F12LOOP 781 | 782 | F12LT: 783 | MOVLPD (R9), X1 784 | MOVQ R10, R9 785 | MOVQ R11, DI 786 | MOVHPD (R9), X1 787 | SUBQ $1, DI 788 | ADDQ $8, R9 789 | 790 | F12LOOP: 791 | MOVUPD (R8), X0 792 | MULPD X2, X0 793 | ADDPD X1, X0 794 | MOVUPD X0, (R8) 795 | ADDQ $16, R8 796 | SUBQ $2, SI 797 | JGE F12LD 798 | JMP F12TAIL 799 | 800 | F12LDF: 801 | CMPQ DI, $1 802 | JE F12LTF 803 | SUBQ $2, DI 804 | JGE F12LOF 805 | MOVQ R10, R9 806 | MOVQ R11, DI 807 | SUBQ $2, DI 808 | 809 | F12LOF: 810 | MOVUPD (R9), X1 811 | ADDQ $16, R9 812 | JMP F12LOOPF 813 | 814 | F12LTF: 815 | MOVLPD (R9), X1 816 | MOVQ R10, R9 817 | MOVQ R11, DI 818 | MOVHPD (R9), X1 819 | SUBQ $1, DI 820 | ADDQ $8, R9 821 | 822 | F12LOOPF: 823 | MOVUPD (R8), X0 824 | 825 | // VMFADD213PD X0, X1, X2 826 | BYTE $0xC4; BYTE $0xE2; BYTE $0xF1; BYTE $0x98; BYTE $0xC2 827 | MOVUPD X0, (R8) 828 | ADDQ $16, R8 829 | SUBQ $2, SI 830 | JGE F12LDF 831 | 832 | F12TAIL: 833 | ADDQ $2, SI 834 | JE F12END 835 | 836 | F12TL: 837 | MOVSD (R8), X0 838 | MOVSD (R9), X1 839 | MULPD X2, X0 840 | ADDPD X1, X0 841 | MOVSD X0, (R8) 842 | ADDQ $8, R8 843 | ADDQ $8, R9 844 | SUBQ $1, SI 845 | JG F12TL 846 | 847 | F12END: 848 | RET 849 | 850 | // func fma21(a float64, x,b []float64) 851 | // x[i] = x[i]*b[i]+a 852 | TEXT 
·Fma21(SB), NOSPLIT, $0 853 | // a ptr 854 | MOVSD a+0(FP), X2 855 | SHUFPD $0, X2, X2 856 | 857 | // x data ptr 858 | MOVQ x_base+8(FP), R8 859 | 860 | // x len 861 | MOVQ x_len+16(FP), SI 862 | 863 | // b data ptr 864 | MOVQ b_base+32(FP), R9 865 | MOVQ R9, R10 866 | 867 | // b len 868 | MOVQ b_len+40(FP), DI 869 | MOVQ DI, R11 870 | 871 | // zero len return 872 | CMPQ SI, $0 873 | JE F21END 874 | 875 | // check tail 876 | SUBQ $2, SI 877 | JL F21TAIL 878 | 879 | F21LD: 880 | CMPQ DI, $1 881 | JE F21LT 882 | SUBQ $2, DI 883 | JGE F21LO 884 | MOVQ R10, R9 885 | MOVQ R11, DI 886 | SUBQ $2, DI 887 | 888 | F21LO: 889 | MOVUPD (R9), X1 890 | ADDQ $16, R9 891 | JMP F21LOOP 892 | 893 | F21LT: 894 | MOVLPD (R9), X1 895 | MOVQ R10, R9 896 | MOVQ R11, DI 897 | MOVHPD (R9), X1 898 | SUBQ $1, DI 899 | ADDQ $8, R9 900 | 901 | F21LOOP: 902 | MOVUPD (R8), X0 903 | MULPD X1, X0 904 | ADDPD X2, X0 905 | MOVUPD X0, (R8) 906 | ADDQ $16, R8 907 | SUBQ $2, SI 908 | JGE F21LD 909 | JMP F21TAIL 910 | 911 | F21LDF: 912 | CMPQ DI, $1 913 | JE F21LTF 914 | SUBQ $2, DI 915 | JGE F21LOF 916 | MOVQ R10, R9 917 | MOVQ R11, DI 918 | SUBQ $2, DI 919 | 920 | F21LOF: 921 | MOVUPD (R9), X1 922 | ADDQ $16, R9 923 | JMP F21LOOPF 924 | 925 | F21LTF: 926 | MOVLPD (R9), X1 927 | MOVQ R10, R9 928 | MOVQ R11, DI 929 | MOVHPD (R9), X1 930 | SUBQ $1, DI 931 | ADDQ $8, R9 932 | 933 | F21LOOPF: 934 | MOVUPD (R8), X0 935 | 936 | // VMFADD213PD X0, X1, X2 937 | BYTE $0xC4; BYTE $0xE2; BYTE $0xF1; BYTE $0xA8; BYTE $0xC2 938 | MOVUPD X0, (R8) 939 | ADDQ $16, R8 940 | SUBQ $2, SI 941 | JGE F21LDF 942 | 943 | F21TAIL: 944 | ADDQ $2, SI 945 | JE F21END 946 | 947 | F21TL: 948 | MOVSD (R8), X0 949 | MOVSD (R9), X1 950 | MULPD X1, X0 951 | ADDPD X2, X0 952 | MOVSD X0, (R8) 953 | ADDQ $8, R8 954 | ADDQ $8, R9 955 | SUBQ $1, SI 956 | JG F21TL 957 | 958 | F21END: 959 | RET 960 | -------------------------------------------------------------------------------- /internal/arithmetic_nasm.go: 
-------------------------------------------------------------------------------- 1 | //+build !amd64 noasm appengine 2 | 3 | package asm 4 | 5 | var ( 6 | Sse3Supt, AvxSupt, Avx2Supt, FmaSupt bool 7 | ) 8 | 9 | func initasm() { 10 | } 11 | 12 | func AddC(c float64, d []float64) { 13 | for i := range d { 14 | d[i] += c 15 | } 16 | } 17 | 18 | func SubtrC(c float64, d []float64) { 19 | for i := range d { 20 | d[i] -= c 21 | } 22 | } 23 | 24 | func MultC(c float64, d []float64) { 25 | for i := range d { 26 | d[i] *= c 27 | } 28 | } 29 | 30 | func DivC(c float64, d []float64) { 31 | for i := range d { 32 | d[i] /= c 33 | } 34 | } 35 | 36 | func Add(a, b []float64) { 37 | lna, lnb := len(a), len(b) 38 | for i, j := 0, 0; i < lna; i, j = i+1, j+1 { 39 | if j >= lnb { 40 | j = 0 41 | } 42 | a[i] += b[j] 43 | } 44 | } 45 | 46 | func Vadd(a, b []float64) { 47 | for i := range a { 48 | a[i] += b[i] 49 | } 50 | } 51 | 52 | func Hadd(st uint64, a []float64) { 53 | ln := uint64(len(a)) 54 | for k := uint64(0); k < ln/st; k++ { 55 | a[k] = a[k*st] 56 | for i := uint64(1); i < st; i++ { 57 | a[k] += a[k*st+i] 58 | } 59 | } 60 | } 61 | 62 | func Subtr(a, b []float64) { 63 | lna, lnb := len(a), len(b) 64 | for i, j := 0, 0; i < lna; i, j = i+1, j+1 { 65 | if j >= lnb { 66 | j = 0 67 | } 68 | a[i] -= b[j] 69 | } 70 | } 71 | 72 | func Mult(a, b []float64) { 73 | lna, lnb := len(a), len(b) 74 | for i, j := 0, 0; i < lna; i, j = i+1, j+1 { 75 | if j >= lnb { 76 | j = 0 77 | } 78 | a[i] *= b[j] 79 | } 80 | } 81 | 82 | func Div(a, b []float64) { 83 | lna, lnb := len(a), len(b) 84 | for i, j := 0, 0; i < lna; i, j = i+1, j+1 { 85 | if j >= lnb { 86 | j = 0 87 | } 88 | a[i] /= b[j] 89 | } 90 | } 91 | 92 | func Fma12(a float64, x, b []float64) { 93 | lnx, lnb := len(x), len(b) 94 | for i, j := 0, 0; i < lnx; i, j = i+1, j+1 { 95 | if j >= lnb { 96 | j = 0 97 | } 98 | x[i] = a*x[i] + b[j] 99 | } 100 | } 101 | 102 | func Fma21(a float64, x, b []float64) { 103 | lnx, lnb := len(x), len(b) 104 
| for i, j := 0, 0; i < lnx; i, j = i+1, j+1 { 105 | if j >= lnb { 106 | j = 0 107 | } 108 | x[i] = x[i]*b[j] + a 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /internal/boolOps_amd64.go: -------------------------------------------------------------------------------- 1 | //+build !noasm,!appengine 2 | 3 | package asm 4 | 5 | func findBool(vals []bool, find bool) (flg bool) 6 | -------------------------------------------------------------------------------- /internal/boolOps_amd64.s: -------------------------------------------------------------------------------- 1 | //+build !noasm,!appengine 2 | 3 | #define NOSPLIT 7 4 | 5 | TEXT ·findBool(SB), NOSPLIT, $0 6 | MOVQ vals_base+0(FP), R8 7 | MOVQ vals_len+8(FP), SI 8 | MOVB find+24(FP), R10 9 | PXOR X0,X0 10 | 11 | CMPQ SI, $0 12 | JE failed 13 | 14 | CMPB R10, $0 15 | JE loop 16 | 17 | MOVQ R10, X0 18 | PUNPCKLBW X0, X0 19 | PSHUFLW $0, X0, X0 20 | PUNPCKLQDQ X0, X0 21 | loop: 22 | MOVOU X0, X1 23 | PCMPEQB (R8), X1 24 | PMOVMSKB X1, R9 25 | BSFL R9,R9 26 | JNZ fnd 27 | ADDQ $16, R8 28 | SUBQ $16, SI 29 | JG loop 30 | JMP failed 31 | fnd: 32 | CMPQ R9, SI 33 | JG failed 34 | MOVB $1, flg+32(FP) 35 | RET 36 | failed: 37 | MOVB $0, flg+32(FP) 38 | RET 39 | 40 | 41 | -------------------------------------------------------------------------------- /internal/boolOps_nasm.go: -------------------------------------------------------------------------------- 1 | //+build !amd64 noasm appengine 2 | 3 | package asm 4 | 5 | func findBool(vals []bool, find bool) (flg bool) { 6 | for _, v := range vals { 7 | if v == find { 8 | return true 9 | } 10 | } 11 | return false 12 | } 13 | -------------------------------------------------------------------------------- /internal/matmul_amd64.go: -------------------------------------------------------------------------------- 1 | //+build amd64,!noasm,!appengine 2 | 3 | package asm 4 | 5 | func DotProd(a, b []float64) float64 6 | 
-------------------------------------------------------------------------------- /internal/matmul_amd64.s: -------------------------------------------------------------------------------- 1 | // +build !noasm !appengine 2 | 3 | #define NOSPLIT 7 4 | 5 | // func dotProd(a,b []float64) (float64) 6 | TEXT ·DotProd(SB), NOSPLIT, $0 7 | // a data ptr 8 | MOVQ a_base+0(FP), R8 9 | MOVQ a_len+8(FP), SI 10 | MOVQ b_base+24(FP), R9 11 | XORQ DI, DI 12 | PXOR X0, X0 13 | 14 | // zero len return 15 | CMPQ SI, $0 16 | JE dotp_end 17 | 18 | // check tail 19 | SUBQ $2, SI 20 | JL dotp_tail 21 | 22 | CMPB ·FmaSupt(SB), $1 23 | JE dotp_fma_loop 24 | 25 | dotp_loop: 26 | MOVOU (R8)(DI*8), X1 27 | MULPD (R9)(DI*8), X1 28 | ADDPD X1, X0 29 | ADDQ $2, DI 30 | CMPQ DI, SI 31 | JLE dotp_loop 32 | dotp_tail: 33 | ADDQ $1, SI 34 | CMPQ DI, SI 35 | JNE dotp_end 36 | MOVSD (R8)(DI*8), X1 37 | MULSD (R9)(DI*8), X1 38 | ADDSD X1, X0 39 | JMP dotp_end 40 | 41 | dotp_fma_loop: 42 | MOVOU (R8)(DI*8), X1 43 | // VMFADD231PD X1, (R9)(DI*8), X0 (x0 += x1*(R9) 44 | BYTE $0xC4; BYTE $0xC2; BYTE $0xF1; BYTE $0xB8; BYTE $0x04; BYTE $0xF9 45 | ADDQ $2, DI 46 | CMPQ DI, SI 47 | JLE dotp_fma_loop 48 | dotp_fma_tail: 49 | ADDQ $1, SI 50 | CMPQ DI, SI 51 | JNE dotp_end 52 | MOVSD (R8)(DI*8), X1 53 | // VMFADD231SD X1, (R9)(DI*8), X0 (x0 += x1*x2) 54 | BYTE $0xC4; BYTE $0xC2; BYTE $0xF1; BYTE $0xB9; BYTE $0x04; BYTE $0xF9 55 | dotp_end: 56 | CMPB ·Sse3Supt(SB), $1 57 | JE dotp_sse3 58 | MOVAPD X0, X1 59 | UNPCKHPD X1, X0 60 | ADDPD X1, X0 61 | MOVSD X0, ret+48(FP) 62 | RET 63 | dotp_sse3: 64 | BYTE $0x66; BYTE $0x0F; BYTE $0x7C; BYTE $0xC0 65 | //HADDPD X0, X0 //Added in 1.6 66 | MOVSD X0, ret+48(FP) 67 | RET 68 | -------------------------------------------------------------------------------- /internal/matmul_nasm.go: -------------------------------------------------------------------------------- 1 | //+build !amd64 noasm appengine 2 | 3 | package asm 4 | 5 | func DotProd(a, b []float64) float64 { 6 | var 
ret float64 7 | for i := range a { 8 | ret += a[i] * b[i] 9 | } 10 | return ret 11 | } 12 | -------------------------------------------------------------------------------- /logical_opt.go: -------------------------------------------------------------------------------- 1 | 2 | package arrgo 3 | 4 | 5 | func (a *Arrb) LogicalAnd(b *Arrb) *Arrb { 6 | var t = EmptyB(a.shape...) 7 | for i, v := range a.data { 8 | t.data[i] = v && b.data[i] 9 | } 10 | return t 11 | } 12 | 13 | func (a *Arrb) LogicalOr(b *Arrb) *Arrb { 14 | var t = EmptyB(a.shape...) 15 | for i, v := range a.data { 16 | t.data[i] = v || b.data[i] 17 | } 18 | return t 19 | } 20 | 21 | func (a *Arrb) LogicalNot() *Arrb { 22 | var t = EmptyB(a.shape...) 23 | for i, v := range a.data { 24 | t.data[i] = !v 25 | } 26 | return t 27 | } 28 | 29 | func LogicalAnd(a, b *Arrb) *Arrb { 30 | var t = EmptyB(a.shape...) 31 | for i, v := range a.data { 32 | t.data[i] = v && b.data[i] 33 | } 34 | return t 35 | } 36 | 37 | func LogicalOr(a, b *Arrb) *Arrb { 38 | var t = EmptyB(a.shape...) 39 | for i, v := range a.data { 40 | t.data[i] = v || b.data[i] 41 | } 42 | return t 43 | } 44 | 45 | func LogicalNot(a *Arrb) *Arrb { 46 | var t = EmptyB(a.shape...) 
47 | for i, v := range a.data { 48 | t.data[i] = !v 49 | } 50 | return t 51 | } -------------------------------------------------------------------------------- /numeric_arrb.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type Arrb struct { 9 | shape []int 10 | strides []int 11 | data []bool 12 | } 13 | 14 | //通过[]bool,形状来创建多维数组。 15 | //输入参数1:data []bool,以·C· 顺序存储,作为多维数组的输入数据,内部复制一份新的internalData,不改变data。 16 | //输入参数2:shape ...int,指定多维数组的形状,多维,类似numpy中的shape。 17 | // 如果某一个(仅支持一个维度)维度为负数,则根据len(data)推断该维度的大小。 18 | //情况1:如果不指定shape,而且data为nil,则创建一个空的*Arrb。 19 | //情况2:如果不指定shape,而且data不为nil,则创建一个len(data)大小的一维*Arrb。 20 | //情况3:如果指定shape,而且data不为nil,则根据data大小创建多维数组,如果len(data)不等于shape,或者len(data)不能整除shape,抛出异常。 21 | //情况4:如果指定shape,而且data为nil,则创建shape大小的全为false的多维数组。 22 | func ArrayB(data []bool, shape ...int) *Arrb { 23 | if len(shape) == 0 && data == nil { 24 | return &Arrb{ 25 | shape: []int{0}, 26 | strides: []int{0, 1}, 27 | data: []bool{}, 28 | } 29 | } 30 | 31 | if len(shape) == 0 && data != nil { 32 | internalData := make([]bool, len(data)) //复制data,不影响输入的值。 33 | copy(internalData, data) 34 | return &Arrb{ 35 | shape: []int{len(data)}, 36 | strides: []int{len(data), 1}, 37 | data: internalData, 38 | } 39 | } 40 | 41 | if data == nil { 42 | for _, v := range shape { 43 | if v <= 0 { 44 | fmt.Println("shape should be positive when data is nill") 45 | panic(SHAPE_ERROR) 46 | } 47 | } 48 | length := ProductIntSlice(shape) 49 | internalShape := make([]int, len(shape)) 50 | copy(internalShape, shape) 51 | strides := make([]int, len(shape)+1) 52 | strides[len(shape)] = 1 53 | for i := len(shape) - 1; i >= 0; i-- { 54 | strides[i] = strides[i+1] * internalShape[i] 55 | } 56 | 57 | return &Arrb{ 58 | shape: internalShape, 59 | strides: strides, 60 | data: make([]bool, length), 61 | } 62 | } 63 | 64 | var dataLength = len(data) 65 | negativeIndex := -1 66 | 
internalShape := make([]int, len(shape)) 67 | copy(internalShape, shape) 68 | for k, v := range shape { 69 | if v < 0 { 70 | if negativeIndex < 0 { 71 | negativeIndex = k 72 | internalShape[k] = 1 73 | } else { 74 | fmt.Println("shape can only have one negative demention.") 75 | panic(SHAPE_ERROR) 76 | } 77 | } 78 | } 79 | shapeLength := ProductIntSlice(internalShape) 80 | 81 | if dataLength < shapeLength { 82 | fmt.Println("data length is shorter than shape length.") 83 | panic(SHAPE_ERROR) 84 | } 85 | if (dataLength % shapeLength) != 0 { 86 | fmt.Println("data length cannot divided by shape length") 87 | panic(SHAPE_ERROR) 88 | } 89 | 90 | if negativeIndex >= 0 { 91 | internalShape[negativeIndex] = dataLength / shapeLength 92 | } 93 | 94 | strides := make([]int, len(internalShape)+1) 95 | strides[len(internalShape)] = 1 96 | for i := len(internalShape) - 1; i >= 0; i-- { 97 | strides[i] = strides[i+1] * internalShape[i] 98 | } 99 | 100 | internalData := make([]bool, len(data)) 101 | copy(internalData, data) 102 | 103 | return &Arrb{ 104 | shape: internalShape, 105 | strides: strides, 106 | data: internalData, 107 | } 108 | } 109 | 110 | //创建shape形状的多维布尔数组,全部填充为fillvalue。 111 | //必须指定shape,否则抛出异常。 112 | func FillB(fullValue bool, shape ...int) *Arrb { 113 | if len(shape) == 0 { 114 | fmt.Println("shape is empty!") 115 | panic(SHAPE_ERROR) 116 | } 117 | arr := ArrayB(nil, shape...) 118 | for i := range arr.data { 119 | arr.data[i] = fullValue 120 | } 121 | 122 | return arr 123 | } 124 | 125 | //创建全为false,形状位shape的多维布尔数组 126 | func EmptyB(shape ...int) (a *Arrb) { 127 | a = FillB(false, shape...) 
128 | return 129 | } 130 | 131 | func (a *Arrb) String() (s string) { 132 | switch { 133 | case a == nil: 134 | return "" 135 | case a.shape == nil || a.strides == nil || a.data == nil: 136 | return "" 137 | case a.strides[0] == 0: 138 | return "[]" 139 | } 140 | 141 | stride := a.strides[len(a.strides)-2] 142 | for i, k := 0, 0; i+stride <= len(a.data); i, k = i+stride, k+1 { 143 | 144 | t := "" 145 | for j, v := range a.strides { 146 | if i%v == 0 && j < len(a.strides)-2 { 147 | t += "[" 148 | } 149 | } 150 | 151 | s += strings.Repeat(" ", len(a.shape)-len(t)-1) + t 152 | s += fmt.Sprint(a.data[i : i+stride]) 153 | 154 | t = "" 155 | for j, v := range a.strides { 156 | if (i+stride)%v == 0 && j < len(a.strides)-2 { 157 | t += "]" 158 | } 159 | } 160 | 161 | s += t + strings.Repeat(" ", len(a.shape)-len(t)-1) 162 | if i+stride != len(a.data) { 163 | s += "\n" 164 | if len(t) > 0 { 165 | s += "\n" 166 | } 167 | } 168 | } 169 | return 170 | } 171 | 172 | //如果多维布尔数组元素都为真,返回true,否则返回false。 173 | func (ab *Arrb) AllTrues() bool { 174 | if len(ab.data) == 0 { 175 | return false 176 | } 177 | for _, v := range ab.data { 178 | if v == false { 179 | return false 180 | } 181 | } 182 | return true 183 | } 184 | 185 | //如果多维布尔数组元素都为假,返回false,否则返回true。 186 | func (ab *Arrb) AnyTrue() bool { 187 | if len(ab.data) == 0 { 188 | return false 189 | } 190 | for _, v := range ab.data { 191 | if v == true { 192 | return true 193 | } 194 | } 195 | return false 196 | } 197 | 198 | //返回多维数组中真值的个数。 199 | func (a *Arrb) Sum() int { 200 | sum := 0 201 | for _, v := range a.data { 202 | if v { 203 | sum++ 204 | } 205 | } 206 | return sum 207 | } -------------------------------------------------------------------------------- /numeric_arrb_test.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestArrayBCond1(t *testing.T) { 9 | arr := ArrayB(nil) 10 | if 
SameBoolSlice(arr.data, []bool{}) != true { 11 | t.Error("ArrayB data should be []bool{}, got ", arr.data) 12 | } 13 | if SameIntSlice(arr.shape, []int{0}) != true { 14 | t.Error("ArrayB shape should be []int{0}, got ", arr.shape) 15 | } 16 | if SameIntSlice(arr.strides, []int{0, 1}) != true { 17 | t.Error("ArrayB strides should be []int{0, 1}, got ", arr.shape) 18 | } 19 | } 20 | 21 | func TestArrayBCond2(t *testing.T) { 22 | arr := ArrayB([]bool{true, true, true}) 23 | if SameBoolSlice(arr.data, []bool{true, true, true}) != true { 24 | t.Error("ArrayB data should be []bool{true, true, true}, got ", arr.data) 25 | } 26 | if SameIntSlice(arr.shape, []int{3}) != true { 27 | t.Error("ArrayB shape should be []int{3}, got ", arr.shape) 28 | } 29 | if SameIntSlice(arr.strides, []int{3, 1}) != true { 30 | t.Error("ArrayB strides should be []int{3, 1}, got ", arr.shape) 31 | } 32 | } 33 | 34 | func TestArrayBCond3ExceptionTwoNegtiveDims(t *testing.T) { 35 | defer func() { 36 | r := recover() 37 | if r != SHAPE_ERROR { 38 | t.Error("Exepcted shape error, got ", r) 39 | } 40 | }() 41 | 42 | ArrayB([]bool{true, true, true}, -1, -1, 4) 43 | } 44 | 45 | func TestArrayBCond3ExceptionLengError(t *testing.T) { 46 | defer func() { 47 | r := recover() 48 | if r != SHAPE_ERROR { 49 | t.Error("Exepcted shape error, got ", r) 50 | } 51 | }() 52 | 53 | ArrayB([]bool{true, true, true}, 3, 4, 5) 54 | } 55 | 56 | func TestArrayBCond3ExceptionDivError(t *testing.T) { 57 | defer func() { 58 | r := recover() 59 | if r != SHAPE_ERROR { 60 | t.Error("Exepcted shape error, got ", r) 61 | } 62 | }() 63 | 64 | ArrayB([]bool{true, true, true, true}, -1, 3) 65 | } 66 | 67 | func TestArrayBCond3(t *testing.T) { 68 | arr := ArrayB([]bool{true, true, true, true}, 2, 2) 69 | if !SameIntSlice(arr.shape, []int{2, 2}) { 70 | t.Error("Expected [true, true, true, true], got ", arr.shape) 71 | } 72 | if !SameIntSlice(arr.strides, []int{4, 2, 1}) { 73 | t.Error("Expected [4,2,1], got", arr.strides) 74 | } 75 
| if !SameBoolSlice(arr.data, []bool{true, true, true, true}) { 76 | t.Error("Expected [true, true, true, true], got ", arr.data) 77 | } 78 | 79 | arr = ArrayB([]bool{true, true, true, true}, 2, -1) 80 | if !SameIntSlice(arr.shape, []int{2, 2}) { 81 | t.Error("Expected [2, 2], got ", arr.shape) 82 | } 83 | if !SameIntSlice(arr.strides, []int{4, 2, 1}) { 84 | t.Error("Expected [4,2,1], got", arr.strides) 85 | } 86 | if !SameBoolSlice(arr.data, []bool{true, true, true, true}) { 87 | t.Error("Expected [true, true, true, true], got ", arr.data) 88 | } 89 | } 90 | 91 | func TestArrayBCond4(t *testing.T) { 92 | arr := ArrayB(nil, 2, 3) 93 | if SameBoolSlice(arr.data, []bool{false, false, false, false, false, false}) != true { 94 | t.Error("ArrayB data should be []bool{false, false, false, false, false, false}, got ", arr.data) 95 | } 96 | if SameIntSlice(arr.shape, []int{2, 3}) != true { 97 | t.Error("ArrayB shape should be []int{2, 3}, got ", arr.shape) 98 | } 99 | if SameIntSlice(arr.strides, []int{6, 3, 1}) != true { 100 | t.Error("ArrayB strides should be []int{6, 3, 1}, got ", arr.shape) 101 | } 102 | 103 | defer func() { 104 | err := recover() 105 | if err != SHAPE_ERROR { 106 | t.Error("should panic shape error, got ", err) 107 | } 108 | }() 109 | 110 | ArrayB(nil, -1, 2, 3) 111 | } 112 | 113 | func TestFillB(t *testing.T) { 114 | arr := FillB(true, 3) 115 | 116 | if !SameIntSlice(arr.shape, []int{3}) { 117 | t.Errorf("Expected [3], got %v", arr.shape) 118 | } 119 | 120 | if !SameIntSlice(arr.strides, []int{3, 1}) { 121 | t.Errorf("Expected [3, 1], got %v", arr.strides) 122 | } 123 | 124 | if !SameBoolSlice(arr.data, []bool{true, true, true}) { 125 | t.Errorf("Expected [true, true, true], got %v", arr.data) 126 | } 127 | } 128 | 129 | func TestFillBException(t *testing.T) { 130 | defer func() { 131 | r := recover() 132 | 133 | if r != SHAPE_ERROR { 134 | t.Errorf("Expected SHAPE_ERROR, got %v", r) 135 | } 136 | }() 137 | 138 | FillB(true) 139 | } 140 | 141 | func 
TestEmptyB(t *testing.T) { 142 | arr := EmptyB(3) 143 | if !SameBoolSlice(arr.data, []bool{false, false, false}) { 144 | t.Errorf("Expected [false, false, false], got %v", arr.data) 145 | } 146 | } 147 | 148 | func TestArrb_AllTrues(t *testing.T) { 149 | arr := ArrayB([]bool{true, true}) 150 | if arr.AllTrues() != true { 151 | t.Errorf("Expected true, got %t", arr.AllTrues()) 152 | } 153 | 154 | arr = ArrayB([]bool{true, false}) 155 | if arr.AllTrues() != false { 156 | t.Errorf("EXepcted false, got %t", arr.AllTrues()) 157 | } 158 | } 159 | 160 | func TestArrb_AnyTrue(t *testing.T) { 161 | arr := ArrayB([]bool{true, true}) 162 | if arr.AnyTrue() != true { 163 | t.Errorf("Expected true, got %t", arr.AnyTrue()) 164 | } 165 | 166 | arr = ArrayB([]bool{true, false}) 167 | if arr.AnyTrue() != true { 168 | t.Errorf("EXepcted true, got %t", arr.AnyTrue()) 169 | } 170 | 171 | arr = ArrayB([]bool{false, false}) 172 | if arr.AnyTrue() != false { 173 | t.Errorf("EXepcted false, got %t", arr.AnyTrue()) 174 | } 175 | } 176 | 177 | func TestArrb_String(t *testing.T) { 178 | var arr *Arrb 179 | 180 | if arr.String() != "" { 181 | t.Errorf("Expected , git %s", arr.String()) 182 | } 183 | 184 | arr = EmptyB(2) 185 | arr.shape = nil 186 | if arr.String() != "" { 187 | t.Errorf("Expected , git %s", arr.String()) 188 | } 189 | 190 | arr = EmptyB(2) 191 | arr.strides = make([]int, 2) 192 | if arr.String() != "[]" { 193 | t.Errorf("Expected [], got %s", arr.String()) 194 | } 195 | 196 | arr = ArrayB([]bool{true, false}, 2, 1) 197 | if strings.Replace(arr.String(), "\n", ":", -1) != "[[true] : [false]]" { 198 | t.Errorf("Expected [[true]\n[false]], got %s", arr.String()) 199 | } 200 | } 201 | 202 | func TestArrb_Sum(t *testing.T) { 203 | arr := ArrayB([]bool{true, true}) 204 | if arr.Sum() != 2 { 205 | t.Errorf("Expected 2, got %d", arr.Sum()) 206 | } 207 | 208 | arr = ArrayB([]bool{true, false}) 209 | if arr.Sum() != 1 { 210 | t.Errorf("Expected 1, got %d", arr.Sum()) 211 | } 212 | 213 
| arr = ArrayB([]bool{false, false}) 214 | if arr.Sum() != 0 { 215 | t.Errorf("Expected 0, got %d", arr.Sum()) 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /numeric_arrf.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "math" 7 | ) 8 | 9 | type Arrf struct { 10 | shape []int 11 | strides []int 12 | data []float64 13 | } 14 | 15 | //通过[]float64,形状来创建多维数组。 16 | //输入参数1:data []float64,以·C· 顺序存储,作为多维数组的输入数据,内部复制一份新的internalData,不改变data。 17 | //输入参数2:shape ...int,指定多维数组的形状,多维,类似numpy中的shape。 18 | // 如果某一个(仅支持一个维度)维度为负数,则根据len(data)推断该维度的大小。 19 | //情况1:如果不指定shape,而且data为nil,则创建一个空的*Arrf。 20 | //情况2:如果不指定shape,而且data不为nil,则创建一个len(data)大小的一维*Arrf。 21 | //情况3:如果指定shape,而且data不为nil,则根据data大小创建多维数组,如果len(data)不等于shape,或者len(data)不能整除shape,抛出异常。 22 | //情况4:如果指定shape,而且data为nil,则创建shape大小的全为0.0的多维数组。 23 | func Array(data []float64, shape ...int) *Arrf { 24 | if len(shape) == 0 && data == nil { 25 | return &Arrf{ 26 | shape: []int{0}, 27 | strides: []int{0, 1}, 28 | data: []float64{}, 29 | } 30 | } 31 | 32 | if len(shape) == 0 && data != nil { 33 | internalData := make([]float64, len(data)) //复制data,不影响输入的值。 34 | copy(internalData, data) 35 | return &Arrf{ 36 | shape: []int{len(data)}, 37 | strides: []int{len(data), 1}, 38 | data: internalData, 39 | } 40 | } 41 | 42 | if data == nil { 43 | for _, v := range shape { 44 | if v <= 0 { 45 | fmt.Println("shape should be positive when data is nill") 46 | panic(SHAPE_ERROR) 47 | } 48 | } 49 | length := ProductIntSlice(shape) 50 | internalShape := make([]int, len(shape)) 51 | copy(internalShape, shape) 52 | strides := make([]int, len(shape)+1) 53 | strides[len(shape)] = 1 54 | for i := len(shape) - 1; i >= 0; i-- { 55 | strides[i] = strides[i+1] * internalShape[i] 56 | } 57 | 58 | return &Arrf{ 59 | shape: internalShape, 60 | strides: strides, 61 | data: make([]float64, length), 62 
| } 63 | } 64 | 65 | var dataLength = len(data) 66 | negativeIndex := -1 67 | internalShape := make([]int, len(shape)) 68 | copy(internalShape, shape) 69 | for k, v := range shape { 70 | if v < 0 { 71 | if negativeIndex < 0 { 72 | negativeIndex = k 73 | internalShape[k] = 1 74 | } else { 75 | fmt.Println("shape can only have one negative demention.") 76 | panic(SHAPE_ERROR) 77 | } 78 | } 79 | } 80 | shapeLength := ProductIntSlice(internalShape) 81 | 82 | if dataLength < shapeLength { 83 | fmt.Println("data length is shorter than shape length.") 84 | panic(SHAPE_ERROR) 85 | } 86 | if (dataLength % shapeLength) != 0 { 87 | fmt.Println("data length cannot divided by shape length") 88 | panic(SHAPE_ERROR) 89 | } 90 | 91 | if negativeIndex >= 0 { 92 | internalShape[negativeIndex] = dataLength / shapeLength 93 | } 94 | 95 | strides := make([]int, len(internalShape)+1) 96 | strides[len(internalShape)] = 1 97 | for i := len(internalShape) - 1; i >= 0; i-- { 98 | strides[i] = strides[i+1] * internalShape[i] 99 | } 100 | 101 | internalData := make([]float64, len(data)) 102 | copy(internalData, data) 103 | 104 | return &Arrf{ 105 | shape: internalShape, 106 | strides: strides, 107 | data: internalData, 108 | } 109 | } 110 | 111 | // 通过指定起始、终止和步进量来创建一维Array。 112 | // 输入参数: vals,可以有三种情况,详见下面描述。 113 | // 情况1:Arange(stop): 以0开始的序列,创建Array [0, 0+(-)1, ..., stop),不包括stop,stop符号决定升降序。 114 | // 情况2:Arange(start, stop):创建Array [start, start +(-)1, ..., stop),如果start小于start则递增,否则递减。 115 | // 情况3:Arange(start, stop, step):创建Array [start, start + step, ..., stop),step符号决定升降序。 116 | // 输入参数多于三个的都会被忽略。 117 | // 输出序列为“整型数”序列。 118 | func Arange(vals ...int) *Arrf { 119 | var start, stop, step int = 0, 0, 1 120 | 121 | switch len(vals) { 122 | case 0: 123 | fmt.Println("range function should have range") 124 | panic(PARAMETER_ERROR) 125 | case 1: 126 | if vals[0] <= 0 { 127 | step = -1 128 | stop = vals[0] + 1 129 | } else { 130 | stop = vals[0] - 1 131 | } 132 | case 2: 133 | if vals[1] < 
vals[0] { 134 | step = -1 135 | stop = vals[1] + 1 136 | } else { 137 | stop = vals[1] - 1 138 | } 139 | start = vals[0] 140 | default: 141 | if vals[1] < vals[0] { 142 | if vals[2] >= 0 { 143 | fmt.Println("increment should be negative.") 144 | panic(PARAMETER_ERROR) 145 | } 146 | stop = vals[1] + 1 147 | } else { 148 | if vals[2] <= 0 { 149 | fmt.Println("increment should be positive.") 150 | panic(PARAMETER_ERROR) 151 | } 152 | stop = vals[1] - 1 153 | } 154 | start, step = vals[0], vals[2] 155 | } 156 | 157 | a := Array(nil, int(math.Abs(float64((stop-start)/step)))+1) 158 | for i, v := 0, start; i < len(a.data); i, v = i+1, v+step { 159 | a.data[i] = float64(v) 160 | } 161 | return a 162 | } 163 | 164 | //判断Arrf是否为空数组。 165 | //如果内部的data长度为0或者为nil,返回true,否则位false。 166 | func (a *Arrf) IsEmpty() bool { 167 | return len(a.data) == 0 || a.data == nil 168 | } 169 | 170 | //创建shape形状的多维数组,全部填充为fillvalue。 171 | //必须指定shape,否则抛出异常。 172 | func Fill(fillValue float64, shape ...int) *Arrf { 173 | if len(shape) == 0 { 174 | fmt.Println("shape is empty!") 175 | panic(SHAPE_ERROR) 176 | } 177 | arr := Array(nil, shape...) 178 | for i := range arr.data { 179 | arr.data[i] = fillValue 180 | } 181 | 182 | return arr 183 | } 184 | 185 | //根据shape创建全为1.0的多维数组。 186 | func Ones(shape ...int) *Arrf { 187 | return Fill(1, shape...) 188 | } 189 | 190 | //根据输入的多维数组的形状创建全1的多维数组。 191 | func OnesLike(a *Arrf) *Arrf { 192 | return Ones(a.shape...) 193 | } 194 | 195 | //根据shape创建全为0的多维数组。 196 | func Zeros(shape ...int) *Arrf { 197 | return Fill(0, shape...) 198 | } 199 | 200 | //根据输入的多维数组的形状创建全0的多维数组。 201 | func ZerosLike(a *Arrf) *Arrf { 202 | return Zeros(a.shape...) 
203 | } 204 | 205 | // String Satisfies the Stringer interface for fmt package 206 | func (a *Arrf) String() (s string) { 207 | switch { 208 | case a == nil: 209 | return "" 210 | case a.data == nil || a.shape == nil || a.strides == nil: 211 | return "" 212 | case a.strides[0] == 0: 213 | return "[]" 214 | case len(a.shape) == 1: 215 | return fmt.Sprint(a.data) 216 | //strs := make([]string, len(a.data)) 217 | //for i := range a.data { 218 | // strs[i] = string(strconv.FormatFloat(a.data[i], 'f', -1, 64)) 219 | // 220 | //} 221 | //return strings.Join(strs, ", ") 222 | } 223 | 224 | stride := a.shape[len(a.shape)-1] 225 | 226 | for i, k := 0, 0; i+stride <= len(a.data); i, k = i+stride, k+1 { 227 | 228 | t := "" 229 | for j, v := range a.strides { 230 | if i%v == 0 && j < len(a.strides)-2 { 231 | t += "[" 232 | } 233 | } 234 | 235 | s += strings.Repeat(" ", len(a.shape)-len(t)-1) + t 236 | s += fmt.Sprint(a.data[i: i+stride]) 237 | 238 | t = "" 239 | for j, v := range a.strides { 240 | if (i+stride)%v == 0 && j < len(a.strides)-2 { 241 | t += "]" 242 | } 243 | } 244 | 245 | s += t + strings.Repeat(" ", len(a.shape)-len(t)-1) 246 | if i+stride != len(a.data) { 247 | s += "\n" 248 | if len(t) > 0 { 249 | s += "\n" 250 | } 251 | } 252 | } 253 | return 254 | } 255 | 256 | //获取index指定位置的元素。 257 | //index必须在shape规定的范围内,否则会抛出异常。 258 | //index的长度必须小于等于维度的个数,否则会抛出异常。 259 | //如果index的个数小于维度个数,则会取后面的第一个值。 260 | func (a *Arrf) At(index ...int) float64 { 261 | idx := a.valIndex(index...) 262 | return a.data[idx] 263 | } 264 | 265 | //详见At函数。 266 | func (a *Arrf) Get(index ...int) float64 { 267 | return a.At(index...) 
268 | } 269 | 270 | //At函数的内部实现,返回index指定的元素在切片中的位置,如果有错误,则返回error。 271 | func (a *Arrf) valIndex(index ...int) int { 272 | idx := 0 273 | if len(index) > len(a.shape) { 274 | fmt.Println("index len should not longer than shape.") 275 | panic(INDEX_ERROR) 276 | } 277 | for i, v := range index { 278 | if v >= a.shape[i] || v < 0 { 279 | fmt.Println("index value out of range.") 280 | panic(INDEX_ERROR) 281 | } 282 | idx += v * a.strides[i+1] 283 | } 284 | return idx 285 | } 286 | 287 | //获取多维数组元素的个数。 288 | func (a *Arrf) Length() int { 289 | return len(a.data) 290 | } 291 | 292 | //创建一个n X n 的2维单位矩阵(数组)。 293 | func Eye(n int) *Arrf { 294 | arr := Zeros(n, n) 295 | for i := 0; i < n; i++ { 296 | arr.Set(1, i, i) 297 | } 298 | return arr 299 | } 300 | 301 | //Eye的另一种称呼,详见Eye函数。 302 | func Identity(n int) *Arrf { 303 | return Eye(n) 304 | } 305 | 306 | //指定位置的元素被新值替换。 307 | //如果index的超出范围则会抛出异常。 308 | //返回当前数组的指引,方便后续的连续操作。 309 | func (a *Arrf) Set(value float64, index ...int) *Arrf { 310 | idx := a.valIndex(index...) 
311 | 312 | a.data[idx] = value 313 | return a 314 | } 315 | 316 | //返回多维数组的内部数组元素。 317 | //对返回值的操作会影响多维数组,一定谨慎操作。 318 | func (a *Arrf) Values() []float64 { 319 | return a.data 320 | } 321 | 322 | //根据[start, stop]指定的区间,创建包含num个元素的一维数组。 323 | func Linspace(start, stop float64, num int) *Arrf { 324 | var data = make([]float64, num) 325 | var startF, stopF = start, stop 326 | if startF <= stopF { 327 | var step = (stopF - startF) / (float64(num - 1.0)) 328 | for i := range data { 329 | data[i] = startF + float64(i)*step 330 | } 331 | return Array(data, num) 332 | } else { 333 | var step = (startF - stopF) / (float64(num - 1.0)) 334 | for i := range data { 335 | data[i] = startF - float64(i)*step 336 | } 337 | return Array(data, num) 338 | } 339 | } 340 | 341 | //复制一个形状一样,但是数据被深度复制的多维数组。 342 | func (a *Arrf) Copy() *Arrf { 343 | b := ZerosLike(a) 344 | copy(b.data, a.data) 345 | return b 346 | } 347 | 348 | //返回多维数组的维度数目。 349 | func (a *Arrf) Ndims() int { 350 | return len(a.shape) 351 | } 352 | 353 | //Returns ta view of the array with axes transposed. 
354 | //根据指定的轴顺序,生成一个新的调整后的多维数组。 355 | //如果是1维数组,则没有任何变化。 356 | //如果是2维数组,则行列交换。 357 | //如果是n维数组,则根据指定的顺序调整,生成新的多维数组。 358 | //输入参数1:如果不指定输入参数,则轴顺序全部反序;如果指定参数则个数必须和轴个数相同,否则抛出异常。 359 | //fixme 这里的实现效率不高,后面有时间需要提升一下。 360 | func (a *Arrf) Transpose(axes ...int) *Arrf { 361 | var n = a.Ndims() 362 | var permutation []int 363 | var nShape []int 364 | 365 | switch len(axes) { 366 | case 0: 367 | permutation = make([]int, n) 368 | nShape = make([]int, n) 369 | for i := range permutation { 370 | permutation[i] = n - i 371 | } 372 | for i := 0; i < n; i++ { 373 | permutation[i] = n - 1 - i 374 | nShape[i] = a.shape[permutation[i]] 375 | } 376 | 377 | case n: 378 | permutation = axes 379 | nShape = make([]int, n) 380 | for i := range nShape { 381 | nShape[i] = a.shape[permutation[i]] 382 | } 383 | 384 | default: 385 | fmt.Println("axis number wrong.") 386 | panic(DIMENTION_ERROR) 387 | } 388 | 389 | var totalIndexSize = 1 390 | for i := range a.shape { 391 | totalIndexSize *= a.shape[i] 392 | } 393 | 394 | var indexsSrc = make([][]int, totalIndexSize) 395 | var indexsDst = make([][]int, totalIndexSize) 396 | 397 | var b = Zeros(nShape...) 398 | var index = make([]int, n) 399 | for i := 0; i < totalIndexSize; i++ { 400 | tindexSrc := make([]int, n) 401 | copy(tindexSrc, index) 402 | indexsSrc[i] = tindexSrc 403 | var tindexDst = make([]int, n) 404 | for j := range tindexDst { 405 | tindexDst[j] = index[permutation[j]] 406 | } 407 | indexsDst[i] = tindexDst 408 | 409 | var j = n - 1 410 | index[j]++ 411 | for { 412 | if j > 0 && index[j] >= a.shape[j] { 413 | index[j-1]++ 414 | index[j] = 0 415 | j-- 416 | } else { 417 | break 418 | } 419 | } 420 | } 421 | for i := range indexsSrc { 422 | b.Set(a.Get(indexsSrc[i]...), indexsDst[i]...) 
423 | } 424 | return b 425 | } 426 | -------------------------------------------------------------------------------- /numeric_arrf_test.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestArrayCond1(t *testing.T) { 9 | arr := Array(nil) 10 | if SameFloat64Slice(arr.data, []float64{}) != true { 11 | t.Error("array data should be []float64{}, got ", arr.data) 12 | } 13 | if SameIntSlice(arr.shape, []int{0}) != true { 14 | t.Error("array shape should be []int{0}, got ", arr.shape) 15 | } 16 | if SameIntSlice(arr.strides, []int{0, 1}) != true { 17 | t.Error("array strides should be []int{0, 1}, got ", arr.shape) 18 | } 19 | } 20 | 21 | func TestArrayCond2(t *testing.T) { 22 | arr := Array([]float64{1, 2, 3}) 23 | if SameFloat64Slice(arr.data, []float64{1, 2, 3}) != true { 24 | t.Error("array data should be []float64{1,2,3}, got ", arr.data) 25 | } 26 | if SameIntSlice(arr.shape, []int{3}) != true { 27 | t.Error("array shape should be []int{3}, got ", arr.shape) 28 | } 29 | if SameIntSlice(arr.strides, []int{3, 1}) != true { 30 | t.Error("array strides should be []int{3, 1}, got ", arr.shape) 31 | } 32 | } 33 | 34 | func TestArrayCond3ExceptionTwoNegtiveDims(t *testing.T) { 35 | defer func() { 36 | r := recover() 37 | if r != SHAPE_ERROR { 38 | t.Error("Exepcted shape error, got ", r) 39 | } 40 | }() 41 | 42 | Array([]float64{1, 2, 3, 4}, -1, -1, 4) 43 | } 44 | 45 | func TestArrayCond3ExceptionLengError(t *testing.T) { 46 | defer func() { 47 | r := recover() 48 | if r != SHAPE_ERROR { 49 | t.Error("Exepcted shape error, got ", r) 50 | } 51 | }() 52 | 53 | Array([]float64{1, 2, 3, 4}, 3, 4, 5) 54 | } 55 | 56 | func TestArrayCond3ExceptionDivError(t *testing.T) { 57 | defer func() { 58 | r := recover() 59 | if r != SHAPE_ERROR { 60 | t.Error("Exepcted shape error, got ", r) 61 | } 62 | }() 63 | 64 | Array([]float64{1, 2, 3, 4}, -1, 3) 65 | } 66 | 67 | 
func TestArrayCond3(t *testing.T) { 68 | arr := Array([]float64{1, 2, 3, 4}, 2, 2) 69 | if !SameIntSlice(arr.shape, []int{2, 2}) { 70 | t.Error("Expected [2, 2], got ", arr.shape) 71 | } 72 | if !SameIntSlice(arr.strides, []int{4, 2, 1}) { 73 | t.Error("Expected [4,2,1], got", arr.strides) 74 | } 75 | if !SameFloat64Slice(arr.data, []float64{1, 2, 3, 4}) { 76 | t.Error("Expected [1,2,3,4], got ", arr.data) 77 | } 78 | 79 | arr = Array([]float64{1, 2, 3, 4}, 2, -1) 80 | if !SameIntSlice(arr.shape, []int{2, 2}) { 81 | t.Error("Expected [2, 2], got ", arr.shape) 82 | } 83 | if !SameIntSlice(arr.strides, []int{4, 2, 1}) { 84 | t.Error("Expected [4,2,1], got", arr.strides) 85 | } 86 | if !SameFloat64Slice(arr.data, []float64{1, 2, 3, 4}) { 87 | t.Error("Expected [1,2,3,4], got ", arr.data) 88 | } 89 | } 90 | 91 | func TestArrayCond4(t *testing.T) { 92 | arr := Array(nil, 2, 3) 93 | if SameFloat64Slice(arr.data, []float64{0, 0, 0, 0, 0, 0}) != true { 94 | t.Error("array data should be []float64{0, 0, 0, 0, 0, 0}, got ", arr.data) 95 | } 96 | if SameIntSlice(arr.shape, []int{2, 3}) != true { 97 | t.Error("array shape should be []int{2, 3}, got ", arr.shape) 98 | } 99 | if SameIntSlice(arr.strides, []int{6, 3, 1}) != true { 100 | t.Error("array strides should be []int{6, 3, 1}, got ", arr.shape) 101 | } 102 | 103 | defer func() { 104 | err := recover() 105 | if err != SHAPE_ERROR { 106 | t.Error("should panic shape error, got ", err) 107 | } 108 | }() 109 | 110 | Array(nil, -1, 2, 3) 111 | } 112 | 113 | func TestArange(t *testing.T) { 114 | a1 := Arange(3) 115 | if !a1.Equal(Array([]float64{0, 1, 2})).AllTrues() { 116 | t.Error("Expected [0, 1, 2], got ", a1) 117 | } 118 | 119 | a1 = Arange(-3) 120 | if !a1.Equal(Array([]float64{0, -1, -2})).AllTrues() { 121 | t.Error("Expected [0, -1, -2], got ", a1) 122 | } 123 | 124 | a1 = Arange(1, 3) 125 | if !a1.Equal(Array([]float64{1, 2})).AllTrues() { 126 | t.Error("Expected [1,2], got ", a1) 127 | } 128 | 129 | a1 = Arange(-1, 2) 
130 | if !a1.Equal(Array([]float64{-1, 0, 1})).AllTrues() { 131 | t.Error("Expected [-1, 0, 1], got ", a1) 132 | } 133 | 134 | a1 = Arange(2, -1) 135 | if !a1.Equal(Array([]float64{2, 1, 0})).AllTrues() { 136 | t.Error("Expected [2, 1, 0], got ", a1) 137 | } 138 | 139 | a1 = Arange(1, 4, 2) 140 | if !a1.Equal(Array([]float64{1, 3})).AllTrues() { 141 | t.Error("Expected [1, 3], got ", a1) 142 | } 143 | 144 | a1 = Arange(4, -1, -2) 145 | if !a1.Equal(Array([]float64{4, 2, 0})).AllTrues() { 146 | t.Error("Expected [4, 2, 0], got ", a1) 147 | } 148 | } 149 | 150 | func TestArangeIncrementExpection1(t *testing.T) { 151 | defer func() { 152 | r := recover() 153 | if r != PARAMETER_ERROR { 154 | t.Error("Expected PARAMTER ERROR, got ", r) 155 | } 156 | }() 157 | 158 | Arange(1, 3, -2) 159 | } 160 | 161 | func TestArangeIncrementExpection2(t *testing.T) { 162 | defer func() { 163 | r := recover() 164 | if r != PARAMETER_ERROR { 165 | t.Error("Expected PARAMTER ERROR, got ", r) 166 | } 167 | }() 168 | 169 | Arange(3, 1, 1) 170 | } 171 | 172 | func TestArangeNullParameterException(t *testing.T) { 173 | defer func() { 174 | r := recover() 175 | if r != PARAMETER_ERROR { 176 | t.Error("Expected PARAMETER ERROR, got ", r) 177 | } 178 | }() 179 | 180 | Arange() 181 | } 182 | 183 | func TestArrf_IsEmpty(t *testing.T) { 184 | empty := Array(nil) 185 | 186 | if empty.IsEmpty() != true { 187 | t.Error("Expected empty arra") 188 | } 189 | 190 | empty.data = make([]float64, 0) 191 | 192 | if empty.IsEmpty() != true { 193 | t.Error("Expected empty arra") 194 | } 195 | } 196 | 197 | func TestFill(t *testing.T) { 198 | arr := Fill(1.0, 3) 199 | 200 | if !SameIntSlice(arr.shape, []int{3}) { 201 | t.Error("Expected [3], got ", arr.shape) 202 | } 203 | 204 | if !SameIntSlice(arr.strides, []int{3, 1}) { 205 | t.Error("Expected [3, 1], got ", arr.strides) 206 | } 207 | 208 | if !SameFloat64Slice(arr.data, []float64{1.0, 1.0, 1.0}) { 209 | t.Error("Expected [1.0, 1.0, 1.0], got ", arr.data) 
210 | } 211 | } 212 | 213 | func TestFillException(t *testing.T) { 214 | defer func() { 215 | r := recover() 216 | 217 | if r != SHAPE_ERROR { 218 | t.Error("Expected SHAPE_ERROR, got ", r) 219 | } 220 | }() 221 | 222 | Fill(1.0) 223 | } 224 | 225 | func TestOnes(t *testing.T) { 226 | arr := Ones(3) 227 | 228 | if !SameIntSlice(arr.shape, []int{3}) { 229 | t.Error("Expected [3], got ", arr.shape) 230 | } 231 | 232 | if !SameIntSlice(arr.strides, []int{3, 1}) { 233 | t.Error("Expected [3, 1], got ", arr.strides) 234 | } 235 | 236 | if !SameFloat64Slice(arr.data, []float64{1, 1, 1}) { 237 | t.Error("Expected [1, 1, 1], got ", arr.data) 238 | } 239 | } 240 | 241 | func TestOnesLike(t *testing.T) { 242 | originalArr := Ones(3) 243 | arr := OnesLike(originalArr) 244 | 245 | if !SameIntSlice(arr.shape, []int{3}) { 246 | t.Error("Expected [3], got ", arr.shape) 247 | } 248 | 249 | if !SameIntSlice(arr.strides, []int{3, 1}) { 250 | t.Error("Expected [3, 1], got ", arr.strides) 251 | } 252 | 253 | if !SameFloat64Slice(arr.data, []float64{1, 1, 1}) { 254 | t.Error("Expected [1, 1, 1], got ", arr.data) 255 | } 256 | } 257 | 258 | func TestZeros(t *testing.T) { 259 | arr := Zeros(3) 260 | 261 | if !SameIntSlice(arr.shape, []int{3}) { 262 | t.Error("Expected [3], got ", arr.shape) 263 | } 264 | 265 | if !SameIntSlice(arr.strides, []int{3, 1}) { 266 | t.Error("Expected [3, 1], got ", arr.strides) 267 | } 268 | 269 | if !SameFloat64Slice(arr.data, []float64{0, 0, 0}) { 270 | t.Error("Expected [0,0,0], got ", arr.data) 271 | } 272 | } 273 | 274 | func TestZerosLike(t *testing.T) { 275 | orignalArr := Zeros(3) 276 | arr := ZerosLike(orignalArr) 277 | 278 | if !SameIntSlice(arr.shape, []int{3}) { 279 | t.Error("Expected [3], got ", arr.shape) 280 | } 281 | 282 | if !SameIntSlice(arr.strides, []int{3, 1}) { 283 | t.Error("Expected [3, 1], got ", arr.strides) 284 | } 285 | 286 | if !SameFloat64Slice(arr.data, []float64{0, 0, 0}) { 287 | t.Error("Expected [0,0,0], got ", arr.data) 288 
| } 289 | } 290 | 291 | func TestArrf_At(t *testing.T) { 292 | arr := Array([]float64{1, 2, 3, 4, 5, 6}, 2, 3) 293 | 294 | if arr.At(0, 1) != 2.0 { 295 | t.Error("Expected 2.0, got ", arr.At(1, 0)) 296 | } 297 | 298 | if arr.At(0) != 1.0 { 299 | t.Error("Expected 1.0, got ", arr.At(0)) 300 | } 301 | 302 | if arr.At(1) != 4.0 { 303 | t.Error("Expected 4.0, got ", arr.At(1.0)) 304 | } 305 | } 306 | 307 | func TestArrf_AtLongIndexException(t *testing.T) { 308 | arr := Array([]float64{1, 2, 3, 4, 5, 6}, 2, 3) 309 | 310 | defer func() { 311 | r := recover() 312 | if r != INDEX_ERROR { 313 | t.Error("Expected INDEX_ERROR, got ", r) 314 | } 315 | }() 316 | 317 | arr.At(0, 0, 1) 318 | } 319 | 320 | func TestArrf_AtIndexOutofRangeException(t *testing.T) { 321 | arr := Array([]float64{1, 2, 3, 4, 5, 6}, 2, 3) 322 | 323 | defer func() { 324 | r := recover() 325 | if r != INDEX_ERROR { 326 | t.Error("Expected INDEX_ERROR, got ", r) 327 | } 328 | }() 329 | 330 | arr.At(2, 0) 331 | } 332 | 333 | func TestArrf_ValIndex(t *testing.T) { 334 | arr := Array([]float64{1, 2, 3, 4, 5, 6}, 2, 3) 335 | 336 | index := arr.valIndex(0, 1) 337 | if index != 1 { 338 | t.Error("Expected 1, got ", index) 339 | } 340 | 341 | index = arr.valIndex(0) 342 | if index != 0 { 343 | t.Error("Expected 0, got ", index) 344 | } 345 | 346 | index = arr.valIndex(1) 347 | if index != 3 { 348 | t.Error("Expected 3, got ", index) 349 | } 350 | } 351 | 352 | func TestArrf_ValIndexExpection1(t *testing.T) { 353 | arr := Array([]float64{1, 2, 3, 4, 5, 6}, 2, 3) 354 | defer func() { 355 | r := recover() 356 | if r != INDEX_ERROR { 357 | t.Error("Expected INDEX_ERROR, got ", r) 358 | } 359 | }() 360 | arr.valIndex(0, 1, 0) 361 | } 362 | 363 | func TestArrf_ValIndexExpection2(t *testing.T) { 364 | arr := Array([]float64{1, 2, 3, 4, 5, 6}, 2, 3) 365 | defer func() { 366 | r := recover() 367 | if r != INDEX_ERROR { 368 | t.Error("Expected INDEX_ERROR, got ", r) 369 | } 370 | }() 371 | arr.valIndex(2) 372 | } 373 | 374 
| func TestArrf_Length(t *testing.T) {
375 | 	arr := Array(nil, 2, 3)
376 |
377 | 	if arr.Length() != 6 {
378 | 		t.Error("Expected 6, got ", arr.Length())
379 | 	}
380 | }
381 |
382 | func TestEye(t *testing.T) {
383 | 	arr := Eye(2)
384 |
385 | 	if !arr.Equal(Array([]float64{1, 0, 0, 1}, 2, 2)).AllTrues() {
386 | 		t.Error("Expected [1, 0, 0, 1], got ", arr)
387 | 	}
388 |
389 | 	defer func() {
390 | 		r := recover()
391 | 		if r != SHAPE_ERROR {
392 | 			t.Error("Expected SHAPE_ERROR, got ", r)
393 | 		}
394 | 	}()
395 |
396 | 	Eye(0)
397 | }
398 |
399 | func TestIdentity(t *testing.T) {
400 | 	arr := Identity(2)
401 |
402 | 	if !arr.Equal(Array([]float64{1, 0, 0, 1}, 2, 2)).AllTrues() {
403 | 		t.Error("Expected [1, 0, 0, 1], got ", arr)
404 | 	}
405 |
406 | 	defer func() {
407 | 		r := recover()
408 | 		if r != SHAPE_ERROR {
409 | 			t.Error("Expected SHAPE_ERROR, got ", r)
410 | 		}
411 | 	}()
412 |
413 | 	Identity(0)
414 | }
415 |
416 | func TestArrf_Set(t *testing.T) {
417 | 	arr := Zeros(3)
418 | 	arr.Set(10, 1)
419 |
420 | 	if arr.Get(1) != 10 {
421 | 		t.Error("Expected 10, got ", arr.Get(1))
422 | 	}
423 | }
424 |
425 | func TestArrf_Values(t *testing.T) {
426 | 	arr := Array([]float64{1, 2, 3})
427 |
428 | 	values := arr.Values()
429 |
430 | 	if !SameFloat64Slice(values, []float64{1, 2, 3}) {
431 | 		t.Error("Expected [1.0, 2.0, 3.0], got ", values)
432 | 	}
433 | 	values[0] = 100
434 |
435 | 	if arr.data[0] != 100 {
436 | 		t.Error("Expected 100, got ", arr.data[0])
437 | 	}
438 | }
439 |
440 | func TestLinspace(t *testing.T) {
441 | 	arr := Linspace(1, 2, 5)
442 |
443 | 	if !arr.Equal(Array([]float64{1, 1.25, 1.5, 1.75, 2})).AllTrues() {
444 | 		t.Error("Expected [1, 1.25, 1.5, 1.75, 2], got ", arr)
445 | 	}
446 |
447 | 	arr = Linspace(2, 1, 5)
448 |
449 | 	if !arr.Equal(Array([]float64{2, 1.75, 1.5, 1.25, 1})).AllTrues() {
450 | 		t.Error("Expected [2, 1.75, 1.5, 1.25, 1], got ", arr)
451 | 	}
452 |
453 | 	arr = Linspace(-2, -1, 5)
454 |
455 | 	if !arr.Equal(Array([]float64{-2, -1.75, -1.5, -1.25, -1})).AllTrues() {
456 | 		t.Error("Expected [-2, -1.75, -1.5, -1.25, -1], got ", arr)
457 | 	}
458 |
459 | 	arr = Linspace(-1, -2, 5)
460 |
461 | 	if !arr.Equal(Array([]float64{-1, -1.25, -1.5, -1.75, -2})).AllTrues() {
462 | 		t.Error("Expected [-1, -1.25, -1.5, -1.75, -2], got ", arr)
463 | 	}
464 |
465 | 	arr = Linspace(-1, 2, 5)
466 |
467 | 	if !arr.Equal(Array([]float64{-1, -0.25, 0.5, 1.25, 2})).AllTrues() {
468 | 		t.Error("Expected [-1, -0.25, 0.5, 1.25, 2], got ", arr)
469 | 	}
470 | }
471 |
472 | func TestArrf_Copy(t *testing.T) {
473 | 	arr := Ones(2)
474 | 	arrCopy := arr.Copy()
475 | 	arr.Set(10, 0)
476 |
477 | 	if !arrCopy.Equal(Array([]float64{1, 1})).AllTrues() {
478 | 		t.Error("Expected [1, 1], got ", arrCopy)
479 | 	}
480 | }
481 |
482 | func TestArrf_Ndims(t *testing.T) {
483 | 	arr := Arange(10)
484 | 	if arr.Ndims() != 1 {
485 | 		t.Error("Expected 1, got ", arr.Ndims())
486 | 	}
487 |
488 | 	arr.Reshape(2, 5)
489 | 	if arr.Ndims() != 2 {
490 | 		t.Error("Expected 2, got ", arr.Ndims())
491 | 	}
492 |
493 | 	arr.Reshape(2, 5, 1)
494 | 	if arr.Ndims() != 3 {
495 | 		t.Error("Expected 3, got ", arr.Ndims())
496 | 	}
497 | }
498 |
499 | func TestArrf_Transpose(t *testing.T) {
500 | 	arr := Arange(4).Reshape(2, 2)
501 |
502 | 	if !arr.Equal(Array([]float64{0, 1, 2, 3}, 2, 2)).AllTrues() {
503 | 		t.Error("Expected [[0,1],[2,3]], got ", arr)
504 | 	}
505 |
506 | 	arrTransposed := arr.Transpose()
507 | 	if !arrTransposed.Equal(Array([]float64{0, 2, 1, 3}, 2, 2)).AllTrues() {
508 | 		t.Error("Expected [[0,2,], [1,3]], got ", arrTransposed)
509 | 	}
510 |
511 | 	arrTransposed = arr.Transpose(1, 0)
512 | 	if !arrTransposed.Equal(Array([]float64{0, 2, 1, 3}, 2, 2)).AllTrues() {
513 | 		t.Error("Expected [[0,2,], [1,3]], got ", arrTransposed)
514 | 	}
515 | }
516 |
517 | func TestArrf_TransposeException(t *testing.T) {
518 | 	arr := Arange(4)
519 |
520 | 	defer func() {
521 | 		r := recover()
522 | 		if r != DIMENTION_ERROR {
523 | 			t.Error("Expected DIMENTION_ERROR, got ", r)
524 | 		}
525 | 	}()
526 | 	arr.Transpose(0, 1)
527 | }
528
| 529 | func TestArrf_String(t *testing.T) {
530 | 	var arr *Arrf
531 | 	if arr.String() != "" {
532 | 		t.Error("Expected , got ", arr.String())
533 | 	}
534 |
535 | 	arr = Zeros(2)
536 | 	arr.data = nil
537 | 	if arr.String() != "" {
538 | 		t.Error("Expected got ", arr.String())
539 | 	}
540 |
541 | 	arr = Array(nil, 1)
542 | 	arr.strides = make([]int, 2)
543 | 	if arr.String() != "[]" {
544 | 		t.Error("Expected [], got ", arr.String())
545 | 	}
546 |
547 | 	arr = Arange(2)
548 | 	if arr.String() != "[0 1]" {
549 | 		t.Error("Expected [0 1], got ", arr.String())
550 | 	}
551 |
552 | 	arr = Arange(2).Reshape(2, 1)
553 | 	if strings.Replace(arr.String(), "\n", ":", -1) != "[[0] : [1]]" {
554 | 		t.Error("Expected , got ", arr.String())
555 | 	}
556 | }
--------------------------------------------------------------------------------
/shape.go:
--------------------------------------------------------------------------------
1 | package arrgo
2 |
3 | import "fmt"
4 |
5 | // Reshape changes the shape of a in place and returns a, so calls can be
6 | // chained. No data is copied; only the shape and strides are rewritten.
7 | // It panics with SHAPE_ERROR if a dimension is negative or if the new shape
8 | // does not describe exactly a.Length() elements.
9 | func (a *Arrf) Reshape(shape ...int) *Arrf {
10 | 	for _, d := range shape {
11 | 		// Without this check an even number of negative dimensions would pass
12 | 		// the size check below and produce negative strides.
13 | 		if d < 0 {
14 | 			fmt.Println("new shape cannot contain negative dimensions.")
15 | 			panic(SHAPE_ERROR)
16 | 		}
17 | 	}
18 | 	if a.Length() != ProductIntSlice(shape) {
19 | 		fmt.Println("new shape length does not equal to original array length.")
20 | 		panic(SHAPE_ERROR)
21 | 	}
22 |
23 | 	// Keep a private copy so later writes to the caller's slice cannot
24 | 	// change a's shape.
25 | 	internalShape := make([]int, len(shape))
26 | 	copy(internalShape, shape)
27 | 	a.shape = internalShape
28 |
29 | 	// Row-major strides: strides[i] is the number of elements spanned by
30 | 	// axes i and later; the extra trailing entry is 1.
31 | 	a.strides = make([]int, len(a.shape)+1)
32 | 	a.strides[len(a.shape)] = 1
33 | 	for i := len(a.shape) - 1; i >= 0; i-- {
34 | 		a.strides[i] = a.strides[i+1] * a.shape[i]
35 | 	}
36 |
37 | 	return a
38 | }
39 |
40 | // SameShapeTo reports whether a and b have identical shapes.
41 | func (a *Arrf) SameShapeTo(b *Arrf) bool {
42 | 	return SameIntSlice(a.shape, b.shape)
43 | }
44 |
45 | // Vstack stacks arrays of at most two dimensions vertically (along axis 0)
46 | // into a new array; 1-D inputs are treated as single rows and the inputs are
47 | // not modified. It returns nil for no arrays and a copy for a single array.
48 | // It panics with SHAPE_ERROR if an input has more than two dimensions or
49 | // the row lengths differ.
50 | func Vstack(arrs ...*Arrf) *Arrf {
51 | 	for _, arr := range arrs {
52 | 		if arr.Ndims() > 2 {
53 | 			fmt.Println("in Vstack function, array dimension cannot bigger than 2.")
54 | 			panic(SHAPE_ERROR)
55 | 		}
56 | 	}
57 | 	if len(arrs) == 0 {
58 | 		return nil
59 | 	}
60 | 	if len(arrs) == 1 {
61 | 		return arrs[0].Copy()
62 | 	}
63 | 	return Concat(0, arrs...)
64 | }
65 |
66 | // Hstack stacks arrays of at most two dimensions horizontally (along axis 1)
67 | // into a new array; 1-D inputs are treated as single rows, so joining 1-D
68 | // arrays yields a [1, n] array. The inputs are not modified. It returns nil
69 | // for no arrays and a copy for a single array. It panics with SHAPE_ERROR if
70 | // an input has more than two dimensions or the number of rows differs.
71 | func Hstack(arrs ...*Arrf) *Arrf {
72 | 	for _, arr := range arrs {
73 | 		if arr.Ndims() > 2 {
74 | 			fmt.Println("in Hstack function, array dimension cannot bigger than 2.")
75 | 			panic(SHAPE_ERROR)
76 | 		}
77 | 	}
78 | 	if len(arrs) == 0 {
79 | 		return nil
80 | 	}
81 | 	if len(arrs) == 1 {
82 | 		return arrs[0].Copy()
83 | 	}
84 | 	return Concat(1, arrs...)
85 | }
86 |
87 | // Concat joins arrs along axis into a new array; the inputs are not
88 | // modified. 1-D inputs are treated as row vectors of shape [1, n] (see
89 | // AtLeast2D), so joining 1-D arrays yields a 2-D result. Concat returns nil
90 | // for no arrays and an unmodified copy for a single array. It panics with
91 | // PARAMETER_ERROR if axis is out of range, and with SHAPE_ERROR if the
92 | // arrays differ in their number of dimensions or in any dimension other
93 | // than axis.
94 | func Concat(axis int, arrs ...*Arrf) *Arrf {
95 | 	if len(arrs) == 0 {
96 | 		return nil
97 | 	}
98 | 	if len(arrs) == 1 {
99 | 		return arrs[0].Copy()
100 | 	}
101 |
102 | 	// Promote 1-D inputs on copies: promoting the caller's arrays in place,
103 | 	// as this function used to, silently changed their shape.
104 | 	parts := make([]*Arrf, len(arrs))
105 | 	for i, arr := range arrs {
106 | 		if arr.Ndims() < 2 {
107 | 			arr = AtLeast2D(arr.Copy())
108 | 		}
109 | 		parts[i] = arr
110 | 	}
111 |
112 | 	ndims := parts[0].Ndims()
113 | 	if axis < 0 {
114 | 		fmt.Println("axis cannot be negative.")
115 | 		panic(PARAMETER_ERROR)
116 | 	}
117 | 	if axis >= ndims {
118 | 		fmt.Println("axis is bigger than dimensions num.")
119 | 		panic(PARAMETER_ERROR)
120 | 	}
121 |
122 | 	// The result matches every input except along axis, where lengths add up.
123 | 	newShape := make([]int, ndims)
124 | 	copy(newShape, parts[0].shape)
125 | 	for _, p := range parts[1:] {
126 | 		if p.Ndims() != ndims {
127 | 			fmt.Println("arrays to concatenate must have the same number of dimensions.")
128 | 			panic(SHAPE_ERROR)
129 | 		}
130 | 		for d := range newShape {
131 | 			switch {
132 | 			case d == axis:
133 | 				newShape[d] += p.shape[d]
134 | 			case p.shape[d] != newShape[d]:
135 | 				panic(SHAPE_ERROR)
136 | 			}
137 | 		}
138 | 	}
139 |
140 | 	// For every position of the leading axes (before axis), each input
141 | 	// contributes one contiguous chunk made of its axes from axis onward.
142 | 	chunk := make([]int, len(parts))
143 | 	for j, p := range parts {
144 | 		chunk[j] = ProductIntSlice(p.shape[axis:])
145 | 	}
146 | 	times := 1
147 | 	if axis > 0 {
148 | 		times = ProductIntSlice(parts[0].shape[:axis])
149 | 	}
150 |
151 | 	data := make([]float64, ProductIntSlice(newShape))
152 | 	pos := 0
153 | 	for i := 0; i < times; i++ {
154 | 		for j, p := range parts {
155 | 			l := chunk[j]
156 | 			copy(data[pos:pos+l], p.data[i*l:(i+1)*l])
157 | 			pos += l
158 | 		}
159 | 	}
160 | 	return Array(data, newShape...)
161 | }
162 |
163 | // AtLeast2D promotes a 1-D array in place to a row vector of shape [1, n]
164 | // and returns it. nil and arrays that already have two or more dimensions
165 | // are returned unchanged.
166 | func AtLeast2D(a *Arrf) *Arrf {
167 | 	if a == nil || a.Ndims() >= 2 {
168 | 		return a
169 | 	}
170 | 	n := a.shape[0]
171 | 	a.shape = []int{1, n}
172 | 	// Strides must follow the new shape (total, row length, 1); updating
173 | 	// only the shape left the strides one entry short.
174 | 	a.strides = []int{n, n, 1}
175 | 	return a
176 | }
177 |
178 | // Flatten returns a new 1-D array holding a copy of a's elements in the
179 | // order they are stored; a is not modified.
180 | func (a *Arrf) Flatten() *Arrf {
181 | 	ra := make([]float64, len(a.data))
182 | 	copy(ra, a.data)
183 | 	return Array(ra, len(a.data))
184 | }
185 |
--------------------------------------------------------------------------------
/shape_test.go:
--------------------------------------------------------------------------------
1 | package arrgo
2 |
3 | import "testing"
4 |
5 | func TestArrf_Reshape(t *testing.T) {
6 | 	arr := Array([]float64{1, 2, 3, 4, 5, 6}, 2, 3)
7 | 	arr2 := arr.Reshape(3, 2)
8 |
9 | 	if !SameIntSlice(arr.strides, []int{6, 2, 1}) {
10 | 		t.Error("Expected [6,2,1], got ", arr.strides)
11 | 	}
12 | 	if !SameIntSlice(arr.shape, []int{3, 2}) {
13 | 		t.Error("Expected [3, 2], got ", arr.shape)
14 | 	}
15 | 	if !SameIntSlice(arr2.shape, []int{3, 2}) {
16 | 		t.Error("Expected [3, 2], got ", arr2.shape)
17 | 	}
18 | }
19 |
20 | func TestArrf_ReshapeException(t *testing.T) {
21 | 	defer func() {
22 | 		r := recover()
23 | 		if r != SHAPE_ERROR {
24 | 			t.Error("Expected shape error, got ", r)
25 | 		}
26 | 	}()
27 |
28 | 	Arange(4).Reshape(5)
29 | }
30 |
31 |
func TestArrf_SameShapeTo(t *testing.T) { 32 | a := Arange(4).Reshape(2, 2) 33 | b := Array([]float64{3, 4, 5, 6}, 2, 2) 34 | if a.SameShapeTo(b) != true { 35 | t.Errorf("Expected true, got %t", a.SameShapeTo(b)) 36 | } 37 | } 38 | 39 | func TestVstack(t *testing.T) { 40 | if Vstack() != nil { 41 | t.Errorf("Expected nil, got %s", Vstack()) 42 | } 43 | 44 | a := Arange(3) 45 | stacked := Vstack(a) 46 | if !stacked.Equal(Arange(3)).AllTrues() { 47 | t.Errorf("Expected [0, 1, 2], got %s", stacked) 48 | } 49 | 50 | b := Array([]float64{3, 4, 5}) 51 | stacked = Vstack(a, b) 52 | if !stacked.Equal(Array([]float64{0, 1, 2, 3, 4, 5}, 2, 3)).AllTrues() { 53 | t.Errorf("Expected [[0 1 2] [3 4 5]], got %s", stacked) 54 | } 55 | 56 | a = Arange(2) 57 | b = Arange(4).Reshape(2, 2) 58 | stacked = Vstack(a, b) 59 | if !stacked.Equal(Array([]float64{0, 1, 0, 1, 2, 3}, 3, 2)).AllTrues() { 60 | t.Errorf("Expected [[0,1], [0,1], [2, 3]], got %s", stacked) 61 | } 62 | } 63 | 64 | func TestVstackException(t *testing.T) { 65 | a := Arange(4).Reshape(1, 2, 2) 66 | defer func() { 67 | r := recover() 68 | if r != SHAPE_ERROR { 69 | t.Errorf("Expected shape error, got %s", r) 70 | } 71 | }() 72 | 73 | Vstack(a) 74 | } 75 | 76 | func TestVstackException2(t *testing.T) { 77 | a := Arange(4) 78 | b := Arange(5) 79 | defer func() { 80 | r := recover() 81 | if r != SHAPE_ERROR { 82 | t.Error("Expected shape error, got ", r) 83 | } 84 | }() 85 | 86 | Vstack(a, b) 87 | } 88 | 89 | func TestHstack(t *testing.T) { 90 | if Hstack() != nil { 91 | t.Error("Expected nil, got ", Hstack()) 92 | } 93 | 94 | a := Arange(3) 95 | stacked := Hstack(a) 96 | if !stacked.Equal(Arange(3)).AllTrues() { 97 | t.Error("Expected [0, 1, 2], got ", stacked) 98 | } 99 | a = a.Reshape(3, 1) 100 | b := Array([]float64{3, 4, 5}).Reshape(3, 1) 101 | stacked = Hstack(a, b) 102 | if !stacked.Equal(Array([]float64{0, 3, 1, 4, 2, 5}, 3, 2)).AllTrues() { 103 | t.Error("Expected [[0 3] [1 4], [2 5]], got ", stacked) 104 | } 105 | 
106 | a = Arange(2).Reshape(2, 1) 107 | b = Arange(4).Reshape(2, 2) 108 | stacked = Hstack(a, b) 109 | if !stacked.Equal(Array([]float64{0, 0, 1, 1, 2, 3}, 2, 3)).AllTrues() { 110 | t.Error("Expected [[0, 0, 1], [1, 2, 3]], got ", stacked) 111 | } 112 | } 113 | 114 | func TestHstackException(t *testing.T) { 115 | a := Arange(4).Reshape(1, 2, 2) 116 | defer func() { 117 | r := recover() 118 | if r != SHAPE_ERROR { 119 | t.Error("Expected shape error, got ", r) 120 | } 121 | }() 122 | 123 | Hstack(a) 124 | } 125 | 126 | func TestHstackException2(t *testing.T) { 127 | a := Arange(4).Reshape(4, 1) 128 | b := Arange(5).Reshape(5, 1) 129 | defer func() { 130 | r := recover() 131 | if r != SHAPE_ERROR { 132 | t.Error("Expected shape error, got ", r) 133 | } 134 | }() 135 | 136 | Hstack(a, b) 137 | } 138 | 139 | func TestConcat(t *testing.T) { 140 | if Concat(0) != nil { 141 | t.Error("Expected nil, got ", Concat(0)) 142 | } 143 | concated := Concat(0, Arange(2)) 144 | if !concated.Equal(Arange(2)).AllTrues() { 145 | t.Error("Expected [0, 1], got ", concated) 146 | } 147 | 148 | a := Arange(3) 149 | b := Arange(1, 4) 150 | 151 | concated = Concat(0, a, b) 152 | if !concated.Equal(Array([]float64{0, 1, 2, 1, 2, 3}, 2, 3)).AllTrues() { 153 | t.Error("Expected [[0,1,2], [1,2,3]], got ", concated) 154 | } 155 | 156 | a = Arange(3) 157 | b = Arange(1, 4) 158 | 159 | concated = Concat(1, a, b) 160 | t.Log(concated) 161 | if !concated.Equal(Array([]float64{0, 1, 2, 1, 2, 3}, 1, 6)).AllTrues() { 162 | t.Error("Expected [[0,1,2,1,2,3]], got ", concated) 163 | } 164 | 165 | } 166 | 167 | func TestConcatException(t *testing.T) { 168 | a := Arange(4) 169 | b := Arange(1, 4) 170 | 171 | defer func() { 172 | r := recover() 173 | if r != SHAPE_ERROR { 174 | t.Error("Expected shape error, got ", r) 175 | } 176 | }() 177 | 178 | Concat(0, a, b) 179 | } 180 | 181 | func TestConcatException2(t *testing.T) { 182 | a := Arange(4) 183 | b := Arange(1, 4) 184 | 185 | defer func() { 186 | r := 
recover() 187 | if r != PARAMETER_ERROR { 188 | t.Error("Expected PARAMETER_ERROR, got ", r) 189 | } 190 | }() 191 | 192 | Concat(2, a, b) 193 | } 194 | 195 | func TestAtLeast2D(t *testing.T) { 196 | a := Arange(10) 197 | AtLeast2D(a) 198 | if !SameIntSlice(a.shape, []int{1, 10}) { 199 | t.Error("Expected [1, 10], got ", a.shape) 200 | } 201 | 202 | a.Reshape(1, 1, 10) 203 | AtLeast2D(a) 204 | if !SameIntSlice(a.shape, []int{1, 1, 10}) { 205 | t.Error("Expected [1, 1, 10], got ", a.shape) 206 | } 207 | 208 | if AtLeast2D(nil) != nil { 209 | t.Error("Expected nil, got ", AtLeast2D(nil)) 210 | } 211 | } 212 | 213 | func TestAtLeast2D2(t *testing.T) { 214 | if AtLeast2D(nil) != nil { 215 | t.Error("Expected nil, got ", AtLeast2D(nil)) 216 | } 217 | 218 | arr := Arange(3) 219 | AtLeast2D(arr) 220 | 221 | if !SameIntSlice(arr.shape, []int{1, 3}) { 222 | t.Error("expected true, got false") 223 | } 224 | 225 | arr = Arange(3).Reshape(3, 1) 226 | AtLeast2D(arr) 227 | 228 | if !SameIntSlice(arr.shape, []int{3, 1}) { 229 | t.Error("expected true, got false") 230 | } 231 | } 232 | 233 | func TestArrf_Flatten(t *testing.T) { 234 | arr := Arange(3).Reshape(3, 1) 235 | flattened := arr.Flatten() 236 | 237 | if !flattened.SameShapeTo(Arange(3)) { 238 | t.Error("expected [3], got ", flattened.shape) 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /stats.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "sort" 5 | 6 | asm "github.com/ledao/arrgo/internal" 7 | ) 8 | 9 | func (a *Arrf) Sum(axis ...int) *Arrf { 10 | if len(axis) == 0 || len(axis) >= a.Ndims() { 11 | tot := float64(0) 12 | for _, v := range a.data { 13 | tot += v 14 | } 15 | return Fill(tot, 1) 16 | } 17 | 18 | //对axis进行排序,按照从大到小的顺序进行规约 19 | sort.IntSlice(axis).Sort() 20 | //规约后的数组的形状 21 | restAxis := make([]int, len(a.shape)-len(axis)) 22 | //对a进行复制,所有的操作都作用于临时变量ta中,最后将ta返回 23 | ta := 
a.Copy() 24 | 25 | axisR: 26 | for i, t := 0, 0; i < len(ta.shape); i++ { 27 | for _, w := range axis { 28 | if i == w { 29 | continue axisR 30 | } 31 | } 32 | restAxis[t] = ta.shape[i] 33 | t++ 34 | } 35 | 36 | //数组的元素的个数保存到ln中 37 | ln := ta.strides[0] 38 | //对每个指定的轴,顺寻进行规约 39 | for k := 0; k < len(axis); k++ { 40 | //如果轴大小为1,则不需要任何操作 41 | if ta.shape[axis[k]] == 1 { 42 | continue 43 | } 44 | //获取当前轴的大小v,当前轴的跨度wd,以及下一个轴的跨度st 45 | v, wd, st := ta.shape[axis[k]], ta.strides[axis[k]], ta.strides[axis[k]+1] 46 | //如果下一个轴st的跨度为1,则说明当前轴为最后一个轴,只需要每wd个跨度进行一个规约即可 47 | if st == 1 { 48 | //每wd个数据进行一次规约,结果依次放到开始的位置 49 | asm.Hadd(uint64(wd), ta.data) 50 | ln /= v 51 | ta.data = ta.data[:ln] 52 | continue 53 | } 54 | //如果不是最后一个轴,则在该轴上进行规约 55 | for w := 0; w < ln; w += wd { 56 | t := ta.data[w/wd*st : (w/wd+1)*st] 57 | copy(t, ta.data[w:w+st]) 58 | for i := 1; i*st+1 < wd; i++ { 59 | asm.Vadd(t, ta.data[w+(i)*st:w+(i+1)*st]) 60 | } 61 | } 62 | ln /= v 63 | ta.data = ta.data[:ln] 64 | } 65 | ta.shape = restAxis 66 | 67 | tmp := 1 68 | for i := len(restAxis); i > 0; i-- { 69 | ta.strides[i] = tmp 70 | tmp *= restAxis[i-1] 71 | } 72 | ta.strides[0] = tmp 73 | ta.data = ta.data[:tmp] 74 | ta.strides = ta.strides[:len(restAxis)+1] 75 | return ta 76 | } 77 | 78 | func Sum(a *Arrf, axis ...int) *Arrf { 79 | return a.Sum(axis...) 80 | } 81 | 82 | func (a *Arrf) Mean(axis ...int) *Arrf { 83 | if len(axis) == 0 || len(axis) >= a.Ndims() { 84 | tot := float64(0) 85 | for _, v := range a.data { 86 | tot += v 87 | } 88 | return Fill(tot/float64(a.strides[0]), 1) 89 | } 90 | 91 | sort.IntSlice(axis).Sort() 92 | selectShape := make([]int, len(axis)) 93 | for i := range selectShape { 94 | selectShape[i] = a.shape[axis[i]] 95 | } 96 | N := ProductIntSlice(selectShape) 97 | 98 | ta := a.Sum(axis...) 99 | 100 | return ta.DivC(float64(N)) 101 | } 102 | 103 | func Mean(a *Arrf, axis ...int) *Arrf { 104 | return a.Mean(axis...) 
105 | } 106 | 107 | func (a *Arrf) Var(axis ...int) *Arrf { 108 | a2 := a.Mul(a).Sum(axis...) 109 | m := a.Mean(axis...) 110 | var N int 111 | if len(axis) == 0 || len(axis) >= a.Ndims() { 112 | N = ProductIntSlice(a.shape) 113 | } else { 114 | selectShape := make([]int, len(axis)) 115 | for i := range selectShape { 116 | selectShape[i] = a.shape[axis[i]] 117 | } 118 | N = ProductIntSlice(selectShape) 119 | } 120 | 121 | m2 := m.Mul(m).MulC(float64(N)) 122 | a_m_2 := a.Sum(axis...).Mul(m).MulC(2) 123 | return a2.Sub(a_m_2).Add(m2).DivC(float64(N)) 124 | } 125 | 126 | func Var(a *Arrf, axis ...int) *Arrf { 127 | return a.Var(axis...) 128 | } 129 | 130 | func (a *Arrf) Std(axis ...int) *Arrf { 131 | return Sqrt(a.Var(axis...)) 132 | } 133 | 134 | func Std(a *Arrf, axis ...int) *Arrf { 135 | return a.Std(axis...) 136 | } 137 | 138 | func (a *Arrf) Min(axis ...int) *Arrf { 139 | if len(axis) == 0 || len(axis) >= a.Ndims() { 140 | minValue := a.data[0] 141 | for _, v := range a.data { 142 | if minValue > v { 143 | minValue = v 144 | } 145 | } 146 | return Fill(minValue, 1) 147 | } 148 | 149 | sort.IntSlice(axis).Sort() 150 | restAxis := make([]int, len(a.shape)-len(axis)) 151 | ta := a.Copy() 152 | axisR: 153 | for i, t := 0, 0; i < len(ta.shape); i++ { 154 | for _, w := range axis { 155 | if i == w { 156 | continue axisR 157 | } 158 | } 159 | restAxis[t] = ta.shape[i] 160 | t++ 161 | } 162 | 163 | //数组的元素的个数保存到ln中 164 | ln := ta.strides[0] 165 | //对每个指定的轴,顺寻进行规约 166 | for k := 0; k < len(axis); k++ { 167 | //如果轴大小为1,则不需要任何操作 168 | if ta.shape[axis[k]] == 1 { 169 | continue 170 | } 171 | //获取当前轴的大小v,当前轴的跨度wd,以及下一个轴的跨度st 172 | v, wd, st := ta.shape[axis[k]], ta.strides[axis[k]], ta.strides[axis[k]+1] 173 | //如果下一个轴st的跨度为1,则说明当前轴为最后一个轴,只需要每wd个跨度进行一个规约即可 174 | if st == 1 { 175 | //每wd个数据进行一次规约,结果依次放到开始的位置 176 | Hmin(wd, ta.data) 177 | ln /= v 178 | ta.data = ta.data[:ln] 179 | continue 180 | } 181 | //如果不是最后一个轴,则在该轴上进行规约 182 | for w := 0; w < ln; w += wd { 183 | t := 
ta.data[w/wd*st : (w/wd+1)*st] 184 | copy(t, ta.data[w:w+st]) 185 | for i := 1; i*st+1 < wd; i++ { 186 | Vmin(t, ta.data[w+(i)*st:w+(i+1)*st]) 187 | } 188 | } 189 | ln /= v 190 | ta.data = ta.data[:ln] 191 | } 192 | 193 | ta.shape = restAxis 194 | 195 | tmp := 1 196 | for i := len(restAxis); i > 0; i-- { 197 | ta.strides[i] = tmp 198 | tmp *= restAxis[i-1] 199 | } 200 | ta.strides[0] = tmp 201 | ta.strides = ta.strides[:len(restAxis)+1] 202 | return ta 203 | } 204 | 205 | func Min(a *Arrf, axis ...int) *Arrf { 206 | return a.Min(axis...) 207 | } 208 | 209 | func (a *Arrf) Max(axis ...int) *Arrf { 210 | if len(axis) == 0 || len(axis) >= a.Ndims() { 211 | maxValue := a.data[0] 212 | for _, v := range a.data { 213 | if maxValue < v { 214 | maxValue = v 215 | } 216 | } 217 | return Fill(maxValue, 1) 218 | } 219 | 220 | sort.IntSlice(axis).Sort() 221 | restAxis := make([]int, len(a.shape)-len(axis)) 222 | ta := a.Copy() 223 | axisR: 224 | for i, t := 0, 0; i < len(ta.shape); i++ { 225 | for _, w := range axis { 226 | if i == w { 227 | continue axisR 228 | } 229 | } 230 | restAxis[t] = ta.shape[i] 231 | t++ 232 | } 233 | 234 | //数组的元素的个数保存到ln中 235 | ln := ta.strides[0] 236 | //对每个指定的轴,顺寻进行规约 237 | for k := 0; k < len(axis); k++ { 238 | //如果轴大小为1,则不需要任何操作 239 | if ta.shape[axis[k]] == 1 { 240 | continue 241 | } 242 | //获取当前轴的大小v,当前轴的跨度wd,以及下一个轴的跨度st 243 | v, wd, st := ta.shape[axis[k]], ta.strides[axis[k]], ta.strides[axis[k]+1] 244 | //如果下一个轴st的跨度为1,则说明当前轴为最后一个轴,只需要每wd个跨度进行一个规约即可 245 | if st == 1 { 246 | //每wd个数据进行一次规约,结果依次放到开始的位置 247 | Hmax(wd, ta.data) 248 | ln /= v 249 | ta.data = ta.data[:ln] 250 | continue 251 | } 252 | //如果不是最后一个轴,则在该轴上进行规约 253 | for w := 0; w < ln; w += wd { 254 | t := ta.data[w/wd*st : (w/wd+1)*st] 255 | copy(t, ta.data[w:w+st]) 256 | for i := 1; i*st+1 < wd; i++ { 257 | Vmax(t, ta.data[w+(i)*st:w+(i+1)*st]) 258 | } 259 | } 260 | ln /= v 261 | ta.data = ta.data[:ln] 262 | } 263 | 264 | ta.shape = restAxis 265 | 266 | tmp := 1 267 | for i := 
len(restAxis); i > 0; i-- { 268 | ta.strides[i] = tmp 269 | tmp *= restAxis[i-1] 270 | } 271 | ta.strides[0] = tmp 272 | ta.strides = ta.strides[:len(restAxis)+1] 273 | return ta 274 | } 275 | 276 | func Max(a *Arrf, axis ...int) *Arrf { 277 | return a.Max(axis...) 278 | } 279 | 280 | func (a *Arrf) ArgMax(axis int) *Arrf { 281 | if axis < 0 { 282 | axis = axis + len(a.shape) 283 | } 284 | restAxis := make([]int, len(a.shape)-1) 285 | ta := a.Copy() 286 | for i, t := 0, 0; i < len(ta.shape); i++ { 287 | if i == axis { 288 | continue 289 | } 290 | restAxis[t] = ta.shape[i] 291 | t++ 292 | } 293 | 294 | //数组的元素的个数保存到ln中 295 | ln := ta.strides[0] 296 | 297 | //获取当前轴的大小v,当前轴的跨度wd,以及下一个轴的跨度st 298 | v, wd, st := ta.shape[axis], ta.strides[axis], ta.strides[axis+1] 299 | //如果下一个轴st的跨度为1,则说明当前轴为最后一个轴,只需要每wd个跨度进行一个规约即可 300 | if st == 1 { 301 | //每wd个数据进行一次规约,结果依次放到开始的位置 302 | Hargmax(wd, ta.data) 303 | ln /= v 304 | ta.data = ta.data[:ln] 305 | } else { 306 | //如果不是最后一个轴,则在该轴上进行规约 307 | td := make([]float64, 0, ln/wd) 308 | for w := 0; w < ln; w += wd { 309 | Vargmax(st, ta.data[w:w+wd]) 310 | td = append(td, ta.data[w : w+wd][:st]...) 
311 | } 312 | ln /= v 313 | ta.data = td 314 | } 315 | 316 | ta.shape = restAxis 317 | 318 | tmp := 1 319 | for i := len(restAxis); i > 0; i-- { 320 | ta.strides[i] = tmp 321 | tmp *= restAxis[i-1] 322 | } 323 | ta.strides[0] = tmp 324 | ta.strides = ta.strides[:len(restAxis)+1] 325 | return ta 326 | } 327 | 328 | func ArgMax(a *Arrf, axis int) *Arrf { 329 | return a.ArgMax(axis) 330 | } 331 | 332 | //fixme has bug 333 | func (a *Arrf) ArgMin(axis int) *Arrf { 334 | if axis < 0 { 335 | axis = axis + len(a.shape) 336 | } 337 | restAxis := make([]int, len(a.shape)-1) 338 | ta := a.Copy() 339 | for i, t := 0, 0; i < len(ta.shape); i++ { 340 | if i == axis { 341 | continue 342 | } 343 | restAxis[t] = ta.shape[i] 344 | t++ 345 | } 346 | 347 | //数组的元素的个数保存到ln中 348 | ln := ta.strides[0] 349 | 350 | //获取当前轴的大小v,当前轴的跨度wd,以及下一个轴的跨度st 351 | v, wd, st := ta.shape[axis], ta.strides[axis], ta.strides[axis+1] 352 | //如果下一个轴st的跨度为1,则说明当前轴为最后一个轴,只需要每wd个跨度进行一个规约即可 353 | if st == 1 { 354 | //每wd个数据进行一次规约,结果依次放到开始的位置 355 | Hargmin(wd, ta.data) 356 | ln /= v 357 | ta.data = ta.data[:ln] 358 | } else { 359 | //如果不是最后一个轴,则在该轴上进行规约 360 | td := make([]float64, 0, ln/wd) 361 | for w := 0; w < ln; w += wd { 362 | Vargmin(st, ta.data[w:w+wd]) 363 | td = append(td, ta.data[w : w+wd][:st]...) 
364 | } 365 | ln /= v 366 | ta.data = td 367 | } 368 | 369 | ta.shape = restAxis 370 | 371 | tmp := 1 372 | for i := len(restAxis); i > 0; i-- { 373 | ta.strides[i] = tmp 374 | tmp *= restAxis[i-1] 375 | } 376 | ta.strides[0] = tmp 377 | ta.strides = ta.strides[:len(restAxis)+1] 378 | return ta 379 | } 380 | 381 | func ArgMin(a *Arrf, axis int) *Arrf { 382 | return a.ArgMin(axis) 383 | } 384 | -------------------------------------------------------------------------------- /stats_test.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import "testing" 4 | 5 | func TestSum(t *testing.T) { 6 | var arr = Arange(100).Reshape(2, 5, 10) 7 | if arr.Sum(0).NotEqual(Array( 8 | []float64{ 9 | 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 10 | 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 11 | 90, 92, 94, 96, 98, 100, 102, 104, 106, 108, 12 | 110, 112, 114, 116, 118, 120, 122, 124, 126, 128, 13 | 130, 132, 134, 136, 138, 140, 142, 144, 146, 148}, 14 | 5, 10)).AnyTrue() { 15 | t.Error(`Expected [[ 50, 52, 54, 56, 58, 60, 62, 64, 66, 68], 16 | [ 70, 72, 74, 76, 78, 80, 82, 84, 86, 88], 17 | [ 90, 92, 94, 96, 98, 100, 102, 104, 106, 108], 18 | [110, 112, 114, 116, 118, 120, 122, 124, 126, 128], 19 | [130, 132, 134, 136, 138, 140, 142, 144, 146, 148]], got `, 20 | arr.Sum(0)) 21 | } 22 | 23 | if arr.Sum(1).NotEqual(Array( 24 | []float64{ 25 | 100, 105, 110, 115, 120, 125, 130, 135, 140, 145, 26 | 350, 355, 360, 365, 370, 375, 380, 385, 390, 395}, 27 | 2, 10)).AnyTrue() { 28 | t.Error(`Expected 29 | [[100, 105, 110, 115, 120, 125, 130, 135, 140, 145], 30 | [350, 355, 360, 365, 370, 375, 380, 385, 390, 395]], got `, 31 | arr.Sum(1)) 32 | } 33 | 34 | if arr.Sum(2).NotEqual(Array( 35 | []float64{ 36 | 45, 145, 245, 345, 445, 37 | 545, 645, 745, 845, 945}, 38 | 2, 5)).AnyTrue() { 39 | t.Error(`Expected 40 | [[ 45, 145, 245, 345, 445], 41 | [545, 645, 745, 845, 945]] 42 | , got 43 | `, arr.Sum(2)) 44 | } 45 | 46 | if arr.Sum(0, 
1).NotEqual(Array( 47 | []float64{ 48 | 450, 460, 470, 480, 490, 500, 510, 520, 530, 540}, 49 | 10)).AnyTrue() { 50 | t.Error(`Expected 51 | [450, 460, 470, 480, 490, 500, 510, 520, 530, 540] 52 | , got`, 53 | arr.Sum(0, 1)) 54 | } 55 | 56 | if arr.Sum(0, 2).NotEqual(Array([]float64{590, 790, 990, 1190, 1390}, 5)).AnyTrue() { 57 | t.Error(`Expected [ 590, 790, 990, 1190, 1390], got `, arr.Sum(0, 2)) 58 | } 59 | 60 | if arr.Sum(1, 2).NotEqual(Array([]float64{1225, 3725})).AnyTrue() { 61 | t.Error("Expected [1225, 3725], got ", arr.Sum(1, 2)) 62 | } 63 | 64 | if arr.Sum(0, 1, 2).NotEqual(Array([]float64{4950})).AnyTrue() { 65 | t.Error("expected [4950], got ", arr.Sum(0, 1, 2)) 66 | } 67 | 68 | if arr.Sum().NotEqual(Array([]float64{4950})).AnyTrue() { 69 | t.Error("expected [4950], got ", arr.Sum()) 70 | } 71 | } 72 | 73 | 74 | func TestArgMax(t *testing.T) { 75 | arr := Array([]float64{17, 10, 22, 3, 2, 7, 15, 9, 23, 4, 14, 18, 5, 8, 0, 12, 1, 76 | 19, 20, 11, 6, 16, 21, 13}, 2,3,4) 77 | 78 | if arr.ArgMax(0).NotEqual(Array([]float64{0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0}, 3, 4)).AnyTrue() { 79 | t.Error(`Expected 80 | [[0, 0, 0, 1], 81 | [0, 1, 1, 1], 82 | [0, 1, 1, 0]], got `, arr.ArgMax(0)) 83 | } 84 | 85 | if arr.ArgMax(-3).NotEqual(Array([]float64{0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0}, 3, 4)).AnyTrue() { 86 | t.Error(`Expected 87 | [[0, 0, 0, 1], 88 | [0, 1, 1, 1], 89 | [0, 1, 1, 0]], got `, arr.ArgMax(0)) 90 | } 91 | 92 | 93 | if arr.ArgMax(1).NotEqual(Array([]float64{2, 0, 0, 2, 2, 1, 2, 2}, 2, 4)).AnyTrue() { 94 | t.Error(`Expected 95 | [[2, 0, 0, 2], 96 | [2, 1, 2, 2]], got `, arr.ArgMax(1)) 97 | } 98 | 99 | if arr.ArgMax(-2).NotEqual(Array([]float64{2, 0, 0, 2, 2, 1, 2, 2}, 2, 4)).AnyTrue() { 100 | t.Error(`Expected 101 | [[2, 0, 0, 2], 102 | [2, 1, 2, 2]], got `, arr.ArgMax(1)) 103 | } 104 | 105 | 106 | if arr.ArgMax(2).NotEqual(Array([]float64{2, 2, 0, 3, 2, 2}, 2, 3)).AnyTrue() { 107 | t.Error(`Expected 108 | [[2, 2, 0], 109 | [3, 2, 2]], got `, 
arr.ArgMax(2)) 110 | } 111 | 112 | if arr.ArgMax(-1).NotEqual(Array([]float64{2, 2, 0, 3, 2, 2}, 2, 3)).AnyTrue() { 113 | t.Error(`Expected 114 | [[2, 2, 0], 115 | [3, 2, 2]], got `, arr.ArgMax(2)) 116 | } 117 | 118 | } 119 | 120 | func TestArgMin(t *testing.T) { 121 | arr := Array([]float64{17, 10, 22, 3, 2, 7, 15, 9, 23, 4, 14, 18, 5, 8, 0, 12, 1, 122 | 19, 20, 11, 6, 16, 21, 13}, 2,3,4) 123 | 124 | if arr.ArgMin(0).NotEqual(Array([]float64{1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1}, 3, 4)).AnyTrue() { 125 | t.Error(`Expected 126 | [[1, 1, 1, 0], 127 | [1, 0, 0, 0], 128 | [1, 0, 0, 1]], got `, arr.ArgMin(0)) 129 | } 130 | 131 | if arr.ArgMin(1).NotEqual(Array([]float64{1, 2, 2, 0, 1, 0, 0, 1}, 2, 4)).AnyTrue() { 132 | t.Error(`Expected 133 | [[1, 2, 2, 0], 134 | [1, 0, 0, 1]], got `, arr.ArgMin(1)) 135 | } 136 | 137 | if arr.ArgMin(2).NotEqual(Array([]float64{3, 0, 1, 2, 0, 0}, 2, 3)).AnyTrue() { 138 | t.Error(`Expected 139 | [[3, 0, 1], 140 | [2, 0, 0]], got `, arr.ArgMin(2)) 141 | } 142 | 143 | if arr.ArgMin(-3).NotEqual(Array([]float64{1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1}, 3, 4)).AnyTrue() { 144 | t.Error(`Expected 145 | [[1, 1, 1, 0], 146 | [1, 0, 0, 0], 147 | [1, 0, 0, 1]], got `, arr.ArgMin(0)) 148 | } 149 | 150 | if arr.ArgMin(-2).NotEqual(Array([]float64{1, 2, 2, 0, 1, 0, 0, 1}, 2, 4)).AnyTrue() { 151 | t.Error(`Expected 152 | [[1, 2, 2, 0], 153 | [1, 0, 0, 1]], got `, arr.ArgMin(1)) 154 | } 155 | 156 | if arr.ArgMin(-1).NotEqual(Array([]float64{3, 0, 1, 2, 0, 0}, 2, 3)).AnyTrue() { 157 | t.Error(`Expected 158 | [[3, 0, 1], 159 | [2, 0, 0]], got `, arr.ArgMin(2)) 160 | } 161 | } -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "math" 5 | "sort" 6 | ) 7 | 8 | func ReverseIntSlice(slice []int) []int { 9 | s := make([]int, len(slice)) 10 | copy(s, slice) 11 | for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 
{ 12 | s[i], s[j] = s[j], s[i] 13 | } 14 | return s 15 | } 16 | 17 | //计算[]int的所有元素的乘积. 18 | func ProductIntSlice(slice []int) int { 19 | var prod = 1 20 | for _, v := range slice { 21 | prod *= v 22 | } 23 | return prod 24 | } 25 | 26 | func Roundf(val float64, places int) float64 { 27 | var t float64 28 | f := math.Pow10(places) 29 | x := val * f 30 | if math.IsInf(x, 0) || math.IsNaN(x) { 31 | return val 32 | } 33 | if x >= 0.0 { 34 | t = math.Ceil(x) 35 | if (t - x) > 0.50000000001 { 36 | t -= 1.0 37 | } 38 | } else { 39 | t = math.Ceil(-x) 40 | if (t + x) > 0.50000000001 { 41 | t -= 1.0 42 | } 43 | t = -t 44 | } 45 | x = t / f 46 | 47 | if !math.IsInf(x, 0) { 48 | return x 49 | } 50 | 51 | return t 52 | } 53 | 54 | func Hmin(ln int, data []float64) { 55 | for i := 0; i*ln < len(data); i++ { 56 | minValue := data[i*ln] 57 | for j := i*ln + 1; j < i*ln+ln; j++ { 58 | if minValue > data[j] { 59 | minValue = data[j] 60 | } 61 | } 62 | data[i] = minValue 63 | } 64 | } 65 | 66 | func Vmin(a, b []float64) { 67 | for i := range a { 68 | if a[i] > b[i] { 69 | a[i] = b[i] 70 | } 71 | } 72 | } 73 | 74 | func Hmax(ln int, data []float64) { 75 | for i := 0; i*ln < len(data); i++ { 76 | maxValue := data[i*ln] 77 | for j := i*ln + 1; j < i*ln+ln; j++ { 78 | if maxValue < data[j] { 79 | maxValue = data[j] 80 | } 81 | } 82 | data[i] = maxValue 83 | } 84 | } 85 | 86 | func Vmax(a, b []float64) { 87 | for i := range a { 88 | if a[i] < b[i] { 89 | a[i] = b[i] 90 | } 91 | } 92 | } 93 | 94 | //在data中计算每ln个数据中,最大值的位置,并将结果依次放到data中。 95 | func Hargmax(ln int, data []float64) { 96 | for i := 0; i*ln < len(data); i += 1 { 97 | maxValue := data[i*ln] 98 | maxIndex := 0.0 99 | for j := i*ln + 1; j < i*ln+ln; j++ { 100 | if maxValue < data[j] { 101 | maxValue = data[j] 102 | maxIndex = float64(j % ln) 103 | } 104 | } 105 | data[i] = maxIndex 106 | } 107 | } 108 | 109 | func Vargmax(ln int, a []float64) { 110 | for i := 0; i < ln; i++ { 111 | maxValue := a[i] 112 | maxIndex := 0.0 113 | for 
j := i + ln; j < len(a); j += ln { 114 | if maxValue < a[j] { 115 | maxValue = a[j] 116 | maxIndex = float64(int(j / ln)) 117 | } 118 | } 119 | a[i] = maxIndex 120 | } 121 | } 122 | 123 | func Hargmin(ln int, data []float64) { 124 | for i := 0; i*ln < len(data); i++ { 125 | minValue := data[i*ln] 126 | minIndex := 0.0 127 | for j := i*ln + 1; j < i*ln+ln; j++ { 128 | if minValue > data[j] { 129 | minValue = data[j] 130 | minIndex = float64(j % ln) 131 | } 132 | } 133 | data[i] = minIndex 134 | } 135 | } 136 | 137 | func Vargmin(ln int, a []float64) { 138 | for i := 0; i < ln; i++ { 139 | minValue := a[i] 140 | minIndex := 0.0 141 | for j := i + ln; j < len(a); j += ln { 142 | if minValue > a[j] { 143 | minValue = a[j] 144 | minIndex = float64(int(j / ln)) 145 | } 146 | } 147 | a[i] = minIndex 148 | } 149 | } 150 | 151 | func Hsort(ln int, data []float64) { 152 | for i := 0; i*ln < len(data); i++ { 153 | sort.Float64s(data[i*ln : i*ln+ln]) 154 | } 155 | } 156 | 157 | func Vsort(ln int, a []float64) { 158 | for i := 0; i < ln; i++ { 159 | tmpSlice := make([]float64, 0, len(a)/ln) 160 | for j := i; j < len(a); j += ln { 161 | tmpSlice = append(tmpSlice, a[j]) 162 | } 163 | sort.Float64s(tmpSlice) 164 | for j := i; j < len(a); j += ln { 165 | a[j] = tmpSlice[j/ln] 166 | } 167 | } 168 | } 169 | 170 | func ContainsFloat64(s []float64, e float64) bool { 171 | for _, v := range s { 172 | if v == e { 173 | return true 174 | } 175 | } 176 | return false 177 | } 178 | 179 | func ContainsInt(s []int, e int) bool { 180 | for _, v := range s { 181 | if v == e { 182 | return true 183 | } 184 | } 185 | return false 186 | } 187 | 188 | //判断两个[]int是否相等。 189 | //相等是严格的相等,否则为不等。 190 | //如果有一个为nil则为不相等。 191 | func SameIntSlice(a, b []int) bool { 192 | if a == nil || b == nil { 193 | return false 194 | } 195 | if len(a) != len(b) { 196 | return false 197 | } else { 198 | for i := range a { 199 | if a[i] != b[i] { 200 | return false 201 | } 202 | } 203 | return true 204 | } 205 | } 206 | 
207 | //判断两个[]float64是否相等。 208 | //相等是严格的相等,否则为不等。 209 | //如果有一个为nil则为不相等。 210 | func SameFloat64Slice(a, b []float64) bool { 211 | if a == nil || b == nil { 212 | return false 213 | } 214 | if len(a) != len(b) { 215 | return false 216 | } else { 217 | for i := range a { 218 | if a[i] != b[i] { 219 | return false 220 | } 221 | } 222 | return true 223 | } 224 | } 225 | 226 | //判断两个[]bool是否相等。 227 | //相等是严格的相等,否则为不等。 228 | //如果有一个为nil则为不相等。 229 | func SameBoolSlice(a, b []bool) bool { 230 | if a == nil || b == nil { 231 | return false 232 | } 233 | if len(a) != len(b) { 234 | return false 235 | } else { 236 | for i := range a { 237 | if a[i] != b[i] { 238 | return false 239 | } 240 | } 241 | return true 242 | } 243 | } 244 | -------------------------------------------------------------------------------- /utils_test.go: -------------------------------------------------------------------------------- 1 | package arrgo 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestSameIntSlice(t *testing.T) { 8 | var s1 []int = nil 9 | var s2 []int = nil 10 | if SameIntSlice(s1, s2) != false { 11 | t.Error("if one of args is nil, the result should be false, but go ", SameIntSlice(s1, s2)) 12 | } 13 | s3 := []int{1, 2, 3} 14 | s4 := []int{1, 2} 15 | if SameIntSlice(s3, s4) != false { 16 | t.Error("different length should get false, got ", SameIntSlice(s3, s4)) 17 | } 18 | s5 := []int{1, 2, 4} 19 | if SameIntSlice(s3, s5) != false { 20 | t.Error("bit wise different should get false, got ", SameIntSlice(s3, s5)) 21 | } 22 | s6 := []int{1, 2, 3} 23 | if SameIntSlice(s3, s6) != true { 24 | t.Error("same int[] should get true, got ", SameIntSlice(s3, s6)) 25 | } 26 | } 27 | --------------------------------------------------------------------------------