diff --git a/NuGet.config b/NuGet.config index 80dd215a4bf01b..29f8225aac34fb 100644 --- a/NuGet.config +++ b/NuGet.config @@ -7,6 +7,11 @@ + + + + + + false + false diff --git a/eng/Subsets.props b/eng/Subsets.props index 77268ffa7b5d09..eb4151f3a185d6 100644 --- a/eng/Subsets.props +++ b/eng/Subsets.props @@ -502,7 +502,7 @@ - + @@ -512,7 +512,7 @@ - + diff --git a/eng/Version.Details.xml b/eng/Version.Details.xml index 83e0c865d5612a..96128eaf65a696 100644 --- a/eng/Version.Details.xml +++ b/eng/Version.Details.xml @@ -1,80 +1,80 @@ - + https://github.com/dotnet/icu - 0bd4ee514dc6abe5cc664282709e70d5e26bb11d + feea7b8dcee39fd35ee6c415197e47d19102bb0b - + https://github.com/dotnet/msquic - 72811ab66f2611ac9f652cbb020dba033fc37401 + bbb1252b31e3a194be3163982d972e4583c75476 https://github.com/dotnet/wcf 7f504aabb1988e9a093c1e74d8040bd52feb2f01 - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd https://github.com/dotnet/command-line-api @@ -85,101 +85,105 @@ 02fe27cd6a9b001c8feb7938e6ef4b3799745759b - + https://github.com/dotnet/cecil - a112f15aa032c029b7d9c77df3427111d93cf407 + 45dd3a73dd5b64b010c4251303b3664bb30df029 - + https://github.com/dotnet/emsdk - 1999c8c8ab7473a7e1c5b7bdf5ba6d9a985a69cc + 51bf18a2e20024dfa89d63e20b0c3b695b5c1eed + + + https://github.com/dotnet/emsdk + 51bf18a2e20024dfa89d63e20b0c3b695b5c1eed - + https://github.com/dotnet/source-build-reference-packages - 3dd2c0ef203db8fe0e849557960b4cd009afbaac + b4fa7f2e1e65ef49881be2ab2df27624280a8c55 - + https://github.com/dotnet/source-build-externals - e9d6489787a5ea5400a31dfa34aa6ad6b590de9b + 3dc05150cf234f76f6936dcb2853d31a0da1f60e - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/xliff-tasks - 194f32828726c3f1f63f79f3dc09b9e99c157b11 + 73f0850939d96131c28cf6ea6ee5aacb4da0083a - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 https://github.com/dotnet/runtime-assets @@ -233,61 +237,61 @@ https://github.com/dotnet/runtime-assets 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd - + https://github.com/dotnet/llvm-project - 08a449c9a9bf593b29fc05de2f424e6882320e5d + a0c65bc3a652036d21cd2c506a54c4b6cf8c49bd https://github.com/dotnet/runtime @@ -330,67 +334,67 @@ https://github.com/dotnet/xharness 480b9159eb7e69b182a87581d5a336e97e0b6dae - + https://github.com/dotnet/arcade - 1d451c32dda2314c721adbf8829e1c0cd4e681ff + 39042b4048580366d35a7c1c4f4ce8fc0dbea4b4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - f67c4e806a744f7274aed9493e4c8d39a3c73445 + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - f67c4e806a744f7274aed9493e4c8d39a3c73445 + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - f67c4e806a744f7274aed9493e4c8d39a3c73445 + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - f67c4e806a744f7274aed9493e4c8d39a3c73445 + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 - + https://github.com/dotnet/hotreload-utils - 4b29cfaccdab45442e15f3b84f75bc9c10ee79b3 + 7e01dcd64329d25070ad66af5eddd02410e80111 https://github.com/dotnet/runtime-assets 99168dcff56809205e7ef8530d1256f3a07fab1f - + https://github.com/dotnet/roslyn - 9233e36abc5e2ca263dbd4d1616f35623440a935 + 12b11685a551e0a6a203dcecb584f610f4df1157 - + https://github.com/dotnet/roslyn - 9233e36abc5e2ca263dbd4d1616f35623440a935 + 12b11685a551e0a6a203dcecb584f610f4df1157 - + https://github.com/dotnet/roslyn - 9233e36abc5e2ca263dbd4d1616f35623440a935 + 12b11685a551e0a6a203dcecb584f610f4df1157 - + https://github.com/dotnet/roslyn-analyzers - 7ec4e8924bcbc469e00aa2bda84251c3e90aa96e + 4ff28092cdb2006c30869fb35b2fd6b7b11382b1 - + https://github.com/dotnet/roslyn-analyzers - 7ec4e8924bcbc469e00aa2bda84251c3e90aa96e + 4ff28092cdb2006c30869fb35b2fd6b7b11382b1 - + https://github.com/dotnet/sdk - d10b02ae5cc670609d920a672985ed4456bdd6b6 + 7e33fd449381b337c290a801057fdcd68c4b7220 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - f67c4e806a744f7274aed9493e4c8d39a3c73445 + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 - + https://dev.azure.com/dnceng/internal/_git/dotnet-optimization - f67c4e806a744f7274aed9493e4c8d39a3c73445 + e5d9d61ccb43b9135a7471429b338aa7332e2eb5 @@ -398,5 +402,9 @@ https://github.com/NuGet/NuGet.Client 8fef55f5a55a3b4f2c96cd1a9b5ddc51d4b927f8 + + https://github.com/dotnet/installer + 46a7370763921ded24dcb70c585ee97883c615d4 + diff --git a/eng/Versions.props b/eng/Versions.props index f3a6b00771bc4d..ed16dbfe0a219a 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -7,14 +7,16 @@ 0 0 8.0.100 - 7.0.11 + 7.0.12 6.0.$([MSBuild]::Add($([System.Version]::Parse('$(PackageVersionNet7)').Build),11)) - rc - 2 + rtm + + - false + true release - -$(PreReleaseVersionLabel).$(PreReleaseVersionIteration) + -$(PreReleaseVersionLabel) + -$(PreReleaseVersionLabel).$(PreReleaseVersionIteration) $(SdkBandVersion)$(WorkloadVersionSuffix) false @@ -34,17 +36,17 @@ - 3.11.0-beta1.23464.2 - 8.0.0-preview.23464.2 + 3.11.0-beta1.23516.2 + 8.0.0-preview.23516.2 - 4.8.0-3.23469.1 - 4.8.0-3.23469.1 - 4.8.0-3.23469.1 + 4.8.0-3.23518.7 + 4.8.0-3.23518.7 + 4.8.0-3.23518.7 - 8.0.100-preview.7.23329.3 + 8.0.100-rtm.23520.8 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 - 2.5.1-beta.23463.1 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 - 8.0.0-beta.23463.1 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 2.5.1-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 + 8.0.0-beta.23516.4 6.0.0-preview.1.102 @@ -108,14 +110,14 @@ 8.0.0-rc.1.23406.6 8.0.0-preview.7.23325.2 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 6.0.0 1.1.1 @@ -156,12 +158,12 @@ 8.0.0-beta.23421.1 8.0.0-beta.23421.1 - 1.0.0-prerelease.23465.3 - 1.0.0-prerelease.23465.3 - 1.0.0-prerelease.23465.3 - 1.0.0-prerelease.23465.3 - 1.0.0-prerelease.23465.3 - 1.0.0-prerelease.23465.3 + 1.0.0-prerelease.23521.3 + 1.0.0-prerelease.23521.3 + 1.0.0-prerelease.23521.3 + 1.0.0-prerelease.23521.3 + 1.0.0-prerelease.23521.3 + 1.0.0-prerelease.23521.3 16.11.29-beta1.23404.4 2.0.0-beta4.23307.1 @@ -184,7 +186,7 @@ 8.0.0-prerelease.23407.2 8.0.0-prerelease.23407.2 8.0.0-prerelease.23407.2 - 8.0.0-alpha.0.23461.1 + 8.0.0-alpha.0.23518.2 2.4.2 1.0.0 2.4.5 @@ -205,57 +207,58 @@ 2.46.3 2.45.0 2.45.0 - - 8.0.100-rc.1.23415.5 1.1.2-beta1.23323.1 8.0.0-preview-20230918.1 8.0.0-rc.1.23406.6 - 0.11.4-alpha.23461.1 + 0.11.4-alpha.23509.2 8.0.0-rc.1.23406.6 - 8.0.0-rc.2.23454.2 + 8.0.0-rtm.23511.1 - 2.2.3-ci.391104 - 8.0.0-alpha.1.23412.1 + 2.2.3 + 8.0.0-alpha.1.23468.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 - 8.0.0-rc.2.23463.1 - $(MicrosoftNETWorkloadEmscriptenCurrentManifest80100TransportVersion) + 8.0.0 + $(MicrosoftNETWorkloadEmscriptenCurrentManifest80100Version) 1.1.87-gba258badda 1.0.0-v3.14.0.5722 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 - 16.0.5-alpha.1.23423.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 + 16.0.5-alpha.1.23478.1 3.1.7 1.0.406601 + + 8.0.100-rtm.23506.1 + diff --git a/eng/common/sdk-task.ps1 b/eng/common/sdk-task.ps1 index 6c4ac6fec1a99a..73828dd30d3179 100644 --- a/eng/common/sdk-task.ps1 +++ b/eng/common/sdk-task.ps1 @@ -64,7 +64,7 @@ try { $GlobalJson.tools | Add-Member -Name "vs" -Value (ConvertFrom-Json "{ `"version`": `"16.5`" }") -MemberType NoteProperty } if( -not ($GlobalJson.tools.PSObject.Properties.Name -match "xcopy-msbuild" )) { - $GlobalJson.tools | Add-Member -Name "xcopy-msbuild" -Value "17.6.0-2" -MemberType NoteProperty + $GlobalJson.tools | Add-Member -Name "xcopy-msbuild" -Value "17.8.1-2" -MemberType NoteProperty } if ($GlobalJson.tools."xcopy-msbuild".Trim() -ine "none") { $xcopyMSBuildToolsFolder = InitializeXCopyMSBuild $GlobalJson.tools."xcopy-msbuild" -install $true diff --git a/eng/common/tools.ps1 b/eng/common/tools.ps1 index aa74ab4a81e782..fdd0cbb91f8596 100644 --- a/eng/common/tools.ps1 +++ b/eng/common/tools.ps1 @@ -379,13 +379,13 @@ function InitializeVisualStudioMSBuild([bool]$install, [object]$vsRequirements = } # Minimum VS version to require. - $vsMinVersionReqdStr = '17.6' + $vsMinVersionReqdStr = '17.7' $vsMinVersionReqd = [Version]::new($vsMinVersionReqdStr) # If the version of msbuild is going to be xcopied, # use this version. Version matches a package here: - # https://dev.azure.com/dnceng/public/_artifacts/feed/dotnet-eng/NuGet/RoslynTools.MSBuild/versions/17.6.0-2 - $defaultXCopyMSBuildVersion = '17.6.0-2' + # https://dev.azure.com/dnceng/public/_artifacts/feed/dotnet-eng/NuGet/RoslynTools.MSBuild/versions/17.8.1-2 + $defaultXCopyMSBuildVersion = '17.8.1-2' if (!$vsRequirements) { if (Get-Member -InputObject $GlobalJson.tools -Name 'vs') { diff --git a/eng/native/ijw/IJW.cmake b/eng/native/ijw/IJW.cmake index 9ef90525dda8ba..33f047d54fca74 100644 --- a/eng/native/ijw/IJW.cmake +++ b/eng/native/ijw/IJW.cmake @@ -51,7 +51,7 @@ if (CLR_CMAKE_HOST_WIN32) # 4365 - signed/unsigned mismatch # 4679 - Could not import member. This is an issue with IJW and static abstract methods in interfaces. - add_compile_options(/wd4365 /wd4679) + add_compile_options(/wd4365 /wd4679 /wd5271) # IJW add_compile_options(/clr:netcore) diff --git a/eng/pipelines/common/evaluate-default-paths.yml b/eng/pipelines/common/evaluate-default-paths.yml index 5fb74a3741f413..0e4279a9697b94 100644 --- a/eng/pipelines/common/evaluate-default-paths.yml +++ b/eng/pipelines/common/evaluate-default-paths.yml @@ -241,6 +241,7 @@ jobs: - src/mono/tools/* - src/mono/wasi/* - src/mono/wasm/debugger/* + - src/mono/wasm/host/* - src/mono/wasm/Wasm.Build.Tests/* - ${{ parameters._const_paths._wasm_pipelines }} - ${{ parameters._const_paths._always_exclude }} @@ -258,6 +259,7 @@ jobs: - eng/testing/workloads-testing.targets - src/mono/mono/component/mini-wasm-debugger.c - src/mono/wasm/debugger/* + - src/mono/wasm/host/* - src/mono/wasm/Wasm.Build.Tests/* - src/mono/nuget/Microsoft.NET.Runtime* src/mono/nuget/Microsoft.NET.Sdk.WebAssembly.Pack/* diff --git a/eng/pipelines/common/xplat-setup.yml b/eng/pipelines/common/xplat-setup.yml index 28257b05265ba0..eb19570aeecac2 100644 --- a/eng/pipelines/common/xplat-setup.yml +++ b/eng/pipelines/common/xplat-setup.yml @@ -108,7 +108,7 @@ jobs: - ${{ if eq(parameters.archType, 'wasm') }}: - name: wasmDarcDependenciesChanged value: $[ or( - eq(dependencies.evaluate_paths.outputs['DarcDependenciesChanged.Microsoft_NET_Workload_Emscripten_Current_Manifest-8_0_100_Transport'], true), + eq(dependencies.evaluate_paths.outputs['DarcDependenciesChanged.Microsoft_NET_Workload_Emscripten_Current_Manifest-8_0_100'], true), eq(dependencies.evaluate_paths.outputs['DarcDependenciesChanged.Microsoft_DotNet_Build_Tasks_Workloads'], true), eq(dependencies.evaluate_paths.outputs['DarcDependenciesChanged.System_Runtime_TimeZoneData'], true), eq(dependencies.evaluate_paths.outputs['DarcDependenciesChanged.Microsoft_Net_Compilers_Toolset'], true), diff --git a/eng/pipelines/coreclr/perf.yml b/eng/pipelines/coreclr/perf.yml index edaadde3e511ee..65d29662504364 100644 --- a/eng/pipelines/coreclr/perf.yml +++ b/eng/pipelines/coreclr/perf.yml @@ -3,7 +3,7 @@ trigger: branches: include: - main - - release/8.0-rc1 + - release/8.0 paths: include: - '*' diff --git a/eng/pipelines/runtime-llvm.yml b/eng/pipelines/runtime-llvm.yml index e31e623a0353c8..9d358e5f793086 100644 --- a/eng/pipelines/runtime-llvm.yml +++ b/eng/pipelines/runtime-llvm.yml @@ -119,7 +119,7 @@ extends: testGroup: innerloop nameSuffix: AllSubsets_Mono_LLVMAOT buildArgs: -s mono+libs+host+packs -c $(_BuildConfig) - /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true condition: >- or( eq(dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'], true), @@ -138,7 +138,7 @@ extends: testGroup: innerloop nameSuffix: AllSubsets_Mono_LLVMAOT buildArgs: -s mono+libs+host+packs -c $(_BuildConfig) - /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true condition: >- or( eq(dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'], true), diff --git a/eng/pipelines/runtime-official.yml b/eng/pipelines/runtime-official.yml index 9c341a04791289..3a9fd8d89ac4b0 100644 --- a/eng/pipelines/runtime-official.yml +++ b/eng/pipelines/runtime-official.yml @@ -334,7 +334,7 @@ extends: runtimeFlavor: mono jobParameters: buildArgs: -s mono+libs+host+packs -c $(_BuildConfig) - /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true nameSuffix: AllSubsets_Mono_LLVMAOT runtimeVariant: LLVMAOT isOfficialBuild: ${{ variables.isOfficialBuild }} diff --git a/eng/pipelines/runtime-wasm-perf.yml b/eng/pipelines/runtime-wasm-perf.yml index bd6a6d979e3e40..69039fb3e2a473 100644 --- a/eng/pipelines/runtime-wasm-perf.yml +++ b/eng/pipelines/runtime-wasm-perf.yml @@ -3,6 +3,7 @@ # UI to this, and thus avoid any scheduled triggers trigger: none +pr: none variables: - template: /eng/pipelines/common/variables.yml diff --git a/eng/pipelines/runtime.yml b/eng/pipelines/runtime.yml index 0f1f9610c60349..3aa0b6504819a7 100644 --- a/eng/pipelines/runtime.yml +++ b/eng/pipelines/runtime.yml @@ -556,6 +556,47 @@ extends: extraBuildArgs: /p:AotHostArchitecture=x64 /p:AotHostOS=$(_hostedOS) alwaysRun: ${{ variables.isRollingBuild }} + # + # Android devices + # Build the whole product using Mono and run libraries tests + # + - template: /eng/pipelines/common/platform-matrix.yml + parameters: + jobTemplate: /eng/pipelines/common/global-build-job.yml + helixQueuesTemplate: /eng/pipelines/libraries/helix-queues-setup.yml + buildConfig: Release + runtimeFlavor: mono + platforms: + - android_arm + - android_arm64 + variables: + # map dependencies variables to local variables + - name: librariesContainsChange + value: $[ dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'] ] + - name: monoContainsChange + value: $[ dependencies.evaluate_paths.outputs['SetPathVars_mono_excluding_wasm.containsChange'] ] + jobParameters: + testGroup: innerloop + nameSuffix: AllSubsets_Mono + buildArgs: -s mono+libs+libs.tests+host+packs -c $(_BuildConfig) /p:ArchiveTests=true /p:RunSmokeTestsOnly=true /p:EnableAdditionalTimezoneChecks=true + timeoutInMinutes: 480 + condition: >- + or( + eq(dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'], true), + eq(dependencies.evaluate_paths.outputs['SetPathVars_mono_excluding_wasm.containsChange'], true), + eq(dependencies.evaluate_paths.outputs['SetPathVars_installer.containsChange'], true), + eq(variables['isRollingBuild'], true)) + # extra steps, run tests + extraStepsTemplate: /eng/pipelines/libraries/helix.yml + extraStepsParameters: + creator: dotnet-bot + testRunNamePrefixSuffix: Mono_$(_BuildConfig) + condition: >- + or( + eq(variables['librariesContainsChange'], true), + eq(variables['monoContainsChange'], true), + eq(variables['isRollingBuild'], true)) + # # iOS/tvOS devices - Full AOT + AggressiveTrimming to reduce size # Build the whole product using Mono and run libraries tests @@ -739,7 +780,7 @@ extends: testGroup: innerloop nameSuffix: AllSubsets_Mono_LLVMAOT buildArgs: -s mono+libs+host+packs -c $(_BuildConfig) - /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true condition: >- or( eq(dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'], true), @@ -758,7 +799,7 @@ extends: testGroup: innerloop nameSuffix: AllSubsets_Mono_LLVMAOT buildArgs: -s mono+libs+host+packs -c $(_BuildConfig) - /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true condition: >- or( eq(dependencies.evaluate_paths.outputs['SetPathVars_libraries.containsChange'], true), @@ -1277,7 +1318,7 @@ extends: testGroup: innerloop nameSuffix: AllSubsets_Mono_LLVMAot_RuntimeTests runtimeVariant: llvmaot - buildArgs: -s mono+libs+clr.hosts+clr.iltools -c Release /p:MonoEnableLLVM=true /p:MonoBundleLLVMOptimizer=true + buildArgs: -s mono+libs+clr.hosts+clr.iltools -c Release /p:MonoEnableLLVM=true /p:MonoAOTEnableLLVM=true /p:MonoBundleLLVMOptimizer=true timeoutInMinutes: 180 condition: >- diff --git a/eng/resolveContract.targets b/eng/resolveContract.targets index 6d414f46f93e6b..3454d7064739a8 100644 --- a/eng/resolveContract.targets +++ b/eng/resolveContract.targets @@ -73,8 +73,9 @@ That is necessary as APICompat is invoked twice, once for the ref <-> src comparision and then again for the package validation (which doesn't include reference assemblies). As both operations don't have all the inputs available, some suppressions might only apply to one or the other and hence unnecessary - suppressions can't be determined. --> - + suppressions can't be determined. + Disable the validation under source build as that might use an out-of-date SDK and not the ApiCompat.Task package. --> + true true diff --git a/eng/testing/performance/performance-setup.ps1 b/eng/testing/performance/performance-setup.ps1 index 8caea345a893dc..8a8cd269dbe454 100644 --- a/eng/testing/performance/performance-setup.ps1 +++ b/eng/testing/performance/performance-setup.ps1 @@ -101,7 +101,7 @@ if ($iOSNativeAOT) { } # FIX ME: This is a workaround until we get this from the actual pipeline -$CleanedBranchName = "main" +$CleanedBranchName = "release/8.0" if($Branch.Contains("refs/heads/release")) { $CleanedBranchName = $Branch.replace('refs/heads/', '') diff --git a/eng/testing/performance/performance-setup.sh b/eng/testing/performance/performance-setup.sh index 9a1c95ec730820..c53ca6924b97b4 100755 --- a/eng/testing/performance/performance-setup.sh +++ b/eng/testing/performance/performance-setup.sh @@ -358,9 +358,7 @@ if [[ "$physicalpromotion" == "true" ]]; then configurations="$configurations PhysicalPromotionType=physicalpromotion" fi - - -cleaned_branch_name="main" +cleaned_branch_name="release/8.0" if [[ $branch == *"refs/heads/release"* ]]; then cleaned_branch_name=${branch/refs\/heads\//} fi @@ -404,15 +402,14 @@ if [[ -n "$wasm_bundle_directory" ]]; then using_wasm=true wasm_bundle_directory_path=$payload_directory mv $wasm_bundle_directory/* $wasm_bundle_directory_path - find $wasm_bundle_directory_path -type d - wasm_args="--experimental-wasm-eh --expose_wasm" + wasm_args="--expose_wasm" if [ "$javascript_engine" == "v8" ]; then # for es6 module support wasm_args="$wasm_args --module" fi # Workaround: escaping the quotes around `--wasmArgs=..` so they get retained for the actual command line - extra_benchmark_dotnet_arguments="$extra_benchmark_dotnet_arguments --wasmEngine /home/helixbot/.jsvu/bin/$javascript_engine --wasmArgs \\\"$wasm_args\\\" --cli \$HELIX_CORRELATION_PAYLOAD/dotnet/dotnet --wasmDataDir \$HELIX_CORRELATION_PAYLOAD/wasm-data" + extra_benchmark_dotnet_arguments="$extra_benchmark_dotnet_arguments --wasmEngine /home/helixbot/.jsvu/bin/$javascript_engine \\\"--wasmArgs=$wasm_args\\\" --cli \$HELIX_CORRELATION_PAYLOAD/dotnet/dotnet --wasmDataDir \$HELIX_CORRELATION_PAYLOAD/wasm-data" if [[ "$wasmaot" == "true" ]]; then extra_benchmark_dotnet_arguments="$extra_benchmark_dotnet_arguments --aotcompilermode wasm --buildTimeout 3600" fi diff --git a/global.json b/global.json index 2b41b42c3256e6..38ec23b6193a2f 100644 --- a/global.json +++ b/global.json @@ -1,16 +1,16 @@ { "sdk": { - "version": "8.0.100-preview.7.23376.3", + "version": "8.0.100-rtm.23506.1", "allowPrerelease": true, "rollForward": "major" }, "tools": { - "dotnet": "8.0.100-preview.7.23376.3" + "dotnet": "8.0.100-rtm.23506.1" }, "msbuild-sdks": { - "Microsoft.DotNet.Arcade.Sdk": "8.0.0-beta.23463.1", - "Microsoft.DotNet.Helix.Sdk": "8.0.0-beta.23463.1", - "Microsoft.DotNet.SharedFramework.Sdk": "8.0.0-beta.23463.1", + "Microsoft.DotNet.Arcade.Sdk": "8.0.0-beta.23516.4", + "Microsoft.DotNet.Helix.Sdk": "8.0.0-beta.23516.4", + "Microsoft.DotNet.SharedFramework.Sdk": "8.0.0-beta.23516.4", "Microsoft.Build.NoTargets": "3.7.0", "Microsoft.Build.Traversal": "3.4.0", "Microsoft.NET.Sdk.IL": "8.0.0-rc.1.23406.6" diff --git a/src/coreclr/debug/daccess/dacdbiimpl.cpp b/src/coreclr/debug/daccess/dacdbiimpl.cpp index 67d5b1e60d948c..07208001b0c3b8 100644 --- a/src/coreclr/debug/daccess/dacdbiimpl.cpp +++ b/src/coreclr/debug/daccess/dacdbiimpl.cpp @@ -7788,8 +7788,9 @@ HRESULT DacStackReferenceWalker::Next(ULONG count, DacGcReference stackRefs[], U stackRefs[i].i64ExtraData = 0; const SOSStackRefData &sosStackRef = mList.Get(i); - if (sosStackRef.Flags & GC_CALL_INTERIOR) + if (sosStackRef.Flags & GC_CALL_INTERIOR || sosStackRef.Address == 0) { + // Direct pointer case - interior pointer, Frame ref, or enregistered var. stackRefs[i].pObject = CLRDATA_ADDRESS_TO_TADDR(sosStackRef.Object) | 1; } else diff --git a/src/coreclr/debug/daccess/dacimpl.h b/src/coreclr/debug/daccess/dacimpl.h index ddf61370f416e9..8de684a5dae9a2 100644 --- a/src/coreclr/debug/daccess/dacimpl.h +++ b/src/coreclr/debug/daccess/dacimpl.h @@ -1253,17 +1253,6 @@ class ClrDataAccess /* [out] */ union STUB_BUF* outBuffer, /* [out] */ ULONG32* outFlags); - DebuggerJitInfo* GetDebuggerJitInfo(MethodDesc* methodDesc, - TADDR addr) - { - if (g_pDebugger) - { - return g_pDebugger->GetJitInfo(methodDesc, (PBYTE)addr, NULL); - } - - return NULL; - } - HRESULT GetMethodExtents(MethodDesc* methodDesc, METH_EXTENTS** extents); HRESULT GetMethodVarInfo(MethodDesc* methodDesc, diff --git a/src/coreclr/debug/daccess/task.cpp b/src/coreclr/debug/daccess/task.cpp index 9e428e81adef2b..ddbf251b7b9825 100644 --- a/src/coreclr/debug/daccess/task.cpp +++ b/src/coreclr/debug/daccess/task.cpp @@ -5225,7 +5225,7 @@ EnumMethodInstances::Next(ClrDataAccess* dac, } } - if (!m_methodIter.Current()->HasNativeCodeReJITAware()) + if (!m_methodIter.Current()->HasNativeCodeAnyVersion()) { goto NextMethod; } @@ -5243,7 +5243,7 @@ EnumMethodInstances::CdStart(MethodDesc* methodDesc, CLRDATA_ENUM* handle) { if (!methodDesc->HasClassOrMethodInstantiation() && - !methodDesc->HasNativeCodeReJITAware()) + !(methodDesc->HasNativeCodeAnyVersion())) { *handle = 0; return S_FALSE; diff --git a/src/coreclr/debug/di/breakpoint.cpp b/src/coreclr/debug/di/breakpoint.cpp index ad45df5c618ace..568d7fc9fc66ad 100644 --- a/src/coreclr/debug/di/breakpoint.cpp +++ b/src/coreclr/debug/di/breakpoint.cpp @@ -211,11 +211,13 @@ HRESULT CordbFunctionBreakpoint::Activate(BOOL fActivate) if (codeIsIL) { pEvent->BreakpointData.nativeCodeMethodDescToken = pEvent->BreakpointData.nativeCodeMethodDescToken.NullPtr(); + pEvent->BreakpointData.codeStartAddress = 0; } else { pEvent->BreakpointData.nativeCodeMethodDescToken = (m_code.GetValue()->AsNativeCode())->GetVMNativeCodeMethodDescToken().ToLsPtr(); + pEvent->BreakpointData.codeStartAddress = (m_code.GetValue()->AsNativeCode())->GetAddress(); } // Note: we're sending a two-way event, so it blocks here diff --git a/src/coreclr/debug/di/rsclass.cpp b/src/coreclr/debug/di/rsclass.cpp index ec52823c07af5f..55f83b48a6d211 100644 --- a/src/coreclr/debug/di/rsclass.cpp +++ b/src/coreclr/debug/di/rsclass.cpp @@ -132,6 +132,7 @@ HRESULT CordbClass::GetStaticFieldValue(mdFieldDef fieldDef, IMetaDataImport * pImport = NULL; EX_TRY { + RSLockHolder lockHolder(GetProcess()->GetProcessLock()); pImport = GetModule()->GetMetaDataImporter(); // throws // Validate the token. @@ -1191,4 +1192,3 @@ HRESULT CordbClass::SearchFieldInfo( // Well, the field doesn't even belong to this class... ThrowHR(E_INVALIDARG); } - diff --git a/src/coreclr/debug/ee/controller.cpp b/src/coreclr/debug/ee/controller.cpp index 7dd186b4113d44..58e63ab399db2c 100644 --- a/src/coreclr/debug/ee/controller.cpp +++ b/src/coreclr/debug/ee/controller.cpp @@ -1247,26 +1247,8 @@ bool DebuggerController::BindPatch(DebuggerControllerPatch *patch, startAddr = (CORDB_ADDRESS_TYPE *) CORDB_ADDRESS_TO_PTR(patch->GetDJI()->m_addrOfCode); _ASSERTE(startAddr != NULL); } - if (startAddr == NULL) - { - // Should not be trying to place patches on MethodDecs's for stubs. - // These stubs will never get jitted. - CONSISTENCY_CHECK_MSGF(!pMD->IsWrapperStub(), ("Can't place patch at stub md %p, %s::%s", - pMD, pMD->m_pszDebugClassName, pMD->m_pszDebugMethodName)); - - startAddr = (CORDB_ADDRESS_TYPE *)g_pEEInterface->GetFunctionAddress(pMD); - // - // Code is not available yet to patch. The prestub should - // notify us when it is executed. - // - if (startAddr == NULL) - { - LOG((LF_CORDB, LL_INFO10000, - "DC::BP: Patch at 0x%zx not bindable yet.\n", patch->offset)); - - return false; - } - } + //We should never be calling this function with both a NULL startAddr and a DJI that doesn't have code. + _ASSERTE(startAddr != NULL); } _ASSERTE(!g_pEEInterface->IsStub((const BYTE *)startAddr)); @@ -8656,7 +8638,7 @@ bool DebuggerFuncEvalComplete::SendEvent(Thread *thread, bool fIpChanged) // DebuggerEnCBreakpoint constructor - creates and activates a new EnC breakpoint // // Arguments: -// offset - native offset in the function to place the patch +// offset - IL offset in the function to place the patch // jitInfo - identifies the function in which the breakpoint is being placed // fTriggerType - breakpoint type: either REMAP_PENDING or REMAP_COMPLETE // pAppDomain - the breakpoint applies to the specified AppDomain only diff --git a/src/coreclr/debug/ee/debugger.cpp b/src/coreclr/debug/ee/debugger.cpp index a44a9e235f36cd..d58a987244a8ea 100644 --- a/src/coreclr/debug/ee/debugger.cpp +++ b/src/coreclr/debug/ee/debugger.cpp @@ -2841,6 +2841,8 @@ HRESULT Debugger::GetILToNativeMapping(PCODE pNativeCodeStartAddress, ULONG32 cM } CONTRACTL_END; + _ASSERTE(pNativeCodeStartAddress != NULL); + #ifdef PROFILING_SUPPORTED // At this point, we're pulling in the debugger. if (!HasLazyData()) @@ -3007,6 +3009,7 @@ HRESULT Debugger::GetILToNativeMappingIntoArrays( _ASSERTE(pcMap != NULL); _ASSERTE(prguiILOffset != NULL); _ASSERTE(prguiNativeOffset != NULL); + _ASSERTE(pNativeCodeStartAddress != NULL); // Any caller of GetILToNativeMappingIntoArrays had better call // InitializeLazyDataIfNecessary first! @@ -5411,28 +5414,6 @@ void Debugger::ReleaseAllRuntimeThreads(AppDomain *pAppDomain) g_pEEInterface->ResumeFromDebug(pAppDomain); } -// Given a method, get's its EnC version number. 1 if the method is not EnCed. -// Note that MethodDescs are reused between versions so this will give us -// the most recent EnC number. -int Debugger::GetMethodEncNumber(MethodDesc * pMethod) -{ - CONTRACTL - { - THROWS; - GC_NOTRIGGER; - } - CONTRACTL_END; - - DebuggerJitInfo * dji = GetLatestJitInfoFromMethodDesc(pMethod); - if (dji == NULL) - { - // If there's no DJI, couldn't have been EnCed. - return 1; - } - return (int) dji->m_encVersion; -} - - bool Debugger::IsJMCMethod(Module* pModule, mdMethodDef tkMethod) { CONTRACTL @@ -6219,25 +6200,6 @@ void Debugger::LockAndSendEnCRemapCompleteEvent(MethodDesc *pMD) Thread *thread = g_pEEInterface->GetThread(); // Note that the debugger lock is reentrant, so we may or may not hold it already. SENDIPCEVENT_BEGIN(this, thread); - - EX_TRY - { - // Ensure the DJI for the latest version of this method has been pre-created. - // It's not clear whether this is necessary or not, but it shouldn't hurt since - // we're going to need to create it anyway since we'll be debugging inside it. - DebuggerJitInfo *dji = g_pDebugger->GetLatestJitInfoFromMethodDesc(pMD); - (void)dji; //prevent "unused variable" error from GCC - _ASSERTE( dji != NULL ); - } - EX_CATCH - { - // GetLatestJitInfo could throw on OOM, but the debugger isn't resiliant to OOM. - // I'm not aware of any other legitimate reason why it may throw, so we'll ASSERT - // if it fails. - _ASSERTE(!"Unexpected exception from Debugger::GetLatestJitInfoFromMethodDesc on EnC remap complete"); - } - EX_END_CATCH(RethrowTerminalExceptions); - // Send an EnC remap complete event to the Right Side. DebuggerIPCEvent* ipce = m_pRCThread->GetIPCEventSendBuffer(); InitIPCEvent(ipce, @@ -7865,6 +7827,7 @@ void Debugger::FirstChanceManagedExceptionCatcherFound(Thread *pThread, // Implements DebugInterface // Call by EE/exception. Must be on managed thread _ASSERTE(GetThreadNULLOk() != NULL); + _ASSERTE(pMethodAddr != NULL); // Quick check. if (!CORDebuggerAttached()) @@ -10498,7 +10461,7 @@ bool Debugger::HandleIPCEvent(DebuggerIPCEvent * pEvent) DebuggerJitInfo * pDJI = NULL; if ((pMethodDesc != NULL) && (pDMI != NULL)) { - pDJI = pDMI->FindOrCreateInitAndAddJitInfo(pMethodDesc, NULL /* startAddr */); + pDJI = pDMI->FindOrCreateInitAndAddJitInfo(pMethodDesc, PINSTRToPCODE(dac_cast(pEvent->BreakpointData.codeStartAddress))); } { @@ -12625,7 +12588,7 @@ DWORD Debugger::GetThreadIdHelper(Thread *pThread) // does not own the memory provided via vars outparameter. //----------------------------------------------------------------------------- void Debugger::GetVarInfo(MethodDesc * fd, // [IN] method of interest - void *DebuggerVersionToken, // [IN] which edit version + CORDB_ADDRESS nativeCodeAddress, // [IN] which edit version SIZE_T * cVars, // [OUT] size of 'vars' const ICorDebugInfo::NativeVarInfo **vars // [OUT] map telling where local vars are stored ) @@ -12637,7 +12600,7 @@ void Debugger::GetVarInfo(MethodDesc * fd, // [IN] method of interest } CONTRACTL_END; - DebuggerJitInfo * ji = (DebuggerJitInfo *)DebuggerVersionToken; + DebuggerJitInfo * ji = g_pDebugger->GetJitInfo(fd, (const BYTE *)nativeCodeAddress); // If we didn't supply a DJI, then we're asking for the most recent version. if (ji == NULL) @@ -12961,6 +12924,11 @@ HRESULT Debugger::UpdateFunction(MethodDesc* pMD, SIZE_T encVersion) // For each offset in the IL->Native map, set a new EnC breakpoint on the // ones that we know could be remap points. + + // Depending on which DJI was picked, the code might compute different IL offsets. The JIT may not guarantee it produces + // the same set of sequence points for every generic instantiation. + // Inside ENCSequencePointHelper there is logic that skips IL offsets that map to the same native offset. + // Its possible that one version of the code maps two IL offsets to the same native offset but another version of the code maps them to different offsets. PTR_DebuggerILToNativeMap seqMap = pJitInfo->GetSequenceMap(); for (unsigned int i = 0; i < pJitInfo->GetSequenceMapCount(); i++) { diff --git a/src/coreclr/debug/ee/debugger.h b/src/coreclr/debug/ee/debugger.h index 26edd26a96140b..2c2440ddaf6977 100644 --- a/src/coreclr/debug/ee/debugger.h +++ b/src/coreclr/debug/ee/debugger.h @@ -1933,8 +1933,6 @@ class Debugger : public DebugInterface bool IsJMCMethod(Module* pModule, mdMethodDef tkMethod); - int GetMethodEncNumber(MethodDesc * pMethod); - bool FirstChanceManagedException(Thread *pThread, SIZE_T currentIP, SIZE_T currentSP); @@ -1980,7 +1978,7 @@ class Debugger : public DebugInterface #endif // EnC_SUPPORTED void GetVarInfo(MethodDesc * fd, // [IN] method of interest - void *DebuggerVersionToken, // [IN] which edit version + CORDB_ADDRESS nativeCodeAddress, // [IN] which edit version SIZE_T * cVars, // [OUT] size of 'vars' const ICorDebugInfo::NativeVarInfo **vars // [OUT] map telling where local vars are stored ); diff --git a/src/coreclr/debug/ee/functioninfo.cpp b/src/coreclr/debug/ee/functioninfo.cpp index 76d4be3ab232f2..6eaa02d2c6de6f 100644 --- a/src/coreclr/debug/ee/functioninfo.cpp +++ b/src/coreclr/debug/ee/functioninfo.cpp @@ -1565,9 +1565,7 @@ DebuggerJitInfo *DebuggerMethodInfo::FindOrCreateInitAndAddJitInfo(MethodDesc* f GC_NOTRIGGER; } CONTRACTL_END; - _ASSERTE(fd != NULL); - // The debugger doesn't track Lightweight-codegen methods b/c they have no metadata. if (fd->IsDynamicMethod()) { @@ -1576,15 +1574,11 @@ DebuggerJitInfo *DebuggerMethodInfo::FindOrCreateInitAndAddJitInfo(MethodDesc* f if (startAddr == NULL) { - // This will grab the start address for the current code version. startAddr = g_pEEInterface->GetFunctionAddress(fd); if (startAddr == NULL) { - startAddr = fd->GetNativeCodeReJITAware(); - if (startAddr == NULL) - { - return NULL; - } + //The only case this should happen is if we are trying to get the DJI for a method that has not been jitted yet. + return NULL; } } else diff --git a/src/coreclr/debug/inc/dbgipcevents.h b/src/coreclr/debug/inc/dbgipcevents.h index 9fe1afd31a54ba..e9643e50f480a2 100644 --- a/src/coreclr/debug/inc/dbgipcevents.h +++ b/src/coreclr/debug/inc/dbgipcevents.h @@ -2011,6 +2011,7 @@ struct MSLAYOUT DebuggerIPCEvent SIZE_T offset; SIZE_T encVersion; LSPTR_METHODDESC nativeCodeMethodDescToken; // points to the MethodDesc if !isIL + CORDB_ADDRESS codeStartAddress; } BreakpointData; struct MSLAYOUT diff --git a/src/coreclr/gc/gc.cpp b/src/coreclr/gc/gc.cpp index daeadfe9821b8c..7351954070725e 100644 --- a/src/coreclr/gc/gc.cpp +++ b/src/coreclr/gc/gc.cpp @@ -823,6 +823,11 @@ class t_join join_struct.r_join_lock = n_th; } + int get_num_threads() + { + return join_struct.n_threads; + } + void destroy () { dprintf (JOIN_LOG, ("Destroying join structure")); @@ -887,6 +892,8 @@ class t_join // avoid race due to the thread about to reset the event (occasionally) being preempted before ResetEvent() if (color == join_struct.lock_color.LoadWithoutBarrier()) { + dprintf (9999, ("---h%d %d j%d %d - respin!!! (c:%d-%d)", + gch->heap_number, join_id, join_struct.n_threads, color, join_struct.lock_color.LoadWithoutBarrier())); goto respin; } @@ -1117,6 +1124,25 @@ t_join bgc_t_join; } \ } +#define spin_and_wait(count_to_spin, expr) \ +{ \ + while (!expr) \ + { \ + for (int j = 0; j < count_to_spin; j++) \ + { \ + if (expr) \ + { \ + break; \ + } \ + YieldProcessor (); \ + } \ + if (!(expr)) \ + { \ + GCToOSInterface::YieldThread (0); \ + } \ + } \ +} + #ifdef BACKGROUND_GC #define max_pending_allocs 64 @@ -1429,8 +1455,6 @@ enter_msl_status gc_heap::enter_spin_lock_msl_helper (GCSpinLock* msl) { #ifdef DYNAMIC_HEAP_COUNT uint64_t start = GetHighPrecisionTimeStamp(); - - msl->msl_wait_count++; #endif //DYNAMIC_HEAP_COUNT unsigned int i = 0; @@ -1485,7 +1509,7 @@ enter_msl_status gc_heap::enter_spin_lock_msl_helper (GCSpinLock* msl) #ifdef DYNAMIC_HEAP_COUNT uint64_t end = GetHighPrecisionTimeStamp(); Interlocked::ExchangeAdd64 (&msl->msl_wait_time, end - start); - dprintf (6666, ("wait for msl lock total time: %zd, total count: %zd, this time: %zd, this count: %u", msl->msl_wait_time, msl->msl_wait_count, end - start, i)); + dprintf (3, ("h%d wait for msl lock wait time %zd, total wait time: %zd", heap_number, (end - start), msl->msl_wait_time)); #endif //DYNAMIC_HEAP_COUNT } while (Interlocked::CompareExchange (&msl->lock, lock_taken, lock_free) != lock_free); @@ -2318,9 +2342,6 @@ sorted_table* gc_heap::seg_table; #ifdef MULTIPLE_HEAPS GCEvent gc_heap::ee_suspend_event; -#ifdef DYNAMIC_HEAP_COUNT -GCEvent gc_heap::gc_idle_thread_event; -#endif //DYNAMIC_HEAP_COUNT size_t gc_heap::min_gen0_balance_delta = 0; size_t gc_heap::min_balance_threshold = 0; #endif //MULTIPLE_HEAPS @@ -2919,6 +2940,12 @@ BOOL gc_heap::should_expand_in_full_gc = FALSE; #ifdef DYNAMIC_HEAP_COUNT int gc_heap::dynamic_adaptation_mode = dynamic_adaptation_default; gc_heap::dynamic_heap_count_data_t SVR::gc_heap::dynamic_heap_count_data; +uint64_t gc_heap::last_suspended_end_time = 0; +size_t gc_heap::gc_index_full_gc_end = 0; + +#ifdef STRESS_DYNAMIC_HEAP_COUNT +int gc_heap::heaps_in_this_gc = 0; +#endif //STRESS_DYNAMIC_HEAP_COUNT #endif // DYNAMIC_HEAP_COUNT // Provisional mode related stuff. @@ -6967,12 +6994,6 @@ BOOL gc_heap::create_thread_support (int number_of_heaps) { goto cleanup; } -#ifdef DYNAMIC_HEAP_COUNT - if (!gc_idle_thread_event.CreateOSManualEventNoThrow (FALSE)) - { - goto cleanup; - } -#endif //DYNAMIC_HEAP_COUNT if (!ee_suspend_event.CreateOSAutoEventNoThrow (FALSE)) { goto cleanup; @@ -7020,10 +7041,6 @@ bool gc_heap::create_gc_thread () return GCToEEInterface::CreateThread(gc_thread_stub, this, false, ".NET Server GC"); } -#ifdef DYNAMIC_HEAP_COUNT -static size_t prev_change_heap_count_gc_index; -#endif //DYNAMIC_HEAP_COUNT - #ifdef _MSC_VER #pragma warning(disable:4715) //IA64 xcompiler recognizes that without the 'break;' the while(1) will never end and therefore not return a value for that code path #endif //_MSC_VER @@ -7042,18 +7059,87 @@ void gc_heap::gc_thread_function () if (heap_number == 0) { - uint32_t wait_result = gc_heap::ee_suspend_event.Wait(gradual_decommit_in_progress_p ? DECOMMIT_TIME_STEP_MILLISECONDS : INFINITE, FALSE); + bool wait_on_time_out_p = gradual_decommit_in_progress_p; + uint32_t wait_time = DECOMMIT_TIME_STEP_MILLISECONDS; +#ifdef DYNAMIC_HEAP_COUNT + // background_running_p can only change from false to true during suspension. + if (!gc_heap::background_running_p () && dynamic_heap_count_data.should_change_heap_count) + { + assert (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes); + + dynamic_heap_count_data_t::sample& sample = dynamic_heap_count_data.samples[dynamic_heap_count_data.sample_index]; + wait_time = min (wait_time, (uint32_t)(sample.elapsed_between_gcs / 1000 / 3)); + wait_time = max (wait_time, 1); + + dprintf (6666, ("gc#0 thread waiting for %d ms (betwen GCs %I64d)", wait_time, sample.elapsed_between_gcs)); + } +#endif //DYNAMIC_HEAP_COUNT + uint32_t wait_result = gc_heap::ee_suspend_event.Wait(wait_on_time_out_p ? wait_time : INFINITE, FALSE); + dprintf (9999, ("waiting for ee done res %d (timeout %d, %I64d ms since last suspend end)(should_change_heap_count is %d) (gradual_decommit_in_progress_p %d)", + wait_result, wait_time, ((GetHighPrecisionTimeStamp() - last_suspended_end_time) / 1000), + dynamic_heap_count_data.should_change_heap_count, gradual_decommit_in_progress_p)); if (wait_result == WAIT_TIMEOUT) { - decommit_lock.Enter(); - gradual_decommit_in_progress_p = decommit_step (DECOMMIT_TIME_STEP_MILLISECONDS); - decommit_lock.Leave(); +#ifdef DYNAMIC_HEAP_COUNT + if (dynamic_heap_count_data.should_change_heap_count) + { +#ifdef BACKGROUND_GC + if (!gc_heap::background_running_p ()) +#endif //BACKGROUND_GC + { + dprintf (6666, ("changing heap count due to timeout")); + check_heap_count(); + } + } +#endif //DYNAMIC_HEAP_COUNT + + if (gradual_decommit_in_progress_p) + { + decommit_lock.Enter (); + gradual_decommit_in_progress_p = decommit_step (DECOMMIT_TIME_STEP_MILLISECONDS); + decommit_lock.Leave (); + } continue; } +#ifdef DYNAMIC_HEAP_COUNT + // We might want to consider also doing this when a BGC finishes. + if (dynamic_heap_count_data.should_change_heap_count) + { +#ifdef BACKGROUND_GC + if (!gc_heap::background_running_p ()) +#endif //BACKGROUND_GC + { + // this was a request to do a GC so make sure we follow through with one. + dprintf (6666, ("changing heap count at a GC start")); + check_heap_count (); + } + } + + // wait till the threads that should have gone idle at least reached the place where they are about to wait on the idle event. + if ((gc_heap::dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes) && + (n_heaps != dynamic_heap_count_data.last_n_heaps)) + { + int spin_count = 1024; + int idle_thread_count = n_max_heaps - n_heaps; + dprintf (9999, ("heap count changed %d->%d, idle should be %d and is %d", dynamic_heap_count_data.last_n_heaps, n_heaps, + idle_thread_count, VolatileLoadWithoutBarrier (&dynamic_heap_count_data.idle_thread_count))); + if (idle_thread_count != dynamic_heap_count_data.idle_thread_count) + { + spin_and_wait (spin_count, (idle_thread_count == dynamic_heap_count_data.idle_thread_count)); + dprintf (9999, ("heap count changed %d->%d, now idle is %d", dynamic_heap_count_data.last_n_heaps, n_heaps, + VolatileLoadWithoutBarrier (&dynamic_heap_count_data.idle_thread_count))); + } + + dynamic_heap_count_data.last_n_heaps = n_heaps; + } +#endif //DYNAMIC_HEAP_COUNT + suspended_start_time = GetHighPrecisionTimeStamp(); BEGIN_TIMING(suspend_ee_during_log); + dprintf (9999, ("h0 suspending EE in GC!")); GCToEEInterface::SuspendEE(SUSPEND_FOR_GC); + dprintf (9999, ("h0 suspended EE in GC!")); END_TIMING(suspend_ee_during_log); proceed_with_gc_p = TRUE; @@ -7067,46 +7153,74 @@ void gc_heap::gc_thread_function () { settings.init_mechanisms(); #ifdef DYNAMIC_HEAP_COUNT - // make sure the other gc threads cannot see this as a request to change heap count - // see explanation below about the cases when we return from gc_start_event.Wait - assert (dynamic_heap_count_data.new_n_heaps == n_heaps); + if (gc_heap::dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes) + { + // make sure the other gc threads cannot see this as a request to change heap count + // see explanation below about the cases when we return from gc_start_event.Wait + assert (dynamic_heap_count_data.new_n_heaps == n_heaps); + } #endif //DYNAMIC_HEAP_COUNT + dprintf (9999, ("GC thread %d setting_gc_start_in_gc(h%d)", heap_number, n_heaps)); gc_start_event.Set(); } dprintf (3, (ThreadStressLog::gcServerThread0StartMsg(), heap_number)); } else { + dprintf (9999, ("GC thread %d waiting_for_gc_start(%d)(gc%Id)", heap_number, n_heaps, VolatileLoadWithoutBarrier(&settings.gc_index))); gc_start_event.Wait(INFINITE, FALSE); #ifdef DYNAMIC_HEAP_COUNT - // we have a couple different cases to handle here when we come back from the wait: - // 1. We are starting a GC. Signaled by dynamic_heap_count_data.new_n_heaps == n_heaps - // a) We are starting a GC, but this thread is idle. Signaled by n_heaps <= heap_number - // b) We are starting a GC, and this thread is participating. Signaled by heap_number < n_heaps - // 2. We are changing heap count. Signaled by dynamic_heap_count_data.new_n_heaps != n_heaps - // a) We are changing heap count, but this thread is idle. Signaled by n_heaps <= heap_number. - // b) We are changing heap count, and this thread is participating. Signaled by heap_number < n_heaps. - - // check for 1.a) and 2.a) cases above - if (n_heaps <= heap_number) - { - dprintf (2, ("GC thread %d idle", heap_number)); - - // make sure GC is complete so we know the gc_idle_thread_event has been reset - g_theGCHeap->WaitUntilGCComplete(); + dprintf (9999, ("GC thread %d waiting_done_gc_start(%d-%d)(i: %d)(gc%Id)", + heap_number, n_heaps, dynamic_heap_count_data.new_n_heaps, dynamic_heap_count_data.init_only_p, VolatileLoadWithoutBarrier (&settings.gc_index))); + + if ((gc_heap::dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes) && + (dynamic_heap_count_data.new_n_heaps != n_heaps)) + { + // The reason why we need to do this is - + // + for threads that were participating, we need them to do work for change_heap_count + // + for threads that were not participating but will need to participate, we need to make sure they are woken now instead of + // randomly sometime later. + int old_n_heaps = n_heaps; + int new_n_heaps = dynamic_heap_count_data.new_n_heaps; + int num_threads_to_wake = max (new_n_heaps, old_n_heaps); + if (heap_number < num_threads_to_wake) + { + dprintf (9999, ("h%d < %d, calling change", heap_number, num_threads_to_wake)); + change_heap_count (dynamic_heap_count_data.new_n_heaps); + if (new_n_heaps < old_n_heaps) + { + dprintf (9999, ("h%d after change", heap_number)); + // at the end of change_heap_count we've changed join's heap count to the new one if it's smaller. So we need to make sure + // only that many threads will participate in the following GCs. + if (heap_number < new_n_heaps) + { + dprintf (9999, ("h%d < %d participating (dec)", heap_number, new_n_heaps)); + } + else + { + Interlocked::Increment (&dynamic_heap_count_data.idle_thread_count); + dprintf (9999, ("GC thread %d wait_on_idle(%d < %d)(gc%Id), total idle %d", heap_number, old_n_heaps, new_n_heaps, + VolatileLoadWithoutBarrier (&settings.gc_index), VolatileLoadWithoutBarrier (&dynamic_heap_count_data.idle_thread_count))); + gc_idle_thread_event.Wait (INFINITE, FALSE); + dprintf (9999, ("GC thread %d waking_from_idle(%d)(gc%Id) after doing change", heap_number, n_heaps, VolatileLoadWithoutBarrier (&settings.gc_index))); + } + } + else + { + dprintf (9999, ("h%d < %d participating (inc)", heap_number, new_n_heaps)); + } + } + else + { + Interlocked::Increment (&dynamic_heap_count_data.idle_thread_count); + dprintf (9999, ("GC thread %d wait_on_idle(< max %d)(gc%Id), total idle %d", heap_number, num_threads_to_wake, + VolatileLoadWithoutBarrier (&settings.gc_index), VolatileLoadWithoutBarrier (&dynamic_heap_count_data.idle_thread_count))); + gc_idle_thread_event.Wait (INFINITE, FALSE); + dprintf (9999, ("GC thread %d waking_from_idle(%d)(gc%Id)", heap_number, n_heaps, VolatileLoadWithoutBarrier (&settings.gc_index))); + } - // now wait on the gc_idle_thread_event - gc_idle_thread_event.Wait(INFINITE, FALSE); - dprintf (2, ("GC thread %d waking from idle", heap_number)); - continue; - } - // case 2.b) above: is this a request to change heap count? - if (dynamic_heap_count_data.new_n_heaps != n_heaps) - { - change_heap_count (dynamic_heap_count_data.new_n_heaps); continue; } - // case 1.b) above: we're starting a GC. #endif //DYNAMIC_HEAP_COUNT dprintf (3, (ThreadStressLog::gcServerThreadNStartMsg(), heap_number)); } @@ -7191,10 +7305,6 @@ void gc_heap::gc_thread_function () { gradual_decommit_in_progress_p = decommit_step (DECOMMIT_TIME_STEP_MILLISECONDS); } -#ifdef DYNAMIC_HEAP_COUNT - // check if we should adjust the number of heaps - check_heap_count(); -#endif //DYNAMIC_HEAP_COUNT } else { @@ -12527,6 +12637,16 @@ void gc_heap::rearrange_uoh_segments() freeable_uoh_segment = 0; } +void gc_heap::delay_free_segments() +{ + rearrange_uoh_segments(); +#ifdef BACKGROUND_GC + background_delay_delete_uoh_segments(); + if (!gc_heap::background_running_p()) + rearrange_small_heap_segments(); +#endif //BACKGROUND_GC +} + #ifndef USE_REGIONS void gc_heap::rearrange_heap_segments(BOOL compacting) { @@ -14860,6 +14980,25 @@ gc_heap::init_gc_heap (int h_number) gc_done_event_lock = -1; gc_done_event_set = false; +#ifdef DYNAMIC_HEAP_COUNT + if (h_number != 0) + { + if (!gc_idle_thread_event.CreateAutoEventNoThrow (FALSE)) + { + return 0; + } + +#ifdef BACKGROUND_GC + if (!bgc_idle_thread_event.CreateAutoEventNoThrow (FALSE)) + { + return 0; + } +#endif //BACKGROUND_GC + + dprintf (9999, ("creating idle events for h%d", h_number)); + } +#endif //DYNAMIC_HEAP_COUNT + if (!init_dynamic_data()) { return 0; @@ -16038,7 +16177,6 @@ void min_fl_list_info::thread_item_no_prev (uint8_t* item) tail = item; } -// This is only implemented for gen2 right now!!!! // the min_fl_list array is arranged as chunks of n_heaps min_fl_list_info, the 1st chunk corresponds to the 1st bucket, // and so on. void allocator::rethread_items (size_t* num_total_fl_items, size_t* num_total_fl_items_rethreaded, gc_heap* current_heap, @@ -17406,6 +17544,7 @@ BOOL gc_heap::a_fit_free_list_uoh_p (size_t size, gen_number, align_const); dd_new_allocation (dynamic_data_of (gen_number)) -= limit; + size_t saved_free_list_size = free_list_size; #ifdef FEATURE_LOH_COMPACTION if (loh_pad) { @@ -17434,7 +17573,7 @@ BOOL gc_heap::a_fit_free_list_uoh_p (size_t size, { generation_free_obj_space (gen) += remain_size; } - generation_free_list_space (gen) -= free_list_size; + generation_free_list_space (gen) -= saved_free_list_size; assert ((ptrdiff_t)generation_free_list_space (gen) >= 0); generation_free_list_allocated (gen) += limit; @@ -22000,11 +22139,70 @@ BOOL gc_heap::should_proceed_with_gc() void gc_heap::update_end_gc_time_per_heap() { +#ifdef DYNAMIC_HEAP_COUNT + size_t prev_gen2_end_time = 0; + if ((heap_number == 0) && (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes) && (settings.condemned_generation == max_generation)) + { + dynamic_data* dd = dynamic_data_of (max_generation); + prev_gen2_end_time = dd_previous_time_clock (dd) + dd_gc_elapsed_time (dd);; + } +#endif //DYNAMIC_HEAP_COUNT + for (int gen_number = 0; gen_number <= settings.condemned_generation; gen_number++) { dynamic_data* dd = dynamic_data_of (gen_number); + + if (heap_number == 0) + { + dprintf (6666, ("prev gen%d GC end time: prev start %I64d + prev gc elapsed %Id = %I64d", + gen_number, dd_previous_time_clock (dd), dd_gc_elapsed_time (dd), (dd_previous_time_clock (dd) + dd_gc_elapsed_time (dd)))); + } + dd_gc_elapsed_time (dd) = (size_t)(end_gc_time - dd_time_clock (dd)); + + if (heap_number == 0) + { + dprintf (6666, ("updated NGC%d %Id elapsed time to %I64d - %I64d = %I64d", gen_number, dd_gc_clock (dd), end_gc_time, dd_time_clock (dd), dd_gc_elapsed_time (dd))); + } + } + +#ifdef DYNAMIC_HEAP_COUNT + if ((heap_number == 0) && (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes)) + { + dynamic_heap_count_data_t::sample& sample = dynamic_heap_count_data.samples[dynamic_heap_count_data.sample_index]; + sample.elapsed_between_gcs = end_gc_time - last_suspended_end_time; + sample.gc_pause_time = dd_gc_elapsed_time (dynamic_data_of (0)); + sample.msl_wait_time = get_msl_wait_time(); + + dprintf (6666, ("sample#%d: this GC end %I64d - last sus end %I64d = %I64d, this GC pause %I64d, msl wait %I64d", + dynamic_heap_count_data.sample_index, end_gc_time, last_suspended_end_time, sample.elapsed_between_gcs, sample.gc_pause_time, sample.msl_wait_time)); + + last_suspended_end_time = end_gc_time; + + GCEventFireHeapCountSample_V1 ( + (uint64_t)VolatileLoadWithoutBarrier (&settings.gc_index), + sample.elapsed_between_gcs, + sample.gc_pause_time, + sample.msl_wait_time); + + dynamic_heap_count_data.sample_index = (dynamic_heap_count_data.sample_index + 1) % dynamic_heap_count_data_t::sample_size; + + if (settings.condemned_generation == max_generation) + { + gc_index_full_gc_end = dd_gc_clock (dynamic_data_of (0)); + size_t elapsed_between_gen2_gcs = end_gc_time - prev_gen2_end_time; + size_t gen2_elapsed_time = sample.gc_pause_time; + dynamic_heap_count_data.gen2_gc_percents[dynamic_heap_count_data.gen2_sample_index] = (float)gen2_elapsed_time * 100.0f / elapsed_between_gen2_gcs; + + dprintf (6666, ("gen2 sample#%d: this GC end %I64d - last gen2 end %I64d = %I64d, GC elapsed %I64d, percent %.3f", + dynamic_heap_count_data.gen2_sample_index, end_gc_time, prev_gen2_end_time, elapsed_between_gen2_gcs, + gen2_elapsed_time, dynamic_heap_count_data.gen2_gc_percents[dynamic_heap_count_data.gen2_sample_index])); + dynamic_heap_count_data.gen2_sample_index = (dynamic_heap_count_data.gen2_sample_index + 1) % dynamic_heap_count_data_t::sample_size; + } + + calculate_new_heap_count (); } +#endif //DYNAMIC_HEAP_COUNT } void gc_heap::update_end_ngc_time() @@ -22151,7 +22349,31 @@ void gc_heap::gc1() { dynamic_data* dd = dynamic_data_of (n); end_gc_time = GetHighPrecisionTimeStamp(); + size_t time_since_last_gen2 = 0; + +#ifdef DYNAMIC_HEAP_COUNT + if ((heap_number == 0) && (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes)) + { + time_since_last_gen2 = (size_t)(end_gc_time - (dd_previous_time_clock (dd) + dd_gc_elapsed_time (dd))); + dprintf (6666, ("BGC %Id end %I64d - (prev gen2 start %I64d + elapsed %Id = %I64d) = time inbewteen gen2 %Id", + dd_gc_clock (dd), end_gc_time, dd_previous_time_clock (dd), dd_gc_elapsed_time (dd), (dd_previous_time_clock (dd) + dd_gc_elapsed_time (dd)), time_since_last_gen2)); + } +#endif //DYNAMIC_HEAP_COUNT + dd_gc_elapsed_time (dd) = (size_t)(end_gc_time - dd_time_clock (dd)); +#ifdef DYNAMIC_HEAP_COUNT + if ((heap_number == 0) && (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes)) + { + dprintf (6666, ("updating BGC %Id elapsed time to %I64d - %I64d = %I64d", dd_gc_clock (dd), end_gc_time, dd_time_clock (dd), dd_gc_elapsed_time (dd))); + + float bgc_percent = (float)dd_gc_elapsed_time (dd) * 100.0f / (float)time_since_last_gen2; + dynamic_heap_count_data.gen2_gc_percents[dynamic_heap_count_data.gen2_sample_index] = bgc_percent; + dprintf (6666, ("gen2 sample %d elapsed %Id * 100 / time inbetween gen2 %Id = %.3f", + dynamic_heap_count_data.gen2_sample_index, dd_gc_elapsed_time (dd), time_since_last_gen2, bgc_percent)); + dynamic_heap_count_data.gen2_sample_index = (dynamic_heap_count_data.gen2_sample_index + 1) % dynamic_heap_count_data_t::sample_size; + gc_index_full_gc_end = dd_gc_clock (dynamic_data_of (0)); + } +#endif //DYNAMIC_HEAP_COUNT #ifdef HEAP_BALANCE_INSTRUMENTATION if (heap_number == 0) @@ -22758,7 +22980,12 @@ void gc_heap::merge_fl_from_other_heaps (int gen_idx, int to_n_heaps, int from_n assert (free_list_space_decrease <= generation_free_list_space (gen)); generation_free_list_space (gen) -= free_list_space_decrease; - assert (free_list_space_decrease <= dd_fragmentation (dd)); + // TODO - I'm seeing for gen2 this is free_list_space_decrease can be a bit larger than frag. + // Need to fix this later. + if (gen_idx != max_generation) + { + assert (free_list_space_decrease <= dd_fragmentation (dd)); + } size_t free_list_space_increase = 0; for (int from_hn = 0; from_hn < from_n_heaps; from_hn++) @@ -23733,9 +23960,6 @@ void gc_heap::garbage_collect (int n) #ifdef MULTIPLE_HEAPS gc_start_event.Reset(); -#ifdef DYNAMIC_HEAP_COUNT - gc_idle_thread_event.Reset(); -#endif //DYNAMIC_HEAP_COUNT gc_t_join.restart(); #endif //MULTIPLE_HEAPS } @@ -23757,6 +23981,9 @@ void gc_heap::garbage_collect (int n) #endif // STRESS_HEAP #ifdef MULTIPLE_HEAPS +#ifdef STRESS_DYNAMIC_HEAP_COUNT + Interlocked::Increment (&heaps_in_this_gc); +#endif //STRESS_DYNAMIC_HEAP_COUNT //align all heaps on the max generation to condemn dprintf (3, ("Joining for max generation to condemn")); condemned_generation_num = generation_to_condemn (n, @@ -23772,30 +23999,31 @@ void gc_heap::garbage_collect (int n) #endif //FEATURE_BASICFREEZE #ifdef MULTIPLE_HEAPS +#ifdef STRESS_DYNAMIC_HEAP_COUNT + dprintf (9999, ("%d heaps, join sees %d, actually joined %d, %d idle threads (%d)", + n_heaps, gc_t_join.get_num_threads (), heaps_in_this_gc, + VolatileLoadWithoutBarrier(&dynamic_heap_count_data.idle_thread_count), (n_max_heaps - n_heaps))); + if (heaps_in_this_gc != n_heaps) + { + dprintf (9999, ("should have %d heaps but actually have %d!!", n_heaps, heaps_in_this_gc)); + GCToOSInterface::DebugBreak (); + } + + heaps_in_this_gc = 0; +#endif //STRESS_DYNAMIC_HEAP_COUNT + for (int i = 0; i < n_heaps; i++) { gc_heap* hp = g_heaps[i]; // check for card table growth if (g_gc_card_table != hp->card_table) hp->copy_brick_card_table(); - - hp->rearrange_uoh_segments(); -#ifdef BACKGROUND_GC - hp->background_delay_delete_uoh_segments(); - if (!gc_heap::background_running_p()) - hp->rearrange_small_heap_segments(); -#endif //BACKGROUND_GC + hp->delay_free_segments(); } #else //MULTIPLE_HEAPS if (g_gc_card_table != card_table) copy_brick_card_table(); - - rearrange_uoh_segments(); -#ifdef BACKGROUND_GC - background_delay_delete_uoh_segments(); - if (!gc_heap::background_running_p()) - rearrange_small_heap_segments(); -#endif //BACKGROUND_GC + delay_free_segments(); #endif //MULTIPLE_HEAPS BOOL should_evaluate_elevation = TRUE; @@ -23882,10 +24110,8 @@ void gc_heap::garbage_collect (int n) do_pre_gc(); #ifdef MULTIPLE_HEAPS + dprintf (9999, ("in GC, resetting gc_start")); gc_start_event.Reset(); -#ifdef DYNAMIC_HEAP_COUNT - gc_idle_thread_event.Reset(); -#endif //DYNAMIC_HEAP_COUNT dprintf(3, ("Starting all gc threads for gc")); gc_t_join.restart(); #endif //MULTIPLE_HEAPS @@ -24341,7 +24567,7 @@ void gc_heap::equalize_promoted_bytes(int condemned_gen_number) // hope is to achieve better work balancing in relocate and compact phases // this is also used when the heap count changes to balance regions between heaps int highest_gen_number = ((condemned_gen_number == max_generation) ? - (total_generation_count - 1) : condemned_gen_number); + (total_generation_count - 1) : condemned_gen_number); int stop_gen_idx = get_stop_generation_index (condemned_gen_number); for (int gen_idx = highest_gen_number; gen_idx >= stop_gen_idx; gen_idx--) @@ -25050,285 +25276,332 @@ void gc_heap::recommission_heap() #endif //RECORD_LOH_STATE } -void gc_heap::check_heap_count () +float median_of_3 (float a, float b, float c) { - dynamic_heap_count_data.new_n_heaps = n_heaps; +#define compare_and_swap(i, j) \ + { \ + if (i < j) \ + { \ + float t = i; \ + i = j; \ + j = t; \ + } \ + } + compare_and_swap (b, a); + compare_and_swap (c, a); + compare_and_swap (c, b); +#undef compare_and_swap + return b; +} - if (dynamic_adaptation_mode != dynamic_adaptation_to_application_sizes) +size_t gc_heap::get_num_completed_gcs () +{ + size_t num_completed_gcs = settings.gc_index; +#ifdef BACKGROUND_GC + if (g_heaps[0]->is_bgc_in_progress ()) { - return; + num_completed_gcs--; + dprintf (6666, ("BGC in prog, completed GCs -> %Id", num_completed_gcs)); } +#endif //BACKGROUND_GC - // we should be calling this only on the main GC thread - assert (heap_number == 0); + return num_completed_gcs; +} - // acquire data for the current sample - uint64_t soh_msl_wait_time = 0; - uint64_t uoh_msl_wait_time = 0; - size_t allocating_thread_count = 0; - size_t heap_size = 0; - for (int i = 0; i < n_heaps; i++) +int gc_heap::calculate_new_heap_count () +{ + assert (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes); + + size_t num_completed_gcs = get_num_completed_gcs (); + + dprintf (6666, ("current GC %Id(completed: %Id), prev completed GCs %Id, last full GC happened at index %Id", + VolatileLoadWithoutBarrier (&settings.gc_index), num_completed_gcs, dynamic_heap_count_data.prev_num_completed_gcs, gc_index_full_gc_end)); + + if (num_completed_gcs < (dynamic_heap_count_data.prev_num_completed_gcs + dynamic_heap_count_data_t::sample_size)) { - gc_heap* hp = g_heaps[i]; + dprintf (6666, ("not enough GCs, skipping")); + return n_heaps; + } - allocating_thread_count += hp->alloc_contexts_used; + float median_gen2_tcp_percent = 0.0f; + if (gc_index_full_gc_end >= (settings.gc_index - dynamic_heap_count_data_t::sample_size)) + { + median_gen2_tcp_percent = dynamic_heap_count_data.get_median_gen2_gc_percent (); + } - soh_msl_wait_time += hp->more_space_lock_soh.msl_wait_time; - hp->more_space_lock_soh.msl_wait_time = 0; - hp->more_space_lock_soh.msl_wait_count = 0; + // If there was a blocking gen2 GC, the overhead would be very large and most likely we would not pick it. So we + // rely on the gen2 sample's overhead calculated above. + float throughput_cost_percents[dynamic_heap_count_data_t::sample_size]; + for (int i = 0; i < dynamic_heap_count_data_t::sample_size; i++) + { + dynamic_heap_count_data_t::sample& sample = dynamic_heap_count_data.samples[i]; + throughput_cost_percents[i] = (sample.elapsed_between_gcs ? (((float)sample.msl_wait_time / n_heaps + sample.gc_pause_time) * 100.0f / (float)sample.elapsed_between_gcs) : 0.0f); + assert (throughput_cost_percents[i] >= 0.0); + if (throughput_cost_percents[i] > 100.0) + throughput_cost_percents[i] = 100.0; + dprintf (6666, ("sample %d: msl %I64d / %d + pause %I64d / elapsed %I64d = throughput_cost_percent: %.3f", i, + sample.msl_wait_time, n_heaps, sample.gc_pause_time, sample.elapsed_between_gcs, throughput_cost_percents[i])); + } - uoh_msl_wait_time += hp->more_space_lock_uoh.msl_wait_time; - hp->more_space_lock_uoh.msl_wait_time = 0; - hp->more_space_lock_uoh.msl_wait_count = 0; + float median_throughput_cost_percent = median_of_3 (throughput_cost_percents[0], throughput_cost_percents[1], throughput_cost_percents[2]); + + // apply exponential smoothing and use 1/3 for the smoothing factor + const float smoothing = 3; + float smoothed_median_throughput_cost_percent = dynamic_heap_count_data.smoothed_median_throughput_cost_percent; + if (smoothed_median_throughput_cost_percent != 0.0f) + { + // average it with the previous value + smoothed_median_throughput_cost_percent = median_throughput_cost_percent / smoothing + (smoothed_median_throughput_cost_percent / smoothing) * (smoothing - 1); + } + else + { + smoothed_median_throughput_cost_percent = median_throughput_cost_percent; + } + + dprintf (6666, ("median tcp: %.3f, smoothed tcp: %.3f, gen2 tcp %.3f(%.3f, %.3f, %.3f)", + median_throughput_cost_percent, smoothed_median_throughput_cost_percent, median_gen2_tcp_percent, + dynamic_heap_count_data.gen2_gc_percents[0], dynamic_heap_count_data.gen2_gc_percents[1], dynamic_heap_count_data.gen2_gc_percents[2])); + + size_t heap_size = 0; + for (int i = 0; i < n_heaps; i++) + { + gc_heap* hp = g_heaps[i]; for (int gen_idx = 0; gen_idx < total_generation_count; gen_idx++) { dynamic_data* dd = hp->dynamic_data_of (gen_idx); // estimate the size of each generation as the live data size plus the budget - heap_size += dd_promoted_size (dd) + dd_desired_allocation (dd); - dprintf (6666, ("h%d g%d promoted: %zd desired allocation: %zd", i, gen_idx, dd_promoted_size (dd), dd_desired_allocation (dd))); + heap_size += dd_current_size (dd) + dd_desired_allocation (dd); + dprintf (3, ("h%d g%d current: %zd desired allocation: %zd", i, gen_idx, dd_promoted_size (dd), dd_desired_allocation (dd))); } } - dynamic_data* hp0_dd0 = g_heaps[0]->dynamic_data_of (0); + // estimate the space cost of adding a heap as the min gen0 budget + size_t heap_space_cost_per_heap = dd_min_size (g_heaps[0]->dynamic_data_of (0)); - // persist data for the current sample - dynamic_heap_count_data_t::sample& sample = dynamic_heap_count_data.samples[dynamic_heap_count_data.sample_index]; + // compute the % space cost of adding a heap + float percent_heap_space_cost_per_heap = heap_space_cost_per_heap * 100.0f / heap_size; - sample.soh_msl_wait_time = soh_msl_wait_time / n_heaps; - sample.uoh_msl_wait_time = uoh_msl_wait_time / n_heaps; - sample.elapsed_between_gcs = dd_time_clock (hp0_dd0) - dd_previous_time_clock (hp0_dd0); - sample.gc_elapsed_time = dd_gc_elapsed_time (hp0_dd0); - sample.allocating_thread_count = allocating_thread_count; - sample.heap_size = heap_size; + // compute reasonable step sizes for the heap count + // + // on the way up, we essentially multiply the heap count by 1.5, so we go 1, 2, 3, 5, 8 ... + // we don't go all the way to the number of CPUs, but stay 1 or 2 short + int step_up = (n_heaps + 1) / 2; + int extra_heaps = 1 + (n_max_heaps >= 32); + step_up = min (step_up, n_max_heaps - extra_heaps - n_heaps); - dprintf (6666, ("sample %d: soh_msl_wait_time: %zd, uoh_msl_wait_time: %zd, elapsed_between_gcs: %zd, gc_elapsed_time: %d, heap_size: %zd MB", - dynamic_heap_count_data.sample_index, - sample.soh_msl_wait_time, - sample.uoh_msl_wait_time, - sample.elapsed_between_gcs, - sample.gc_elapsed_time, - sample.heap_size/(1024*1024))); + // on the way down, we essentially divide the heap count by 1.5 + int step_down = (n_heaps + 1) / 3; - dynamic_heap_count_data.sample_index = (dynamic_heap_count_data.sample_index + 1) % dynamic_heap_count_data_t::sample_size; + // estimate the potential time benefit of going up a step + float tcp_reduction_per_step_up = smoothed_median_throughput_cost_percent * step_up / (n_heaps + step_up); - GCEventFireHeapCountSample_V1( - sample.gc_elapsed_time, - sample.soh_msl_wait_time, - sample.uoh_msl_wait_time, - sample.elapsed_between_gcs - ); + // estimate the potential time cost of going down a step + float tcp_increase_per_step_down = smoothed_median_throughput_cost_percent * step_down / (n_heaps - step_down); + + // estimate the potential space cost of going up a step + float scp_increase_per_step_up = percent_heap_space_cost_per_heap * step_up; - if (settings.gc_index < prev_change_heap_count_gc_index + 3) + // estimate the potential space saving of going down a step + float scp_decrease_per_step_down = percent_heap_space_cost_per_heap * step_down; + + dprintf (6666, ("[CHP] u %d, d %d | space cost %Id / heap %Id(%.2fmb) = scp %.3f (u: %.3f, d: %.3f) | stcp %.3f, u * %.1f = %.3f, d * %.1f = %.3f", + step_up, step_down, + heap_space_cost_per_heap, heap_size, ((float)heap_size / (float)1000 / (float)1000), percent_heap_space_cost_per_heap, + scp_increase_per_step_up, scp_decrease_per_step_down, + smoothed_median_throughput_cost_percent, + ((float)step_up / (float)(n_heaps + step_up)), tcp_reduction_per_step_up, + ((float)step_down / (float)(n_heaps - step_down)), tcp_increase_per_step_down)); + +#ifdef STRESS_DYNAMIC_HEAP_COUNT + // quick hack for initial testing + int new_n_heaps = (int)gc_rand::get_rand (n_max_heaps - 1) + 1; + + // if we are adjusting down, make sure we adjust lower than the lowest uoh msl heap + if ((new_n_heaps < n_heaps) && (dynamic_heap_count_data.lowest_heap_with_msl_uoh != -1)) { - // reconsider the decision every few gcs - return; + new_n_heaps = min (dynamic_heap_count_data.lowest_heap_with_msl_uoh, new_n_heaps); + new_n_heaps = max (new_n_heaps, 1); } - - if (gc_heap::background_running_p()) + dprintf (6666, ("stress %d -> %d", n_heaps, new_n_heaps)); +#else //STRESS_DYNAMIC_HEAP_COUNT + int new_n_heaps = n_heaps; + if (median_throughput_cost_percent > 10.0f) { - // can't have background gc running while we change the number of heaps - // so it's useless to compute a new number of heaps here + // ramp up more agressively - use as many heaps as it would take to bring + // the tcp down to 5% + new_n_heaps = (int)(n_heaps * (median_throughput_cost_percent / 5.0)); + dprintf (6666, ("[CHP0] tcp %.3f -> %d * %.3f = %d", median_throughput_cost_percent, n_heaps, (median_throughput_cost_percent / 5.0), new_n_heaps)); + new_n_heaps = min (new_n_heaps, n_max_heaps - extra_heaps); } - else + // if the median tcp is 10% or less, react slower + else if ((smoothed_median_throughput_cost_percent > 5.0f) || (median_gen2_tcp_percent > 10.0f)) { - // compute the % overhead from msl waiting time and gc time for each of the samples - float percent_overhead[dynamic_heap_count_data_t::sample_size]; - for (int i = 0; i < dynamic_heap_count_data_t::sample_size; i++) - { - dynamic_heap_count_data_t::sample& sample = dynamic_heap_count_data.samples[i]; - uint64_t overhead_time = sample.soh_msl_wait_time + sample.uoh_msl_wait_time + sample.gc_elapsed_time; - percent_overhead[i] = overhead_time * 100.0f / sample.elapsed_between_gcs; - if (percent_overhead[i] < 0) - percent_overhead[i] = 0; - else if (percent_overhead[i] > 100) - percent_overhead[i] = 100; - dprintf (6666, ("sample %d: percent_overhead: %d%%", i, (int)percent_overhead[i])); - } - // compute the median of the percent overhead samples - #define compare_and_swap(i, j) \ - { \ - if (percent_overhead[i] < percent_overhead[j]) \ - { \ - float t = percent_overhead[i]; \ - percent_overhead[i] = percent_overhead[j]; \ - percent_overhead[j] = t; \ - } \ - } - compare_and_swap (1, 0); - compare_and_swap (2, 0); - compare_and_swap (2, 1); - #undef compare_and_swap - - // the middle element is the median overhead percentage - float median_percent_overhead = percent_overhead[1]; - - // apply exponential smoothing and use 1/3 for the smoothing factor - const float smoothing = 3; - float smoothed_median_percent_overhead = dynamic_heap_count_data.smoothed_median_percent_overhead; - if (smoothed_median_percent_overhead != 0.0f) - { - // average it with the previous value - smoothed_median_percent_overhead = median_percent_overhead / smoothing + (smoothed_median_percent_overhead / smoothing) * (smoothing - 1); + if (smoothed_median_throughput_cost_percent > 5.0f) + { + dprintf (6666, ("[CHP1] stcp %.3f > 5, %d + %d = %d", smoothed_median_throughput_cost_percent, n_heaps, step_up, (n_heaps + step_up))); } else { - // first time? initialize to the median - smoothed_median_percent_overhead = median_percent_overhead; + dprintf (6666, ("[CHP2] tcp %.3f > 10, %d + %d = %d", median_gen2_tcp_percent, n_heaps, step_up, (n_heaps + step_up))); } + new_n_heaps += step_up; + } + // if we can save at least 1% more in time than we spend in space, increase number of heaps + else if ((tcp_reduction_per_step_up - scp_increase_per_step_up) >= 1.0f) + { + dprintf (6666, ("[CHP3] % .3f - % .3f = % .3f, % d + % d = % d", + tcp_reduction_per_step_up, scp_increase_per_step_up, (tcp_reduction_per_step_up - scp_increase_per_step_up), + n_heaps, step_up, (n_heaps + step_up))); + new_n_heaps += step_up; + } + // if we can save at least 1% more in space than we spend in time, decrease number of heaps + else if ((smoothed_median_throughput_cost_percent < 1.0f) && + (median_gen2_tcp_percent < 5.0f) && + ((scp_decrease_per_step_down - tcp_increase_per_step_down) >= 1.0f)) + { + dprintf (6666, ("[CHP4] stcp %.3f tcp %.3f, %.3f - %.3f = %.3f, %d + %d = %d", + smoothed_median_throughput_cost_percent, median_gen2_tcp_percent, + scp_decrease_per_step_down, tcp_increase_per_step_down, (scp_decrease_per_step_down - tcp_increase_per_step_down), + n_heaps, step_up, (n_heaps + step_up))); + new_n_heaps -= step_down; + } - dprintf (6666, ("median overhead: %d%% smoothed median overhead: %d%%", (int)(median_percent_overhead*1000), (int)(smoothed_median_percent_overhead*1000))); - - // estimate the space cost of adding a heap as the min gen0 size - size_t heap_space_cost_per_heap = dd_min_size (hp0_dd0); - - // compute the % space cost of adding a heap - float percent_heap_space_cost_per_heap = heap_space_cost_per_heap * 100.0f / heap_size; - - // compute reasonable step sizes for the heap count + assert (new_n_heaps >= 1); + assert (new_n_heaps <= n_max_heaps); +#endif //STRESS_DYNAMIC_HEAP_COUNT - // on the way up, we essentially multiply the heap count by 1.5, so we go 1, 2, 3, 5, 8 ... - // we don't go all the way to the number of CPUs, but stay 1 or 2 short - int step_up = (n_heaps + 1) / 2; - int extra_heaps = 1 + (n_max_heaps >= 32); - step_up = min (step_up, n_max_heaps - extra_heaps - n_heaps); + // store data used for decision to emit in ETW event + dynamic_heap_count_data.median_throughput_cost_percent = median_throughput_cost_percent; + dynamic_heap_count_data.smoothed_median_throughput_cost_percent = smoothed_median_throughput_cost_percent; + dynamic_heap_count_data.percent_heap_space_cost_per_heap = percent_heap_space_cost_per_heap; + dynamic_heap_count_data.tcp_reduction_per_step_up = tcp_reduction_per_step_up; + dynamic_heap_count_data.tcp_increase_per_step_down = tcp_increase_per_step_down; + dynamic_heap_count_data.scp_increase_per_step_up = scp_increase_per_step_up; + dynamic_heap_count_data.scp_decrease_per_step_down = scp_decrease_per_step_down; + + GCEventFireHeapCountTuning_V1 ( + (uint16_t)dynamic_heap_count_data.new_n_heaps, + (uint64_t)VolatileLoadWithoutBarrier (&settings.gc_index), + dynamic_heap_count_data.median_throughput_cost_percent, + dynamic_heap_count_data.smoothed_median_throughput_cost_percent, + dynamic_heap_count_data.tcp_reduction_per_step_up, + dynamic_heap_count_data.tcp_increase_per_step_down, + dynamic_heap_count_data.scp_increase_per_step_up, + dynamic_heap_count_data.scp_decrease_per_step_down + ); - // on the way down, we essentially divide the heap count by 1.5 - int step_down = (n_heaps + 1) / 3; + dynamic_heap_count_data.prev_num_completed_gcs = num_completed_gcs; - // estimate the potential time benefit of going up a step - float overhead_reduction_per_step_up = smoothed_median_percent_overhead * step_up / (n_heaps + step_up); + if (new_n_heaps != n_heaps) + { + dprintf (6666, ("should change! %d->%d", n_heaps, new_n_heaps)); + dynamic_heap_count_data.heap_count_to_change_to = new_n_heaps; + dynamic_heap_count_data.should_change_heap_count = true; + } - // estimate the potential time cost of going down a step - float overhead_increase_per_step_down = smoothed_median_percent_overhead * step_down / (n_heaps - step_down); + return new_n_heaps; +} - // estimate the potential space cost of going up a step - float space_cost_increase_per_step_up = percent_heap_space_cost_per_heap * step_up; +void gc_heap::check_heap_count () +{ + dynamic_heap_count_data.new_n_heaps = dynamic_heap_count_data.heap_count_to_change_to; - // estimate the potential space saving of going down a step - float space_cost_decrease_per_step_down = percent_heap_space_cost_per_heap * step_down; + assert (dynamic_heap_count_data.new_n_heaps != n_heaps); -#ifdef STRESS_DYNAMIC_HEAP_COUNT - // quick hack for initial testing - int new_n_heaps = (int)gc_rand::get_rand (n_max_heaps - 1) + 1; + if (dynamic_heap_count_data.new_n_heaps != n_heaps) + { + dprintf (9999, ("h0 suspending EE in check")); + // can't have threads allocating while we change the number of heaps + GCToEEInterface::SuspendEE(SUSPEND_FOR_GC_PREP); + dprintf (9999, ("h0 suspended EE in check")); - // if we are adjusting down, make sure we adjust lower than the lowest uoh msl heap - if ((new_n_heaps < n_heaps) && (dynamic_heap_count_data.lowest_heap_with_msl_uoh != -1)) +#ifdef BACKGROUND_GC + if (gc_heap::background_running_p()) { - new_n_heaps = min (dynamic_heap_count_data.lowest_heap_with_msl_uoh, new_n_heaps); + // background GC is running - reset the new heap count + dynamic_heap_count_data.new_n_heaps = n_heaps; + dprintf (6666, ("can't change heap count! BGC in progress")); - // but not down to zero, obviously... - new_n_heaps = max (new_n_heaps, 1); - } -#else //STRESS_DYNAMIC_HEAP_COUNT - int new_n_heaps = n_heaps; - if (median_percent_overhead > 10.0f) - { - // ramp up more agressively - use as many heaps as it would take to bring - // the overhead down to 5% - new_n_heaps = (int)(n_heaps * (median_percent_overhead / 5.0)); - new_n_heaps = min (new_n_heaps, n_max_heaps - extra_heaps); - } - // if the median overhead is 10% or less, react slower - else if (smoothed_median_percent_overhead > 5.0f) - { - new_n_heaps += step_up; - } - // if we can save at least 1% more in time than we spend in space, increase number of heaps - else if (overhead_reduction_per_step_up - space_cost_increase_per_step_up >= 1.0f) - { - new_n_heaps += step_up; - } - // if we can save at least 1% more in space than we spend in time, decrease number of heaps - else if (smoothed_median_percent_overhead < 1.0f && space_cost_decrease_per_step_down - overhead_increase_per_step_down >= 1.0f) - { - new_n_heaps -= step_down; + GCToEEInterface::RestartEE(TRUE); } +#endif //BACKGROUND_GC + } - dprintf (6666, ("or: %d, si: %d, sd: %d, oi: %d => %d -> %d", - (int)overhead_reduction_per_step_up, - (int)space_cost_increase_per_step_up, - (int)space_cost_decrease_per_step_down, - (int)overhead_increase_per_step_down, - n_heaps, - new_n_heaps)); - - assert (1 <= new_n_heaps); - assert (new_n_heaps <= n_max_heaps); -#endif //STRESS_DYNAMIC_HEAP_COUNT - - dynamic_heap_count_data.new_n_heaps = new_n_heaps; - - // store data used for decision to emit in ETW event - dynamic_heap_count_data.median_percent_overhead = median_percent_overhead; - dynamic_heap_count_data.smoothed_median_percent_overhead = smoothed_median_percent_overhead; - dynamic_heap_count_data.percent_heap_space_cost_per_heap = percent_heap_space_cost_per_heap; - dynamic_heap_count_data.overhead_reduction_per_step_up = overhead_reduction_per_step_up; - dynamic_heap_count_data.overhead_increase_per_step_down = overhead_increase_per_step_down; - dynamic_heap_count_data.space_cost_increase_per_step_up = space_cost_increase_per_step_up; - dynamic_heap_count_data.space_cost_decrease_per_step_down = space_cost_decrease_per_step_down; - - GCEventFireHeapCountTuning_V1( - (uint16_t)dynamic_heap_count_data.new_n_heaps, - (uint64_t)VolatileLoad(&settings.gc_index), - dynamic_heap_count_data.median_percent_overhead, - dynamic_heap_count_data.smoothed_median_percent_overhead, - dynamic_heap_count_data.overhead_reduction_per_step_up, - dynamic_heap_count_data.overhead_increase_per_step_down, - dynamic_heap_count_data.space_cost_increase_per_step_up, - dynamic_heap_count_data.space_cost_decrease_per_step_down - ); - - if (new_n_heaps != n_heaps) + if (dynamic_heap_count_data.new_n_heaps != n_heaps) + { + dprintf (6666, ("prep to change from %d to %d", n_heaps, dynamic_heap_count_data.new_n_heaps)); + if (!prepare_to_change_heap_count (dynamic_heap_count_data.new_n_heaps)) { - // can't have threads allocating while we change the number of heaps - GCToEEInterface::SuspendEE(SUSPEND_FOR_GC_PREP); - - if (gc_heap::background_running_p()) - { - // background GC is running - reset the new heap count - dynamic_heap_count_data.new_n_heaps = n_heaps; - - GCToEEInterface::RestartEE(TRUE); - } + // we don't have sufficient resources - reset the new heap count + dynamic_heap_count_data.new_n_heaps = n_heaps; } } if (dynamic_heap_count_data.new_n_heaps == n_heaps) { // heap count stays the same, no work to do - dprintf (6666, ("heap count stays the same, no work to do %d == %d", dynamic_heap_count_data.new_n_heaps, n_heaps)); + dynamic_heap_count_data.prev_num_completed_gcs = get_num_completed_gcs (); + dynamic_heap_count_data.should_change_heap_count = false; - // come back after 3 GCs to reconsider - prev_change_heap_count_gc_index = settings.gc_index; + dprintf (6666, ("heap count stays the same %d, no work to do, set prev completed to %Id", dynamic_heap_count_data.new_n_heaps, dynamic_heap_count_data.prev_num_completed_gcs)); return; } - if (GCScan::GetGcRuntimeStructuresValid()) + int new_n_heaps = dynamic_heap_count_data.new_n_heaps; + + assert (!(dynamic_heap_count_data.init_only_p)); + { + // At this point we are guaranteed to be able to change the heap count to the new one. + // Change the heap count for joins here because we will need to join new_n_heaps threads together. + dprintf (9999, ("changing join hp %d->%d", n_heaps, new_n_heaps)); + int max_threads_to_wake = max (n_heaps, new_n_heaps); + gc_t_join.update_n_threads (max_threads_to_wake); + // make sure the other gc threads cannot see this as a request to GC assert (dynamic_heap_count_data.new_n_heaps != n_heaps); + + if (n_heaps < new_n_heaps) + { + int saved_idle_thread_count = dynamic_heap_count_data.idle_thread_count; + Interlocked::ExchangeAdd (&dynamic_heap_count_data.idle_thread_count, (n_heaps - new_n_heaps)); + dprintf (9999, ("GC thread %d setting idle events for h%d-h%d, total idle %d -> %d", heap_number, n_heaps, (new_n_heaps - 1), + saved_idle_thread_count, VolatileLoadWithoutBarrier (&dynamic_heap_count_data.idle_thread_count))); + + for (int heap_idx = n_heaps; heap_idx < new_n_heaps; heap_idx++) + { + g_heaps[heap_idx]->gc_idle_thread_event.Set(); +#ifdef BACKGROUND_GC + g_heaps[heap_idx]->bgc_idle_thread_event.Set(); +#endif //BACKGROUND_GC + } + } + gc_start_event.Set(); } int old_n_heaps = n_heaps; + (dynamic_heap_count_data.heap_count_change_count)++; change_heap_count (dynamic_heap_count_data.new_n_heaps); GCToEEInterface::RestartEE(TRUE); - prev_change_heap_count_gc_index = settings.gc_index; + dprintf (9999, ("h0 restarted EE")); // we made changes to the heap count that will change the overhead, // so change the smoothed overhead to reflect that - int new_n_heaps = n_heaps; - dynamic_heap_count_data.smoothed_median_percent_overhead = dynamic_heap_count_data.smoothed_median_percent_overhead/new_n_heaps*old_n_heaps; + dynamic_heap_count_data.smoothed_median_throughput_cost_percent = dynamic_heap_count_data.smoothed_median_throughput_cost_percent / n_heaps * old_n_heaps; + + dprintf (6666, ("h0 finished changing, set should change to false!")); + dynamic_heap_count_data.should_change_heap_count = false; } bool gc_heap::prepare_to_change_heap_count (int new_n_heaps) { - dprintf (6666, ("trying to change heap count %d -> %d", n_heaps, new_n_heaps)); + dprintf (9999, ("trying to change heap count %d -> %d", n_heaps, new_n_heaps)); // use this variable for clarity - n_heaps will change during the transition int old_n_heaps = n_heaps; @@ -25371,6 +25644,17 @@ bool gc_heap::prepare_to_change_heap_count (int new_n_heaps) } } + // Before we look at whether we have sufficient regions we should return regions that should be deleted to free + // so we don't lose them when we decommission heaps. We could do this for only heaps that we are about + // to decomission. But it's better to do this for all heaps because we don't need to worry about adding them to the + // heaps remain (freeable uoh/soh regions) and we get rid of regions with the heap_segment_flags_uoh_delete flag + // because background_delay_delete_uoh_segments makes the assumption it can't be the start region. + for (int i = 0; i < old_n_heaps; i++) + { + gc_heap* hp = g_heaps[i]; + hp->delay_free_segments (); + } + // if we want to increase the number of heaps, we have to make sure we can give // each heap a region for each generation. If we cannot do that, we have to give up ptrdiff_t region_count_in_gen[total_generation_count]; @@ -25451,39 +25735,34 @@ bool gc_heap::prepare_to_change_heap_count (int new_n_heaps) bool gc_heap::change_heap_count (int new_n_heaps) { + dprintf (9999, ("BEG heap%d changing %d->%d", heap_number, n_heaps, new_n_heaps)); + // use this variable for clarity - n_heaps will change during the transition int old_n_heaps = n_heaps; + bool init_only_p = dynamic_heap_count_data.init_only_p; - if (heap_number == 0) { - if (!prepare_to_change_heap_count (new_n_heaps)) - { - // we don't have sufficient resources - reset the new heap count - dynamic_heap_count_data.new_n_heaps = n_heaps; - } - } - - if (GCScan::GetGcRuntimeStructuresValid()) - { - // join for sufficient resources decision gc_t_join.join (this, gc_join_merge_temp_fl); if (gc_t_join.joined ()) { + // BGC is not running, we can safely change its join's heap count. +#ifdef BACKGROUND_GC + bgc_t_join.update_n_threads (new_n_heaps); +#endif //BACKGROUND_GC + + dynamic_heap_count_data.init_only_p = false; + dprintf (9999, ("in change h%d resetting gc_start, update bgc join to %d heaps", heap_number, new_n_heaps)); gc_start_event.Reset(); gc_t_join.restart (); } } - // gc_heap::n_heaps may have changed by now, compare to the snapshot *before* the join - if (dynamic_heap_count_data.new_n_heaps == old_n_heaps) - { - dprintf (6666, ("failed to change heap count, no work to do %d == %d", dynamic_heap_count_data.new_n_heaps, old_n_heaps)); - return false; - } + assert (dynamic_heap_count_data.new_n_heaps != old_n_heaps); + + dprintf (9999, ("Waiting h0 heap%d changing %d->%d", heap_number, n_heaps, new_n_heaps)); if (heap_number == 0) { - // after having checked for sufficient resources, we are now committed to actually change the heap count dprintf (3, ("switching heap count from %d to %d heaps", old_n_heaps, new_n_heaps)); // spread finalization data out to heaps coming into service @@ -25504,17 +25783,23 @@ bool gc_heap::change_heap_count (int new_n_heaps) from_heap_number = (from_heap_number + 1) % old_n_heaps; } - // prepare for the switch by fixing the allocation contexts on the old heaps, + // prepare for the switch by fixing the allocation contexts on the old heaps, unify the gen0_bricks_cleared flag, // and setting the survived size for the existing regions to their allocated size + BOOL unified_gen0_bricks_cleared = TRUE; for (int i = 0; i < old_n_heaps; i++) { gc_heap* hp = g_heaps[i]; - if (GCScan::GetGcRuntimeStructuresValid()) + if (!init_only_p) { hp->fix_allocation_contexts (TRUE); } + if (unified_gen0_bricks_cleared && (hp->gen0_bricks_cleared == FALSE)) + { + unified_gen0_bricks_cleared = FALSE; + } + for (int gen_idx = 0; gen_idx < total_generation_count; gen_idx++) { generation* gen = hp->generation_of (gen_idx); @@ -25614,7 +25899,7 @@ bool gc_heap::change_heap_count (int new_n_heaps) hpd->free_regions[kind].transfer_regions(&hp->free_regions[kind]); } } - // update number of heaps + dprintf (9999, ("h%d changing %d->%d", heap_number, n_heaps, new_n_heaps)); n_heaps = new_n_heaps; // even out the regions over the current number of heaps @@ -25625,6 +25910,8 @@ bool gc_heap::change_heap_count (int new_n_heaps) { gc_heap* hp = g_heaps[i]; + hp->gen0_bricks_cleared = unified_gen0_bricks_cleared; + // establish invariants regarding the ephemeral segment generation* gen0 = hp->generation_of (0); if ((hp->ephemeral_heap_segment == nullptr) || @@ -25653,7 +25940,9 @@ bool gc_heap::change_heap_count (int new_n_heaps) } } - if (GCScan::GetGcRuntimeStructuresValid()) + dprintf (3, ("individual heap%d changing %d->%d", heap_number, n_heaps, new_n_heaps)); + + if (!init_only_p) { // join for rethreading the free lists gc_t_join.join (this, gc_join_merge_temp_fl); @@ -25665,7 +25954,11 @@ bool gc_heap::change_heap_count (int new_n_heaps) // rethread the free lists for (int gen_idx = 0; gen_idx < total_generation_count; gen_idx++) { - rethread_fl_items (gen_idx); + if (heap_number < old_n_heaps) + { + dprintf (3, ("h%d calling per heap work!", heap_number)); + rethread_fl_items (gen_idx); + } // join for merging the free lists gc_t_join.join (this, gc_join_merge_temp_fl); @@ -25676,18 +25969,14 @@ bool gc_heap::change_heap_count (int new_n_heaps) gc_t_join.restart (); } } +#ifdef BACKGROUND_GC // there should be no items in the bgc_alloc_lock bgc_alloc_lock->check(); +#endif //BACKGROUND_GC } if (heap_number == 0) { - // udate the number of heaps in the joins - gc_t_join.update_n_threads(new_n_heaps); - #ifdef BACKGROUND_GC - bgc_t_join.update_n_threads(new_n_heaps); - #endif //BACKGROUND_GC - // compute the total budget per generation over the old heaps // and figure out what the new budget per heap is ptrdiff_t budget_per_heap[total_generation_count]; @@ -25747,21 +26036,50 @@ bool gc_heap::change_heap_count (int new_n_heaps) hp->decommission_heap(); } - if (GCScan::GetGcRuntimeStructuresValid()) + if (!init_only_p) { // make sure no allocation contexts point to idle heaps fix_allocation_contexts_heaps(); } - if (old_n_heaps < new_n_heaps) + dynamic_heap_count_data.last_n_heaps = old_n_heaps; + } + + // join the last time to change the heap count again if needed. + if (new_n_heaps < old_n_heaps) + { + gc_t_join.join (this, gc_join_merge_temp_fl); + if (gc_t_join.joined ()) { - // wake up threads for the new heaps - gc_idle_thread_event.Set(); + dprintf (9999, ("now changing the join heap count to the smaller one %d", new_n_heaps)); + gc_t_join.update_n_threads (new_n_heaps); + + gc_t_join.restart (); } } return true; } + +size_t gc_heap::get_msl_wait_time() +{ + assert (dynamic_adaptation_mode == dynamic_adaptation_to_application_sizes); + + size_t msl_wait_since_pause = 0; + + for (int i = 0; i < n_heaps; i++) + { + gc_heap* hp = g_heaps[i]; + + msl_wait_since_pause += hp->more_space_lock_soh.msl_wait_time; + hp->more_space_lock_soh.msl_wait_time = 0; + + msl_wait_since_pause += hp->more_space_lock_uoh.msl_wait_time; + hp->more_space_lock_uoh.msl_wait_time = 0; + } + + return msl_wait_since_pause; +} #endif //DYNAMIC_HEAP_COUNT #endif //USE_REGIONS @@ -32805,17 +33123,17 @@ void gc_heap::plan_phase (int condemned_gen_number) } else { - dprintf (2, ("gen2 didn't grow (end seg alloc: %zd, , condemned alloc: %zd, gen1 c alloc: %zd", + dprintf (1, ("gen2 didn't grow (end seg alloc: %zd, , condemned alloc: %zd, gen1 c alloc: %zd", end_seg_allocated, condemned_allocated, generation_condemned_allocated (generation_of (max_generation - 1)))); } - dprintf (1, ("older gen's free alloc: %zd->%zd, seg alloc: %zd->%zd, condemned alloc: %zd->%zd", + dprintf (2, ("older gen's free alloc: %zd->%zd, seg alloc: %zd->%zd, condemned alloc: %zd->%zd", r_older_gen_free_list_allocated, generation_free_list_allocated (older_gen), r_older_gen_end_seg_allocated, generation_end_seg_allocated (older_gen), r_older_gen_condemned_allocated, generation_condemned_allocated (older_gen))); - dprintf (1, ("this GC did %zd free list alloc(%zd bytes free space rejected)", + dprintf (2, ("this GC did %zd free list alloc(%zd bytes free space rejected)", free_list_allocated, rejected_free_space)); maxgen_size_increase* maxgen_size_info = &(get_gc_data_per_heap()->maxgen_size_info); @@ -38908,9 +39226,9 @@ void gc_heap::bgc_thread_function() { // this is the case where we have more background GC threads than heaps // - wait until we're told to continue... - dprintf (3, ("BGC thread %d idle", heap_number)); - gc_idle_thread_event.Wait(INFINITE, FALSE); - dprintf (3, ("BGC thread %d waking from idle", heap_number)); + dprintf (9999, ("BGC thread %d idle (%d heaps) (gc%Id)", heap_number, n_heaps, VolatileLoadWithoutBarrier (&settings.gc_index))); + bgc_idle_thread_event.Wait(INFINITE, FALSE); + dprintf (9999, ("BGC thread %d waking from idle (%d heaps) (gc%Id)", heap_number, n_heaps, VolatileLoadWithoutBarrier (&settings.gc_index))); continue; } #endif //DYNAMIC_HEAP_COUNT @@ -38982,7 +39300,7 @@ void gc_heap::bgc_thread_function() dprintf (SPINLOCK_LOG, ("bgc Lgc")); leave_spin_lock (&gc_lock); #ifdef MULTIPLE_HEAPS - dprintf(1, ("End of BGC - starting all BGC threads")); + dprintf(1, ("End of BGC")); bgc_t_join.restart(); #endif //MULTIPLE_HEAPS } @@ -42859,6 +43177,9 @@ bool gc_heap::init_dynamic_data() { process_start_time = now; smoothed_desired_total[0] = dynamic_data_of (0)->min_size * n_heaps; +#ifdef DYNAMIC_HEAP_COUNT + last_suspended_end_time = now; +#endif //DYNAMIC_HEAP_COUNT #ifdef HEAP_BALANCE_INSTRUMENTATION last_gc_end_time_us = now; dprintf (HEAP_BALANCE_LOG, ("qpf=%zd, start: %zd(%d)", qpf, start_raw_ts, now)); @@ -47957,6 +48278,7 @@ HRESULT GCHeap::Initialize() uint32_t nhp = 1; uint32_t nhp_from_config = 0; + uint32_t max_nhp_from_config = (uint32_t)GCConfig::GetMaxHeapCount(); #ifndef MULTIPLE_HEAPS GCConfig::SetServerGC(false); @@ -48151,6 +48473,10 @@ HRESULT GCHeap::Initialize() #ifdef MULTIPLE_HEAPS assert (nhp <= g_num_processors); + if (max_nhp_from_config) + { + nhp = min (nhp, max_nhp_from_config); + } gc_heap::n_max_heaps = nhp; gc_heap::n_heaps = nhp; hr = gc_heap::initialize_gc (seg_size, large_seg_size, pin_seg_size, nhp); @@ -48301,9 +48627,32 @@ HRESULT GCHeap::Initialize() { // start with only 1 heap gc_heap::smoothed_desired_total[0] /= gc_heap::n_heaps; - gc_heap::g_heaps[0]->change_heap_count (1); + int initial_n_heaps = 1; + dprintf (9999, ("gc_heap::n_heaps is %d, initial %d", gc_heap::n_heaps, initial_n_heaps)); + + { + if (!gc_heap::prepare_to_change_heap_count (initial_n_heaps)) + { + // we don't have sufficient resources. + return E_FAIL; + } + + gc_heap::dynamic_heap_count_data.new_n_heaps = initial_n_heaps; + gc_heap::dynamic_heap_count_data.idle_thread_count = 0; + gc_heap::dynamic_heap_count_data.init_only_p = true; + + int max_threads_to_wake = max (gc_heap::n_heaps, initial_n_heaps); + gc_t_join.update_n_threads (max_threads_to_wake); + gc_heap::gc_start_event.Set (); + } + + gc_heap::g_heaps[0]->change_heap_count (initial_n_heaps); + gc_heap::gc_start_event.Reset (); + + // This needs to be different from our initial heap count so we can make sure we wait for + // the idle threads correctly in gc_thread_function. + gc_heap::dynamic_heap_count_data.last_n_heaps = 0; } - gc_heap::dynamic_heap_count_data.new_n_heaps = gc_heap::n_heaps; #endif //DYNAMIC_HEAP_COUNT GCScan::GcRuntimeStructuresValid (TRUE); @@ -49875,10 +50224,16 @@ void gc_heap::do_post_gc() } #endif //BGC_SERVO_TUNING +#ifdef BACKGROUND_GC + const char* str_gc_type = (settings.concurrent ? "BGC" : (gc_heap::background_running_p () ? "FGC" : "NGC")); +#else + const char* str_gc_type = "NGC"; +#endif //BACKGROUND_GC + dprintf (1, (ThreadStressLog::gcDetailedEndMsg(), - VolatileLoad(&settings.gc_index), - dd_collection_count(hp->dynamic_data_of(0)), - (size_t)(GetHighPrecisionTimeStamp() / 1000), + VolatileLoad (&settings.gc_index), + dd_collection_count (hp->dynamic_data_of (0)), + (size_t)(GetHighPrecisionTimeStamp () / 1000), settings.condemned_generation, (settings.concurrent ? "BGC" : (gc_heap::background_running_p() ? "FGC" : "NGC")), (settings.compaction ? "C" : "S"), diff --git a/src/coreclr/gc/gcconfig.h b/src/coreclr/gc/gcconfig.h index 72786778d5a978..aeded6bc97f17f 100644 --- a/src/coreclr/gc/gcconfig.h +++ b/src/coreclr/gc/gcconfig.h @@ -83,6 +83,7 @@ class GCConfigStringHolder INT_CONFIG (BGCSpinCount, "BGCSpinCount", NULL, 140, "Specifies the bgc spin count") \ INT_CONFIG (BGCSpin, "BGCSpin", NULL, 2, "Specifies the bgc spin time") \ INT_CONFIG (HeapCount, "GCHeapCount", "System.GC.HeapCount", 0, "Specifies the number of server GC heaps") \ + INT_CONFIG (MaxHeapCount, "GCMaxHeapCount", "System.GC.MaxHeapCount", 0, "Specifies the max number of server GC heaps to adjust to") \ INT_CONFIG (Gen0Size, "GCgen0size", NULL, 0, "Specifies the smallest gen0 budget") \ INT_CONFIG (SegmentSize, "GCSegmentSize", NULL, 0, "Specifies the managed heap segment size") \ INT_CONFIG (LatencyMode, "GCLatencyMode", NULL, -1, "Specifies the GC latency mode - batch, interactive or low latency (note that the same " \ diff --git a/src/coreclr/gc/gcpriv.h b/src/coreclr/gc/gcpriv.h index 1a73add83b429f..cce6c5ee28adf0 100644 --- a/src/coreclr/gc/gcpriv.h +++ b/src/coreclr/gc/gcpriv.h @@ -402,8 +402,6 @@ struct GCDebugSpinLock { #if defined(DYNAMIC_HEAP_COUNT) // time in microseconds we wait for the more space lock uint64_t msl_wait_time; - // number of times we wait for the more space lock - uint64_t msl_wait_count; #endif //DYNAMIC_HEAP_COUNT GCDebugSpinLock() @@ -415,7 +413,7 @@ struct GCDebugSpinLock { , num_switch_thread(0), num_wait_longer(0), num_switch_thread_w(0), num_disable_preemptive_w(0) #endif #if defined(DYNAMIC_HEAP_COUNT) - , msl_wait_time(0), msl_wait_count(0) + , msl_wait_time(0) #endif //DYNAMIC_HEAP_COUNT { } @@ -1148,15 +1146,12 @@ class dynamic_data // // The following 3 fields are updated at the beginning of each GC, if that GC condemns this generation. // - // The number of GC that condemned this generation. The only difference between this - // and collection_count is just that collection_count is maintained for all physical generations - // (currently there are 5) whereas this is only updated for logical generations (there are 3). - size_t gc_clock; - uint64_t time_clock; //time when this gc started + size_t gc_clock; // the gc index + uint64_t time_clock; // time when this gc started uint64_t previous_time_clock; // time when previous gc started // Updated at the end of a GC, if that GC condemns this generation. - size_t gc_elapsed_time; // Time it took for the gc to complete + size_t gc_elapsed_time; // time it took for the gc to complete // // The following fields (and fields in sdata) are initialized during GC init time and do not change. @@ -1495,6 +1490,8 @@ class mark_queue_t void verify_empty(); }; +float median_of_3 (float a, float b, float c); + //class definition of the internal class class gc_heap { @@ -2422,6 +2419,7 @@ class gc_heap #ifndef USE_REGIONS PER_HEAP_METHOD void rearrange_heap_segments(BOOL compacting); #endif //!USE_REGIONS + PER_HEAP_METHOD void delay_free_segments(); PER_HEAP_ISOLATED_METHOD void distribute_free_regions(); #ifdef BACKGROUND_GC PER_HEAP_ISOLATED_METHOD void reset_write_watch_for_gc_heap(void* base_address, size_t region_size); @@ -2597,11 +2595,17 @@ class gc_heap // re-initialize a heap in preparation to putting it back into service PER_HEAP_METHOD void recommission_heap(); + PER_HEAP_ISOLATED_METHOD size_t get_num_completed_gcs(); + + PER_HEAP_ISOLATED_METHOD int calculate_new_heap_count(); + // check if we should change the heap count PER_HEAP_METHOD void check_heap_count(); - PER_HEAP_METHOD bool prepare_to_change_heap_count (int new_n_heaps); + PER_HEAP_ISOLATED_METHOD bool prepare_to_change_heap_count (int new_n_heaps); PER_HEAP_METHOD bool change_heap_count (int new_n_heaps); + + PER_HEAP_ISOLATED_METHOD size_t get_msl_wait_time(); #endif //DYNAMIC_HEAP_COUNT #endif //USE_REGIONS @@ -3778,6 +3782,13 @@ class gc_heap PER_HEAP_FIELD_MAINTAINED mark* loh_pinned_queue; #endif //FEATURE_LOH_COMPACTION +#ifdef DYNAMIC_HEAP_COUNT + PER_HEAP_FIELD_MAINTAINED GCEvent gc_idle_thread_event; +#ifdef BACKGROUND_GC + PER_HEAP_FIELD_MAINTAINED GCEvent bgc_idle_thread_event; +#endif //BACKGROUND_GC +#endif //DYNAMIC_HEAP_COUNT + /******************************************/ // PER_HEAP_FIELD_MAINTAINED_ALLOC fields // /******************************************/ @@ -4084,7 +4095,6 @@ class gc_heap // These 2 fields' values do not change but are set/unset per GC PER_HEAP_ISOLATED_FIELD_SINGLE_GC GCEvent gc_start_event; PER_HEAP_ISOLATED_FIELD_SINGLE_GC GCEvent ee_suspend_event; - PER_HEAP_ISOLATED_FIELD_SINGLE_GC GCEvent gc_idle_thread_event; // Also updated on the heap#0 GC thread because that's where we are actually doing the decommit. PER_HEAP_ISOLATED_FIELD_SINGLE_GC BOOL gradual_decommit_in_progress_p; @@ -4163,6 +4173,10 @@ class gc_heap PER_HEAP_ISOLATED_FIELD_SINGLE_GC uint8_t* gc_high; // high end of the highest region being condemned #endif //USE_REGIONS +#ifdef STRESS_DYNAMIC_HEAP_COUNT + PER_HEAP_ISOLATED_FIELD_SINGLE_GC int heaps_in_this_gc; +#endif //STRESS_DYNAMIC_HEAP_COUNT + /**************************************************/ // PER_HEAP_ISOLATED_FIELD_SINGLE_GC_ALLOC fields // /**************************************************/ @@ -4261,37 +4275,65 @@ class gc_heap #endif //USE_REGIONS #ifdef DYNAMIC_HEAP_COUNT + // Sample collection - + // + // For every GC, we collect the msl wait time + GC pause duration info and use both to calculate the + // throughput cost percentage. We will also be using the wait time and the GC pause duration separately + // for other purposes in the future. + // + // For all gen2 GCs we also keep a separate array currently just for the GC cost. This serves as a backstop + // to smooth out the situation when we rarely pick the gen2 GCs in the first array. struct dynamic_heap_count_data_t { static const int sample_size = 3; struct sample { - uint64_t elapsed_between_gcs; // time between gcs in microseconds - uint64_t gc_elapsed_time; // time the gc took - uint64_t soh_msl_wait_time; // time the allocator spent waiting for the soh msl lock - uint64_t uoh_msl_wait_time; // time the allocator spent waiting for the uoh msl lock - size_t allocating_thread_count;// number of allocating threads - size_t heap_size; + uint64_t elapsed_between_gcs; // time between gcs in microseconds (this should really be between_pauses) + uint64_t gc_pause_time; // pause time for this GC + uint64_t msl_wait_time; }; - unsigned sample_index; + uint32_t sample_index; sample samples[sample_size]; + size_t prev_num_completed_gcs; + + uint32_t gen2_sample_index; + // This is (gc_elapsed_time / time inbetween this and the last gen2 GC) + float gen2_gc_percents[sample_size]; - float median_percent_overhead; // estimated overhead of allocator + gc - float smoothed_median_percent_overhead; // exponentially smoothed version - float percent_heap_space_cost_per_heap; // percent space cost of adding a heap - float overhead_reduction_per_step_up; // percentage effect on overhead of increasing heap count - float overhead_increase_per_step_down; // percentage effect on overhead of decreasing heap count - float space_cost_increase_per_step_up; // percentage effect on space of increasing heap count - float space_cost_decrease_per_step_down;// percentage effect on space of decreasing heap count + float median_throughput_cost_percent; // estimated overhead of allocator + gc + float smoothed_median_throughput_cost_percent; // exponentially smoothed version + float percent_heap_space_cost_per_heap; // percent space cost of adding a heap + float tcp_reduction_per_step_up; // throughput cost percent effect of increasing heap count + float tcp_increase_per_step_down; // throughput cost percent effect of decreasing heap count + float scp_increase_per_step_up; // space cost percent effect of increasing heap count + float scp_decrease_per_step_down; // space cost percent effect of decreasing heap count int new_n_heaps; + // the heap count we changed from + int last_n_heaps; + // don't start a GC till we see (n_max_heaps - new_n_heaps) number of threads idling + VOLATILE(int32_t) idle_thread_count; + bool init_only_p; + + bool should_change_heap_count; + int heap_count_to_change_to; + int heap_count_change_count; #ifdef STRESS_DYNAMIC_HEAP_COUNT int lowest_heap_with_msl_uoh; #endif //STRESS_DYNAMIC_HEAP_COUNT + + float get_median_gen2_gc_percent() + { + return median_of_3 (gen2_gc_percents[0], gen2_gc_percents[1], gen2_gc_percents[2]); + } }; PER_HEAP_ISOLATED_FIELD_MAINTAINED dynamic_heap_count_data_t dynamic_heap_count_data; + PER_HEAP_ISOLATED_FIELD_MAINTAINED uint64_t last_suspended_end_time; + // If the last full GC is blocking, this is that GC's index; for BGC, this is the settings.gc_index + // when the BGC ended. + PER_HEAP_ISOLATED_FIELD_MAINTAINED size_t gc_index_full_gc_end; #endif //DYNAMIC_HEAP_COUNT /****************************************************/ @@ -4867,7 +4909,6 @@ uint64_t& dd_previous_time_clock (dynamic_data* inst) return inst->previous_time_clock; } - inline size_t& dd_gc_clock_interval (dynamic_data* inst) { diff --git a/src/coreclr/gc/unix/gcenv.unix.cpp b/src/coreclr/gc/unix/gcenv.unix.cpp index 285b783485802a..b45cd40d8073fe 100644 --- a/src/coreclr/gc/unix/gcenv.unix.cpp +++ b/src/coreclr/gc/unix/gcenv.unix.cpp @@ -168,6 +168,17 @@ enum membarrier_cmd bool CanFlushUsingMembarrier() { + +#ifdef TARGET_ANDROID + // Avoid calling membarrier on older Android versions where membarrier + // may be barred by seccomp causing the process to be killed. + int apiLevel = android_get_device_api_level(); + if (apiLevel < __ANDROID_API_Q__) + { + return false; + } +#endif + // Starting with Linux kernel 4.14, process memory barriers can be generated // using MEMBARRIER_CMD_PRIVATE_EXPEDITED. diff --git a/src/coreclr/inc/safemath.h b/src/coreclr/inc/safemath.h index 3f6d5c5716bdb4..fcd51af3de8cb0 100644 --- a/src/coreclr/inc/safemath.h +++ b/src/coreclr/inc/safemath.h @@ -688,6 +688,10 @@ template class ClrSafeInt INDEBUG( mutable bool m_checkedOverflow; ) }; +#if defined(_MSC_VER) && defined(HOST_ARM64) // Workaround for https://github.com/dotnet/runtime/issues/93442 +#pragma optimize("", off) +#endif + template <> inline bool ClrSafeInt::multiply(int64_t lhs, int64_t rhs, int64_t &result) { @@ -874,6 +878,10 @@ inline bool ClrSafeInt::multiply(uint8_t lhs, uint8_t rhs, uint8_t &res return true; } +#if defined(_MSC_VER) && defined(HOST_ARM64) // Workaround for https://github.com/dotnet/runtime/issues/93442 +#pragma optimize("", on) +#endif + // Allows creation of a ClrSafeInt corresponding to the type of the argument. template ClrSafeInt AsClrSafeInt(T t) diff --git a/src/coreclr/jit/emitxarch.cpp b/src/coreclr/jit/emitxarch.cpp index 980d40a47ac318..1c48d1c52f0bb2 100644 --- a/src/coreclr/jit/emitxarch.cpp +++ b/src/coreclr/jit/emitxarch.cpp @@ -5485,6 +5485,13 @@ void emitter::emitInsRMW(instruction ins, emitAttr attr, GenTreeStoreInd* storeI { assert(!src->isContained()); // there must be one non-contained src + if (addr->isContained() && addr->OperIs(GT_LCL_ADDR)) + { + GenTreeLclVarCommon* lclVar = addr->AsLclVarCommon(); + emitIns_S_R(ins, attr, src->GetRegNum(), lclVar->GetLclNum(), lclVar->GetLclOffs()); + return; + } + // ind, reg id = emitNewInstrAmd(attr, offset); emitHandleMemOp(storeInd, id, emitInsModeFormat(ins, IF_ARD_RRD), ins); diff --git a/src/coreclr/jit/gentree.cpp b/src/coreclr/jit/gentree.cpp index a30c3794efdd29..7c34c51571e66d 100644 --- a/src/coreclr/jit/gentree.cpp +++ b/src/coreclr/jit/gentree.cpp @@ -18534,10 +18534,12 @@ CORINFO_CLASS_HANDLE Compiler::gtGetFieldClassHandle(CORINFO_FIELD_HANDLE fieldH { JITDUMP("Field's current class not available\n"); } + + return fieldClass; } } - return fieldClass; + return NO_CLASS_HANDLE; } //------------------------------------------------------------------------ @@ -19639,8 +19641,8 @@ GenTree* Compiler::gtNewSimdBinOpNode( } else { - assert(op2->TypeIs(type, simdBaseType, genActualType(simdBaseType)) || - (op2->TypeIs(TYP_SIMD12) && type == TYP_SIMD16)); + assert((genActualType(op2) == genActualType(type)) || (genActualType(op2) == genActualType(simdBaseType)) || + (op2->TypeIs(TYP_SIMD12) && (type == TYP_SIMD16))); } NamedIntrinsic intrinsic = NI_Illegal; diff --git a/src/coreclr/jit/lowerxarch.cpp b/src/coreclr/jit/lowerxarch.cpp index d1c3e317d83b87..150ad04a55d99f 100644 --- a/src/coreclr/jit/lowerxarch.cpp +++ b/src/coreclr/jit/lowerxarch.cpp @@ -6504,7 +6504,7 @@ void Lowering::ContainCheckStoreIndir(GenTreeStoreInd* node) case NI_AVX2_ConvertToUInt32: { // These intrinsics are "ins reg/mem, xmm" - isContainable = varTypeIsIntegral(simdBaseType); + isContainable = varTypeIsIntegral(simdBaseType) && (genTypeSize(src) == genTypeSize(node)); break; } @@ -6568,7 +6568,8 @@ void Lowering::ContainCheckStoreIndir(GenTreeStoreInd* node) size_t numArgs = hwintrinsic->GetOperandCount(); GenTree* lastOp = hwintrinsic->Op(numArgs); - isContainable = HWIntrinsicInfo::isImmOp(intrinsicId, lastOp) && lastOp->IsCnsIntOrI(); + isContainable = HWIntrinsicInfo::isImmOp(intrinsicId, lastOp) && lastOp->IsCnsIntOrI() && + (genTypeSize(simdBaseType) == genTypeSize(node)); if (isContainable && (intrinsicId == NI_SSE2_Extract)) { @@ -7956,6 +7957,9 @@ bool Lowering::IsContainableHWIntrinsicOp(GenTreeHWIntrinsic* parentNode, GenTre // The memory form of this already takes a pointer and should be treated like a MemoryLoad supportsGeneralLoads = !childNode->OperIsHWIntrinsic(); } + + supportsGeneralLoads = + supportsGeneralLoads && (genTypeSize(childNode) >= genTypeSize(parentNode->GetSimdBaseType())); break; } @@ -8199,10 +8203,12 @@ bool Lowering::IsContainableHWIntrinsicOp(GenTreeHWIntrinsic* parentNode, GenTre case NI_AVX2_BroadcastScalarToVector256: case NI_AVX512F_BroadcastScalarToVector512: { - var_types baseType = hwintrinsic->GetSimdBaseType(); - if (varTypeIsSmall(baseType)) + var_types parentBaseType = parentNode->GetSimdBaseType(); + var_types childBaseType = hwintrinsic->GetSimdBaseType(); + + if (varTypeIsSmall(parentBaseType) || (genTypeSize(parentBaseType) != genTypeSize(childBaseType))) { - // early return if the base type is not embedded broadcast compatible. + // early return if either base type is not embedded broadcast compatible. return false; } @@ -8210,7 +8216,7 @@ bool Lowering::IsContainableHWIntrinsicOp(GenTreeHWIntrinsic* parentNode, GenTre if (intrinsicId == NI_SSE3_MoveAndDuplicate) { // NI_SSE3_MoveAndDuplicate is for Vector128 only. - assert(baseType == TYP_DOUBLE); + assert(childBaseType == TYP_DOUBLE); } if (comp->compOpportunisticallyDependsOn(InstructionSet_AVX512F_VL) && @@ -8243,6 +8249,15 @@ bool Lowering::IsContainableHWIntrinsicOp(GenTreeHWIntrinsic* parentNode, GenTre case NI_AVX_BroadcastScalarToVector128: case NI_AVX_BroadcastScalarToVector256: { + var_types parentBaseType = parentNode->GetSimdBaseType(); + var_types childBaseType = hwintrinsic->GetSimdBaseType(); + + if (varTypeIsSmall(parentBaseType) || (genTypeSize(parentBaseType) != genTypeSize(childBaseType))) + { + // early return if either base type is not embedded broadcast compatible. + return false; + } + return parentNode->OperIsEmbBroadcastCompatible(); } diff --git a/src/coreclr/jit/morph.cpp b/src/coreclr/jit/morph.cpp index 153f9b8bba8a82..3deada8eec085b 100644 --- a/src/coreclr/jit/morph.cpp +++ b/src/coreclr/jit/morph.cpp @@ -10807,8 +10807,6 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node) } #if defined(TARGET_XARCH) - case NI_AVX512F_Add: - case NI_AVX512BW_Add: case NI_AVX512F_And: case NI_AVX512DQ_And: case NI_AVX512F_AndNot: @@ -10850,13 +10848,6 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node) switch (intrinsicId) { - case NI_AVX512F_Add: - case NI_AVX512BW_Add: - { - maskIntrinsicId = NI_AVX512F_AddMask; - break; - } - case NI_AVX512F_And: case NI_AVX512DQ_And: { diff --git a/src/coreclr/jit/valuenum.cpp b/src/coreclr/jit/valuenum.cpp index 2ea2ee7845b8b2..6943c3c5e07e26 100644 --- a/src/coreclr/jit/valuenum.cpp +++ b/src/coreclr/jit/valuenum.cpp @@ -2860,6 +2860,81 @@ ValueNum ValueNumStore::VNForMapPhysicalSelect( return result; } +typedef JitHashTable, bool> ValueNumSet; + +class SmallValueNumSet +{ + union { + ValueNum m_inlineElements[4]; + ValueNumSet* m_set; + }; + unsigned m_numElements = 0; + +public: + unsigned Count() + { + return m_numElements; + } + + template + void ForEach(Func func) + { + if (m_numElements <= ArrLen(m_inlineElements)) + { + for (unsigned i = 0; i < m_numElements; i++) + { + func(m_inlineElements[i]); + } + } + else + { + for (ValueNum vn : ValueNumSet::KeyIteration(m_set)) + { + func(vn); + } + } + } + + void Add(Compiler* comp, ValueNum vn) + { + if (m_numElements <= ArrLen(m_inlineElements)) + { + for (unsigned i = 0; i < m_numElements; i++) + { + if (m_inlineElements[i] == vn) + { + return; + } + } + + if (m_numElements < ArrLen(m_inlineElements)) + { + m_inlineElements[m_numElements] = vn; + m_numElements++; + } + else + { + ValueNumSet* set = new (comp, CMK_ValueNumber) ValueNumSet(comp->getAllocator(CMK_ValueNumber)); + for (ValueNum oldVn : m_inlineElements) + { + set->Set(oldVn, true); + } + + set->Set(vn, true); + + m_set = set; + m_numElements++; + assert(m_numElements == set->GetCount()); + } + } + else + { + m_set->Set(vn, true, ValueNumSet::SetKind::Overwrite); + m_numElements = m_set->GetCount(); + } + } +}; + //------------------------------------------------------------------------------ // VNForMapSelectInner: Select value from a map and record loop memory dependencies. // @@ -2874,10 +2949,10 @@ ValueNum ValueNumStore::VNForMapPhysicalSelect( // ValueNum ValueNumStore::VNForMapSelectInner(ValueNumKind vnk, var_types type, ValueNum map, ValueNum index) { - int budget = m_mapSelectBudget; - bool usedRecursiveVN = false; - ArrayStack memoryDependencies(m_alloc); - ValueNum result = VNForMapSelectWork(vnk, type, map, index, &budget, &usedRecursiveVN, &memoryDependencies); + int budget = m_mapSelectBudget; + bool usedRecursiveVN = false; + SmallValueNumSet memoryDependencies; + ValueNum result = VNForMapSelectWork(vnk, type, map, index, &budget, &usedRecursiveVN, memoryDependencies); // The remaining budget should always be between [0..m_mapSelectBudget] assert((budget >= 0) && (budget <= m_mapSelectBudget)); @@ -2888,11 +2963,9 @@ ValueNum ValueNumStore::VNForMapSelectInner(ValueNumKind vnk, var_types type, Va if ((m_pComp->compCurBB != nullptr) && (m_pComp->compCurTree != nullptr) && m_pComp->compCurBB->bbNatLoopNum != BasicBlock::NOT_IN_LOOP) { - for (int i = 0; i < memoryDependencies.Height(); i++) - { - m_pComp->optRecordLoopMemoryDependence(m_pComp->compCurTree, m_pComp->compCurBB, - memoryDependencies.Bottom(i)); - } + memoryDependencies.ForEach([this](ValueNum vn) { + m_pComp->optRecordLoopMemoryDependence(m_pComp->compCurTree, m_pComp->compCurBB, vn); + }); } return result; @@ -2903,19 +2976,16 @@ ValueNum ValueNumStore::VNForMapSelectInner(ValueNumKind vnk, var_types type, Va // cache entry. // // Arguments: -// alloc - Allocator to use if memory is required. -// deps - Array stack containing the memory dependencies. -// startIndex - Start index into 'deps' of memory dependencies. +// comp - Compiler instance +// set - Set of memory dependencies to store in the entry. // -void ValueNumStore::MapSelectWorkCacheEntry::SetMemoryDependencies(CompAllocator alloc, - ArrayStack& deps, - unsigned startIndex) +void ValueNumStore::MapSelectWorkCacheEntry::SetMemoryDependencies(Compiler* comp, SmallValueNumSet& set) { - m_numMemoryDependencies = deps.Height() - startIndex; + m_numMemoryDependencies = set.Count(); ValueNum* arr; if (m_numMemoryDependencies > ArrLen(m_inlineMemoryDependencies)) { - m_memoryDependencies = new (alloc) ValueNum[m_numMemoryDependencies]; + m_memoryDependencies = new (comp, CMK_ValueNumber) ValueNum[m_numMemoryDependencies]; arr = m_memoryDependencies; } @@ -2924,27 +2994,29 @@ void ValueNumStore::MapSelectWorkCacheEntry::SetMemoryDependencies(CompAllocator arr = m_inlineMemoryDependencies; } - for (unsigned i = 0; i < m_numMemoryDependencies; i++) - { - arr[i] = deps.Bottom(startIndex + i); - } + size_t i = 0; + set.ForEach([&i, arr](ValueNum vn) { + arr[i] = vn; + i++; + }); } //------------------------------------------------------------------------------ // GetMemoryDependencies: Push all of the memory dependencies cached in this -// entry into the specified array stack. +// entry into the specified set. // // Arguments: -// result - Array stack to push memory dependencies into. +// comp - Compiler instance +// result - Set to add memory dependencies to. // -void ValueNumStore::MapSelectWorkCacheEntry::GetMemoryDependencies(ArrayStack& result) +void ValueNumStore::MapSelectWorkCacheEntry::GetMemoryDependencies(Compiler* comp, SmallValueNumSet& result) { ValueNum* arr = m_numMemoryDependencies <= ArrLen(m_inlineMemoryDependencies) ? m_inlineMemoryDependencies : m_memoryDependencies; for (unsigned i = 0; i < m_numMemoryDependencies; i++) { - result.Push(arr[i]); + result.Add(comp, arr[i]); } } @@ -2959,7 +3031,7 @@ void ValueNumStore::MapSelectWorkCacheEntry::GetMemoryDependencies(ArrayStack* memoryDependencies) +ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, + var_types type, + ValueNum map, + ValueNum index, + int* pBudget, + bool* pUsedRecursiveVN, + SmallValueNumSet& memoryDependencies) { TailCall: // This label allows us to directly implement a tail call by setting up the arguments, and doing a goto to here. @@ -2997,13 +3069,12 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, assert(selLim == 0 || m_numMapSels < selLim); #endif - int firstMemoryDependency = memoryDependencies->Height(); MapSelectWorkCacheEntry entry; VNDefFuncApp<2> fstruct(VNF_MapSelect, map, index); if (GetMapSelectWorkCache()->Lookup(fstruct, &entry)) { - entry.GetMemoryDependencies(*memoryDependencies); + entry.GetMemoryDependencies(m_pComp, memoryDependencies); return entry.Result; } @@ -3029,6 +3100,8 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, return RecursiveVN; } + SmallValueNumSet recMemoryDependencies; + VNFuncApp funcApp; if (GetVNFunc(map, &funcApp)) { @@ -3047,7 +3120,7 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, funcApp.m_args[0], map, funcApp.m_args[1], funcApp.m_args[2], index, funcApp.m_args[2]); #endif - memoryDependencies->Push(funcApp.m_args[0]); + memoryDependencies.Add(m_pComp, funcApp.m_args[0]); return funcApp.m_args[2]; } @@ -3191,7 +3264,7 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, bool allSame = true; ValueNum argRest = phiFuncApp.m_args[1]; ValueNum sameSelResult = VNForMapSelectWork(vnk, type, phiArgVN, index, pBudget, - pUsedRecursiveVN, memoryDependencies); + pUsedRecursiveVN, recMemoryDependencies); // It is possible that we just now exceeded our budget, if so we need to force an early exit // and stop calling VNForMapSelectWork @@ -3233,7 +3306,7 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, { bool usedRecursiveVN = false; ValueNum curResult = VNForMapSelectWork(vnk, type, phiArgVN, index, pBudget, - &usedRecursiveVN, memoryDependencies); + &usedRecursiveVN, recMemoryDependencies); *pUsedRecursiveVN |= usedRecursiveVN; if (sameSelResult == ValueNumStore::RecursiveVN) @@ -3261,11 +3334,14 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, if (!*pUsedRecursiveVN) { entry.Result = sameSelResult; - entry.SetMemoryDependencies(m_alloc, *memoryDependencies, firstMemoryDependency); + entry.SetMemoryDependencies(m_pComp, recMemoryDependencies); GetMapSelectWorkCache()->Set(fstruct, entry); } + recMemoryDependencies.ForEach( + [this, &memoryDependencies](ValueNum vn) { memoryDependencies.Add(m_pComp, vn); }); + return sameSelResult; } // Otherwise, fall through to creating the select(phi(m1, m2), x) function application. @@ -3294,11 +3370,13 @@ ValueNum ValueNumStore::VNForMapSelectWork(ValueNumKind vnk, fapp->m_args[1] = fstruct.m_args[1]; entry.Result = c->m_baseVN + offsetWithinChunk; - entry.SetMemoryDependencies(m_alloc, *memoryDependencies, firstMemoryDependency); + entry.SetMemoryDependencies(m_pComp, recMemoryDependencies); GetMapSelectWorkCache()->Set(fstruct, entry); } + recMemoryDependencies.ForEach([this, &memoryDependencies](ValueNum vn) { memoryDependencies.Add(m_pComp, vn); }); + return entry.Result; } @@ -7891,7 +7969,7 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(var_types type, #endif { // Handle `x ^ x == 0` - return arg0VN; + return VNZeroForType(type); } default: diff --git a/src/coreclr/jit/valuenum.h b/src/coreclr/jit/valuenum.h index 8417579ea2ff7a..04fed7bfbc1f67 100644 --- a/src/coreclr/jit/valuenum.h +++ b/src/coreclr/jit/valuenum.h @@ -684,13 +684,13 @@ class ValueNumStore ValueNum VNForMapSelectInner(ValueNumKind vnk, var_types type, ValueNum map, ValueNum index); // A method that does the work for VNForMapSelect and may call itself recursively. - ValueNum VNForMapSelectWork(ValueNumKind vnk, - var_types type, - ValueNum map, - ValueNum index, - int* pBudget, - bool* pUsedRecursiveVN, - ArrayStack* loopMemoryDependencies); + ValueNum VNForMapSelectWork(ValueNumKind vnk, + var_types type, + ValueNum map, + ValueNum index, + int* pBudget, + bool* pUsedRecursiveVN, + class SmallValueNumSet& loopMemoryDependencies); // A specialized version of VNForFunc that is used for VNF_MapStore and provides some logging when verbose is set ValueNum VNForMapStore(ValueNum map, ValueNum index, ValueNum value); @@ -1821,8 +1821,8 @@ class ValueNumStore public: ValueNum Result; - void SetMemoryDependencies(CompAllocator alloc, ArrayStack& deps, unsigned startIndex); - void GetMemoryDependencies(ArrayStack& deps); + void SetMemoryDependencies(Compiler* comp, class SmallValueNumSet& deps); + void GetMemoryDependencies(Compiler* comp, class SmallValueNumSet& deps); }; typedef JitHashTable, VNDefFuncAppKeyFuncs<2>, MapSelectWorkCacheEntry> MapSelectWorkCache; diff --git a/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Publish.targets b/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Publish.targets index caee0777b30ddd..da6c90642f6f13 100644 --- a/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Publish.targets +++ b/src/coreclr/nativeaot/BuildIntegration/Microsoft.NETCore.Native.Publish.targets @@ -95,7 +95,10 @@ + + diff --git a/src/coreclr/nativeaot/Directory.Build.props b/src/coreclr/nativeaot/Directory.Build.props index ebfa725e4efd2c..005d6ae997adab 100644 --- a/src/coreclr/nativeaot/Directory.Build.props +++ b/src/coreclr/nativeaot/Directory.Build.props @@ -25,6 +25,9 @@ false v4.0.30319 + + $(ProductVersion) + $(ProductVersion) $(NoWarn),0419,0649,CA2249,CA1830 diff --git a/src/coreclr/nativeaot/Runtime.Base/src/System/Runtime/InternalCalls.cs b/src/coreclr/nativeaot/Runtime.Base/src/System/Runtime/InternalCalls.cs index 3c9d6c86ffc323..5c11243bbad99b 100644 --- a/src/coreclr/nativeaot/Runtime.Base/src/System/Runtime/InternalCalls.cs +++ b/src/coreclr/nativeaot/Runtime.Base/src/System/Runtime/InternalCalls.cs @@ -62,12 +62,12 @@ internal static class InternalCalls [RuntimeExport("RhCollect")] internal static void RhCollect(int generation, InternalGCCollectionMode mode, bool lowMemoryP = false) { - RhpCollect(generation, mode, lowMemoryP); + RhpCollect(generation, mode, lowMemoryP ? Interop.BOOL.TRUE : Interop.BOOL.FALSE); } [DllImport(Redhawk.BaseName)] [UnmanagedCallConv(CallConvs = new Type[] { typeof(CallConvCdecl) })] - private static extern void RhpCollect(int generation, InternalGCCollectionMode mode, bool lowMemoryP); + private static extern void RhpCollect(int generation, InternalGCCollectionMode mode, Interop.BOOL lowMemoryP); [RuntimeExport("RhGetGcTotalMemory")] internal static long RhGetGcTotalMemory() diff --git a/src/coreclr/nativeaot/Runtime/arm64/AllocFast.S b/src/coreclr/nativeaot/Runtime/arm64/AllocFast.S index 2bab323e65abca..79ffed2b05210d 100644 --- a/src/coreclr/nativeaot/Runtime/arm64/AllocFast.S +++ b/src/coreclr/nativeaot/Runtime/arm64/AllocFast.S @@ -46,7 +46,7 @@ OFFSETOF__Thread__m_alloc_context__alloc_limit = OFFSETOF__Thread__m_rgbAll add x2, x2, x12 ldr x13, [x1, #OFFSETOF__Thread__m_alloc_context__alloc_limit] cmp x2, x13 - bhi RhpNewFast_RarePath + bhi LOCAL_LABEL(RhpNewFast_RarePath) // Update the alloc pointer to account for the allocation. str x2, [x1, #OFFSETOF__Thread__m_alloc_context__alloc_ptr] @@ -57,7 +57,7 @@ OFFSETOF__Thread__m_alloc_context__alloc_limit = OFFSETOF__Thread__m_rgbAll mov x0, x12 ret -RhpNewFast_RarePath: +LOCAL_LABEL(RhpNewFast_RarePath): mov x1, #0 b C_FUNC(RhpNewObject) LEAF_END RhpNewFast, _TEXT @@ -88,12 +88,12 @@ RhpNewFast_RarePath: bl C_FUNC(RhpGcAlloc) // Set the new objects MethodTable pointer on success. - cbz x0, NewOutOfMemory + cbz x0, LOCAL_LABEL(NewOutOfMemory) POP_COOP_PINVOKE_FRAME EPILOG_RETURN -NewOutOfMemory: +LOCAL_LABEL(NewOutOfMemory): // This is the OOM failure path. We are going to tail-call to a managed helper that will throw // an out of memory exception that the caller of this allocator understands. @@ -113,7 +113,7 @@ NewOutOfMemory: movz x2, MAX_STRING_LENGTH & 0xFFFF movk x2, MAX_STRING_LENGTH >> 16, lsl 16 cmp x1, x2 - bhi StringSizeOverflow + bhi LOCAL_LABEL(StringSizeOverflow) // Compute overall allocation size (align(base size + (element size * elements), 8)). mov w2, #STRING_COMPONENT_SIZE @@ -139,7 +139,7 @@ NewOutOfMemory: add x2, x2, x12 ldr x12, [x3, #OFFSETOF__Thread__m_alloc_context__alloc_limit] cmp x2, x12 - bhi C_FUNC(RhpNewArrayRare) + bhi LOCAL_LABEL(RhNewString_Rare) // Reload new object address into r12. ldr x12, [x3, #OFFSETOF__Thread__m_alloc_context__alloc_ptr] @@ -156,7 +156,7 @@ NewOutOfMemory: ret -StringSizeOverflow: +LOCAL_LABEL(StringSizeOverflow): // We get here if the length of the final string object can not be represented as an unsigned // 32-bit value. We are going to tail-call to a managed helper that will throw // an OOM exception that the caller of this allocator understands. @@ -164,6 +164,9 @@ StringSizeOverflow: // x0 holds MethodTable pointer already mov x1, #1 // Indicate that we should throw OverflowException b C_FUNC(RhExceptionHandling_FailedAllocation) + +LOCAL_LABEL(RhNewString_Rare): + b C_FUNC(RhpNewArrayRare) LEAF_END RhNewString, _Text // Allocate one dimensional, zero based array (SZARRAY). @@ -177,7 +180,7 @@ StringSizeOverflow: // case (32 dimensional MdArray) is less than 0xffff, and thus the product fits in 64 bits. mov x2, #0x7FFFFFFF cmp x1, x2 - bhi ArraySizeOverflow + bhi LOCAL_LABEL(ArraySizeOverflow) ldrh w2, [x0, #OFFSETOF__MethodTable__m_usComponentSize] umull x2, w1, w2 @@ -204,7 +207,7 @@ StringSizeOverflow: add x2, x2, x12 ldr x12, [x3, #OFFSETOF__Thread__m_alloc_context__alloc_limit] cmp x2, x12 - bhi C_FUNC(RhpNewArrayRare) + bhi LOCAL_LABEL(RhpNewArray_Rare) // Reload new object address into x12. ldr x12, [x3, #OFFSETOF__Thread__m_alloc_context__alloc_ptr] @@ -221,7 +224,7 @@ StringSizeOverflow: ret -ArraySizeOverflow: +LOCAL_LABEL(ArraySizeOverflow): // We get here if the size of the final array object can not be represented as an unsigned // 32-bit value. We are going to tail-call to a managed helper that will throw // an overflow exception that the caller of this allocator understands. @@ -229,6 +232,9 @@ ArraySizeOverflow: // x0 holds MethodTable pointer already mov x1, #1 // Indicate that we should throw OverflowException b C_FUNC(RhExceptionHandling_FailedAllocation) + +LOCAL_LABEL(RhpNewArray_Rare): + b C_FUNC(RhpNewArrayRare) LEAF_END RhpNewArray, _TEXT // Allocate one dimensional, zero based array (SZARRAY) using the slow path that calls a runtime helper. @@ -254,12 +260,12 @@ ArraySizeOverflow: bl C_FUNC(RhpGcAlloc) // Set the new objects MethodTable pointer and length on success. - cbz x0, ArrayOutOfMemory + cbz x0, LOCAL_LABEL(ArrayOutOfMemory) POP_COOP_PINVOKE_FRAME EPILOG_RETURN -ArrayOutOfMemory: +LOCAL_LABEL(ArrayOutOfMemory): // This is the OOM failure path. We are going to tail-call to a managed helper that will throw // an out of memory exception that the caller of this allocator understands. diff --git a/src/coreclr/nativeaot/Runtime/arm64/ExceptionHandling.S b/src/coreclr/nativeaot/Runtime/arm64/ExceptionHandling.S index d0425171e1d191..7c04f15ad3b858 100644 --- a/src/coreclr/nativeaot/Runtime/arm64/ExceptionHandling.S +++ b/src/coreclr/nativeaot/Runtime/arm64/ExceptionHandling.S @@ -275,7 +275,7 @@ // where the tail-calling thread had saved LR, which may not match where we have saved LR. ldr x1, [x2, #OFFSETOF__Thread__m_pvHijackedReturnAddress] - cbz x1, NotHijacked + cbz x1, LOCAL_LABEL(NotHijacked) ldr x3, [x2, #OFFSETOF__Thread__m_ppvHijackedReturnAddressLocation] @@ -286,13 +286,13 @@ add x12, sp, #(STACKSIZEOF_ExInfo + SIZEOF__PAL_LIMITED_CONTEXT) // re-compute SP at callsite cmp x3, x12 // if (m_ppvHijackedReturnAddressLocation < SP at callsite) - blo TailCallWasHijacked + blo LOCAL_LABEL(TailCallWasHijacked) // normal case where a valid return address location is hijacked str x1, [x3] - b ClearThreadState + b LOCAL_LABEL(ClearThreadState) -TailCallWasHijacked: +LOCAL_LABEL(TailCallWasHijacked): // Abnormal case where the return address location is now invalid because we ended up here via a tail // call. In this case, our hijacked return address should be the correct caller of this method. @@ -302,13 +302,13 @@ TailCallWasHijacked: str lr, [sp, #(rsp_offsetof_Context + OFFSETOF__PAL_LIMITED_CONTEXT__LR)] str lr, [sp, #(rsp_offsetof_Context + OFFSETOF__PAL_LIMITED_CONTEXT__IP)] -ClearThreadState: +LOCAL_LABEL(ClearThreadState): // clear the Thread's hijack state str xzr, [x2, #OFFSETOF__Thread__m_ppvHijackedReturnAddressLocation] str xzr, [x2, #OFFSETOF__Thread__m_pvHijackedReturnAddress] -NotHijacked: +LOCAL_LABEL(NotHijacked): add x1, sp, #rsp_offsetof_ExInfo // x1 <- ExInfo* str xzr, [x1, #OFFSETOF__ExInfo__m_exception] // pExInfo->m_exception = null @@ -429,13 +429,13 @@ NotHijacked: add x12, x5, #OFFSETOF__Thread__m_ThreadStateFlags -ClearRetry_Catch: +LOCAL_LABEL(ClearRetry_Catch): ldxr w4, [x12] bic w4, w4, #TSF_DoNotTriggerGc stxr w6, w4, [x12] - cbz w6, ClearSuccess_Catch - b ClearRetry_Catch -ClearSuccess_Catch: + cbz w6, LOCAL_LABEL(ClearSuccess_Catch) + b LOCAL_LABEL(ClearRetry_Catch) +LOCAL_LABEL(ClearSuccess_Catch): // // set preserved regs to the values expected by the funclet @@ -487,21 +487,21 @@ ClearSuccess_Catch: ldr x3, [sp, #rsp_offset_x3] // x3 <- current ExInfo* ldr x2, [x2, #OFFSETOF__REGDISPLAY__SP] // x2 <- resume SP value -PopExInfoLoop: +LOCAL_LABEL(PopExInfoLoop): ldr x3, [x3, #OFFSETOF__ExInfo__m_pPrevExInfo] // x3 <- next ExInfo - cbz x3, DonePopping // if (pExInfo == null) { we're done } + cbz x3, LOCAL_LABEL(DonePopping) // if (pExInfo == null) { we're done } cmp x3, x2 - blt PopExInfoLoop // if (pExInfo < resume SP} { keep going } + blt LOCAL_LABEL(PopExInfoLoop) // if (pExInfo < resume SP} { keep going } -DonePopping: +LOCAL_LABEL(DonePopping): str x3, [x1, #OFFSETOF__Thread__m_pExInfoStackHead] // store the new head on the Thread PREPARE_EXTERNAL_VAR_INDIRECT_W RhpTrapThreads, 3 - tbz x3, #TrapThreadsFlags_AbortInProgress_Bit, NoAbort + tbz x3, #TrapThreadsFlags_AbortInProgress_Bit, LOCAL_LABEL(NoAbort) ldr x3, [sp, #rsp_offset_is_not_handling_thread_abort] - cbnz x3, NoAbort + cbnz x3, LOCAL_LABEL(NoAbort) // It was the ThreadAbortException, so rethrow it // reset SP @@ -510,7 +510,7 @@ DonePopping: mov sp, x2 b C_FUNC(RhpThrowHwEx) -NoAbort: +LOCAL_LABEL(NoAbort): // reset SP and jump to continuation address mov sp, x2 br x0 @@ -564,13 +564,13 @@ NoAbort: add x12, x2, #OFFSETOF__Thread__m_ThreadStateFlags -ClearRetry: +LOCAL_LABEL(ClearRetry): ldxr w4, [x12] bic w4, w4, #TSF_DoNotTriggerGc stxr w3, w4, [x12] - cbz w3, ClearSuccess - b ClearRetry -ClearSuccess: + cbz w3, LOCAL_LABEL(ClearSuccess) + b LOCAL_LABEL(ClearRetry) +LOCAL_LABEL(ClearSuccess): // // set preserved regs to the values expected by the funclet @@ -602,13 +602,13 @@ ClearSuccess: ldr x2, [sp, rsp_FinallyFunclet_offset_thread] add x12, x2, #OFFSETOF__Thread__m_ThreadStateFlags -SetRetry: +LOCAL_LABEL(SetRetry): ldxr w1, [x12] orr w1, w1, #TSF_DoNotTriggerGc stxr w3, w1, [x12] - cbz w3, SetSuccess - b SetRetry -SetSuccess: + cbz w3, LOCAL_LABEL(SetSuccess) + b LOCAL_LABEL(SetRetry) +LOCAL_LABEL(SetSuccess): ldp d8, d9, [sp, #0x00] ldp d10, d11, [sp, #0x10] @@ -707,13 +707,13 @@ SetSuccess: add x12, x5, #OFFSETOF__Thread__m_ThreadStateFlags -ClearRetry_Propagate: +LOCAL_LABEL(ClearRetry_Propagate): ldxr w4, [x12] bic w4, w4, #TSF_DoNotTriggerGc stxr w6, w4, [x12] - cbz w6, ClearSuccess_Propagate - b ClearRetry_Propagate -ClearSuccess_Propagate: + cbz w6, LOCAL_LABEL(ClearSuccess_Propagate) + b LOCAL_LABEL(ClearRetry_Propagate) +LOCAL_LABEL(ClearSuccess_Propagate): // // set preserved regs to the values expected by the funclet @@ -749,13 +749,13 @@ ClearSuccess_Propagate: ldr x3, [sp, #rsp_offset_x3] // x3 <- current ExInfo* ldr x2, [x2, #OFFSETOF__REGDISPLAY__SP] // x2 <- resume SP value -Propagate_PopExInfoLoop: +LOCAL_LABEL(Propagate_PopExInfoLoop): ldr x3, [x3, #OFFSETOF__ExInfo__m_pPrevExInfo] // x3 <- next ExInfo - cbz x3, Propagate_DonePopping // if (pExInfo == null) { we're done } + cbz x3, LOCAL_LABEL(Propagate_DonePopping) // if (pExInfo == null) { we're done } cmp x3, x2 - blt Propagate_PopExInfoLoop // if (pExInfo < resume SP} { keep going } + blt LOCAL_LABEL(Propagate_PopExInfoLoop) // if (pExInfo < resume SP} { keep going } -Propagate_DonePopping: +LOCAL_LABEL(Propagate_DonePopping): str x3, [x1, #OFFSETOF__Thread__m_pExInfoStackHead] // store the new head on the Thread // restore preemptive mode diff --git a/src/coreclr/nativeaot/Runtime/arm64/GcProbe.S b/src/coreclr/nativeaot/Runtime/arm64/GcProbe.S index e27834bae6fedd..abe7555b761134 100644 --- a/src/coreclr/nativeaot/Runtime/arm64/GcProbe.S +++ b/src/coreclr/nativeaot/Runtime/arm64/GcProbe.S @@ -127,10 +127,10 @@ NESTED_ENTRY RhpGcProbeHijack, _TEXT, NoHandler FixupHijackedCallstack PREPARE_EXTERNAL_VAR_INDIRECT_W RhpTrapThreads, 3 - tbnz x3, #TrapThreadsFlags_TrapThreads_Bit, WaitForGC + tbnz x3, #TrapThreadsFlags_TrapThreads_Bit, LOCAL_LABEL(WaitForGC) ret -WaitForGC: +LOCAL_LABEL(WaitForGC): orr x12, x12, DEFAULT_FRAME_SAVE_FLAGS + PTFF_SAVE_X0 + PTFF_SAVE_X1 b C_FUNC(RhpWaitForGC) NESTED_END RhpGcProbeHijack @@ -144,11 +144,11 @@ NESTED_ENTRY RhpWaitForGC, _TEXT, NoHandler bl C_FUNC(RhpWaitForGC2) ldr x2, [sp, #OFFSETOF__PInvokeTransitionFrame__m_Flags] - tbnz x2, #PTFF_THREAD_ABORT_BIT, ThrowThreadAbort + tbnz x2, #PTFF_THREAD_ABORT_BIT, LOCAL_LABEL(ThrowThreadAbort) POP_PROBE_FRAME EPILOG_RETURN -ThrowThreadAbort: +LOCAL_LABEL(ThrowThreadAbort): POP_PROBE_FRAME mov w0, #STATUS_REDHAWK_THREAD_ABORT mov x1, lr // return address as exception PC @@ -159,8 +159,10 @@ NESTED_END RhpWaitForGC LEAF_ENTRY RhpGcPoll PREPARE_EXTERNAL_VAR_INDIRECT_W RhpTrapThreads, 0 - cbnz w0, C_FUNC(RhpGcPollRare) // TrapThreadsFlags_None = 0 + cbnz w0, LOCAL_LABEL(RhpGcPoll_Rare) // TrapThreadsFlags_None = 0 ret +LOCAL_LABEL(RhpGcPoll_Rare): + b C_FUNC(RhpGcPollRare) LEAF_END RhpGcPoll NESTED_ENTRY RhpGcPollRare, _TEXT, NoHandler diff --git a/src/coreclr/nativeaot/Runtime/arm64/WriteBarriers.S b/src/coreclr/nativeaot/Runtime/arm64/WriteBarriers.S index d00ffb3a4a9978..835466c3b9e7e4 100644 --- a/src/coreclr/nativeaot/Runtime/arm64/WriteBarriers.S +++ b/src/coreclr/nativeaot/Runtime/arm64/WriteBarriers.S @@ -224,9 +224,11 @@ LEAF_END RhpByRefAssignRefArm64, _TEXT PREPARE_EXTERNAL_VAR_INDIRECT g_highest_address, x12 ccmp x14, x12, #0x2, hs - blo C_FUNC(RhpAssignRefArm64) + bhs LOCAL_LABEL(NotInHeap) -NotInHeap: + b C_FUNC(RhpAssignRefArm64) + +LOCAL_LABEL(NotInHeap): ALTERNATE_ENTRY RhpCheckedAssignRefAVLocation str x15, [x14], 8 ret @@ -293,44 +295,44 @@ LEAF_END RhpAssignRef, _TEXT #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT PREPARE_EXTERNAL_VAR_INDIRECT_W g_cpuFeatures, 16 - tbz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, CmpXchgRetry + tbz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, LOCAL_LABEL(CmpXchgRetry) #endif mov x10, x2 ALTERNATE_ENTRY RhpCheckedLockCmpXchgAVLocation casal x10, x1, [x0] // exchange cmp x2, x10 - bne CmpXchgNoUpdate + bne LOCAL_LABEL(CmpXchgNoUpdate) #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT - b DoCardsCmpXchg -CmpXchgRetry: + b LOCAL_LABEL(DoCardsCmpXchg) +LOCAL_LABEL(CmpXchgRetry): // Check location value is what we expect. ALTERNATE_ENTRY RhpCheckedLockCmpXchgAVLocation2 ldaxr x10, [x0] cmp x10, x2 - bne CmpXchgNoUpdate + bne LOCAL_LABEL(CmpXchgNoUpdate) // Current value matches comparand, attempt to update with the new value. stlxr w12, x1, [x0] - cbnz w12, CmpXchgRetry + cbnz w12, LOCAL_LABEL(CmpXchgRetry) #endif -DoCardsCmpXchg: +LOCAL_LABEL(DoCardsCmpXchg): // We have successfully updated the value of the objectref so now we need a GC write barrier. // The following barrier code takes the destination in x0 and the value in x1 so the arguments are // already correctly set up. INSERT_CHECKED_WRITE_BARRIER_CORE x0, x1 -CmpXchgNoUpdate: +LOCAL_LABEL(CmpXchgNoUpdate): // x10 still contains the original value. mov x0, x10 #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT - tbnz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, NoBarrierCmpXchg + tbnz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, LOCAL_LABEL(NoBarrierCmpXchg) InterlockedOperationBarrier -NoBarrierCmpXchg: +LOCAL_LABEL(NoBarrierCmpXchg): #endif ret lr @@ -357,25 +359,25 @@ NoBarrierCmpXchg: #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT PREPARE_EXTERNAL_VAR_INDIRECT_W g_cpuFeatures, 16 - tbz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, ExchangeRetry + tbz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, LOCAL_LABEL(ExchangeRetry) #endif ALTERNATE_ENTRY RhpCheckedXchgAVLocation swpal x1, x10, [x0] // exchange #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT - b DoCardsXchg -ExchangeRetry: + b LOCAL_LABEL(DoCardsXchg) +LOCAL_LABEL(ExchangeRetry): // Read the existing memory location. ALTERNATE_ENTRY RhpCheckedXchgAVLocation2 ldaxr x10, [x0] // Attempt to update with the new value. stlxr w12, x1, [x0] - cbnz w12, ExchangeRetry + cbnz w12, LOCAL_LABEL(ExchangeRetry) #endif -DoCardsXchg: +LOCAL_LABEL(DoCardsXchg): // We have successfully updated the value of the objectref so now we need a GC write barrier. // The following barrier code takes the destination in x0 and the value in x1 so the arguments are // already correctly set up. @@ -386,9 +388,9 @@ DoCardsXchg: mov x0, x10 #ifndef LSE_INSTRUCTIONS_ENABLED_BY_DEFAULT - tbnz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, NoBarrierXchg + tbnz w16, #ARM64_ATOMICS_FEATURE_FLAG_BIT, LOCAL_LABEL(NoBarrierXchg) InterlockedOperationBarrier -NoBarrierXchg: +LOCAL_LABEL(NoBarrierXchg): #endif ret diff --git a/src/coreclr/nativeaot/Runtime/unix/unixasmmacros.inc b/src/coreclr/nativeaot/Runtime/unix/unixasmmacros.inc index ef6d393fd248b1..bde1d517b7e823 100644 --- a/src/coreclr/nativeaot/Runtime/unix/unixasmmacros.inc +++ b/src/coreclr/nativeaot/Runtime/unix/unixasmmacros.inc @@ -3,6 +3,11 @@ #define INVALIDGCVALUE 0xCCCCCCCD +// Enforce subsections via symbols to workaround bugs in Xcode 15 linker. +#if defined(__APPLE__) +.subsections_via_symbols +#endif + #if defined(__APPLE__) #define C_FUNC(name) _##name #define EXTERNAL_C_FUNC(name) C_FUNC(name) diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/CompatibilitySuppressions.xml b/src/coreclr/nativeaot/System.Private.CoreLib/src/CompatibilitySuppressions.xml index d5fbde8e348dc3..229085a10afaa0 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/CompatibilitySuppressions.xml +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/CompatibilitySuppressions.xml @@ -800,6 +800,10 @@ CP0001 T:System.Diagnostics.DebugAnnotations + + CP0001 + T:System.Diagnostics.DebuggerGuidedStepThroughAttribute + CP0001 T:System.MDArray @@ -864,6 +868,10 @@ CP0001 T:System.Reflection.RuntimeAssemblyName + + CP0001 + T:System.Runtime.CompilerServices.EagerStaticClassConstructionAttribute + CP0001 T:System.Runtime.CompilerServices.ForceDictionaryLookupsAttribute diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/CrashInfo.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/CrashInfo.cs index 8a2f33c93f1f8c..c2421fc6b4ceb9 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/CrashInfo.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/CrashInfo.cs @@ -103,12 +103,17 @@ private bool WriteHeader(RhFailFastReason reason, ulong crashingThreadId, string if (!WriteValue("version"u8, "1.0.0"u8)) return false; - if (!WriteValue("runtime"u8, new ReadOnlySpan(RuntimeImports.RhGetRuntimeVersion(out int cbLength), cbLength))) + static void Dummy() { } + + if (!WriteHexValue("runtime_base"u8, (ulong)RuntimeImports.RhGetOSModuleFromPointer((nint)(void*)(delegate*)&Dummy))) return false; if (!WriteIntValue("runtime_type"u8, (int)RuntimeType.NativeAOT)) return false; + if (!WriteValue("runtime_version"u8, new ReadOnlySpan(RuntimeImports.RhGetRuntimeVersion(out int cbLength), cbLength))) + return false; + CrashReason crashReason = reason switch { RhFailFastReason.EnvironmentFailFast => CrashReason.EnvironmentFailFast, diff --git a/src/coreclr/pal/inc/unixasmmacros.inc b/src/coreclr/pal/inc/unixasmmacros.inc index 658a65bb4b35aa..120b26543e3faa 100644 --- a/src/coreclr/pal/inc/unixasmmacros.inc +++ b/src/coreclr/pal/inc/unixasmmacros.inc @@ -3,6 +3,11 @@ #define INVALIDGCVALUE 0xCCCCCCCD +// Enforce subsections via symbols to workaround bugs in Xcode 15 linker. +#if defined(__APPLE__) +.subsections_via_symbols +#endif + #if defined(__APPLE__) #define C_FUNC(name) _##name #define EXTERNAL_C_FUNC(name) C_FUNC(name) diff --git a/src/coreclr/tools/Common/TypeSystem/Ecma/SymbolReader/UnmanagedPdbSymbolReader.cs b/src/coreclr/tools/Common/TypeSystem/Ecma/SymbolReader/UnmanagedPdbSymbolReader.cs index c791a509f02a99..107170e743fbc7 100644 --- a/src/coreclr/tools/Common/TypeSystem/Ecma/SymbolReader/UnmanagedPdbSymbolReader.cs +++ b/src/coreclr/tools/Common/TypeSystem/Ecma/SymbolReader/UnmanagedPdbSymbolReader.cs @@ -68,7 +68,7 @@ protected override object CreateObject(IntPtr externalComObject, CreateObjectFla Debug.Assert(flags == CreateObjectFlags.UniqueInstance); var iid = ICLRMetaHost.IID; - if (Marshal.QueryInterface(externalComObject, ref iid, out IntPtr hostPtr) != 0) + if (Marshal.QueryInterface(externalComObject, in iid, out IntPtr hostPtr) != 0) { throw new ArgumentException("Expected ICLRMetaHost COM interface"); } @@ -284,7 +284,7 @@ private CoCreateWrapperCache() { } Debug.Assert(flags == CreateObjectFlags.UniqueInstance); var iid = new Guid("AA544D42-28CB-11d3-BD22-0000F80849BD"); - if (Marshal.QueryInterface(externalComObject, ref iid, out IntPtr ppv) != 0) + if (Marshal.QueryInterface(externalComObject, in iid, out IntPtr ppv) != 0) { return null; } diff --git a/src/coreclr/tools/aot/Directory.Build.props b/src/coreclr/tools/aot/Directory.Build.props deleted file mode 100644 index 5a5e0e9914b730..00000000000000 --- a/src/coreclr/tools/aot/Directory.Build.props +++ /dev/null @@ -1,6 +0,0 @@ - - - - true - - diff --git a/src/coreclr/tools/aot/Directory.Build.targets b/src/coreclr/tools/aot/Directory.Build.targets new file mode 100644 index 00000000000000..4f855d71288f72 --- /dev/null +++ b/src/coreclr/tools/aot/Directory.Build.targets @@ -0,0 +1,6 @@ + + + + true + + diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/EETypeNode.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/EETypeNode.cs index 61520b4bfadaff..eb64be16023150 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/EETypeNode.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/DependencyAnalysis/EETypeNode.cs @@ -437,10 +437,14 @@ public sealed override IEnumerable GetConditionalSt // Add conditional dependencies for interface methods the type implements. For example, if the type T implements // interface IFoo which has a method M1, add a dependency on T.M1 dependent on IFoo.M1 being called, since it's // possible for any IFoo object to actually be an instance of T. + DefType defTypeDefinition = (DefType)defType.GetTypeDefinition(); DefType[] defTypeRuntimeInterfaces = defType.RuntimeInterfaces; + DefType[] defTypeDefinitionRuntimeInterfaces = defTypeDefinition.RuntimeInterfaces; + Debug.Assert(defTypeDefinitionRuntimeInterfaces.Length == defTypeRuntimeInterfaces.Length); for (int interfaceIndex = 0; interfaceIndex < defTypeRuntimeInterfaces.Length; interfaceIndex++) { DefType interfaceType = defTypeRuntimeInterfaces[interfaceIndex]; + DefType interfaceDefinitionType = defTypeDefinitionRuntimeInterfaces[interfaceIndex]; Debug.Assert(interfaceType.IsInterface); @@ -457,11 +461,22 @@ public sealed override IEnumerable GetConditionalSt if (!isStaticInterfaceMethod && !needsDependenciesForInstanceInterfaceMethodImpls) continue; + MethodDesc interfaceMethodDefinition = interfaceMethod; + if (interfaceType != interfaceDefinitionType) + interfaceMethodDefinition = factory.TypeSystemContext.GetMethodForInstantiatedType(interfaceMethodDefinition.GetTypicalMethodDefinition(), (InstantiatedType)interfaceDefinitionType); + MethodDesc implMethod = isStaticInterfaceMethod ? - defType.ResolveInterfaceMethodToStaticVirtualMethodOnType(interfaceMethod) : - defType.ResolveInterfaceMethodToVirtualMethodOnType(interfaceMethod); + defTypeDefinition.ResolveInterfaceMethodToStaticVirtualMethodOnType(interfaceMethodDefinition) : + defTypeDefinition.ResolveInterfaceMethodToVirtualMethodOnType(interfaceMethodDefinition); if (implMethod != null) { + TypeDesc implType = defType; + while (!implType.HasSameTypeDefinition(implMethod.OwningType)) + implType = implType.BaseType; + + if (!implType.IsTypeDefinition) + implMethod = factory.TypeSystemContext.GetMethodForInstantiatedType(implMethod.GetTypicalMethodDefinition(), (InstantiatedType)implType); + if (isStaticInterfaceMethod) { Debug.Assert(!implMethod.IsVirtual); @@ -500,12 +515,7 @@ public sealed override IEnumerable GetConditionalSt // Is the implementation provided by a default interface method? // If so, add a dependency on the entrypoint directly since nobody else is going to do that // (interface types have an empty vtable, modulo their generic dictionary). - TypeDesc interfaceOnDefinition = defType.GetTypeDefinition().RuntimeInterfaces[interfaceIndex]; - MethodDesc interfaceMethodDefinition = interfaceMethod; - if (!interfaceType.IsTypeDefinition) - interfaceMethodDefinition = factory.TypeSystemContext.GetMethodForInstantiatedType(interfaceMethod.GetTypicalMethodDefinition(), (InstantiatedType)interfaceOnDefinition); - - var resolution = defType.GetTypeDefinition().ResolveInterfaceMethodToDefaultImplementationOnType(interfaceMethodDefinition, out implMethod); + var resolution = defTypeDefinition.ResolveInterfaceMethodToDefaultImplementationOnType(interfaceMethodDefinition, out implMethod); if (resolution == DefaultInterfaceMethodResolution.DefaultImplementation) { DefType providingInterfaceDefinitionType = (DefType)implMethod.OwningType; diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/RootingHelpers.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/RootingHelpers.cs index 038c1ee1f38dc0..ffd286cce759f2 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/RootingHelpers.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/RootingHelpers.cs @@ -11,11 +11,11 @@ namespace ILCompiler { public class RootingHelpers { - public static bool TryRootType(IRootingServiceProvider rootProvider, TypeDesc type, string reason) + public static bool TryRootType(IRootingServiceProvider rootProvider, TypeDesc type, bool rootBaseTypes, string reason) { try { - RootType(rootProvider, type, reason); + RootType(rootProvider, type, rootBaseTypes, reason); return true; } catch (TypeSystemException) @@ -24,7 +24,7 @@ public static bool TryRootType(IRootingServiceProvider rootProvider, TypeDesc ty } } - public static void RootType(IRootingServiceProvider rootProvider, TypeDesc type, string reason) + public static void RootType(IRootingServiceProvider rootProvider, TypeDesc type, bool rootBaseTypes, string reason) { rootProvider.AddReflectionRoot(type, reason); @@ -40,13 +40,13 @@ public static void RootType(IRootingServiceProvider rootProvider, TypeDesc type, rootProvider.AddReflectionRoot(type, reason); } - // Also root base types. This is so that we make methods on the base types callable. - // This helps in cases like "class Foo : Bar { }" where we discover new - // generic instantiations. - TypeDesc baseType = type.BaseType; - if (baseType != null) + if (rootBaseTypes) { - RootType(rootProvider, baseType.NormalizeInstantiation(), reason); + TypeDesc baseType = type.BaseType; + if (baseType != null) + { + RootType(rootProvider, baseType.NormalizeInstantiation(), rootBaseTypes, reason); + } } if (type.IsDefType) diff --git a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/UsageBasedMetadataManager.cs b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/UsageBasedMetadataManager.cs index 0d4a855e736ed0..bd3e1069881361 100644 --- a/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/UsageBasedMetadataManager.cs +++ b/src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/UsageBasedMetadataManager.cs @@ -351,7 +351,7 @@ protected override void GetMetadataDependenciesDueToReflectability(ref Dependenc var rootProvider = new RootingServiceProvider(factory, dependencies.Add); foreach (TypeDesc t in mdType.Module.GetAllTypes()) { - RootingHelpers.TryRootType(rootProvider, t, reason); + RootingHelpers.TryRootType(rootProvider, t, rootBaseTypes: false, reason); } } } diff --git a/src/coreclr/tools/aot/ILCompiler/RdXmlRootProvider.cs b/src/coreclr/tools/aot/ILCompiler/RdXmlRootProvider.cs index 2c6b01849db031..6d263eddc1eb5f 100644 --- a/src/coreclr/tools/aot/ILCompiler/RdXmlRootProvider.cs +++ b/src/coreclr/tools/aot/ILCompiler/RdXmlRootProvider.cs @@ -71,7 +71,7 @@ private void ProcessAssemblyDirective(IRootingServiceProvider rootProvider, XEle foreach (TypeDesc type in ((EcmaModule)assembly).GetAllTypes()) { - RootingHelpers.TryRootType(rootProvider, type, "RD.XML root"); + RootingHelpers.TryRootType(rootProvider, type, rootBaseTypes: true, "RD.XML root"); } } @@ -103,7 +103,7 @@ private static void ProcessTypeDirective(IRootingServiceProvider rootProvider, M if (dynamicDegreeAttribute.Value != "Required All") throw new NotSupportedException($"\"{dynamicDegreeAttribute.Value}\" is not a supported value for the \"Dynamic\" attribute of the \"Type\" Runtime Directive. Supported values are \"Required All\"."); - RootingHelpers.RootType(rootProvider, type, "RD.XML root"); + RootingHelpers.RootType(rootProvider, type, rootBaseTypes: true, "RD.XML root"); } var marshalStructureDegreeAttribute = typeElement.Attribute("MarshalStructure"); diff --git a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestDatabase.cs b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestDatabase.cs index 31cd4455663e2e..57c3de36dd7574 100644 --- a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestDatabase.cs +++ b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestDatabase.cs @@ -34,7 +34,12 @@ public static IEnumerable InlineArrays () return TestNamesBySuiteName(); } - public static IEnumerable LinkXml() + public static IEnumerable Libraries() + { + return TestNamesBySuiteName(); + } + + public static IEnumerable LinkXml() { return TestNamesBySuiteName(); } diff --git a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestSuites.cs b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestSuites.cs index 744fb23e416f98..f8d7fb5c69d02b 100644 --- a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestSuites.cs +++ b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCases/TestSuites.cs @@ -37,7 +37,14 @@ public void InlineArrays(string t) Run(t); } - [Theory] + [Theory] + [MemberData(nameof(TestDatabase.Libraries), MemberType = typeof(TestDatabase))] + public void Libraries(string t) + { + Run(t); + } + + [Theory] [MemberData (nameof (TestDatabase.LinkXml), MemberType = typeof (TestDatabase))] public void LinkXml (string t) { diff --git a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs index 7280796f5abc4e..393837de75f076 100644 --- a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs +++ b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs @@ -123,9 +123,12 @@ public void Verify () throw new NotImplementedException ($"Don't know how to check member of type {originalMember.GetType ()}"); } - // Filter out all members which are not from the main assembly - // The Kept attributes are "optional" for non-main assemblies - string mainModuleName = originalAssembly.Name.Name; + // Verify anything not in the main assembly + VerifyLinkingOfOtherAssemblies(this.originalAssembly); + + // Filter out all members which are not from the main assembly + // The Kept attributes are "optional" for non-main assemblies + string mainModuleName = originalAssembly.Name.Name; List externalMembers = linkedMembers.Where (m => GetModuleName (m.Value.Entity) != mainModuleName).Select (m => m.Key).ToList (); foreach (var externalMember in externalMembers) { linkedMembers.Remove (externalMember); @@ -136,7 +139,7 @@ public void Verify () false, "Linked output includes unexpected member:\n " + string.Join ("\n ", linkedMembers.Values.Select (e => e.Entity.GetDisplayName ()))); - } + } static bool IsCompilerGeneratedMemberName (string memberName) { @@ -304,12 +307,23 @@ static bool ShouldIncludeType (TypeDesc type) static bool ShouldIncludeMethod (MethodDesc method) => ShouldIncludeType (method.OwningType) && ShouldIncludeEntityByDisplayName (method); } + private static MetadataType? GetOwningType (TypeSystemEntity? entity) + { + return entity switch + { + MetadataType type => type.ContainingType as MetadataType, + MethodDesc method => method.OwningType as MetadataType, + PropertyPseudoDesc prop => prop.OwningType, + EventPseudoDesc e => e.OwningType, + _ => null + }; + } + private static string? GetModuleName (TypeSystemEntity entity) { return entity switch { MetadataType type => type.Module.ToString (), - MethodDesc { OwningType: MetadataType owningType } => owningType.Module.ToString (), - _ => null + _ => GetOwningType(entity)?.Module.ToString() }; } @@ -1338,38 +1352,38 @@ private static bool HasActiveKeptDerivedAttribute (ICustomAttributeProvider prov return GetActiveKeptDerivedAttributes (provider).Any (); } - private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original) + internal void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original) { var checks = BuildOtherAssemblyCheckTable (original); - // TODO - // For now disable the code below by simply removing all checks - checks.Clear (); - try { foreach (var assemblyName in checks.Keys) { - var linkedAssembly = ResolveLinkedAssembly (assemblyName); + var linkedMembersInAssembly = ResolveLinkedMembersForAssembly (assemblyName); + var originalTargetAssembly = ResolveOriginalsAssembly(assemblyName); foreach (var checkAttrInAssembly in checks[assemblyName]) { var attributeTypeName = checkAttrInAssembly.AttributeType.Name; switch (attributeTypeName) { case nameof (KeptAllTypesAndMembersInAssemblyAttribute): - VerifyKeptAllTypesAndMembersInAssembly (linkedAssembly); + VerifyKeptAllTypesAndMembersInAssembly (assemblyName, linkedMembersInAssembly); continue; case nameof (KeptAttributeInAssemblyAttribute): - VerifyKeptAttributeInAssembly (checkAttrInAssembly, linkedAssembly); + // VerifyKeptAttributeInAssembly (checkAttrInAssembly, linkedAssembly); continue; case nameof (RemovedAttributeInAssembly): - VerifyRemovedAttributeInAssembly (checkAttrInAssembly, linkedAssembly); + // VerifyRemovedAttributeInAssembly (checkAttrInAssembly, linkedAssembly); continue; default: break; } - var expectedTypeName = checkAttrInAssembly.ConstructorArguments[1].Value.ToString ()!; - TypeDefinition? linkedType = linkedAssembly.MainModule.GetType (expectedTypeName); + var expectedTypeName = checkAttrInAssembly.ConstructorArguments[1].Value.ToString ()!; + var expectedType = originalTargetAssembly.MainModule.GetType(expectedTypeName); + linkedMembersInAssembly.TryGetValue(new AssemblyQualifiedToken(expectedType), out LinkedEntity? linkedTypeEntity); + MetadataType? linkedType = linkedTypeEntity?.Entity as MetadataType; - if (linkedType == null && linkedAssembly.MainModule.HasExportedTypes) { +#if false + if (linkedType == null && linkedAssembly.MainModule.HasExportedTypes) { ExportedType? exportedType = linkedAssembly.MainModule.ExportedTypes .FirstOrDefault (exported => exported.FullName == expectedTypeName); @@ -1381,6 +1395,7 @@ private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original) linkedType = exportedType?.Resolve (); } +#endif switch (attributeTypeName) { case nameof (RemovedTypeInAssemblyAttribute): @@ -1392,6 +1407,7 @@ private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original) if (linkedType == null) Assert.Fail ($"Type `{expectedTypeName}' should have been kept in assembly {assemblyName}"); break; +#if false case nameof (RemovedInterfaceOnTypeInAssemblyAttribute): if (linkedType == null) Assert.Fail ($"Type `{expectedTypeName}' should have been kept in assembly {assemblyName}"); @@ -1444,11 +1460,15 @@ private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original) Assert.Fail ($"Type `{expectedTypeName}` should have been kept in assembly {assemblyName}"); VerifyExpectedInstructionSequenceOnMemberInAssembly (checkAttrInAssembly, linkedType); break; - default: + default: UnhandledOtherAssemblyAssertion (expectedTypeName, checkAttrInAssembly, linkedType); break; - } - } +#else + default: + break; +#endif + } + } } } catch (AssemblyResolutionException e) { Assert.Fail ($"Failed to resolve linked assembly `{e.AssemblyReference.Name}`. It must not exist in the output."); @@ -1740,54 +1760,62 @@ protected virtual bool TryVerifyKeptMemberInAssemblyAsMethod (string memberName, private void VerifyKeptReferencesInAssembly (CustomAttribute inAssemblyAttribute) { - var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); +#if false + var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); var expectedReferenceNames = ((CustomAttributeArgument[]) inAssemblyAttribute.ConstructorArguments[1].Value).Select (attr => (string) attr.Value).ToList (); for (int i = 0; i < expectedReferenceNames.Count; i++) if (expectedReferenceNames[i].EndsWith (".dll")) expectedReferenceNames[i] = expectedReferenceNames[i].Substring (0, expectedReferenceNames[i].LastIndexOf (".")); Assert.Equal (assembly.MainModule.AssemblyReferences.Select (asm => asm.Name), expectedReferenceNames); +#endif } private void VerifyKeptResourceInAssembly (CustomAttribute inAssemblyAttribute) { - var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); +#if false + var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); var resourceName = inAssemblyAttribute.ConstructorArguments[1].Value.ToString (); Assert.Contains (resourceName, assembly.MainModule.Resources.Select (r => r.Name)); +#endif } private void VerifyRemovedResourceInAssembly (CustomAttribute inAssemblyAttribute) { - var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); +#if false + var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!); var resourceName = inAssemblyAttribute.ConstructorArguments[1].Value.ToString (); Assert.DoesNotContain (resourceName, assembly.MainModule.Resources.Select (r => r.Name)); +#endif } - private void VerifyKeptAllTypesAndMembersInAssembly (AssemblyDefinition linked) + private void VerifyKeptAllTypesAndMembersInAssembly (string assemblyName, Dictionary linkedMembers) { - var original = ResolveOriginalsAssembly (linked.MainModule.Assembly.Name.Name); + var original = ResolveOriginalsAssembly (assemblyName); if (original == null) - Assert.Fail ($"Failed to resolve original assembly {linked.MainModule.Assembly.Name.Name}"); + Assert.Fail ($"Failed to resolve original assembly {assemblyName}"); - var originalTypes = original.AllDefinedTypes ().ToDictionary (t => t.FullName); - var linkedTypes = linked.AllDefinedTypes ().ToDictionary (t => t.FullName); + var originalTypes = original.AllDefinedTypes ().ToDictionary (t => new AssemblyQualifiedToken(t)); + var linkedTypes = linkedMembers.Where(t => t.Value.Entity is TypeDesc).ToDictionary(); var missingInLinked = originalTypes.Keys.Except (linkedTypes.Keys); - Assert.True (missingInLinked.Any (), $"Expected all types to exist in the linked assembly, but one or more were missing"); + Assert.False (missingInLinked.Any (), $"Expected all types to exist in the linked assembly {assemblyName}, but one or more were missing"); foreach (var originalKvp in originalTypes) { var linkedType = linkedTypes[originalKvp.Key]; + TypeDesc linkedTypeDesc = (TypeDesc)linkedType.Entity; - var originalMembers = originalKvp.Value.AllMembers ().Select (m => m.FullName); - var linkedMembers = linkedType.AllMembers ().Select (m => m.FullName); + // NativeAOT field trimming is very different (it basically doesn't trim fields, not in the same way trimmer does) + var originalMembers = originalKvp.Value.AllMembers ().Where(m => m is not FieldDefinition).Select (m => new AssemblyQualifiedToken(m)); + var linkedMembersOnType = linkedMembers.Where(t => GetOwningType(t.Value.Entity) == linkedTypeDesc).Select(t => t.Key); - var missingMembersInLinked = originalMembers.Except (linkedMembers); + var missingMembersInLinked = originalMembers.Except (linkedMembersOnType); - Assert.True (missingMembersInLinked.Any (), $"Expected all members of `{originalKvp.Key}`to exist in the linked assembly, but one or more were missing"); + Assert.False (missingMembersInLinked.Any (), $"Expected all members of `{linkedTypeDesc.GetDisplayName()}`to exist in the linked assembly, but one or more were missing"); } } @@ -1823,6 +1851,11 @@ private static Dictionary> BuildOtherAssemblyCheck foreach (var typeWithRemoveInAssembly in original.AllDefinedTypes ()) { foreach (var attr in typeWithRemoveInAssembly.CustomAttributes.Where (IsTypeInOtherAssemblyAssertion)) { var assemblyName = (string) attr.ConstructorArguments[0].Value; + + Tool? toolTarget = (Tool?)(int?)attr.GetPropertyValue("Tool"); + if (toolTarget is not null && !toolTarget.Value.HasFlag(Tool.NativeAot)) + continue; + if (!checks.TryGetValue (assemblyName, out List? checksForAssembly)) checks[assemblyName] = checksForAssembly = new List (); @@ -1833,14 +1866,13 @@ private static Dictionary> BuildOtherAssemblyCheck return checks; } - protected AssemblyDefinition ResolveLinkedAssembly (string assemblyName) + private Dictionary ResolveLinkedMembersForAssembly (string assemblyName) { - //var cleanAssemblyName = assemblyName; - //if (assemblyName.EndsWith (".exe") || assemblyName.EndsWith (".dll")) - //cleanAssemblyName = System.IO.Path.GetFileNameWithoutExtension (assemblyName); - //return _linkedResolver.Resolve (new AssemblyNameReference (cleanAssemblyName, null), _linkedReaderParameters); - // TODO - adapt to Native AOT - return ResolveOriginalsAssembly (assemblyName); + var cleanAssemblyName = assemblyName; + if (assemblyName.EndsWith(".exe") || assemblyName.EndsWith(".dll")) + cleanAssemblyName = System.IO.Path.GetFileNameWithoutExtension(assemblyName); + + return this.linkedMembers.Where(e => GetModuleName(e.Value.Entity) == cleanAssemblyName).ToDictionary(); } protected AssemblyDefinition ResolveOriginalsAssembly (string assemblyName) diff --git a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/ILCompilerDriver.cs b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/ILCompilerDriver.cs index 275d035c66843e..5fe3c9adf30d40 100644 --- a/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/ILCompilerDriver.cs +++ b/src/coreclr/tools/aot/Mono.Linker.Tests/TestCasesRunner/ILCompilerDriver.cs @@ -103,7 +103,7 @@ public ILScanResults Trim (ILCompilerOptions options, ILogWriter logWriter) new ManifestResourceBlockingPolicy (logger, options.FeatureSwitches, new Dictionary>()), logFile: null, new NoStackTraceEmissionPolicy (), - new NoDynamicInvokeThunkGenerationPolicy (), + new DefaultDynamicInvokeThunkGenerationPolicy (), new FlowAnnotations (logger, ilProvider, compilerGeneratedState), UsageBasedMetadataGenerationOptions.ReflectionILScanning, options: default, diff --git a/src/coreclr/vm/arm64/asmhelpers.S b/src/coreclr/vm/arm64/asmhelpers.S index cdbe24ec427a98..89dab80461c356 100644 --- a/src/coreclr/vm/arm64/asmhelpers.S +++ b/src/coreclr/vm/arm64/asmhelpers.S @@ -329,7 +329,9 @@ WRITE_BARRIER_ENTRY JIT_CheckedWriteBarrier // branch below is not taken. ccmp x14, x12, #0x2, hs - blo C_FUNC(JIT_WriteBarrier) + bhs LOCAL_LABEL(NotInHeap) + + b C_FUNC(JIT_WriteBarrier) LOCAL_LABEL(NotInHeap): str x15, [x14], 8 diff --git a/src/coreclr/vm/callsiteinspect.cpp b/src/coreclr/vm/callsiteinspect.cpp index dabbe89a497720..8209e41e6a7d44 100644 --- a/src/coreclr/vm/callsiteinspect.cpp +++ b/src/coreclr/vm/callsiteinspect.cpp @@ -433,7 +433,8 @@ void CallsiteInspect::PropagateOutParametersBackToCallsite( *(ARG_SLOT *)(frame->GetReturnValuePtr()) = retVal; } #ifdef ENREGISTERED_RETURNTYPE_MAXSIZE - else if (argit.HasNonStandardByvalReturn()) + else if (argit.HasNonStandardByvalReturn() + && !(flags & CallsiteDetails::HResultReturn)) { // In these cases, put the pointer to the return buffer into // the frame's return value slot. diff --git a/src/coreclr/vm/callsiteinspect.h b/src/coreclr/vm/callsiteinspect.h index 373b9347dfd9c9..4ca66eca9feba2 100644 --- a/src/coreclr/vm/callsiteinspect.h +++ b/src/coreclr/vm/callsiteinspect.h @@ -25,6 +25,7 @@ struct CallsiteDetails BeginInvoke = 0x01, EndInvoke = 0x02, Ctor = 0x04, + HResultReturn = 0x08, }; INT32 Flags; }; diff --git a/src/coreclr/vm/clrtocomcall.cpp b/src/coreclr/vm/clrtocomcall.cpp index 06d28f507249b4..c604a6c8a90116 100644 --- a/src/coreclr/vm/clrtocomcall.cpp +++ b/src/coreclr/vm/clrtocomcall.cpp @@ -364,7 +364,7 @@ UINT32 CLRToCOMEventCallWorker(ComPlusMethodFrame* pFrame, ComPlusCallMethodDesc return 0; } -CallsiteDetails CreateCallsiteDetails(_In_ FramedMethodFrame *pFrame) +static CallsiteDetails CreateCallsiteDetails(_In_ FramedMethodFrame *pFrame) { CONTRACTL { @@ -442,10 +442,20 @@ CallsiteDetails CreateCallsiteDetails(_In_ FramedMethodFrame *pFrame) SigTypeContext::InitTypeContext(pMD, actualType, &typeContext); } + // If the signature is marked preserve sig, then the return + // is required to be an HRESULT, per COM rules. We set a flag to + // indicate this state to avoid issues when a C# developer defines + // an HRESULT in C# as a ValueClass with a single int field. This + // is convenient but does violate the COM ABI. Setting the flag + // lets us permit this convention and allow either a 4 byte primitive + // or the commonly used C# type "struct HResult { int Value; }". + if (IsMiPreserveSig(pMD->GetImplAttrs())) + callsiteFlags |= CallsiteDetails::HResultReturn; + _ASSERTE(!signature.IsEmpty() && pModule != nullptr); // Create details - return CallsiteDetails{ { signature, pModule, &typeContext }, pFrame, pMD, fIsDelegate }; + return CallsiteDetails{ { signature, pModule, &typeContext }, pFrame, pMD, fIsDelegate, callsiteFlags }; } UINT32 CLRToCOMLateBoundWorker( diff --git a/src/coreclr/vm/dbginterface.h b/src/coreclr/vm/dbginterface.h index daa57d25c86cf3..85b9785bccbb9b 100644 --- a/src/coreclr/vm/dbginterface.h +++ b/src/coreclr/vm/dbginterface.h @@ -203,7 +203,7 @@ class DebugInterface // Get debugger variable information for a specific version of a method virtual void GetVarInfo(MethodDesc * fd, // [IN] method of interest - void *DebuggerVersionToken, // [IN] which edit version + CORDB_ADDRESS nativeCodeAddress, // [IN] which edit version SIZE_T * cVars, // [OUT] size of 'vars' const ICorDebugInfo::NativeVarInfo **vars // [OUT] map telling where local vars are stored ) = 0; @@ -262,11 +262,6 @@ class DebugInterface virtual bool IsJMCMethod(Module* pModule, mdMethodDef tkMethod) = 0; - // Given a method, get's its EnC version number. 1 if the method is not EnCed. - // Note that MethodDescs are reused between versions so this will give us - // the most recent EnC number. - virtual int GetMethodEncNumber(MethodDesc * pMethod) = 0; - virtual void SendLogSwitchSetting (int iLevel, int iReason, _In_z_ LPCWSTR pLogSwitchName, diff --git a/src/coreclr/vm/eedbginterfaceimpl.cpp b/src/coreclr/vm/eedbginterfaceimpl.cpp index 792c608918a61d..352a534d5c1a88 100644 --- a/src/coreclr/vm/eedbginterfaceimpl.cpp +++ b/src/coreclr/vm/eedbginterfaceimpl.cpp @@ -630,7 +630,6 @@ PCODE EEDbgInterfaceImpl::GetFunctionAddress(MethodDesc *pFD) SUPPORTS_DAC; } CONTRACTL_END; - return pFD->GetNativeCode(); } diff --git a/src/coreclr/vm/encee.cpp b/src/coreclr/vm/encee.cpp index 1dcfb8bf091f4c..3339462ad7fe77 100644 --- a/src/coreclr/vm/encee.cpp +++ b/src/coreclr/vm/encee.cpp @@ -806,8 +806,8 @@ NOINLINE void EditAndContinueModule::FixContextAndResume( // Get the var info which the codemanager will use for updating // enregistered variables correctly, or variables whose lifetimes differ // at the update point - g_pDebugInterface->GetVarInfo(pMD, oldDebuggerFuncHandle, &oldVarInfoCount, &pOldVarInfo); - g_pDebugInterface->GetVarInfo(pMD, NULL, &newVarInfoCount, &pNewVarInfo); + g_pDebugInterface->GetVarInfo(pMD, oldCodeInfo.GetCodeAddress(), &oldVarInfoCount, &pOldVarInfo); + g_pDebugInterface->GetVarInfo(pMD, newCodeInfo.GetCodeAddress(), &newVarInfoCount, &pNewVarInfo); #ifdef TARGET_X86 // save the frame pointer as FixContextForEnC might step on it. diff --git a/src/coreclr/vm/ilmarshalers.h b/src/coreclr/vm/ilmarshalers.h index f3c9f31628f156..61ff10ac2b2b86 100644 --- a/src/coreclr/vm/ilmarshalers.h +++ b/src/coreclr/vm/ilmarshalers.h @@ -3138,39 +3138,13 @@ class ILMngdMarshaler : public ILMarshaler void EmitClearNative(ILCodeStream* pslILEmit) override { WRAPPER_NO_CONTRACT; - ILCodeLabel* pNoManagedValueLabel = nullptr; - if (IsFieldMarshal(m_dwMarshalFlags)) - { - pNoManagedValueLabel = pslILEmit->NewCodeLabel(); - pslILEmit->EmitLDARG(StructMarshalStubs::MANAGED_STRUCT_ARGIDX); - pslILEmit->EmitBRFALSE(pNoManagedValueLabel); - } - EmitCallMngdMarshalerMethod(pslILEmit, GetClearNativeMethod()); - - if (IsFieldMarshal(m_dwMarshalFlags)) - { - pslILEmit->EmitLabel(pNoManagedValueLabel); - } } void EmitClearNativeContents(ILCodeStream* pslILEmit) override { WRAPPER_NO_CONTRACT; - ILCodeLabel* pNoManagedValueLabel = nullptr; - if (IsFieldMarshal(m_dwMarshalFlags)) - { - pNoManagedValueLabel = pslILEmit->NewCodeLabel(); - pslILEmit->EmitLDARG(StructMarshalStubs::MANAGED_STRUCT_ARGIDX); - pslILEmit->EmitBRFALSE(pNoManagedValueLabel); - } - EmitCallMngdMarshalerMethod(pslILEmit, GetClearNativeContentsMethod()); - - if (IsFieldMarshal(m_dwMarshalFlags)) - { - pslILEmit->EmitLabel(pNoManagedValueLabel); - } } bool NeedsClearCLR() override diff --git a/src/coreclr/vm/method.cpp b/src/coreclr/vm/method.cpp index 62b24e3dc091c6..29910d6cb4c1c3 100644 --- a/src/coreclr/vm/method.cpp +++ b/src/coreclr/vm/method.cpp @@ -913,7 +913,6 @@ PCODE MethodDesc::GetNativeCode() WRAPPER_NO_CONTRACT; SUPPORTS_DAC; _ASSERTE(!IsDefaultInterfaceMethod() || HasNativeCodeSlot()); - if (HasNativeCodeSlot()) { // When profiler is enabled, profiler may ask to rejit a code even though we @@ -935,7 +934,7 @@ PCODE MethodDesc::GetNativeCode() return GetStableEntryPoint(); } -PCODE MethodDesc::GetNativeCodeReJITAware() +PCODE MethodDesc::GetNativeCodeAnyVersion() { WRAPPER_NO_CONTRACT; SUPPORTS_DAC; @@ -946,19 +945,23 @@ PCODE MethodDesc::GetNativeCodeReJITAware() return pDefaultCode; } + else { CodeVersionManager *pCodeVersionManager = GetCodeVersionManager(); CodeVersionManager::LockHolder codeVersioningLockHolder; - ILCodeVersion ilVersion = pCodeVersionManager->GetActiveILCodeVersion(PTR_MethodDesc(this)); - if (!ilVersion.IsDefaultVersion()) + ILCodeVersionCollection ilVersionCollection = pCodeVersionManager->GetILCodeVersions(PTR_MethodDesc(this)); + for (ILCodeVersionIterator curIL = ilVersionCollection.Begin(), endIL = ilVersionCollection.End(); curIL != endIL; curIL++) { - NativeCodeVersion activeNativeCodeVersion = ilVersion.GetActiveNativeCodeVersion(PTR_MethodDesc(this)); - if (!activeNativeCodeVersion.IsNull()) + NativeCodeVersionCollection nativeCollection = curIL->GetNativeCodeVersions(PTR_MethodDesc(this)); + for (NativeCodeVersionIterator curNative = nativeCollection.Begin(), endNative = nativeCollection.End(); curNative != endNative; curNative++) { - return activeNativeCodeVersion.GetNativeCode(); + PCODE native = curNative->GetNativeCode(); + if(native != NULL) + { + return native; + } } } - return NULL; } } diff --git a/src/coreclr/vm/method.hpp b/src/coreclr/vm/method.hpp index e51d9f7453d35e..12b3c86e0f7414 100644 --- a/src/coreclr/vm/method.hpp +++ b/src/coreclr/vm/method.hpp @@ -1373,11 +1373,11 @@ class MethodDesc } // Perf warning: takes the CodeVersionManagerLock on every call - BOOL HasNativeCodeReJITAware() + BOOL HasNativeCodeAnyVersion() { LIMITED_METHOD_DAC_CONTRACT; - return GetNativeCodeReJITAware() != NULL; + return GetNativeCodeAnyVersion() != NULL; } BOOL SetNativeCodeInterlocked(PCODE addr, PCODE pExpected = NULL); @@ -1437,9 +1437,9 @@ class MethodDesc PCODE GetNativeCode(); // Returns GetNativeCode() if it exists, but also checks to see if there - // is a non-default IL code version and returns that. + // is a non-default code version that is populated with a code body and returns that. // Perf warning: takes the CodeVersionManagerLock on every call - PCODE GetNativeCodeReJITAware(); + PCODE GetNativeCodeAnyVersion(); #if defined(FEATURE_JIT_PITCHING) bool IsPitchable(); diff --git a/src/coreclr/vm/perfinfo.cpp b/src/coreclr/vm/perfinfo.cpp index 0be2e519936fbe..98fc667661a504 100644 --- a/src/coreclr/vm/perfinfo.cpp +++ b/src/coreclr/vm/perfinfo.cpp @@ -32,8 +32,8 @@ void PerfInfo::LogImage(PEAssembly* pPEAssembly, CHAR* guid) PRECONDITION(guid != nullptr); } CONTRACTL_END; - SString value; - const SString& path = pPEAssembly->GetPath(); + // Nothing to log if the assembly path isn't present. + SString path{ pPEAssembly->GetPath() }; if (path.IsEmpty()) { return; @@ -49,12 +49,11 @@ void PerfInfo::LogImage(PEAssembly* pPEAssembly, CHAR* guid) } } + SString value; value.Printf("%s%c%s%c%p", path.GetUTF8(), sDelimiter, guid, sDelimiter, baseAddr); - SString command; - command.Printf("%s", "ImageLoad"); + SString command{ SString::Literal, "ImageLoad" }; WriteLine(command, value); - } // Writes a command line, with "type" being the type of command, with "value" as the command's corresponding instructions/values. This is to be used to log specific information, e.g. LogImage diff --git a/src/installer/managed/Microsoft.NET.HostModel/AppHost/PEUtils.cs b/src/installer/managed/Microsoft.NET.HostModel/AppHost/PEUtils.cs index 0d0b33ed55569e..1bfc80fcfca492 100644 --- a/src/installer/managed/Microsoft.NET.HostModel/AppHost/PEUtils.cs +++ b/src/installer/managed/Microsoft.NET.HostModel/AppHost/PEUtils.cs @@ -1,8 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; +using System.Buffers.Binary; using System.IO; using System.IO.MemoryMappedFiles; +using System.Reflection.PortableExecutable; namespace Microsoft.NET.HostModel.AppHost { @@ -15,29 +18,13 @@ public static class PEUtils /// true if the accessor represents a PE image, false otherwise. internal static unsafe bool IsPEImage(MemoryMappedViewAccessor accessor) { - byte* pointer = null; + if (accessor.Capacity < PEOffsets.DosStub.PESignatureOffset + sizeof(uint)) + return false; - try - { - accessor.SafeMemoryMappedViewHandle.AcquirePointer(ref pointer); - byte* bytes = pointer + accessor.PointerOffset; - - // https://en.wikipedia.org/wiki/Portable_Executable - // Validate that we're looking at Windows PE file - if (((ushort*)bytes)[0] != PEOffsets.DosImageSignature - || accessor.Capacity < PEOffsets.DosStub.PESignatureOffset + sizeof(uint)) - { - return false; - } - return true; - } - finally - { - if (pointer != null) - { - accessor.SafeMemoryMappedViewHandle.ReleasePointer(); - } - } + // https://en.wikipedia.org/wiki/Portable_Executable + // Validate that we're looking at Windows PE file + ushort signature = AsLittleEndian(accessor.ReadUInt16(0)); + return signature == PEOffsets.DosImageSignature; } public static bool IsPEImage(string filePath) @@ -60,40 +47,15 @@ public static bool IsPEImage(string filePath) /// The memory accessor which has the apphost file opened. internal static unsafe void SetWindowsGraphicalUserInterfaceBit(MemoryMappedViewAccessor accessor) { - byte* pointer = null; - - try - { - accessor.SafeMemoryMappedViewHandle.AcquirePointer(ref pointer); - byte* bytes = pointer + accessor.PointerOffset; - - // https://en.wikipedia.org/wiki/Portable_Executable - uint peHeaderOffset = ((uint*)(bytes + PEOffsets.DosStub.PESignatureOffset))[0]; - - if (accessor.Capacity < peHeaderOffset + PEOffsets.PEHeader.Subsystem + sizeof(ushort)) - { - throw new AppHostNotPEFileException("Subsystem offset out of file range."); - } - - ushort* subsystem = ((ushort*)(bytes + peHeaderOffset + PEOffsets.PEHeader.Subsystem)); - - // https://docs.microsoft.com/en-us/windows/desktop/Debug/pe-format#windows-subsystem - // The subsystem of the prebuilt apphost should be set to CUI - if (subsystem[0] != (ushort)PEOffsets.Subsystem.WindowsCui) - { - throw new AppHostNotCUIException(subsystem[0]); - } - - // Set the subsystem to GUI - subsystem[0] = (ushort)PEOffsets.Subsystem.WindowsGui; - } - finally - { - if (pointer != null) - { - accessor.SafeMemoryMappedViewHandle.ReleasePointer(); - } - } + // https://learn.microsoft.com/windows/win32/debug/pe-format#windows-subsystem + // The subsystem of the prebuilt apphost should be set to CUI + uint peHeaderOffset; + ushort subsystem = GetWindowsSubsystem(accessor, out peHeaderOffset); + if (subsystem != (ushort)Subsystem.WindowsCui) + throw new AppHostNotCUIException(subsystem); + + // Set the subsystem to GUI + accessor.Write(peHeaderOffset + PEOffsets.PEHeader.Subsystem, AsLittleEndian((ushort)Subsystem.WindowsGui)); } public static unsafe void SetWindowsGraphicalUserInterfaceBit(string filePath) @@ -113,32 +75,7 @@ public static unsafe void SetWindowsGraphicalUserInterfaceBit(string filePath) /// The memory accessor which has the apphost file opened. internal static unsafe ushort GetWindowsGraphicalUserInterfaceBit(MemoryMappedViewAccessor accessor) { - byte* pointer = null; - - try - { - accessor.SafeMemoryMappedViewHandle.AcquirePointer(ref pointer); - byte* bytes = pointer + accessor.PointerOffset; - - // https://en.wikipedia.org/wiki/Portable_Executable - uint peHeaderOffset = ((uint*)(bytes + PEOffsets.DosStub.PESignatureOffset))[0]; - - if (accessor.Capacity < peHeaderOffset + PEOffsets.PEHeader.Subsystem + sizeof(ushort)) - { - throw new AppHostNotPEFileException("Subsystem offset out of file range."); - } - - ushort* subsystem = ((ushort*)(bytes + peHeaderOffset + PEOffsets.PEHeader.Subsystem)); - - return subsystem[0]; - } - finally - { - if (pointer != null) - { - accessor.SafeMemoryMappedViewHandle.ReleasePointer(); - } - } + return GetWindowsSubsystem(accessor, out _); } public static unsafe ushort GetWindowsGraphicalUserInterfaceBit(string filePath) @@ -151,5 +88,25 @@ public static unsafe ushort GetWindowsGraphicalUserInterfaceBit(string filePath) } } } + + private static ushort GetWindowsSubsystem(MemoryMappedViewAccessor accessor, out uint peHeaderOffset) + { + // https://en.wikipedia.org/wiki/Portable_Executable + if (accessor.Capacity < PEOffsets.DosStub.PESignatureOffset + sizeof(uint)) + throw new AppHostNotPEFileException("PESignature offset out of file range."); + + peHeaderOffset = AsLittleEndian(accessor.ReadUInt32(PEOffsets.DosStub.PESignatureOffset)); + if (accessor.Capacity < peHeaderOffset + PEOffsets.PEHeader.Subsystem + sizeof(ushort)) + throw new AppHostNotPEFileException("Subsystem offset out of file range."); + + // https://learn.microsoft.com/windows/win32/debug/pe-format#windows-subsystem + return AsLittleEndian(accessor.ReadUInt16(peHeaderOffset + PEOffsets.PEHeader.Subsystem)); + } + + private static ushort AsLittleEndian(ushort value) + => BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value); + + private static uint AsLittleEndian(uint value) + => BitConverter.IsLittleEndian ? value : BinaryPrimitives.ReverseEndianness(value); } } diff --git a/src/installer/pkg/sfx/Microsoft.NETCore.App/Microsoft.NETCore.App.Runtime.sfxproj b/src/installer/pkg/sfx/Microsoft.NETCore.App/Microsoft.NETCore.App.Runtime.sfxproj index ba597425bcfc1d..d0e9b6e16ff1a1 100644 --- a/src/installer/pkg/sfx/Microsoft.NETCore.App/Microsoft.NETCore.App.Runtime.sfxproj +++ b/src/installer/pkg/sfx/Microsoft.NETCore.App/Microsoft.NETCore.App.Runtime.sfxproj @@ -7,11 +7,14 @@ dotnet-runtime-internal dotnet-runtime dotnet-runtime-internal - $(SharedFrameworkName).PGO true + false dotnet-runtime-symbols NetCore.SharedFramework true + + true diff --git a/src/installer/tests/Microsoft.NET.HostModel.Tests/Microsoft.NET.HostModel.AppHost.Tests/AppHostUpdateTests.cs b/src/installer/tests/Microsoft.NET.HostModel.Tests/Microsoft.NET.HostModel.AppHost.Tests/AppHostUpdateTests.cs index 232410be8f2596..b4d038b99b5ce3 100644 --- a/src/installer/tests/Microsoft.NET.HostModel.Tests/Microsoft.NET.HostModel.AppHost.Tests/AppHostUpdateTests.cs +++ b/src/installer/tests/Microsoft.NET.HostModel.Tests/Microsoft.NET.HostModel.AppHost.Tests/AppHostUpdateTests.cs @@ -11,6 +11,7 @@ using Microsoft.NET.HostModel.AppHost; using Microsoft.DotNet.CoreSetup.Test; using System.Diagnostics; +using System.Reflection.PortableExecutable; namespace Microsoft.NET.HostModel.Tests { @@ -111,7 +112,9 @@ public void ItCanSetWindowsGUISubsystem() BitConverter .ToUInt16(File.ReadAllBytes(destinationFilePath), SubsystemOffset) .Should() - .Be(2); + .Be((ushort)Subsystem.WindowsGui); + + Assert.Equal((ushort)Subsystem.WindowsGui, PEUtils.GetWindowsGraphicalUserInterfaceBit(destinationFilePath)); } } @@ -153,6 +156,7 @@ public void ItFailsToSetGUISubsystemWithWrongDefault() string destinationFilePath = Path.Combine(testDirectory.Path, "DestinationAppHost.exe.mock"); string appBinaryFilePath = "Test/App/Binary/Path.dll"; + Assert.Equal(42, PEUtils.GetWindowsGraphicalUserInterfaceBit(sourceAppHostMock)); Assert.Throws(() => HostWriter.CreateAppHost( sourceAppHostMock, diff --git a/src/libraries/Common/src/Interop/Interop.Locale.iOS.cs b/src/libraries/Common/src/Interop/Interop.Locale.iOS.cs index 0f7b1763d58509..6725436b4f0945 100644 --- a/src/libraries/Common/src/Interop/Interop.Locale.iOS.cs +++ b/src/libraries/Common/src/Interop/Interop.Locale.iOS.cs @@ -24,5 +24,8 @@ internal static partial class Globalization [LibraryImport(Libraries.GlobalizationNative, EntryPoint = "GlobalizationNative_GetLocaleTimeFormatNative", StringMarshalling = StringMarshalling.Utf8)] internal static partial string GetLocaleTimeFormatNative(string localeName, [MarshalAs(UnmanagedType.Bool)] bool shortFormat); + + [LibraryImport(Libraries.GlobalizationNative, EntryPoint = "GlobalizationNative_GetLocalesNative", StringMarshalling = StringMarshalling.Utf16)] + internal static partial int GetLocalesNative([Out] char[]? value, int valueLength); } } diff --git a/src/libraries/Common/src/SourceGenerators/DiagnosticInfo.cs b/src/libraries/Common/src/SourceGenerators/DiagnosticInfo.cs new file mode 100644 index 00000000000000..74f44f99c62baa --- /dev/null +++ b/src/libraries/Common/src/SourceGenerators/DiagnosticInfo.cs @@ -0,0 +1,60 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq; +using System.Numerics.Hashing; +using Microsoft.CodeAnalysis; + +namespace SourceGenerators; + +/// +/// Descriptor for diagnostic instances using structural equality comparison. +/// Provides a work-around for https://github.com/dotnet/roslyn/issues/68291. +/// +internal readonly struct DiagnosticInfo : IEquatable +{ + public DiagnosticDescriptor Descriptor { get; private init; } + public object?[] MessageArgs { get; private init; } + public Location? Location { get; private init; } + + public static DiagnosticInfo Create(DiagnosticDescriptor descriptor, Location? location, object?[]? messageArgs) + { + Location? trimmedLocation = location is null ? null : GetTrimmedLocation(location); + + return new DiagnosticInfo + { + Descriptor = descriptor, + Location = trimmedLocation, + MessageArgs = messageArgs ?? Array.Empty() + }; + + // Creates a copy of the Location instance that does not capture a reference to Compilation. + static Location GetTrimmedLocation(Location location) + => Location.Create(location.SourceTree?.FilePath ?? "", location.SourceSpan, location.GetLineSpan().Span); + } + + public Diagnostic CreateDiagnostic() + => Diagnostic.Create(Descriptor, Location, MessageArgs); + + public override readonly bool Equals(object? obj) => obj is DiagnosticInfo info && Equals(info); + + public readonly bool Equals(DiagnosticInfo other) + { + return Descriptor.Equals(other.Descriptor) && + MessageArgs.SequenceEqual(other.MessageArgs) && + Location == other.Location; + } + + public override readonly int GetHashCode() + { + int hashCode = Descriptor.GetHashCode(); + foreach (object? messageArg in MessageArgs) + { + hashCode = HashHelpers.Combine(hashCode, messageArg?.GetHashCode() ?? 0); + } + + hashCode = HashHelpers.Combine(hashCode, Location?.GetHashCode() ?? 0); + return hashCode; + } +} diff --git a/src/libraries/System.Text.Json/gen/Helpers/ImmutableEquatableArray.cs b/src/libraries/Common/src/SourceGenerators/ImmutableEquatableArray.cs similarity index 85% rename from src/libraries/System.Text.Json/gen/Helpers/ImmutableEquatableArray.cs rename to src/libraries/Common/src/SourceGenerators/ImmutableEquatableArray.cs index ac3aa804fdd9dc..47fdde1751882a 100644 --- a/src/libraries/System.Text.Json/gen/Helpers/ImmutableEquatableArray.cs +++ b/src/libraries/Common/src/SourceGenerators/ImmutableEquatableArray.cs @@ -1,12 +1,13 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections; using System.Collections.Generic; using System.Linq; using System.Numerics.Hashing; -namespace System.Text.Json.SourceGeneration +namespace SourceGenerators { /// /// Provides an immutable list implementation which implements sequence equality. @@ -72,15 +73,9 @@ public bool MoveNext() } } - public static class ImmutableEquatableArray + internal static class ImmutableEquatableArray { - public static ImmutableEquatableArray Empty() where T : IEquatable - => ImmutableEquatableArray.Empty; - public static ImmutableEquatableArray ToImmutableEquatableArray(this IEnumerable values) where T : IEquatable => new(values); - - public static ImmutableEquatableArray Create(params T[] values) where T : IEquatable - => values is { Length: > 0 } ? new(values) : ImmutableEquatableArray.Empty; } } diff --git a/src/libraries/Common/src/SourceGenerators/TypeModelHelper.cs b/src/libraries/Common/src/SourceGenerators/TypeModelHelper.cs index 73c19d61ca1225..7a3a3e98fd7fde 100644 --- a/src/libraries/Common/src/SourceGenerators/TypeModelHelper.cs +++ b/src/libraries/Common/src/SourceGenerators/TypeModelHelper.cs @@ -3,6 +3,8 @@ using Microsoft.CodeAnalysis; using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; namespace SourceGenerators { @@ -32,5 +34,7 @@ void TraverseContainingTypes(INamedTypeSymbol current) } } } + + public static string GetFullyQualifiedName(this ITypeSymbol type) => type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); } } diff --git a/src/libraries/System.Text.Json/gen/Model/TypeRef.cs b/src/libraries/Common/src/SourceGenerators/TypeRef.cs similarity index 96% rename from src/libraries/System.Text.Json/gen/Model/TypeRef.cs rename to src/libraries/Common/src/SourceGenerators/TypeRef.cs index 050aba0cda658c..cfbf33ed741366 100644 --- a/src/libraries/System.Text.Json/gen/Model/TypeRef.cs +++ b/src/libraries/Common/src/SourceGenerators/TypeRef.cs @@ -1,10 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Diagnostics; using Microsoft.CodeAnalysis; -namespace System.Text.Json.SourceGeneration +namespace SourceGenerators { /// /// An equatable value representing type identity. diff --git a/src/libraries/Common/tests/SourceGenerators/GeneratorTestHelpers.cs b/src/libraries/Common/tests/SourceGenerators/GeneratorTestHelpers.cs new file mode 100644 index 00000000000000..d62a3c788e73dc --- /dev/null +++ b/src/libraries/Common/tests/SourceGenerators/GeneratorTestHelpers.cs @@ -0,0 +1,84 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Xunit; + +namespace SourceGenerators.Tests +{ + internal static class GeneratorTestHelpers + { + /// + /// Asserts for structural equality, returning a path to the mismatching data when not equal. + /// + public static void AssertStructurallyEqual(T expected, T actual) + { + CheckAreEqualCore(expected, actual, new()); + static void CheckAreEqualCore(object expected, object actual, Stack path) + { + if (expected is null || actual is null) + { + if (expected is not null || actual is not null) + { + FailNotEqual(); + } + + return; + } + + Type type = expected.GetType(); + if (type != actual.GetType()) + { + FailNotEqual(); + return; + } + + if (expected is IEnumerable leftCollection) + { + if (actual is not IEnumerable rightCollection) + { + FailNotEqual(); + return; + } + + object?[] expectedValues = leftCollection.Cast().ToArray(); + object?[] actualValues = rightCollection.Cast().ToArray(); + + for (int i = 0; i < Math.Max(expectedValues.Length, actualValues.Length); i++) + { + object? expectedElement = i < expectedValues.Length ? expectedValues[i] : ""; + object? actualElement = i < actualValues.Length ? actualValues[i] : ""; + + path.Push($"[{i}]"); + CheckAreEqualCore(expectedElement, actualElement, path); + path.Pop(); + } + } + + if (type.GetProperty("EqualityContract", BindingFlags.Instance | BindingFlags.NonPublic, null, returnType: typeof(Type), types: Array.Empty(), null) != null) + { + // Type is a C# record, run pointwise equality comparison. + foreach (PropertyInfo property in type.GetProperties(BindingFlags.Public | BindingFlags.Instance)) + { + path.Push("." + property.Name); + CheckAreEqualCore(property.GetValue(expected), property.GetValue(actual), path); + path.Pop(); + } + + return; + } + + if (!expected.Equals(actual)) + { + FailNotEqual(); + } + + void FailNotEqual() => Assert.Fail($"Value not equal in ${string.Join("", path.Reverse())}: expected {expected}, but was {actual}."); + } + } + } +} diff --git a/src/libraries/Microsoft.Bcl.Cryptography/src/PACKAGE.md b/src/libraries/Microsoft.Bcl.Cryptography/src/PACKAGE.md new file mode 100644 index 00000000000000..215d29e162c4b6 --- /dev/null +++ b/src/libraries/Microsoft.Bcl.Cryptography/src/PACKAGE.md @@ -0,0 +1,39 @@ +## About + +This library provides some cryptographic types and functionality for .NET Standard and .NET Framework. This library is not necessary nor recommended when targeting versions of .NET that include the relevant support. + +## Key Features + +* Enables the use of some cryptographic functionality on older .NET platforms. + +## How to Use + +This package should only be used by platforms where the desired functionality is not built-in. + +```C# +using System.Security.Cryptography; + +internal static class Program +{ + private static void Main() + { + byte[] key = LoadKey(); + SP800108HmacCounterKdf kbkdf = new(key, HashAlgorithmName.SHA256); + byte[] derivedKey = kbkdf.DeriveKey("label"u8, "context"u8, derivedKeyLengthInBytes: 32); + } +} +``` + +## Main Types + +The main types provided by this library are: + +* `System.Security.Cryptography.SP800108HmacCounterKdf` + +## Additional Documentation + +* [API documentation](https://learn.microsoft.com/dotnet/api/System.Security.Cryptography) + +## Feedback & Contributing + +Microsoft.Bcl.Cryptography is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Caching.Abstractions/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Caching.Abstractions/src/PACKAGE.md new file mode 100644 index 00000000000000..eb8a9beacbc44b --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Caching.Abstractions/src/PACKAGE.md @@ -0,0 +1,53 @@ +## About + + + +Provides the abstractions to create and use in-memory and distributed caching in your applications. + +This library defines how in-memory and distributed caches should be implemented; it doesn’t contain any cache implementation. +With the abstractions provided in this library, various types of caches can be built and used interchangeably, whether the data is kept in memory, in files, or even across a network. + +## Key Features + + + +* Interfaces for building and using in-memory and distributed caches. + +## How to Use + + + +This package is typically used with an implementation of the caching abstractions, such as `Microsoft.Extensions.Caching.Memory` or `Microsoft.Extensions.Caching.SqlServer`. + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Caching.Abstractions.ICacheEntry` +* `Microsoft.Extensions.Caching.Abstractions.IMemoryCache` +* `Microsoft.Extensions.Caching.Abstractions.IDistributedCache` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/caching) +* API documentation + * [Microsoft.Extensions.Caching.Memory](https://learn.microsoft.com/dotnet/api/microsoft.extensions.caching.memory) + * [Microsoft.Extensions.Caching.Distributed](https://learn.microsoft.com/dotnet/api/microsoft.extensions.caching.distributed) + +## Related Packages + + + +* In-memory caching: [Microsoft.Extensions.Caching.Memory](https://www.nuget.org/packages/Microsoft.Extensions.Caching.Memory/) +* SQL Server caching: [Microsoft.Extensions.Caching.SqlServer](https://www.nuget.org/packages/Microsoft.Extensions.Caching.SqlServer/) +* Redis caching: [Microsoft.Extensions.Caching.StackExchangeRedis](https://www.nuget.org/packages/Microsoft.Extensions.Caching.StackExchangeRedis/) + +## Feedback & Contributing + + + +Microsoft.Extensions.Caching.Abstractions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Emitter.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Emitter.cs index 7206d549041147..1721a124dead95 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Emitter.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Emitter.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Collections.Immutable; using Microsoft.CodeAnalysis; using SourceGenerators; @@ -11,19 +10,22 @@ public sealed partial class ConfigurationBindingGenerator : IIncrementalGenerato { private sealed partial class Emitter { - private readonly SourceProductionContext _context; - private readonly SourceGenerationSpec _sourceGenSpec; + private readonly InterceptorInfo _interceptorInfo; + private readonly BindingHelperInfo _bindingHelperInfo; + private readonly TypeIndex _typeIndex; + private readonly SourceWriter _writer = new(); - public Emitter(SourceProductionContext context, SourceGenerationSpec sourceGenSpec) + public Emitter(SourceGenerationSpec sourceGenSpec) { - _context = context; - _sourceGenSpec = sourceGenSpec; + _interceptorInfo = sourceGenSpec.InterceptorInfo; + _bindingHelperInfo = sourceGenSpec.BindingHelperInfo; + _typeIndex = new TypeIndex(sourceGenSpec.ConfigTypes); } - public void Emit() + public void Emit(SourceProductionContext context) { - if (!ShouldEmitBindingExtensions()) + if (!ShouldEmitMethods(MethodsToGen.Any)) { return; } @@ -52,7 +54,7 @@ file static class {{Identifier.BindingExtensions}} EmitEndBlock(); // Binding namespace. - _context.AddSource($"{Identifier.BindingExtensions}.g.cs", _writer.ToSourceText()); + context.AddSource($"{Identifier.BindingExtensions}.g.cs", _writer.ToSourceText()); } private void EmitInterceptsLocationAttrDecl() @@ -79,7 +81,7 @@ public InterceptsLocationAttribute(string filePath, int line, int column) private void EmitUsingStatements() { - foreach (string @namespace in _sourceGenSpec.Namespaces.ToImmutableSortedSet()) + foreach (string @namespace in _bindingHelperInfo.Namespaces) { _writer.WriteLine($"using {@namespace};"); } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Parser.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Parser.cs index 2a6f5d2126e8c8..d01c5dbae13f3c 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Parser.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.Parser.cs @@ -7,45 +7,73 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Threading; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.Operations; using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { public sealed partial class ConfigurationBindingGenerator : IIncrementalGenerator { - private sealed partial class Parser + internal sealed partial class Parser(CompilationData compilationData) { - private record struct InvocationDiagnosticInfo(DiagnosticDescriptor Descriptor, object[]? MessageArgs); + private readonly KnownTypeSymbols _typeSymbols = compilationData.TypeSymbols!; + private readonly bool _langVersionIsSupported = compilationData.LanguageVersionIsSupported; - private readonly SourceProductionContext _context; - private readonly SourceGenerationSpec _sourceGenSpec = new(); - private readonly KnownTypeSymbols _typeSymbols; - private readonly ImmutableArray _invocations; + private readonly List _invocationTypeParseInfo = new(); + private readonly Queue _typesToParse = new(); + private readonly Dictionary _createdTypeSpecs = new(SymbolEqualityComparer.Default); - private readonly Dictionary _createdSpecs = new(SymbolEqualityComparer.Default); - private readonly HashSet _unsupportedTypes = new(SymbolEqualityComparer.Default); + private readonly InterceptorInfo.Builder _interceptorInfoBuilder = new(); + private BindingHelperInfo.Builder? _helperInfoBuilder; // Init'ed with type index when registering interceptors, after creating type specs. - private readonly List _invocationTargetTypeDiags = new(); - private readonly Dictionary> _typeDiagnostics = new(SymbolEqualityComparer.Default); + public List? Diagnostics { get; private set; } - public Parser(SourceProductionContext context, KnownTypeSymbols typeSymbols, ImmutableArray invocations) + public SourceGenerationSpec? GetSourceGenerationSpec(ImmutableArray invocations, CancellationToken cancellationToken) { - _context = context; - _typeSymbols = typeSymbols; - _invocations = invocations; - } + if (!_langVersionIsSupported) + { + RecordDiagnostic(DiagnosticDescriptors.LanguageVersionNotSupported, trimmedLocation: Location.None); + return null; + } - public SourceGenerationSpec? GetSourceGenerationSpec() - { if (_typeSymbols is not { IConfiguration: { }, ConfigurationBinder: { } }) { return null; } - foreach (BinderInvocation invocation in _invocations) + ParseInvocations(invocations); + CreateTypeSpecs(cancellationToken); + RegisterInterceptors(); + + return new SourceGenerationSpec { + InterceptorInfo = _interceptorInfoBuilder.ToIncrementalValue(), + BindingHelperInfo = _helperInfoBuilder!.ToIncrementalValue(), + ConfigTypes = _createdTypeSpecs.Values.OrderBy(s => s.TypeRef.FullyQualifiedName).ToImmutableEquatableArray(), + }; + } + + private bool IsValidRootConfigType([NotNullWhen(true)] ITypeSymbol? type) + { + if (type is null || + type.SpecialType is SpecialType.System_Object or SpecialType.System_Void || + !_typeSymbols.Compilation.IsSymbolAccessibleWithin(type, _typeSymbols.Compilation.Assembly) || + type.TypeKind is TypeKind.TypeParameter or TypeKind.Pointer or TypeKind.Error || + type.IsRefLikeType || + ContainsGenericParameters(type)) + { + return false; + } + + return true; + } + + private void ParseInvocations(ImmutableArray invocations) + { + foreach (BinderInvocation? invocation in invocations) + { + Debug.Assert(invocation is not null); IMethodSymbol targetMethod = invocation.Operation.TargetMethod; INamedTypeSymbol? candidateBinderType = targetMethod.ContainingType; Debug.Assert(targetMethod.IsExtensionMethod); @@ -63,174 +91,124 @@ public Parser(SourceProductionContext context, KnownTypeSymbols typeSymbols, Imm ParseInvocation_ServiceCollectionExt(invocation); } } - - return _sourceGenSpec; } - private bool IsValidRootConfigType(ITypeSymbol? type) + private void CreateTypeSpecs(CancellationToken cancellationToken) { - if (type is null || - type.SpecialType is SpecialType.System_Object or SpecialType.System_Void || - !_typeSymbols.Compilation.IsSymbolAccessibleWithin(type, _typeSymbols.Compilation.Assembly) || - type.TypeKind is TypeKind.TypeParameter or TypeKind.Pointer or TypeKind.Error || - type.IsRefLikeType || - ContainsGenericParameters(type)) + while (_typesToParse.Count > 0) { - return false; - } + cancellationToken.ThrowIfCancellationRequested(); - return true; + TypeParseInfo typeParseInfo = _typesToParse.Dequeue(); + ITypeSymbol typeSymbol = typeParseInfo.TypeSymbol; + + if (!_createdTypeSpecs.ContainsKey(typeSymbol)) + { + _createdTypeSpecs.Add(typeSymbol, CreateTypeSpec(typeParseInfo)); + } + } } - private TypeSpec? GetTargetTypeForRootInvocation(ITypeSymbol? type, Location? invocationLocation) + private void RegisterInterceptors() { - if (!IsValidRootConfigType(type)) + TypeIndex typeIndex = new(_createdTypeSpecs.Values); + _helperInfoBuilder = new(typeIndex); + + foreach (TypeParseInfo typeParseInfo in _invocationTypeParseInfo) { - _context.ReportDiagnostic(Diagnostic.Create(Diagnostics.CouldNotDetermineTypeInfo, invocationLocation)); - return null; + TypeSpec typeSpec = _createdTypeSpecs[typeParseInfo.TypeSymbol]; + MethodsToGen overload = typeParseInfo.BindingOverload; + + if ((MethodsToGen.ConfigBinder_Any & overload) is not 0) + { + RegisterInterceptor_ConfigurationBinder(typeParseInfo, typeSpec); + } + else if ((MethodsToGen.OptionsBuilderExt_Any & overload) is not 0) + { + RegisterInterceptor_OptionsBuilderExt(typeParseInfo, typeSpec); + } + else + { + Debug.Assert((MethodsToGen.ServiceCollectionExt_Any & overload) is not 0); + RegisterInterceptor_ServiceCollectionExt(typeParseInfo, typeSpec); + } } + } - return GetTargetTypeForRootInvocationCore(type, invocationLocation); + private void EnqueueTargetTypeForRootInvocation(ITypeSymbol? typeSymbol, MethodsToGen overload, BinderInvocation invocation) + { + if (!IsValidRootConfigType(typeSymbol)) + { + RecordDiagnostic(DiagnosticDescriptors.CouldNotDetermineTypeInfo, invocation.Location); + } + else + { + TypeParseInfo typeParseInfo = TypeParseInfo.Create(typeSymbol, overload, invocation, containingTypeDiagInfo: null); + _typesToParse.Enqueue(typeParseInfo); + _invocationTypeParseInfo.Add(typeParseInfo); + } } - public TypeSpec? GetTargetTypeForRootInvocationCore(ITypeSymbol type, Location? invocationLocation) + private TypeRef EnqueueTransitiveType(TypeParseInfo containingTypeParseInfo, ITypeSymbol memberTypeSymbol, DiagnosticDescriptor diagDescriptor, string? memberName = null) { - TypeSpec? spec = GetOrCreateTypeSpec(type); + TypeParseInfo memberTypeParseInfo = containingTypeParseInfo.ToTransitiveTypeParseInfo(memberTypeSymbol, diagDescriptor, memberName); - foreach (InvocationDiagnosticInfo diag in _invocationTargetTypeDiags) + if (_createdTypeSpecs.TryGetValue(memberTypeSymbol, out TypeSpec? memberTypeSpec)) { - _context.ReportDiagnostic(Diagnostic.Create(diag.Descriptor, invocationLocation, diag.MessageArgs)); + RecordTypeDiagnosticIfRequired(memberTypeParseInfo, memberTypeSpec); + return memberTypeSpec.TypeRef; } - _invocationTargetTypeDiags.Clear(); - return spec; + _typesToParse.Enqueue(memberTypeParseInfo); + return new TypeRef(memberTypeSymbol); } - private TypeSpec? GetOrCreateTypeSpec(ITypeSymbol type) + private TypeSpec CreateTypeSpec(TypeParseInfo typeParseInfo) { - if (_createdSpecs.TryGetValue(type, out TypeSpec? spec)) - { - if (_typeDiagnostics.TryGetValue(type, out HashSet? typeDiags)) - { - _invocationTargetTypeDiags.AddRange(typeDiags); - } - - return spec; - } + ITypeSymbol type = typeParseInfo.TypeSymbol; + TypeSpec spec; if (IsNullable(type, out ITypeSymbol? underlyingType)) { - spec = MemberTypeIsBindable(type, underlyingType, Diagnostics.NullableUnderlyingTypeNotSupported, out TypeSpec? underlyingTypeSpec) - ? new NullableSpec(type, underlyingTypeSpec) - : null; + TypeRef underlyingTypeRef = EnqueueTransitiveType( + typeParseInfo, + underlyingType, + DiagnosticDescriptors.NullableUnderlyingTypeNotSupported); + + spec = new NullableSpec(type, underlyingTypeRef); } else if (IsParsableFromString(type, out StringParsableTypeKind specialTypeKind)) { ParsableFromStringSpec stringParsableSpec = new(type) { StringParsableTypeKind = specialTypeKind }; - - if (stringParsableSpec.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue) - { - _sourceGenSpec.PrimitivesForHelperGen.Add(stringParsableSpec); - } - spec = stringParsableSpec; } - else if (IsSupportedArrayType(type)) + else if (type.TypeKind is TypeKind.Array) { - spec = CreateArraySpec((type as IArrayTypeSymbol)); + spec = CreateArraySpec(typeParseInfo); + Debug.Assert(spec is ArraySpec or UnsupportedTypeSpec); } else if (IsCollection(type)) { - spec = CreateCollectionSpec((INamedTypeSymbol)type); + spec = CreateCollectionSpec(typeParseInfo); } else if (SymbolEqualityComparer.Default.Equals(type, _typeSymbols.IConfigurationSection)) { spec = new ConfigurationSectionSpec(type); } - else if (type is INamedTypeSymbol namedType) + else if (type is INamedTypeSymbol) { - // List is used in generated code as a temp holder for formatting - // an error for config properties that don't map to object properties. - _sourceGenSpec.Namespaces.Add("System.Collections.Generic"); - - spec = CreateObjectSpec(namedType); + spec = CreateObjectSpec(typeParseInfo); } else { - RegisterUnsupportedType(type, Diagnostics.TypeNotSupported); + spec = CreateUnsupportedTypeSpec(typeParseInfo, NotSupportedReason.UnknownType); } - foreach (InvocationDiagnosticInfo diag in _invocationTargetTypeDiags) - { - RegisterTypeDiagnostic(type, diag); - } - - if (spec is { Namespace: string @namespace } && @namespace is not "") - { - _sourceGenSpec.Namespaces.Add(@namespace); - } - - return _createdSpecs[type] = spec; - } - - private bool TryRegisterTypeForBindCoreMainGen(ComplexTypeSpec type) - { - if (type.HasBindableMembers) - { - bool registeredForBindCoreGen = TryRegisterTypeForBindCoreGen(type); - Debug.Assert(registeredForBindCoreGen); - - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCoreMain, type); - Register_AsConfigWithChildren_HelperForGen_IfRequired(type); - return true; - } - - return false; - } + RecordTypeDiagnosticIfRequired(typeParseInfo, spec); - private bool TryRegisterTypeForBindCoreGen(ComplexTypeSpec type) - { - if (type.HasBindableMembers) - { - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCore, type); - return true; - } - - return false; - } - - private void RegisterTypeForGetCoreGen(TypeSpec typeSpec) - { - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.GetCore, typeSpec); - Register_AsConfigWithChildren_HelperForGen_IfRequired(typeSpec); - } - - private void RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper method, TypeSpec type) - { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(method, out HashSet? types)) - { - _sourceGenSpec.TypesForGen_CoreBindingHelper_Methods[method] = types = new HashSet(); - } - - types.Add(type); - _sourceGenSpec.MethodsToGen_CoreBindingHelper |= method; - } - - private void Register_AsConfigWithChildren_HelperForGen_IfRequired(TypeSpec possibleComplexType) - { - if (possibleComplexType is ComplexTypeSpec) - { - _sourceGenSpec.MethodsToGen_CoreBindingHelper |= MethodsToGen_CoreBindingHelper.AsConfigWithChildren; - } + return spec; } - /// - /// Registers interceptors for root binding methods, except for ConfigurationBinder.Bind, - /// which is handled by - /// - private void RegisterInterceptor(Enum method, IInvocationOperation operation) => - _sourceGenSpec.InterceptionInfo.RegisterCacheEntry(method, new InterceptorLocationInfo(operation)); - private static bool IsNullable(ITypeSymbol type, [NotNullWhen(true)] out ITypeSymbol? underlyingType) { if (type is INamedTypeSymbol { IsGenericType: true } genericType && @@ -349,232 +327,197 @@ private bool IsParsableFromString(ITypeSymbol type, out StringParsableTypeKind t } } - private EnumerableSpec? CreateArraySpec(IArrayTypeSymbol arrayTypeSymbol) + private TypeSpec CreateArraySpec(TypeParseInfo typeParseInfo) { - ITypeSymbol elementTypeSymbol = arrayTypeSymbol.ElementType; + IArrayTypeSymbol typeSymbol = (IArrayTypeSymbol)typeParseInfo.TypeSymbol; - if (!MemberTypeIsBindable(arrayTypeSymbol, elementTypeSymbol, Diagnostics.ElementTypeNotSupported, out TypeSpec elementTypeSpec)) + if (typeSymbol.Rank > 1) { - return null; + return CreateUnsupportedTypeSpec(typeParseInfo, NotSupportedReason.MultiDimArraysNotSupported); } - // We want a BindCore method for List as a temp holder for the array values. - // Since the element type is supported, we can certainly a list of elements. - EnumerableSpec listTypeSpec = (EnumerableSpec)GetOrCreateTypeSpec(_typeSymbols.List.Construct(elementTypeSymbol)); + TypeRef elementTypeRef = EnqueueTransitiveType( + typeParseInfo, + typeSymbol.ElementType, + DiagnosticDescriptors.ElementTypeNotSupported); - EnumerableSpec spec = new EnumerableSpec(arrayTypeSymbol) + return new ArraySpec(typeSymbol) { - ElementType = elementTypeSpec, - InstantiationStrategy = InstantiationStrategy.Array, - PopulationStrategy = CollectionPopulationStrategy.Cast_Then_Add, // Using the concrete list type as a temp holder. - TypeToInstantiate = listTypeSpec, - PopulationCastType = null, + ElementTypeRef = elementTypeRef, }; - - bool registeredForBindCore = TryRegisterTypeForBindCoreGen(listTypeSpec) && TryRegisterTypeForBindCoreGen(spec); - Debug.Assert(registeredForBindCore); - return spec; } - private CollectionSpec? CreateCollectionSpec(INamedTypeSymbol type) + private TypeSpec CreateCollectionSpec(TypeParseInfo typeParseInfo) { - CollectionSpec? spec; - if (IsCandidateDictionary(type, out ITypeSymbol keyType, out ITypeSymbol elementType)) + INamedTypeSymbol type = (INamedTypeSymbol)typeParseInfo.TypeSymbol; + + TypeSpec spec; + if (IsCandidateDictionary(type, out ITypeSymbol? keyType, out ITypeSymbol? elementType)) { - spec = CreateDictionarySpec(type, keyType, elementType); - Debug.Assert(spec is null or DictionarySpec { KeyType: null or ParsableFromStringSpec }); + spec = CreateDictionarySpec(typeParseInfo, keyType, elementType); + Debug.Assert(spec is DictionarySpec or UnsupportedTypeSpec); } else { - spec = CreateEnumerableSpec(type); + spec = CreateEnumerableSpec(typeParseInfo); + Debug.Assert(spec is EnumerableSpec or UnsupportedTypeSpec); } - if (spec is null) - { - return null; - } - - bool registerForBindCoreGen = TryRegisterTypeForBindCoreGen(spec); - Debug.Assert(registerForBindCoreGen); return spec; } - private DictionarySpec CreateDictionarySpec(INamedTypeSymbol type, ITypeSymbol keyType, ITypeSymbol elementType) + private TypeSpec CreateDictionarySpec(TypeParseInfo typeParseInfo, ITypeSymbol keyTypeSymbol, ITypeSymbol elementTypeSymbol) { - if (!MemberTypeIsBindable(type, keyType, Diagnostics.DictionaryKeyNotSupported, out TypeSpec keySpec) || - !MemberTypeIsBindable(type, elementType, Diagnostics.ElementTypeNotSupported, out TypeSpec elementSpec)) - { - return null; - } + INamedTypeSymbol type = (INamedTypeSymbol)typeParseInfo.TypeSymbol; - if (keySpec.SpecKind is not TypeSpecKind.ParsableFromString) - { - RegisterUnsupportedType(type, Diagnostics.DictionaryKeyNotSupported); - return null; - } - - InstantiationStrategy constructionStrategy; - CollectionPopulationStrategy populationStrategy; - INamedTypeSymbol? typeToInstantiate = null; - INamedTypeSymbol? populationCastType = null; + CollectionInstantiationStrategy instantiationStrategy; + CollectionInstantiationConcreteType instantiationConcreteType; + CollectionPopulationCastType populationCastType; if (HasPublicParameterLessCtor(type)) { - constructionStrategy = InstantiationStrategy.ParameterlessConstructor; + instantiationStrategy = CollectionInstantiationStrategy.ParameterlessConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.Self; - if (HasAddMethod(type, keyType, elementType)) + if (HasAddMethod(type, keyTypeSymbol, elementTypeSymbol)) { - populationStrategy = CollectionPopulationStrategy.Add; + populationCastType = CollectionPopulationCastType.NotApplicable; } - else if (GetInterface(type, _typeSymbols.GenericIDictionary_Unbound) is not null) + else if (_typeSymbols.GenericIDictionary is not null && GetInterface(type, _typeSymbols.GenericIDictionary_Unbound) is not null) { - populationCastType = _typeSymbols.GenericIDictionary; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; + populationCastType = CollectionPopulationCastType.IDictionary; } else { - RegisterUnsupportedType(type, Diagnostics.CollectionNotSupported); - return null; + return CreateUnsupportedCollectionSpec(typeParseInfo); } } - else if (IsInterfaceMatch(type, _typeSymbols.GenericIDictionary_Unbound) || IsInterfaceMatch(type, _typeSymbols.IDictionary)) + else if (_typeSymbols.Dictionary is not null && + (IsInterfaceMatch(type, _typeSymbols.GenericIDictionary_Unbound) || IsInterfaceMatch(type, _typeSymbols.IDictionary))) { - typeToInstantiate = _typeSymbols.Dictionary; - constructionStrategy = InstantiationStrategy.ParameterlessConstructor; - populationStrategy = CollectionPopulationStrategy.Add; + instantiationStrategy = CollectionInstantiationStrategy.ParameterlessConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.Dictionary; + populationCastType = CollectionPopulationCastType.NotApplicable; } - else if (IsInterfaceMatch(type, _typeSymbols.IReadOnlyDictionary_Unbound)) + else if (_typeSymbols.Dictionary is not null && IsInterfaceMatch(type, _typeSymbols.IReadOnlyDictionary_Unbound)) { - typeToInstantiate = _typeSymbols.Dictionary; - populationCastType = _typeSymbols.GenericIDictionary; - constructionStrategy = InstantiationStrategy.ToEnumerableMethod; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; - _sourceGenSpec.Namespaces.Add("System.Linq"); + instantiationStrategy = CollectionInstantiationStrategy.LinqToDictionary; + instantiationConcreteType = CollectionInstantiationConcreteType.Dictionary; + populationCastType = CollectionPopulationCastType.IDictionary; } else { - RegisterUnsupportedType(type, Diagnostics.CollectionNotSupported); - return null; + return CreateUnsupportedCollectionSpec(typeParseInfo); } - Debug.Assert(!(populationStrategy is CollectionPopulationStrategy.Cast_Then_Add && populationCastType is null)); + TypeRef keyTypeRef = EnqueueTransitiveType(typeParseInfo, keyTypeSymbol, DiagnosticDescriptors.DictionaryKeyNotSupported); + TypeRef elementTypeRef = EnqueueTransitiveType(typeParseInfo, elementTypeSymbol, DiagnosticDescriptors.ElementTypeNotSupported); - DictionarySpec spec = new(type) + return new DictionarySpec(type) { - KeyType = (ParsableFromStringSpec)keySpec, - ElementType = elementSpec, - InstantiationStrategy = constructionStrategy, - PopulationStrategy = populationStrategy, - TypeToInstantiate = ConstructGenericCollectionSpecIfRequired(typeToInstantiate, keyType, elementType) as DictionarySpec, - PopulationCastType = ConstructGenericCollectionSpecIfRequired(populationCastType, keyType, elementType) as DictionarySpec, + KeyTypeRef = keyTypeRef, + ElementTypeRef = elementTypeRef, + InstantiationStrategy = instantiationStrategy, + InstantiationConcreteType = instantiationConcreteType, + PopulationCastType = populationCastType, }; - - return spec; } - private EnumerableSpec? CreateEnumerableSpec(INamedTypeSymbol type) + private TypeSpec CreateEnumerableSpec(TypeParseInfo typeParseInfo) { - if (!TryGetElementType(type, out ITypeSymbol? elementType) || - !MemberTypeIsBindable(type, elementType, Diagnostics.ElementTypeNotSupported, out TypeSpec elementSpec)) + INamedTypeSymbol type = (INamedTypeSymbol)typeParseInfo.TypeSymbol; + + if (!TryGetElementType(type, out ITypeSymbol? elementType)) { - return null; + return CreateUnsupportedCollectionSpec(typeParseInfo); } - InstantiationStrategy instantiationStrategy; - CollectionPopulationStrategy populationStrategy; - INamedTypeSymbol? typeToInstantiate = null; - INamedTypeSymbol? populationCastType = null; + CollectionInstantiationStrategy instantiationStrategy; + CollectionInstantiationConcreteType instantiationConcreteType; + CollectionPopulationCastType populationCastType; if (HasPublicParameterLessCtor(type)) { - instantiationStrategy = InstantiationStrategy.ParameterlessConstructor; + instantiationStrategy = CollectionInstantiationStrategy.ParameterlessConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.Self; if (HasAddMethod(type, elementType)) { - populationStrategy = CollectionPopulationStrategy.Add; + populationCastType = CollectionPopulationCastType.NotApplicable; } - else if (GetInterface(type, _typeSymbols.GenericICollection_Unbound) is not null) + else if (_typeSymbols.GenericICollection is not null && GetInterface(type, _typeSymbols.GenericICollection_Unbound) is not null) { - populationCastType = _typeSymbols.GenericICollection; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; + populationCastType = CollectionPopulationCastType.ICollection; } else { - RegisterUnsupportedType(type, Diagnostics.CollectionNotSupported); - return null; + return CreateUnsupportedCollectionSpec(typeParseInfo); } } - else if (IsInterfaceMatch(type, _typeSymbols.GenericICollection_Unbound) || - IsInterfaceMatch(type, _typeSymbols.GenericIList_Unbound)) + else if ((IsInterfaceMatch(type, _typeSymbols.GenericICollection_Unbound) || IsInterfaceMatch(type, _typeSymbols.GenericIList_Unbound))) { - typeToInstantiate = _typeSymbols.List; - instantiationStrategy = InstantiationStrategy.ParameterlessConstructor; - populationStrategy = CollectionPopulationStrategy.Add; + instantiationStrategy = CollectionInstantiationStrategy.ParameterlessConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.List; + populationCastType = CollectionPopulationCastType.NotApplicable; } else if (IsInterfaceMatch(type, _typeSymbols.GenericIEnumerable_Unbound)) { - typeToInstantiate = _typeSymbols.List; - populationCastType = _typeSymbols.GenericICollection; - instantiationStrategy = InstantiationStrategy.ParameterizedConstructor; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; + instantiationStrategy = CollectionInstantiationStrategy.CopyConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.List; + populationCastType = CollectionPopulationCastType.ICollection; } else if (IsInterfaceMatch(type, _typeSymbols.ISet_Unbound)) { - typeToInstantiate = _typeSymbols.HashSet; - instantiationStrategy = InstantiationStrategy.ParameterlessConstructor; - populationStrategy = CollectionPopulationStrategy.Add; + instantiationStrategy = CollectionInstantiationStrategy.ParameterlessConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.HashSet; + populationCastType = CollectionPopulationCastType.NotApplicable; } else if (IsInterfaceMatch(type, _typeSymbols.IReadOnlySet_Unbound)) { - typeToInstantiate = _typeSymbols.HashSet; - populationCastType = _typeSymbols.ISet; - instantiationStrategy = InstantiationStrategy.ParameterizedConstructor; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; + instantiationStrategy = CollectionInstantiationStrategy.CopyConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.HashSet; + populationCastType = CollectionPopulationCastType.ISet; } else if (IsInterfaceMatch(type, _typeSymbols.IReadOnlyList_Unbound) || IsInterfaceMatch(type, _typeSymbols.IReadOnlyCollection_Unbound)) { - typeToInstantiate = _typeSymbols.List; - populationCastType = _typeSymbols.GenericICollection; - instantiationStrategy = InstantiationStrategy.ParameterizedConstructor; - populationStrategy = CollectionPopulationStrategy.Cast_Then_Add; + instantiationStrategy = CollectionInstantiationStrategy.CopyConstructor; + instantiationConcreteType = CollectionInstantiationConcreteType.List; + populationCastType = CollectionPopulationCastType.ICollection; } else { - RegisterUnsupportedType(type, Diagnostics.CollectionNotSupported); - return null; + return CreateUnsupportedCollectionSpec(typeParseInfo); } - Debug.Assert(!(populationStrategy is CollectionPopulationStrategy.Cast_Then_Add && populationCastType is null)); + TypeRef elementTypeRef = EnqueueTransitiveType(typeParseInfo, elementType, DiagnosticDescriptors.ElementTypeNotSupported); - EnumerableSpec spec = new(type) + return new EnumerableSpec(type) { - ElementType = elementSpec, + ElementTypeRef = elementTypeRef, InstantiationStrategy = instantiationStrategy, - PopulationStrategy = populationStrategy, - TypeToInstantiate = ConstructGenericCollectionSpecIfRequired(typeToInstantiate, elementType) as EnumerableSpec, - PopulationCastType = ConstructGenericCollectionSpecIfRequired(populationCastType, elementType) as EnumerableSpec, + InstantiationConcreteType = instantiationConcreteType, + PopulationCastType = populationCastType, }; - - return spec; } - private ObjectSpec? CreateObjectSpec(INamedTypeSymbol objectSymbol) + private ObjectSpec CreateObjectSpec(TypeParseInfo typeParseInfo) { - // Add spec to cache before traversing properties to avoid stack overflow. - ObjectSpec objectSpec = new(objectSymbol); - _createdSpecs.Add(objectSymbol, objectSpec); + INamedTypeSymbol typeSymbol = (INamedTypeSymbol)typeParseInfo.TypeSymbol; + string typeName = typeSymbol.GetTypeName().Name; - string typeName = objectSpec.Name; - IMethodSymbol? ctor = null; + ObjectInstantiationStrategy initializationStrategy = ObjectInstantiationStrategy.None; DiagnosticDescriptor? initDiagDescriptor = null; + string? initExceptionMessage = null; + + IMethodSymbol? ctor = null; - if (!(objectSymbol.IsAbstract || objectSymbol.TypeKind is TypeKind.Interface)) + if (!(typeSymbol.IsAbstract || typeSymbol.TypeKind is TypeKind.Interface)) { IMethodSymbol? parameterlessCtor = null; IMethodSymbol? parameterizedCtor = null; bool hasMultipleParameterizedCtors = false; - foreach (IMethodSymbol candidate in objectSymbol.InstanceConstructors) + foreach (IMethodSymbol candidate in typeSymbol.InstanceConstructors) { if (candidate.DeclaredAccessibility is not Accessibility.Public) { @@ -595,14 +538,14 @@ private DictionarySpec CreateDictionarySpec(INamedTypeSymbol type, ITypeSymbol k } } - bool hasPublicParameterlessCtor = objectSymbol.IsValueType || parameterlessCtor is not null; + bool hasPublicParameterlessCtor = typeSymbol.IsValueType || parameterlessCtor is not null; if (!hasPublicParameterlessCtor && hasMultipleParameterizedCtors) { - initDiagDescriptor = Diagnostics.MultipleParameterizedConstructors; - objectSpec.InitExceptionMessage = string.Format(Emitter.ExceptionMessages.MultipleParameterizedConstructors, typeName); + initDiagDescriptor = DiagnosticDescriptors.MultipleParameterizedConstructors; + initExceptionMessage = string.Format(Emitter.ExceptionMessages.MultipleParameterizedConstructors, typeName); } - ctor = objectSymbol.IsValueType + ctor = typeSymbol.IsValueType // Roslyn ctor fetching APIs include paramerterless ctors for structs, unlike System.Reflection. ? parameterizedCtor ?? parameterlessCtor : parameterlessCtor ?? parameterizedCtor; @@ -610,21 +553,23 @@ private DictionarySpec CreateDictionarySpec(INamedTypeSymbol type, ITypeSymbol k if (ctor is null) { - initDiagDescriptor = Diagnostics.MissingPublicInstanceConstructor; - objectSpec.InitExceptionMessage = string.Format(Emitter.ExceptionMessages.MissingPublicInstanceConstructor, typeName); + initDiagDescriptor = DiagnosticDescriptors.MissingPublicInstanceConstructor; + initExceptionMessage = string.Format(Emitter.ExceptionMessages.MissingPublicInstanceConstructor, typeName); } else { - objectSpec.InstantiationStrategy = ctor.Parameters.Length is 0 ? InstantiationStrategy.ParameterlessConstructor : InstantiationStrategy.ParameterizedConstructor; + initializationStrategy = ctor.Parameters.Length is 0 ? ObjectInstantiationStrategy.ParameterlessConstructor : ObjectInstantiationStrategy.ParameterizedConstructor; } if (initDiagDescriptor is not null) { - Debug.Assert(objectSpec.InitExceptionMessage is not null); - RegisterUnsupportedType(objectSymbol, initDiagDescriptor); + Debug.Assert(initExceptionMessage is not null); + RecordTypeDiagnostic(typeParseInfo, initDiagDescriptor); } - INamedTypeSymbol current = objectSymbol; + Dictionary? properties = null; + + INamedTypeSymbol? current = typeSymbol; while (current is not null) { ImmutableArray members = current.GetMembers(); @@ -633,105 +578,90 @@ private DictionarySpec CreateDictionarySpec(INamedTypeSymbol type, ITypeSymbol k if (member is IPropertySymbol { IsIndexer: false, IsImplicitlyDeclared: false } property) { string propertyName = property.Name; - TypeSpec propertyTypeSpec = GetOrCreateTypeSpec(property.Type); + TypeRef propertyTypeRef = EnqueueTransitiveType(typeParseInfo, property.Type, DiagnosticDescriptors.PropertyNotSupported, propertyName); - if (propertyTypeSpec?.CanBindTo is not true) - { - InvocationDiagnosticInfo propertyDiagnostic = new InvocationDiagnosticInfo(Diagnostics.PropertyNotSupported, new string[] { propertyName, objectSymbol.ToDisplayString() }); - RegisterTypeDiagnostic(causingType: objectSymbol, propertyDiagnostic); - _invocationTargetTypeDiags.Add(propertyDiagnostic); - } + AttributeData? attributeData = property.GetAttributes().FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, _typeSymbols.ConfigurationKeyNameAttribute)); + string configKeyName = attributeData?.ConstructorArguments.FirstOrDefault().Value as string ?? propertyName; - if (propertyTypeSpec is not null) + PropertySpec spec = new(property) { - AttributeData? attributeData = property.GetAttributes().FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, _typeSymbols.ConfigurationKeyNameAttribute)); - string configKeyName = attributeData?.ConstructorArguments.FirstOrDefault().Value as string ?? propertyName; - PropertySpec spec = new(property) { Type = propertyTypeSpec, ConfigurationKeyName = configKeyName }; + TypeRef = propertyTypeRef, + ConfigurationKeyName = configKeyName + }; - objectSpec.Properties[propertyName] = spec; - Register_AsConfigWithChildren_HelperForGen_IfRequired(propertyTypeSpec); - } + (properties ??= new(StringComparer.OrdinalIgnoreCase))[propertyName] = spec; } } current = current.BaseType; } - if (objectSpec.InstantiationStrategy is InstantiationStrategy.ParameterizedConstructor) + List? ctorParams = null; + + if (initializationStrategy is ObjectInstantiationStrategy.ParameterizedConstructor) { - List missingParameters = new(); - List invalidParameters = new(); + Debug.Assert(ctor is not null); + List? missingParameters = null; + List? invalidParameters = null; foreach (IParameterSymbol parameter in ctor.Parameters) { string parameterName = parameter.Name; - if (!objectSpec.Properties.TryGetValue(parameterName, out PropertySpec? propertySpec)) + if (properties?.TryGetValue(parameterName, out PropertySpec? propertySpec) is not true) { - missingParameters.Add(parameterName); + (missingParameters ??= new()).Add(parameterName); } else if (parameter.RefKind is not RefKind.None) { - invalidParameters.Add(parameterName); + (invalidParameters ??= new()).Add(parameterName); } else { ParameterSpec paramSpec = new ParameterSpec(parameter) { - Type = propertySpec.Type, + TypeRef = propertySpec.TypeRef, ConfigurationKeyName = propertySpec.ConfigurationKeyName, }; propertySpec.MatchingCtorParam = paramSpec; - objectSpec.ConstructorParameters.Add(paramSpec); + (ctorParams ??= new()).Add(paramSpec); } } - if (invalidParameters.Count > 0) + if (invalidParameters?.Count > 0) { - objectSpec.InitExceptionMessage = string.Format(Emitter.ExceptionMessages.CannotBindToConstructorParameter, typeName, FormatParams(invalidParameters)); + initExceptionMessage = string.Format(Emitter.ExceptionMessages.CannotBindToConstructorParameter, typeName, FormatParams(invalidParameters)); } - else if (missingParameters.Count > 0) + else if (missingParameters?.Count > 0) { - if (objectSymbol.IsValueType) + if (typeSymbol.IsValueType) { - objectSpec.InstantiationStrategy = InstantiationStrategy.ParameterlessConstructor; + initializationStrategy = ObjectInstantiationStrategy.ParameterlessConstructor; } else { - objectSpec.InitExceptionMessage = string.Format(Emitter.ExceptionMessages.ConstructorParametersDoNotMatchProperties, typeName, FormatParams(missingParameters)); + initExceptionMessage = string.Format(Emitter.ExceptionMessages.ConstructorParametersDoNotMatchProperties, typeName, FormatParams(missingParameters)); } } - if (objectSpec.CanInstantiate) - { - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.Initialize, objectSpec); - } - static string FormatParams(List names) => string.Join(",", names); } - Debug.Assert((objectSpec.CanInstantiate && objectSpec.InitExceptionMessage is null) || - (!objectSpec.CanInstantiate && objectSpec.InitExceptionMessage is not null) || - (!objectSpec.CanInstantiate && (objectSymbol.IsAbstract || objectSymbol.TypeKind is TypeKind.Interface))); - - TryRegisterTypeForBindCoreGen(objectSpec); - return objectSpec; + return new ObjectSpec( + typeSymbol, + initializationStrategy, + properties: properties?.Values.ToImmutableEquatableArray(), + constructorParameters: ctorParams?.ToImmutableEquatableArray(), + initExceptionMessage); } - private bool MemberTypeIsBindable(ITypeSymbol containingTypeSymbol, ITypeSymbol memberTypeSymbol, DiagnosticDescriptor containingTypeDiagDescriptor, out TypeSpec? memberTypeSpec) - { - if (GetOrCreateTypeSpec(memberTypeSymbol) is TypeSpec { CanBindTo: true } spec) - { - memberTypeSpec = spec; - return true; - } + private static UnsupportedTypeSpec CreateUnsupportedCollectionSpec(TypeParseInfo typeParseInfo) + => CreateUnsupportedTypeSpec(typeParseInfo, NotSupportedReason.CollectionNotSupported); - RegisterUnsupportedType(containingTypeSymbol, containingTypeDiagDescriptor); - memberTypeSpec = null; - return false; - } + private static UnsupportedTypeSpec CreateUnsupportedTypeSpec(TypeParseInfo typeParseInfo, NotSupportedReason reason) => + new(typeParseInfo.TypeSymbol) { NotSupportedReason = reason }; - private bool TryGetElementType(INamedTypeSymbol type, out ITypeSymbol? elementType) + private bool TryGetElementType(INamedTypeSymbol type, [NotNullWhen(true)] out ITypeSymbol? elementType) { INamedTypeSymbol? candidate = GetInterface(type, _typeSymbols.GenericIEnumerable_Unbound); @@ -745,7 +675,7 @@ private bool TryGetElementType(INamedTypeSymbol type, out ITypeSymbol? elementTy return false; } - private bool IsCandidateDictionary(INamedTypeSymbol type, out ITypeSymbol? keyType, out ITypeSymbol? elementType) + private bool IsCandidateDictionary(INamedTypeSymbol type, [NotNullWhen(true)] out ITypeSymbol? keyType, [NotNullWhen(true)] out ITypeSymbol? elementType) { INamedTypeSymbol? candidate = GetInterface(type, _typeSymbols.GenericIDictionary_Unbound) ?? GetInterface(type, _typeSymbols.IReadOnlyDictionary_Unbound); @@ -771,24 +701,13 @@ private bool IsCandidateDictionary(INamedTypeSymbol type, out ITypeSymbol? keyTy private bool IsCollection(ITypeSymbol type) => type is INamedTypeSymbol namedType && GetInterface(namedType, _typeSymbols.IEnumerable) is not null; - private bool IsSupportedArrayType(ITypeSymbol type) + private static INamedTypeSymbol? GetInterface(INamedTypeSymbol type, INamedTypeSymbol? @interface) { - if (type is not IArrayTypeSymbol arrayType) + if (@interface is null) { - return false; - } - - if (arrayType.Rank > 1) - { - RegisterUnsupportedType(arrayType, Diagnostics.MultiDimArraysNotSupported); - return false; + return null; } - return true; - } - - private static INamedTypeSymbol? GetInterface(INamedTypeSymbol type, INamedTypeSymbol @interface) - { if (IsInterfaceMatch(type, @interface)) { return type; @@ -805,8 +724,13 @@ private bool IsSupportedArrayType(ITypeSymbol type) return type.AllInterfaces.FirstOrDefault(candidate => SymbolEqualityComparer.Default.Equals(candidate, @interface)); } - private static bool IsInterfaceMatch(INamedTypeSymbol type, INamedTypeSymbol @interface) + private static bool IsInterfaceMatch(INamedTypeSymbol type, INamedTypeSymbol? @interface) { + if (@interface is null) + { + return false; + } + if (type.IsGenericType) { INamedTypeSymbol unbound = type.ConstructUnboundGenericType(); @@ -840,8 +764,8 @@ private static bool HasPublicParameterLessCtor(INamedTypeSymbol type) => private static bool HasAddMethod(INamedTypeSymbol type, ITypeSymbol element) { - INamedTypeSymbol current = type; - while (current != null) + INamedTypeSymbol? current = type; + while (current is not null) { if (current.GetMembers("Add").Any(member => member is IMethodSymbol { Parameters.Length: 1 } method && @@ -856,8 +780,8 @@ private static bool HasAddMethod(INamedTypeSymbol type, ITypeSymbol element) private static bool HasAddMethod(INamedTypeSymbol type, ITypeSymbol key, ITypeSymbol element) { - INamedTypeSymbol current = type; - while (current != null) + INamedTypeSymbol? current = type; + while (current is not null) { if (current.GetMembers("Add").Any(member => member is IMethodSymbol { Parameters.Length: 2 } method && @@ -873,40 +797,51 @@ private static bool HasAddMethod(INamedTypeSymbol type, ITypeSymbol key, ITypeSy private static bool IsEnum(ITypeSymbol type) => type is INamedTypeSymbol { EnumUnderlyingType: INamedTypeSymbol { } }; - private CollectionSpec? ConstructGenericCollectionSpecIfRequired(INamedTypeSymbol? collectionType, params ITypeSymbol[] parameters) => - (collectionType is not null ? ConstructGenericCollectionSpec(collectionType, parameters) : null); - - private CollectionSpec? ConstructGenericCollectionSpec(INamedTypeSymbol type, params ITypeSymbol[] parameters) - { - Debug.Assert(type.IsGenericType); - INamedTypeSymbol constructedType = type.Construct(parameters); - return CreateCollectionSpec(constructedType); - } - - private void RegisterUnsupportedType(ITypeSymbol type, DiagnosticDescriptor descriptor = null) + private void RecordTypeDiagnosticIfRequired(TypeParseInfo typeParseInfo, TypeSpec typeSpec) { - InvocationDiagnosticInfo diagInfo = new(descriptor, new string[] { type.ToDisplayString() }); + ContainingTypeDiagnosticInfo? containingTypeDiagInfo = typeParseInfo.ContainingTypeDiagnosticInfo; - if (!_unsupportedTypes.Contains(type)) + if (typeSpec is UnsupportedTypeSpec unsupportedTypeSpec) + { + DiagnosticDescriptor descriptor = DiagnosticDescriptors.GetNotSupportedDescriptor(unsupportedTypeSpec.NotSupportedReason); + RecordTypeDiagnostic(typeParseInfo, descriptor); + } + else if (containingTypeDiagInfo?.Descriptor == DiagnosticDescriptors.DictionaryKeyNotSupported && + typeSpec is not ParsableFromStringSpec) { - RegisterTypeDiagnostic(type, diagInfo); - _unsupportedTypes.Add(type); + ReportContainingTypeDiagnosticIfRequired(typeParseInfo); } + } - _invocationTargetTypeDiags.Add(diagInfo); + private void RecordTypeDiagnostic(TypeParseInfo typeParseInfo, DiagnosticDescriptor descriptor) + { + RecordDiagnostic(descriptor, typeParseInfo.BinderInvocation.Location, new object?[] { typeParseInfo.TypeName }); + ReportContainingTypeDiagnosticIfRequired(typeParseInfo); } - private void RegisterTypeDiagnostic(ITypeSymbol causingType, InvocationDiagnosticInfo info) + private void ReportContainingTypeDiagnosticIfRequired(TypeParseInfo typeParseInfo) { - bool typeHadDiags = _typeDiagnostics.TryGetValue(causingType, out HashSet? typeDiags); - typeDiags ??= new HashSet(); - typeDiags.Add(info); + ContainingTypeDiagnosticInfo? containingTypeDiagInfo = typeParseInfo.ContainingTypeDiagnosticInfo; - if (!typeHadDiags) + while (containingTypeDiagInfo is not null) { - _typeDiagnostics[causingType] = typeDiags; + string containingTypeName = containingTypeDiagInfo.TypeName; + + object[] messageArgs = containingTypeDiagInfo.MemberName is string memberName + ? new[] { memberName, containingTypeName } + : new[] { containingTypeName }; + + RecordDiagnostic(containingTypeDiagInfo.Descriptor, typeParseInfo.BinderInvocation.Location, messageArgs); + + containingTypeDiagInfo = containingTypeDiagInfo.ContainingTypeInfo; } } + + private void RecordDiagnostic(DiagnosticDescriptor descriptor, Location trimmedLocation, params object?[]? messageArgs) + { + Diagnostics ??= new List(); + Diagnostics.Add(DiagnosticInfo.Create(descriptor, trimmedLocation, messageArgs)); + } } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.cs index fbca2dd3cfc507..ec4b234a61045c 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/ConfigurationBindingGenerator.cs @@ -2,9 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. //#define LAUNCH_DEBUGGER -using System.Collections.Immutable; +using System; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { @@ -14,7 +15,9 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration [Generator] public sealed partial class ConfigurationBindingGenerator : IIncrementalGenerator { - private static readonly string ProjectName = Emitter.s_assemblyName.Name; + private static readonly string ProjectName = Emitter.s_assemblyName.Name!; + + public const string GenSpecTrackingName = nameof(SourceGenerationSpec); public void Initialize(IncrementalGeneratorInitializationContext context) { @@ -30,39 +33,61 @@ public void Initialize(IncrementalGeneratorInitializationContext context) ? new CompilationData((CSharpCompilation)compilation) : null); - IncrementalValuesProvider inputCalls = context.SyntaxProvider + IncrementalValueProvider<(SourceGenerationSpec?, ImmutableEquatableArray?)> genSpec = context.SyntaxProvider .CreateSyntaxProvider( (node, _) => BinderInvocation.IsCandidateSyntaxNode(node), BinderInvocation.Create) - .Where(invocation => invocation is not null); + .Where(invocation => invocation is not null) + .Collect() + .Combine(compilationData) + .Select((tuple, cancellationToken) => + { + if (tuple.Right is not CompilationData compilationData) + { + return (null, null); + } - IncrementalValueProvider<(CompilationData?, ImmutableArray)> inputData = compilationData.Combine(inputCalls.Collect()); + try + { + Parser parser = new(compilationData); + SourceGenerationSpec? spec = parser.GetSourceGenerationSpec(tuple.Left, cancellationToken); + ImmutableEquatableArray? diagnostics = parser.Diagnostics?.ToImmutableEquatableArray(); + return (spec, diagnostics); + } + catch (Exception ex) + { + throw ex; + } + }) + .WithTrackingName(GenSpecTrackingName); - context.RegisterSourceOutput(inputData, (spc, source) => Execute(source.Item1, source.Item2, spc)); + context.RegisterSourceOutput(genSpec, ReportDiagnosticsAndEmitSource); } - private static void Execute(CompilationData compilationData, ImmutableArray inputCalls, SourceProductionContext context) - { - if (inputCalls.IsDefaultOrEmpty) - { - return; - } + /// + /// Instrumentation helper for unit tests. + /// + public Action? OnSourceEmitting { get; init; } - if (compilationData?.LanguageVersionIsSupported is not true) + private void ReportDiagnosticsAndEmitSource(SourceProductionContext sourceProductionContext, (SourceGenerationSpec? SourceGenerationSpec, ImmutableEquatableArray? Diagnostics) input) + { + if (input.Diagnostics is ImmutableEquatableArray diagnostics) { - context.ReportDiagnostic(Diagnostic.Create(Parser.Diagnostics.LanguageVersionNotSupported, location: null)); - return; + foreach (DiagnosticInfo diagnostic in diagnostics) + { + sourceProductionContext.ReportDiagnostic(diagnostic.CreateDiagnostic()); + } } - Parser parser = new(context, compilationData.TypeSymbols!, inputCalls); - if (parser.GetSourceGenerationSpec() is SourceGenerationSpec spec) + if (input.SourceGenerationSpec is SourceGenerationSpec spec) { - Emitter emitter = new(context, spec); - emitter.Emit(); + OnSourceEmitting?.Invoke(spec); + Emitter emitter = new(spec); + emitter.Emit(sourceProductionContext); } } - private sealed record CompilationData + internal sealed class CompilationData { public bool LanguageVersionIsSupported { get; } public KnownTypeSymbols? TypeSymbols { get; } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/ConfigurationBinder.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/ConfigurationBinder.cs index f1c7d5f7ff2150..7d723139bde3e4 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/ConfigurationBinder.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/ConfigurationBinder.cs @@ -1,8 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Collections.Generic; using System.Diagnostics; +using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { @@ -10,11 +10,9 @@ public sealed partial class ConfigurationBindingGenerator { private sealed partial class Emitter { - private bool ShouldEmitMethods(MethodsToGen_ConfigurationBinder methods) => (_sourceGenSpec.MethodsToGen_ConfigurationBinder & methods) != 0; - private void EmitBindingExtensions_IConfiguration() { - if (!ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Any)) + if (!ShouldEmitMethods(MethodsToGen.ConfigBinder_Any)) { return; } @@ -31,30 +29,30 @@ private void EmitGetMethods() const string expressionForGetCore = nameof(MethodsToGen_CoreBindingHelper.GetCore); const string documentation = "Attempts to bind the configuration instance to a new instance of type T."; - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Get_T)) + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Get_T)) { - StartMethodDefinition(MethodsToGen_ConfigurationBinder.Get_T, documentation); + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_Get_T, documentation); _writer.WriteLine($"public static T? {Identifier.Get}(this {Identifier.IConfiguration} {Identifier.configuration}) => " + $"(T?)({expressionForGetCore}({Identifier.configuration}, typeof(T), {Identifier.configureOptions}: null) ?? default(T));"); } - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Get_T_BinderOptions)) + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Get_T_BinderOptions)) { - StartMethodDefinition(MethodsToGen_ConfigurationBinder.Get_T_BinderOptions, documentation); + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_Get_T_BinderOptions, documentation); _writer.WriteLine($"public static T? {Identifier.Get}(this {Identifier.IConfiguration} {Identifier.configuration}, {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions}) => " + $"(T?)({expressionForGetCore}({Identifier.configuration}, typeof(T), {Identifier.configureOptions}) ?? default(T));"); } - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Get_TypeOf)) + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Get_TypeOf)) { - StartMethodDefinition(MethodsToGen_ConfigurationBinder.Get_TypeOf, documentation); + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_Get_TypeOf, documentation); _writer.WriteLine($"public static object? {Identifier.Get}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}) => " + $"{expressionForGetCore}({Identifier.configuration}, {Identifier.type}, {Identifier.configureOptions}: null);"); } - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Get_TypeOf_BinderOptions)) + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Get_TypeOf_BinderOptions)) { - StartMethodDefinition(MethodsToGen_ConfigurationBinder.Get_TypeOf_BinderOptions, documentation); + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_Get_TypeOf_BinderOptions, documentation); _writer.WriteLine($"public static object? {Identifier.Get}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}, {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions}) => " + $"{expressionForGetCore}({Identifier.configuration}, {Identifier.type}, {Identifier.configureOptions});"); } @@ -65,30 +63,30 @@ private void EmitGetValueMethods() const string expressionForGetValueCore = $"{Identifier.BindingExtensions}.{nameof(MethodsToGen_CoreBindingHelper.GetValueCore)}"; const string documentation = "Extracts the value with the specified key and converts it to the specified type."; - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.GetValue_T_key)) + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_GetValue_T_key)) { - StartMethodDefinition(MethodsToGen_ConfigurationBinder.GetValue_T_key, documentation); + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_GetValue_T_key, documentation); _writer.WriteLine($"public static T? {Identifier.GetValue}(this {Identifier.IConfiguration} {Identifier.configuration}, string {Identifier.key}) => " + $"(T?)({expressionForGetValueCore}({Identifier.configuration}, typeof(T), {Identifier.key}) ?? default(T));"); } - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.GetValue_T_key_defaultValue)) + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_GetValue_T_key_defaultValue)) { - StartMethodDefinition(MethodsToGen_ConfigurationBinder.GetValue_T_key_defaultValue, documentation); + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_GetValue_T_key_defaultValue, documentation); _writer.WriteLine($"public static T? {Identifier.GetValue}(this {Identifier.IConfiguration} {Identifier.configuration}, string {Identifier.key}, T {Identifier.defaultValue}) => " + $"(T?)({expressionForGetValueCore}({Identifier.configuration}, typeof(T), {Identifier.key}) ?? {Identifier.defaultValue});"); } - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.GetValue_TypeOf_key)) + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_GetValue_TypeOf_key)) { - StartMethodDefinition(MethodsToGen_ConfigurationBinder.GetValue_TypeOf_key, documentation); + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_GetValue_TypeOf_key, documentation); _writer.WriteLine($"public static object? {Identifier.GetValue}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}, string {Identifier.key}) => " + $"{expressionForGetValueCore}({Identifier.configuration}, {Identifier.type}, {Identifier.key});"); } - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.GetValue_TypeOf_key_defaultValue)) + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_GetValue_TypeOf_key_defaultValue)) { - StartMethodDefinition(MethodsToGen_ConfigurationBinder.GetValue_TypeOf_key_defaultValue, documentation); + EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen.ConfigBinder_GetValue_TypeOf_key_defaultValue, documentation); _writer.WriteLine($"public static object? {Identifier.GetValue}(this {Identifier.IConfiguration} {Identifier.configuration}, Type {Identifier.type}, string {Identifier.key}, object? {Identifier.defaultValue}) => " + $"{expressionForGetValueCore}({Identifier.configuration}, {Identifier.type}, {Identifier.key}) ?? {Identifier.defaultValue};"); } @@ -96,50 +94,52 @@ private void EmitGetValueMethods() private void EmitBindMethods_ConfigurationBinder() { - if (!ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Bind)) + if (!ShouldEmitMethods(MethodsToGen.ConfigBinder_Bind)) { return; } string instanceParamExpr = $"object? {Identifier.instance}"; - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Bind_instance)) + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Bind_instance)) { EmitMethods( - MethodsToGen_ConfigurationBinder.Bind_instance, + _interceptorInfo.ConfigBinder_Bind_instance, additionalParams: instanceParamExpr, configExpression: Identifier.configuration, configureOptions: false); } - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Bind_instance_BinderOptions)) + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Bind_instance_BinderOptions)) { EmitMethods( - MethodsToGen_ConfigurationBinder.Bind_instance_BinderOptions, + _interceptorInfo.ConfigBinder_Bind_instance_BinderOptions, additionalParams: $"{instanceParamExpr}, {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions}", configExpression: Identifier.configuration, configureOptions: true); } - if (ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Bind_key_instance)) + if (ShouldEmitMethods(MethodsToGen.ConfigBinder_Bind_key_instance)) { EmitMethods( - MethodsToGen_ConfigurationBinder.Bind_key_instance, + _interceptorInfo.ConfigBinder_Bind_key_instance, additionalParams: $"string {Identifier.key}, {instanceParamExpr}", configExpression: $"{Expression.configurationGetSection}({Identifier.key})", configureOptions: false); } - void EmitMethods(MethodsToGen_ConfigurationBinder method, string additionalParams, string configExpression, bool configureOptions) + void EmitMethods(ImmutableEquatableArray? interceptorInfo, string additionalParams, string configExpression, bool configureOptions) { - foreach ((ComplexTypeSpec type, List interceptorInfoList) in _sourceGenSpec.InterceptionInfo_ConfigBinder.GetOverloadInfo(method)) + Debug.Assert(interceptorInfo is not null); + + foreach ((ComplexTypeSpec type, ImmutableEquatableArray locations) in interceptorInfo) { EmitBlankLineIfRequired(); _writer.WriteLine($"/// Attempts to bind the given object instance to configuration values by matching property names against configuration keys recursively."); - EmitInterceptsLocationAnnotations(interceptorInfoList); + EmitInterceptsLocationAnnotations(locations); EmitStartBlock($"public static void {Identifier.Bind}_{type.IdentifierCompatibleSubstring}(this {Identifier.IConfiguration} {Identifier.configuration}, {additionalParams})"); - if (type.HasBindableMembers) + if (_typeIndex.HasBindableMembers(type)) { Debug.Assert(!type.IsValueType); string binderOptionsArg = configureOptions ? $"{Identifier.GetBinderOptions}({Identifier.configureOptions})" : $"{Identifier.binderOptions}: null"; @@ -147,7 +147,7 @@ void EmitMethods(MethodsToGen_ConfigurationBinder method, string additionalParam EmitCheckForNullArgument_WithBlankLine(Identifier.configuration); EmitCheckForNullArgument_WithBlankLine(Identifier.instance, voidReturn: true); _writer.WriteLine($$""" - var {{Identifier.typedObj}} = ({{type.EffectiveType.DisplayString}}){{Identifier.instance}}; + var {{Identifier.typedObj}} = ({{type.DisplayString}}){{Identifier.instance}}; {{nameof(MethodsToGen_CoreBindingHelper.BindCore)}}({{configExpression}}, ref {{Identifier.typedObj}}, defaultValueIfNotFound: false, {{binderOptionsArg}}); """); } @@ -157,11 +157,11 @@ void EmitMethods(MethodsToGen_ConfigurationBinder method, string additionalParam } } - private void StartMethodDefinition(MethodsToGen_ConfigurationBinder method, string documentation) + private void EmitStartDefinition_Get_Or_GetValue_Overload(MethodsToGen overload, string documentation) { EmitBlankLineIfRequired(); _writer.WriteLine($"/// {documentation}"); - EmitInterceptsLocationAnnotations(method); + EmitInterceptsLocationAnnotations(overload); } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/CoreBindingHelpers.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/CoreBindingHelpers.cs index 90531efe1b0c10..499d4085bbd362 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/CoreBindingHelpers.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/CoreBindingHelpers.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Text.RegularExpressions; using Microsoft.CodeAnalysis; +using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { @@ -18,7 +19,7 @@ private sealed partial class Emitter private bool _emitBlankLineBeforeNextStatement; private static readonly Regex s_arrayBracketsRegex = new(Regex.Escape("[]")); - private bool ShouldEmitMethods(MethodsToGen_CoreBindingHelper methods) => (_sourceGenSpec.MethodsToGen_CoreBindingHelper & methods) != 0; + private bool ShouldEmitMethods(MethodsToGen_CoreBindingHelper methods) => (_bindingHelperInfo.MethodsToGen & methods) != 0; private void EmitCoreBindingHelpers() { @@ -36,33 +37,54 @@ private void EmitCoreBindingHelpers() private void EmitConfigurationKeyCaches() { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.BindCore, out HashSet targetTypes)) + if (_bindingHelperInfo.TypesForGen_BindCore is not { Count: not 0 } types) { return; } EmitBlankLineIfRequired(); - foreach (TypeSpec type in targetTypes) + foreach (TypeSpec type in types) { if (type is not ObjectSpec objectType) { continue; } - HashSet keys = new(objectType.ConstructorParameters.Select(m => GetCacheElement(m))); - keys.UnionWith(objectType.Properties.Values.Select(m => GetCacheElement(m))); + Debug.Assert(_typeIndex.HasBindableMembers(objectType)); + + HashSet? keys = null; static string GetCacheElement(MemberSpec member) => $@"""{member.ConfigurationKeyName}"""; + if (objectType.ConstructorParameters?.Select(m => GetCacheElement(m)) is IEnumerable paramNames) + { + keys = new(paramNames); + } + + if (objectType.Properties?.Select(m => GetCacheElement(m)) is IEnumerable propNames) + { + if (keys is null) + { + keys = new(propNames); + } + else + { + keys.UnionWith(propNames); + } + } + + // Type has bindable members. + Debug.Assert(keys is not null); + string configKeysSource = string.Join(", ", keys); - string fieldName = GetConfigKeyCacheFieldName(objectType); + string fieldName = TypeIndex.GetConfigKeyCacheFieldName(objectType); _writer.WriteLine($@"private readonly static Lazy<{TypeDisplayString.HashSetOfString}> {fieldName} = new(() => new {TypeDisplayString.HashSetOfString}(StringComparer.OrdinalIgnoreCase) {{ {configKeysSource} }});"); } } private void EmitGetCoreMethod() { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.GetCore, out HashSet? types)) + if (_bindingHelperInfo.TypesForGen_GetCore is not { Count: not 0 } targetTypes) { return; } @@ -78,10 +100,11 @@ private void EmitGetCoreMethod() EmitIConfigurationHasValueOrChildrenCheck(voidReturn: false); bool isFirstType = true; - foreach (TypeSpec type in types) + foreach (TypeSpec type in targetTypes) { - TypeSpec effectiveType = type.EffectiveType; - TypeSpecKind kind = effectiveType.SpecKind; + Debug.Assert(_typeIndex.CanBindTo(type.TypeRef)); + + TypeSpec effectiveType = _typeIndex.GetEffectiveTypeSpec(type); string conditionKindExpr = GetConditionKindExpr(ref isFirstType); EmitStartBlock($"{conditionKindExpr} ({Identifier.type} == typeof({type.DisplayString}))"); @@ -101,7 +124,7 @@ private void EmitGetCoreMethod() useIncrementalStringValueIdentifier: false); } break; - case ConfigurationSectionSpec configurationSectionSpec: + case ConfigurationSectionSpec: { EmitCastToIConfigurationSection(); _writer.WriteLine($"return {Identifier.section};"); @@ -109,7 +132,7 @@ private void EmitGetCoreMethod() break; case ComplexTypeSpec complexType: { - if (complexType.CanInstantiate) + if (_typeIndex.CanInstantiate(complexType)) { EmitBindingLogic(complexType, Identifier.instance, Identifier.configuration, InitializationKind.Declaration, ValueDefaulting.CallSetter); _writer.WriteLine($"return {Identifier.instance};"); @@ -118,6 +141,12 @@ private void EmitGetCoreMethod() { _writer.WriteLine($@"throw new {Identifier.InvalidOperationException}(""{exMsg}"");"); } +#if DEBUG + else + { + Debug.Fail($"Complex should not be included for GetCore gen: {complexType.DisplayString}"); + } +#endif } break; } @@ -141,7 +170,7 @@ void EmitCastToIConfigurationSection() => private void EmitGetValueCoreMethod() { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.GetValueCore, out HashSet? targetTypes)) + if (_bindingHelperInfo.TypesForGen_GetValueCore is not { Count: not 0 } targetTypes) { return; } @@ -169,7 +198,7 @@ private void EmitGetValueCoreMethod() EmitStartBlock($"{conditionKindExpr} ({Identifier.type} == typeof({type.DisplayString}))"); EmitBindingLogic( - (ParsableFromStringSpec)type.EffectiveType, + (ParsableFromStringSpec)_typeIndex.GetEffectiveTypeSpec(type), Identifier.value, Expression.sectionPath, writeOnSuccess: (parsedValueExpr) => _writer.WriteLine($"return {parsedValueExpr};"), @@ -188,7 +217,7 @@ private void EmitGetValueCoreMethod() private void EmitBindCoreMainMethod() { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.BindCoreMain, out HashSet? targetTypes)) + if (_bindingHelperInfo.TypesForGen_BindCoreMain is not { Count: not 0 } targetTypes) { return; } @@ -203,8 +232,8 @@ private void EmitBindCoreMainMethod() bool isFirstType = true; foreach (ComplexTypeSpec type in targetTypes) { - ComplexTypeSpec effectiveType = (ComplexTypeSpec)type.EffectiveType; - Debug.Assert(effectiveType.HasBindableMembers); + ComplexTypeSpec effectiveType = (ComplexTypeSpec)_typeIndex.GetEffectiveTypeSpec(type); + Debug.Assert(_typeIndex.HasBindableMembers(effectiveType)); string conditionKindExpr = GetConditionKindExpr(ref isFirstType); EmitStartBlock($"{conditionKindExpr} ({Identifier.type} == typeof({type.DisplayString}))"); @@ -221,14 +250,14 @@ private void EmitBindCoreMainMethod() private void EmitBindCoreMethods() { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.BindCore, out HashSet? targetTypes)) + if (_bindingHelperInfo.TypesForGen_BindCore is not ImmutableEquatableArray types) { return; } - foreach (ComplexTypeSpec type in targetTypes) + foreach (ComplexTypeSpec type in types) { - Debug.Assert(type.HasBindableMembers); + Debug.Assert(_typeIndex.HasBindableMembers(type)); EmitBlankLineIfRequired(); EmitBindCoreMethod(type); } @@ -239,26 +268,35 @@ private void EmitBindCoreMethod(ComplexTypeSpec type) string objParameterExpression = $"ref {type.DisplayString} {Identifier.instance}"; EmitStartBlock(@$"public static void {nameof(MethodsToGen_CoreBindingHelper.BindCore)}({Identifier.IConfiguration} {Identifier.configuration}, {objParameterExpression}, bool defaultValueIfNotFound, {Identifier.BinderOptions}? {Identifier.binderOptions})"); - ComplexTypeSpec effectiveType = (ComplexTypeSpec)type.EffectiveType; - if (effectiveType is EnumerableSpec enumerable) - { - if (effectiveType.InstantiationStrategy is InstantiationStrategy.Array) - { - Debug.Assert(type == effectiveType); - EmitPopulationImplForArray((EnumerableSpec)type); - } - else - { - EmitPopulationImplForEnumerableWithAdd(enumerable); - } - } - else if (effectiveType is DictionarySpec dictionary) - { - EmitBindCoreImplForDictionary(dictionary); - } - else + ComplexTypeSpec effectiveType = (ComplexTypeSpec)_typeIndex.GetEffectiveTypeSpec(type); + + switch (effectiveType) { - EmitBindCoreImplForObject((ObjectSpec)effectiveType); + case ArraySpec arrayType: + { + EmitBindCoreImplForArray(arrayType); + } + break; + case EnumerableSpec enumerableType: + { + EmitBindCoreImplForEnumerableWithAdd(enumerableType); + } + break; + case DictionarySpec dictionaryType: + { + EmitBindCoreImplForDictionary(dictionaryType); + } + break; + case ObjectSpec objectType: + { + EmitBindCoreImplForObject(objectType); + } + break; + default: + { + Debug.Fail($"Unsupported spec for bind core gen: {effectiveType.GetType()}"); + } + break; } EmitEndBlock(); @@ -266,12 +304,12 @@ private void EmitBindCoreMethod(ComplexTypeSpec type) private void EmitInitializeMethods() { - if (!_sourceGenSpec.TypesForGen_CoreBindingHelper_Methods.TryGetValue(MethodsToGen_CoreBindingHelper.Initialize, out HashSet? targetTypes)) + if (_bindingHelperInfo.TypesForGen_Initialize is not ImmutableEquatableArray types) { return; } - foreach (ObjectSpec type in targetTypes) + foreach (ObjectSpec type in types) { EmitBlankLineIfRequired(); EmitInitializeMethod(type); @@ -280,16 +318,20 @@ private void EmitInitializeMethods() private void EmitInitializeMethod(ObjectSpec type) { - Debug.Assert(type.CanInstantiate); - List ctorParams = type.ConstructorParameters; - IEnumerable initOnlyProps = type.Properties.Values.Where(prop => prop is { SetOnInit: true }); + Debug.Assert(type.InstantiationStrategy is ObjectInstantiationStrategy.ParameterizedConstructor); + Debug.Assert(_typeIndex.CanInstantiate(type)); + Debug.Assert( + type is { Properties: not null, ConstructorParameters: not null }, + $"Expecting type for init method, {type.DisplayString}, to have both properties and ctor params."); + + IEnumerable initOnlyProps = type.Properties.Where(prop => prop is { SetOnInit: true }); List ctorArgList = new(); string displayString = type.DisplayString; EmitStartBlock($"public static {type.DisplayString} {GetInitalizeMethodDisplayString(type)}({Identifier.IConfiguration} {Identifier.configuration}, {Identifier.BinderOptions}? {Identifier.binderOptions})"); _emitBlankLineBeforeNextStatement = false; - foreach (ParameterSpec parameter in ctorParams) + foreach (ParameterSpec parameter in type.ConstructorParameters) { string name = parameter.Name; string argExpr = parameter.RefKind switch @@ -307,7 +349,7 @@ private void EmitInitializeMethod(ObjectSpec type) foreach (PropertySpec property in initOnlyProps) { - if (property.ShouldBindTo && property.MatchingCtorParam is null) + if (_typeIndex.ShouldBindTo(property) && property.MatchingCtorParam is null) { EmitBindImplForMember(property); } @@ -335,7 +377,7 @@ private void EmitInitializeMethod(ObjectSpec type) void EmitBindImplForMember(MemberSpec member) { - TypeSpec memberType = member.Type; + TypeSpec memberType = _typeIndex.GetTypeSpec(member.TypeRef); string parsedMemberDeclarationLhs = $"{memberType.DisplayString} {member.Name}"; string configKeyName = member.ConfigurationKeyName; string parsedMemberAssignmentLhsExpr; @@ -427,29 +469,32 @@ private void EmitHelperMethods() } if (ShouldEmitMethods(MethodsToGen_CoreBindingHelper.BindCoreMain | MethodsToGen_CoreBindingHelper.GetCore) || - ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Bind_instance_BinderOptions)) + ShouldEmitMethods(MethodsToGen.ConfigBinder_Bind_instance_BinderOptions)) { EmitBlankLineIfRequired(); EmitGetBinderOptionsHelper(); } - bool enumTypeExists = false; - - foreach (ParsableFromStringSpec type in _sourceGenSpec.PrimitivesForHelperGen) + if (_bindingHelperInfo.TypesForGen_ParsePrimitive is { Count: not 0 } stringParsableTypes) { - EmitBlankLineIfRequired(); + bool enumTypeExists = false; - if (type.StringParsableTypeKind == StringParsableTypeKind.Enum) + foreach (ParsableFromStringSpec type in stringParsableTypes) { - if (!enumTypeExists) + EmitBlankLineIfRequired(); + + if (type.StringParsableTypeKind == StringParsableTypeKind.Enum) { - EmitEnumParseMethod(); - enumTypeExists = true; + if (!enumTypeExists) + { + EmitEnumParseMethod(); + enumTypeExists = true; + } + } + else + { + EmitPrimitiveParseMethod(type); } - } - else - { - EmitPrimitiveParseMethod(type); } } } @@ -615,7 +660,7 @@ private void EmitPrimitiveParseMethod(ParsableFromStringSpec type) string exceptionArg1 = string.Format(ExceptionMessages.FailedBinding, $"{{{Identifier.getPath}()}}", $"{{typeof({typeDisplayString})}}"); - EmitStartBlock($"public static {typeDisplayString} {type.ParseMethodName}(string {Identifier.value}, Func {Identifier.getPath})"); + EmitStartBlock($"public static {typeDisplayString} {TypeIndex.GetParseMethodName(type)}(string {Identifier.value}, Func {Identifier.getPath})"); EmitEndBlock($$""" try { @@ -628,13 +673,19 @@ private void EmitPrimitiveParseMethod(ParsableFromStringSpec type) """); } - private void EmitPopulationImplForArray(EnumerableSpec type) + private void EmitBindCoreImplForArray(ArraySpec type) { - EnumerableSpec typeToInstantiate = (EnumerableSpec)type.TypeToInstantiate; - - // Create list and bind elements. + TypeRef elementTypeRef = type.ElementTypeRef; + string elementTypeDisplayString = _typeIndex.GetTypeSpec(elementTypeRef).DisplayString; string tempIdentifier = GetIncrementalIdentifier(Identifier.temp); - EmitBindingLogic(typeToInstantiate, tempIdentifier, Identifier.configuration, InitializationKind.Declaration, ValueDefaulting.None); + + // Create temp list. + _writer.WriteLine($"var {tempIdentifier} = new List<{elementTypeDisplayString}>();"); + _writer.WriteLine(); + + // Bind elements to temp list. + EmitBindingLogicForEnumerableWithAdd(elementTypeRef, tempIdentifier); + _writer.WriteLine(); // Resize array and add binded elements. _writer.WriteLine($$""" @@ -644,15 +695,19 @@ private void EmitPopulationImplForArray(EnumerableSpec type) """); } - private void EmitPopulationImplForEnumerableWithAdd(EnumerableSpec type) + private void EmitBindCoreImplForEnumerableWithAdd(EnumerableSpec type) { EmitCollectionCastIfRequired(type, out string instanceIdentifier); + EmitBindingLogicForEnumerableWithAdd(type.ElementTypeRef, instanceIdentifier); + } + private void EmitBindingLogicForEnumerableWithAdd(TypeRef elementTypeRef, string enumerableIdentifier) + { Emit_Foreach_Section_In_ConfigChildren_StartBlock(); - string addExpr = $"{instanceIdentifier}.{Identifier.Add}"; + string addExpr = $"{enumerableIdentifier}.{Identifier.Add}"; - switch (type.ElementType) + switch (_typeIndex.GetEffectiveTypeSpec(elementTypeRef)) { case ParsableFromStringSpec stringParsableType: { @@ -666,12 +721,12 @@ private void EmitPopulationImplForEnumerableWithAdd(EnumerableSpec type) useIncrementalStringValueIdentifier: false); } break; - case ConfigurationSectionSpec configurationSection: + case ConfigurationSectionSpec: { _writer.WriteLine($"{addExpr}({Identifier.section});"); } break; - case ComplexTypeSpec { CanInstantiate: true } complexType: + case ComplexTypeSpec complexType when _typeIndex.CanInstantiate(complexType): { EmitBindingLogic(complexType, Identifier.value, Identifier.section, InitializationKind.Declaration, ValueDefaulting.None); _writer.WriteLine($"{addExpr}({Identifier.value});"); @@ -688,8 +743,8 @@ private void EmitBindCoreImplForDictionary(DictionarySpec type) Emit_Foreach_Section_In_ConfigChildren_StartBlock(); - ParsableFromStringSpec keyType = type.KeyType; - TypeSpec elementType = type.ElementType; + ParsableFromStringSpec keyType = (ParsableFromStringSpec)_typeIndex.GetEffectiveTypeSpec(type.KeyTypeRef); + TypeSpec elementType = _typeIndex.GetTypeSpec(type.ElementTypeRef); // Parse key EmitBindingLogic( @@ -717,15 +772,13 @@ void Emit_BindAndAddLogic_ForElement(string parsedKeyExpr) useIncrementalStringValueIdentifier: false); } break; - case ConfigurationSectionSpec configurationSection: + case ConfigurationSectionSpec: { _writer.WriteLine($"{instanceIdentifier}[{parsedKeyExpr}] = {Identifier.section};"); } break; case ComplexTypeSpec complexElementType: { - Debug.Assert(complexElementType.CanInstantiate); - if (keyType.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue) { // Save value to local to avoid parsing twice - during look-up and during add. @@ -746,12 +799,32 @@ void Emit_BindAndAddLogic_ForElement(string parsedKeyExpr) conditionToUseExistingElement += $" && {expressionForElementIsNotNull}"; } - EmitStartBlock($"if (!({conditionToUseExistingElement}))"); - EmitObjectInit(complexElementType, Identifier.element, InitializationKind.SimpleAssignment, Identifier.section); - EmitEndBlock(); + if (_typeIndex.CanInstantiate(complexElementType)) + { + EmitStartBlock($"if (!({conditionToUseExistingElement}))"); + EmitObjectInit(complexElementType, Identifier.element, InitializationKind.SimpleAssignment, Identifier.section); + EmitEndBlock(); + + EmitBindingLogic(); + } + else + { + EmitStartBlock($"if ({conditionToUseExistingElement})"); + EmitBindingLogic(); + EmitEndBlock(); + } - EmitBindingLogic(complexElementType, Identifier.element, Identifier.section, InitializationKind.None, ValueDefaulting.None); - _writer.WriteLine($"{instanceIdentifier}[{parsedKeyExpr}] = {Identifier.element};"); + void EmitBindingLogic() + { + this.EmitBindingLogic( + complexElementType, + Identifier.element, + Identifier.section, + InitializationKind.None, + ValueDefaulting.None); + + _writer.WriteLine($"{instanceIdentifier}[{parsedKeyExpr}] = {Identifier.element};"); + } } break; } @@ -762,16 +835,15 @@ void Emit_BindAndAddLogic_ForElement(string parsedKeyExpr) private void EmitBindCoreImplForObject(ObjectSpec type) { - Debug.Assert(type.HasBindableMembers); + Debug.Assert(_typeIndex.HasBindableMembers(type)); - string keyCacheFieldName = GetConfigKeyCacheFieldName(type); + string keyCacheFieldName = TypeIndex.GetConfigKeyCacheFieldName(type); string validateMethodCallExpr = $"{Identifier.ValidateConfigurationKeys}(typeof({type.DisplayString}), {keyCacheFieldName}, {Identifier.configuration}, {Identifier.binderOptions});"; _writer.WriteLine(validateMethodCallExpr); - foreach (PropertySpec property in type.Properties.Values) + foreach (PropertySpec property in type.Properties!) { - bool noSetter_And_IsReadonly = !property.CanSet && property.Type is CollectionSpec { InstantiationStrategy: InstantiationStrategy.ParameterizedConstructor }; - if (property.ShouldBindTo && !noSetter_And_IsReadonly) + if (_typeIndex.ShouldBindTo(property)) { string containingTypeRef = property.IsStatic ? type.DisplayString : Identifier.instance; EmitBindImplForMember( @@ -791,11 +863,9 @@ private bool EmitBindImplForMember( bool canSet, InitializationKind initializationKind) { - TypeSpec effectiveMemberType = member.Type.EffectiveType; - string sectionParseExpr = GetSectionFromConfigurationExpression(member.ConfigurationKeyName); - switch (effectiveMemberType) + switch (_typeIndex.GetEffectiveTypeSpec(member.TypeRef)) { case ParsableFromStringSpec stringParsableType: { @@ -804,8 +874,8 @@ private bool EmitBindImplForMember( bool useDefaultValueIfSectionValueIsNull = initializationKind == InitializationKind.Declaration && member is PropertySpec && - member.Type.IsValueType && - member.Type.SpecKind is not TypeSpecKind.Nullable; + member.TypeRef.IsValueType && + _typeIndex.GetTypeSpec(member.TypeRef) is not NullableSpec; EmitBlankLineIfRequired(); EmitBindingLogic( @@ -840,7 +910,7 @@ member is PropertySpec && EmitBindingLogicForComplexMember(member, memberAccessExpr, sectionIdentifier, canSet); EmitEndBlock(); - return complexType.CanInstantiate; + return _typeIndex.CanInstantiate(complexType); } default: return false; @@ -854,8 +924,8 @@ private void EmitBindingLogicForComplexMember( bool canSet) { - TypeSpec memberType = member.Type; - ComplexTypeSpec effectiveMemberType = (ComplexTypeSpec)memberType.EffectiveType; + TypeSpec memberType = _typeIndex.GetTypeSpec(member.TypeRef); + ComplexTypeSpec effectiveMemberType = (ComplexTypeSpec)_typeIndex.GetEffectiveTypeSpec(memberType); string tempIdentifier = GetIncrementalIdentifier(Identifier.temp); InitializationKind initKind; @@ -872,7 +942,7 @@ private void EmitBindingLogicForComplexMember( string effectiveMemberTypeDisplayString = effectiveMemberType.DisplayString; initKind = InitializationKind.None; - if (memberType.SpecKind is TypeSpecKind.Nullable) + if (memberType is NullableSpec) { string nullableTempIdentifier = GetIncrementalIdentifier(Identifier.temp); @@ -902,12 +972,12 @@ private void EmitBindingLogicForComplexMember( Action? writeOnSuccess = !canSet ? null : bindedValueIdentifier => + { + if (memberAccessExpr != bindedValueIdentifier) { - if (memberAccessExpr != bindedValueIdentifier) - { - _writer.WriteLine($"{memberAccessExpr} = {bindedValueIdentifier};"); - } - }; + _writer.WriteLine($"{memberAccessExpr} = {bindedValueIdentifier};"); + } + }; EmitBindingLogic( effectiveMemberType, @@ -927,11 +997,11 @@ private void EmitBindingLogic( ValueDefaulting valueDefaulting, Action? writeOnSuccess = null) { - if (!type.HasBindableMembers) + if (!_typeIndex.HasBindableMembers(type)) { if (initKind is not InitializationKind.None) { - if (type.CanInstantiate) + if (_typeIndex.CanInstantiate(type)) { EmitObjectInit(type, memberAccessExpr, initKind, configArgExpr); } @@ -965,7 +1035,7 @@ void EmitBindingLogic(string instanceToBindExpr, InitializationKind initKind) { string bindCoreCall = $@"{nameof(MethodsToGen_CoreBindingHelper.BindCore)}({configArgExpr}, ref {instanceToBindExpr}, defaultValueIfNotFound: {FormatDefaultValueIfNotFound()}, {Identifier.binderOptions});"; - if (type.CanInstantiate) + if (_typeIndex.CanInstantiate(type)) { if (initKind is not InitializationKind.None) { @@ -977,15 +1047,13 @@ void EmitBindingLogic(string instanceToBindExpr, InitializationKind initKind) else { Debug.Assert(!type.IsValueType); - + EmitStartBlock($"if ({instanceToBindExpr} is not null)"); + EmitBindCoreCall(); + EmitEndBlock(); if (type is ObjectSpec { InitExceptionMessage: string exMsg }) { + EmitStartBlock("else"); _writer.WriteLine($@"throw new {Identifier.InvalidOperationException}(""{exMsg}"");"); - } - else - { - EmitStartBlock($"if ({instanceToBindExpr} is not null)"); - EmitBindCoreCall(); EmitEndBlock(); } } @@ -1018,7 +1086,7 @@ private void EmitBindingLogic( { StringParsableTypeKind.AssignFromSectionValue => stringValueToParse_Expr, StringParsableTypeKind.Enum => $"ParseEnum<{type.DisplayString}>({stringValueToParse_Expr}, () => {sectionPathExpr})", - _ => $"{type.ParseMethodName}({stringValueToParse_Expr}, () => {sectionPathExpr})", + _ => $"{TypeIndex.GetParseMethodName(type)}({stringValueToParse_Expr}, () => {sectionPathExpr})", }; if (!checkForNullSectionValue) @@ -1046,56 +1114,72 @@ private void EmitBindingLogic( private bool EmitObjectInit(ComplexTypeSpec type, string memberAccessExpr, InitializationKind initKind, string configArgExpr) { CollectionSpec? collectionType = type as CollectionSpec; + ObjectSpec? objectType = type as ObjectSpec; + + string? castExpr = null; string initExpr; string effectiveDisplayString = type.DisplayString; if (collectionType is not null) { - if (collectionType is EnumerableSpec { InstantiationStrategy: InstantiationStrategy.Array }) + if (collectionType is ArraySpec) { initExpr = $"new {s_arrayBracketsRegex.Replace(effectiveDisplayString, "[0]", 1)}"; } else { - effectiveDisplayString = (collectionType.TypeToInstantiate ?? collectionType).DisplayString; - initExpr = $"new {effectiveDisplayString}()"; + CollectionWithCtorInitSpec collectionWithCtorInitType = (CollectionWithCtorInitSpec)collectionType; + + if (collectionWithCtorInitType.InstantiationConcreteType is not CollectionInstantiationConcreteType.Self) + { + castExpr = $"({collectionWithCtorInitType.DisplayString})"; + } + + effectiveDisplayString = _typeIndex.GetInstantiationTypeDisplayString(collectionWithCtorInitType); + initExpr = $"{castExpr}new {effectiveDisplayString}()"; } } - else if (type.InstantiationStrategy is InstantiationStrategy.ParameterlessConstructor) - { - initExpr = $"new {effectiveDisplayString}()"; - } else { - Debug.Assert(type.InstantiationStrategy is InstantiationStrategy.ParameterizedConstructor); - string initMethodIdentifier = GetInitalizeMethodDisplayString(((ObjectSpec)type)); - initExpr = $"{initMethodIdentifier}({configArgExpr}, {Identifier.binderOptions})"; + Debug.Assert(objectType is not null); + ObjectInstantiationStrategy strategy = objectType.InstantiationStrategy; + + if (strategy is ObjectInstantiationStrategy.ParameterlessConstructor) + { + initExpr = $"new {effectiveDisplayString}()"; + } + else + { + Debug.Assert(strategy is ObjectInstantiationStrategy.ParameterizedConstructor); + string initMethodIdentifier = GetInitalizeMethodDisplayString(((ObjectSpec)type)); + initExpr = $"{initMethodIdentifier}({configArgExpr}, {Identifier.binderOptions})"; + } } switch (initKind) { case InitializationKind.Declaration: { - Debug.Assert(!memberAccessExpr.Contains(".")); + Debug.Assert(!memberAccessExpr.Contains('.')); _writer.WriteLine($"var {memberAccessExpr} = {initExpr};"); } break; case InitializationKind.AssignmentWithNullCheck: { - if (collectionType is CollectionSpec + + if (collectionType is CollectionWithCtorInitSpec { - InstantiationStrategy: InstantiationStrategy.ParameterizedConstructor or InstantiationStrategy.ToEnumerableMethod - }) + InstantiationStrategy: CollectionInstantiationStrategy.CopyConstructor or CollectionInstantiationStrategy.LinqToDictionary + } collectionWithCtorInitType) { - if (collectionType.InstantiationStrategy is InstantiationStrategy.ParameterizedConstructor) - { - _writer.WriteLine($"{memberAccessExpr} = {memberAccessExpr} is null ? {initExpr} : new {effectiveDisplayString}({memberAccessExpr});"); - } - else - { - Debug.Assert(collectionType is DictionarySpec); - _writer.WriteLine($"{memberAccessExpr} = {memberAccessExpr} is null ? {initExpr} : {memberAccessExpr}.ToDictionary(pair => pair.Key, pair => pair.Value);"); - } + string assignmentValueIfMemberNull = collectionWithCtorInitType.InstantiationStrategy is CollectionInstantiationStrategy.CopyConstructor + ? $"new {effectiveDisplayString}({memberAccessExpr})" + : $"{memberAccessExpr}.ToDictionary(pair => pair.Key, pair => pair.Value)"; + + Debug.Assert(castExpr is not null || collectionWithCtorInitType.InstantiationConcreteType is CollectionInstantiationConcreteType.Self); + assignmentValueIfMemberNull = $"{castExpr}{assignmentValueIfMemberNull}"; + + _writer.WriteLine($"{memberAccessExpr} = {memberAccessExpr} is null ? {initExpr} : {assignmentValueIfMemberNull};"); } else { @@ -1130,20 +1214,25 @@ private void EmitIConfigurationHasValueOrChildrenCheck(bool voidReturn) _writer.WriteLine(); } - private void EmitCollectionCastIfRequired(CollectionSpec type, out string instanceIdentifier) + private void EmitCollectionCastIfRequired(CollectionWithCtorInitSpec type, out string instanceIdentifier) { - instanceIdentifier = Identifier.instance; - if (type.PopulationStrategy is CollectionPopulationStrategy.Cast_Then_Add) + if (type.PopulationCastType is CollectionPopulationCastType.NotApplicable) { - instanceIdentifier = Identifier.temp; - _writer.WriteLine($$""" - if ({{Identifier.instance}} is not {{type.PopulationCastType!.DisplayString}} {{instanceIdentifier}}) + instanceIdentifier = Identifier.instance; + return; + } + + string castTypeDisplayString = _typeIndex.GetPopulationCastTypeDisplayString(type); + instanceIdentifier = Identifier.temp; + + _writer.WriteLine($$""" + if ({{Identifier.instance}} is not {{castTypeDisplayString}} {{instanceIdentifier}}) { return; } """); - _writer.WriteLine(); - } + _writer.WriteLine(); + } private void Emit_Foreach_Section_In_ConfigChildren_StartBlock() => @@ -1171,9 +1260,6 @@ private static string GetConditionKindExpr(ref bool isFirstType) return "else if"; } - - private static string GetConfigKeyCacheFieldName(ObjectSpec type) => - $"s_configKeys_{type.IdentifierCompatibleSubstring}"; } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/Helpers.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/Helpers.cs index a7db2fb5163979..34a97d3c64c76c 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/Helpers.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/Helpers.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Collections.Generic; using System.Diagnostics; using System.Reflection; @@ -135,30 +134,29 @@ private static class Identifier public const string Value = nameof(Value); } - private bool ShouldEmitBindingExtensions() => - ShouldEmitMethods(MethodsToGen_ConfigurationBinder.Any) || - ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder.Any) || - ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection.Any); + private bool ShouldEmitMethods(MethodsToGen methods) => (_interceptorInfo.MethodsToGen & methods) != 0; - private void EmitInterceptsLocationAnnotations(Enum generatedBindingOverload) + private void EmitInterceptsLocationAnnotations(MethodsToGen overload) { + IEnumerable? infoList = _interceptorInfo.GetInfo(overload); + bool interceptsCalls = infoList is not null; + // The only time a generated binding method won't have any locations to // intercept is when either of these methods are used as helpers for // other generated OptionsBuilder or ServiceCollection binding extensions. - bool interceptsCalls = _sourceGenSpec.InterceptionInfo.TryGetValue(generatedBindingOverload, out List? infoList); Debug.Assert(interceptsCalls || - generatedBindingOverload is MethodsToGen_Extensions_ServiceCollection.Configure_T_name_BinderOptions || - generatedBindingOverload is MethodsToGen_Extensions_OptionsBuilder.Bind_T_BinderOptions); + overload is MethodsToGen.ServiceCollectionExt_Configure_T_name_BinderOptions || + overload is MethodsToGen.OptionsBuilderExt_Bind_T_BinderOptions); if (interceptsCalls) { - EmitInterceptsLocationAnnotations(infoList); + EmitInterceptsLocationAnnotations(infoList!); } } - private void EmitInterceptsLocationAnnotations(List infoList) + private void EmitInterceptsLocationAnnotations(IEnumerable infoList) { - foreach (InterceptorLocationInfo info in infoList) + foreach (InvocationLocationInfo info in infoList) { _writer.WriteLine($@"[{Identifier.InterceptsLocation}(@""{info.FilePath}"", {info.LineNumber}, {info.CharacterNumber})]"); } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsBuilderConfigurationExtensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsBuilderConfigurationExtensions.cs index 7fd5d695eaf45a..fdc4286e34c559 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsBuilderConfigurationExtensions.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsBuilderConfigurationExtensions.cs @@ -7,11 +7,9 @@ public sealed partial class ConfigurationBindingGenerator { private sealed partial class Emitter { - private bool ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder methods) => (_sourceGenSpec.MethodsToGen_OptionsBuilderExt & methods) != 0; - private void EmitBindingExtensions_OptionsBuilder() { - if (!ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder.Any)) + if (!ShouldEmitMethods(MethodsToGen.OptionsBuilderExt_Any)) { return; } @@ -24,7 +22,7 @@ private void EmitBindingExtensions_OptionsBuilder() private void EmitBindMethods_Extensions_OptionsBuilder() { - if (!ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder.Bind)) + if (!ShouldEmitMethods(MethodsToGen.OptionsBuilderExt_Bind)) { return; } @@ -32,15 +30,15 @@ private void EmitBindMethods_Extensions_OptionsBuilder() const string documentation = @"/// Registers a configuration instance which will bind against."; const string paramList = $"{Identifier.IConfiguration} {Identifier.config}"; - if (ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder.Bind_T)) + if (ShouldEmitMethods(MethodsToGen.OptionsBuilderExt_Bind_T)) { - EmitMethodStartBlock(MethodsToGen_Extensions_OptionsBuilder.Bind_T, "Bind", paramList, documentation); + EmitMethodStartBlock(MethodsToGen.OptionsBuilderExt_Bind_T, "Bind", paramList, documentation); _writer.WriteLine($"return Bind({Identifier.optionsBuilder}, {Identifier.config}, {Identifier.configureBinder}: null);"); EmitEndBlock(); } EmitMethodStartBlock( - MethodsToGen_Extensions_OptionsBuilder.Bind_T_BinderOptions, + MethodsToGen.OptionsBuilderExt_Bind_T_BinderOptions, "Bind", paramList + $", {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureBinder}", documentation); @@ -57,7 +55,7 @@ private void EmitBindMethods_Extensions_OptionsBuilder() private void EmitBindConfigurationMethod() { - if (!ShouldEmitMethods(MethodsToGen_Extensions_OptionsBuilder.BindConfiguration_T_path_BinderOptions)) + if (!ShouldEmitMethods(MethodsToGen.OptionsBuilderExt_BindConfiguration_T_path_BinderOptions)) { return; } @@ -65,7 +63,7 @@ private void EmitBindConfigurationMethod() const string documentation = $@"/// Registers the dependency injection container to bind against the obtained from the DI service provider."; string paramList = $"string {Identifier.configSectionPath}, {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureBinder} = null"; - EmitMethodStartBlock(MethodsToGen_Extensions_OptionsBuilder.BindConfiguration, "BindConfiguration", paramList, documentation); + EmitMethodStartBlock(MethodsToGen.OptionsBuilderExt_BindConfiguration, "BindConfiguration", paramList, documentation); EmitCheckForNullArgument_WithBlankLine(Identifier.optionsBuilder); EmitCheckForNullArgument_WithBlankLine(Identifier.configSectionPath); @@ -89,7 +87,7 @@ private void EmitBindConfigurationMethod() EmitEndBlock(); } - private void EmitMethodStartBlock(MethodsToGen_Extensions_OptionsBuilder method, string methodName, string paramList, string documentation) + private void EmitMethodStartBlock(MethodsToGen method, string methodName, string paramList, string documentation) { paramList = $"this {TypeDisplayString.OptionsBuilderOfTOptions} {Identifier.optionsBuilder}, {paramList}"; EmitBlankLineIfRequired(); diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsConfigurationServiceCollectionExtensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsConfigurationServiceCollectionExtensions.cs index 7577e0c49de4d0..daa3b79db8abc4 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsConfigurationServiceCollectionExtensions.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Emitter/OptionsConfigurationServiceCollectionExtensions.cs @@ -7,11 +7,9 @@ public sealed partial class ConfigurationBindingGenerator { private sealed partial class Emitter { - private bool ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection methods) => (_sourceGenSpec.MethodsToGen_ServiceCollectionExt & methods) != 0; - private void EmitBindingExtensions_IServiceCollection() { - if (!ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection.Any)) + if (!ShouldEmitMethods(MethodsToGen.ServiceCollectionExt_Any)) { return; } @@ -26,26 +24,26 @@ private void EmitConfigureMethods() const string defaultNameExpr = "string.Empty"; string configParam = $"{Identifier.IConfiguration} {Identifier.config}"; - if (ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection.Configure_T)) + if (ShouldEmitMethods(MethodsToGen.ServiceCollectionExt_Configure_T)) { - EmitStartMethod(MethodsToGen_Extensions_ServiceCollection.Configure_T, configParam); + EmitStartMethod(MethodsToGen.ServiceCollectionExt_Configure_T, configParam); _writer.WriteLine($"return {Identifier.Configure}<{Identifier.TOptions}>({Identifier.services}, {defaultNameExpr}, {Identifier.config}, {Identifier.configureOptions}: null);"); EmitEndBlock(); } - if (ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection.Configure_T_name)) + if (ShouldEmitMethods(MethodsToGen.ServiceCollectionExt_Configure_T_name)) { EmitStartMethod( - MethodsToGen_Extensions_ServiceCollection.Configure_T_name, + MethodsToGen.ServiceCollectionExt_Configure_T_name, paramList: $"string? {Identifier.name}, " + configParam); _writer.WriteLine($"return {Identifier.Configure}<{Identifier.TOptions}>({Identifier.services}, {Identifier.name}, {Identifier.config}, {Identifier.configureOptions}: null);"); EmitEndBlock(); } - if (ShouldEmitMethods(MethodsToGen_Extensions_ServiceCollection.Configure_T_BinderOptions)) + if (ShouldEmitMethods(MethodsToGen.ServiceCollectionExt_Configure_T_BinderOptions)) { EmitStartMethod( - MethodsToGen_Extensions_ServiceCollection.Configure_T_BinderOptions, + MethodsToGen.ServiceCollectionExt_Configure_T_BinderOptions, paramList: configParam + $", {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions}"); _writer.WriteLine($"return {Identifier.Configure}<{Identifier.TOptions}>({Identifier.services}, {defaultNameExpr}, {Identifier.config}, {Identifier.configureOptions});"); EmitEndBlock(); @@ -54,7 +52,7 @@ private void EmitConfigureMethods() // Core Configure method that the other overloads call. // Like the others, it is public API that could be called directly by users. // So, it is always generated whenever a Configure overload is called. - EmitStartMethod(MethodsToGen_Extensions_ServiceCollection.Configure_T_name_BinderOptions, paramList: $"string? {Identifier.name}, " + configParam + $", {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions}"); + EmitStartMethod(MethodsToGen.ServiceCollectionExt_Configure_T_name_BinderOptions, paramList: $"string? {Identifier.name}, " + configParam + $", {TypeDisplayString.NullableActionOfBinderOptions} {Identifier.configureOptions}"); EmitCheckForNullArgument_WithBlankLine(Identifier.services); EmitCheckForNullArgument_WithBlankLine(Identifier.config); _writer.WriteLine($$""" @@ -65,7 +63,7 @@ private void EmitConfigureMethods() EmitEndBlock(); } - private void EmitStartMethod(MethodsToGen_Extensions_ServiceCollection overload, string paramList) + private void EmitStartMethod(MethodsToGen overload, string paramList) { paramList = $"this {Identifier.IServiceCollection} {Identifier.services}, {paramList}"; diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Microsoft.Extensions.Configuration.Binder.SourceGeneration.csproj b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Microsoft.Extensions.Configuration.Binder.SourceGeneration.csproj index 92c4c04cfa67c9..764682b43daa86 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Microsoft.Extensions.Configuration.Binder.SourceGeneration.csproj +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Microsoft.Extensions.Configuration.Binder.SourceGeneration.csproj @@ -9,6 +9,11 @@ $(DefineConstants);LAUNCH_DEBUGGER + + + $(NetCoreAppToolCurrent);netstandard2.0 + + @@ -17,15 +22,19 @@ - - - + + + + + + + @@ -38,20 +47,20 @@ - + + - - + + + - - diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/BinderInvocation.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/BinderInvocation.cs index ad7c4c09204d4b..b1cf51acb3b4a6 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/BinderInvocation.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/BinderInvocation.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; @@ -9,8 +9,17 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal sealed record BinderInvocation(IInvocationOperation Operation, Location Location) + internal sealed class BinderInvocation { + private BinderInvocation(IInvocationOperation operation, Location location) + { + Operation = operation; + Location = location; + } + + public IInvocationOperation Operation { get; } + public Location Location { get; } + public static BinderInvocation? Create(GeneratorSyntaxContext context, CancellationToken cancellationToken) { Debug.Assert(IsCandidateSyntaxNode(context.Node)); @@ -35,8 +44,8 @@ public static bool IsCandidateSyntaxNode(SyntaxNode node) } && IsCandidateBindingMethodName(memberName); static bool IsCandidateBindingMethodName(string name) => - IsCandidateMethodName_ConfigurationBinder(name) || - IsCandidateMethodName_OptionsBuilderConfigurationExtensions(name) || + IsValidMethodName_ConfigurationBinder(name) || + IsValidMethodName_OptionsBuilderConfigurationExtensions(name) || IsValidMethodName_OptionsConfigurationServiceCollectionExtensions(name); } @@ -62,10 +71,10 @@ public static bool IsBindingOperation(IInvocationOperation operation) { "ConfigurationBinder" => containingNamespaceName is "Microsoft.Extensions.Configuration" && - IsCandidateMethodName_ConfigurationBinder(methodName), + IsValidMethodName_ConfigurationBinder(methodName), "OptionsBuilderConfigurationExtensions" => containingNamespaceName is "Microsoft.Extensions.DependencyInjection" && - IsCandidateMethodName_OptionsBuilderConfigurationExtensions(methodName), + IsValidMethodName_OptionsBuilderConfigurationExtensions(methodName), "OptionsConfigurationServiceCollectionExtensions" => containingNamespaceName is "Microsoft.Extensions.DependencyInjection" && IsValidMethodName_OptionsConfigurationServiceCollectionExtensions(methodName), @@ -73,16 +82,10 @@ containingNamespaceName is "Microsoft.Extensions.DependencyInjection" && }; } - private static bool IsCandidateMethodName_ConfigurationBinder(string name) => name is - nameof(MethodsToGen_ConfigurationBinder.Bind) or - nameof(MethodsToGen_ConfigurationBinder.Get) or - nameof(MethodsToGen_ConfigurationBinder.GetValue); + private static bool IsValidMethodName_ConfigurationBinder(string name) => name is "Bind" or "Get" or "GetValue"; - private static bool IsCandidateMethodName_OptionsBuilderConfigurationExtensions(string name) => name is - nameof(MethodsToGen_Extensions_OptionsBuilder.Bind) or - nameof(MethodsToGen_Extensions_OptionsBuilder.BindConfiguration); + private static bool IsValidMethodName_OptionsBuilderConfigurationExtensions(string name) => name is "Bind" or "BindConfiguration"; - private static bool IsValidMethodName_OptionsConfigurationServiceCollectionExtensions(string name) => name is - nameof(MethodsToGen_Extensions_ServiceCollection.Configure); + private static bool IsValidMethodName_OptionsConfigurationServiceCollectionExtensions(string name) => name is "Configure"; } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/ConfigurationBinder.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/ConfigurationBinder.cs index 3996142adf9089..645786e35c1c55 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/ConfigurationBinder.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/ConfigurationBinder.cs @@ -6,28 +6,29 @@ using System.Linq; using Microsoft.CodeAnalysis.Operations; using Microsoft.CodeAnalysis; +using System.Diagnostics; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { public sealed partial class ConfigurationBindingGenerator { - private sealed partial class Parser + internal sealed partial class Parser { private void ParseInvocation_ConfigurationBinder(BinderInvocation invocation) { switch (invocation.Operation.TargetMethod.Name) { - case nameof(MethodsToGen_ConfigurationBinder.Bind): + case "Bind": { ParseBindInvocation_ConfigurationBinder(invocation); } break; - case nameof(MethodsToGen_ConfigurationBinder.Get): + case "Get": { ParseGetInvocation(invocation); } break; - case nameof(MethodsToGen_ConfigurationBinder.GetValue): + case "GetValue": { ParseGetValueInvocation(invocation); } @@ -46,39 +47,39 @@ private void ParseBindInvocation_ConfigurationBinder(BinderInvocation invocation return; } - MethodsToGen_ConfigurationBinder overload = MethodsToGen_ConfigurationBinder.None; + MethodsToGen overload = MethodsToGen.None; if (paramCount is 2) { - overload = MethodsToGen_ConfigurationBinder.Bind_instance; + overload = MethodsToGen.ConfigBinder_Bind_instance; } else if (paramCount is 3) { if (@params[1].Type.SpecialType is SpecialType.System_String) { - overload = MethodsToGen_ConfigurationBinder.Bind_key_instance; + overload = MethodsToGen.ConfigBinder_Bind_key_instance; } else if (SymbolEqualityComparer.Default.Equals(@params[2].Type, _typeSymbols.ActionOfBinderOptions)) { - overload = MethodsToGen_ConfigurationBinder.Bind_instance_BinderOptions; + overload = MethodsToGen.ConfigBinder_Bind_instance_BinderOptions; } } - if (overload is MethodsToGen_ConfigurationBinder.None) + if (overload is MethodsToGen.None) { return; } int instanceIndex = overload switch { - MethodsToGen_ConfigurationBinder.Bind_instance => 1, - MethodsToGen_ConfigurationBinder.Bind_instance_BinderOptions => 1, - MethodsToGen_ConfigurationBinder.Bind_key_instance => 2, + MethodsToGen.ConfigBinder_Bind_instance => 1, + MethodsToGen.ConfigBinder_Bind_instance_BinderOptions => 1, + MethodsToGen.ConfigBinder_Bind_key_instance => 2, _ => throw new InvalidOperationException() }; IArgumentOperation instanceArg = GetArgumentForParameterAtIndex(operation.Arguments, instanceIndex); - if (instanceArg.Parameter.Type.SpecialType != SpecialType.System_Object) + if (instanceArg.Parameter?.Type.SpecialType is not SpecialType.System_Object) { return; } @@ -87,20 +88,17 @@ private void ParseBindInvocation_ConfigurationBinder(BinderInvocation invocation if (!IsValidRootConfigType(type)) { - _context.ReportDiagnostic(Diagnostic.Create(Diagnostics.CouldNotDetermineTypeInfo, invocation.Location)); + RecordDiagnostic(DiagnosticDescriptors.CouldNotDetermineTypeInfo, invocation.Location); return; } - if (type!.IsValueType) + if (type.IsValueType) { - _context.ReportDiagnostic(Diagnostic.Create(Diagnostics.ValueTypesInvalidForBind, invocation.Location, type)); + RecordDiagnostic(DiagnosticDescriptors.ValueTypesInvalidForBind, invocation.Location, messageArgs: new object[] { type }); return; } - if (GetTargetTypeForRootInvocationCore(type, invocation.Location) is TypeSpec typeSpec) - { - RegisterInterceptor(overload, typeSpec, invocation.Operation); - } + EnqueueTargetTypeForRootInvocation(type, overload, invocation); static ITypeSymbol? ResolveType(IOperation conversionOperation) => conversionOperation switch @@ -144,7 +142,7 @@ private void ParseGetInvocation(BinderInvocation invocation) return; } - MethodsToGen_ConfigurationBinder overload = MethodsToGen_ConfigurationBinder.None; + MethodsToGen overload = MethodsToGen.None; ITypeSymbol? type; if (targetMethod.IsGenericMethod) @@ -158,11 +156,11 @@ private void ParseGetInvocation(BinderInvocation invocation) if (paramCount is 1) { - overload = MethodsToGen_ConfigurationBinder.Get_T; + overload = MethodsToGen.ConfigBinder_Get_T; } else if (paramCount is 2 && SymbolEqualityComparer.Default.Equals(@params[1].Type, _typeSymbols.ActionOfBinderOptions)) { - overload = MethodsToGen_ConfigurationBinder.Get_T_BinderOptions; + overload = MethodsToGen.ConfigBinder_Get_T_BinderOptions; } } else if (paramCount > 3) @@ -176,20 +174,15 @@ private void ParseGetInvocation(BinderInvocation invocation) if (paramCount is 2) { - overload = MethodsToGen_ConfigurationBinder.Get_TypeOf; + overload = MethodsToGen.ConfigBinder_Get_TypeOf; } else if (paramCount is 3 && SymbolEqualityComparer.Default.Equals(@params[2].Type, _typeSymbols.ActionOfBinderOptions)) { - overload = MethodsToGen_ConfigurationBinder.Get_TypeOf_BinderOptions; + overload = MethodsToGen.ConfigBinder_Get_TypeOf_BinderOptions; } } - if (GetTargetTypeForRootInvocation(type, invocation.Location) is TypeSpec typeSpec) - { - RegisterInvocation(overload, invocation.Operation); - RegisterTypeForGetCoreGen(typeSpec); - } - + EnqueueTargetTypeForRootInvocation(type, overload, invocation); } private void ParseGetValueInvocation(BinderInvocation invocation) @@ -199,7 +192,7 @@ private void ParseGetValueInvocation(BinderInvocation invocation) ImmutableArray @params = targetMethod.Parameters; int paramCount = @params.Length; - MethodsToGen_ConfigurationBinder overload = MethodsToGen_ConfigurationBinder.None; + MethodsToGen overload = MethodsToGen.None; ITypeSymbol? type; if (targetMethod.IsGenericMethod) @@ -213,11 +206,11 @@ private void ParseGetValueInvocation(BinderInvocation invocation) if (paramCount is 2) { - overload = MethodsToGen_ConfigurationBinder.GetValue_T_key; + overload = MethodsToGen.ConfigBinder_GetValue_T_key; } else if (paramCount is 3 && SymbolEqualityComparer.Default.Equals(@params[2].Type, type)) { - overload = MethodsToGen_ConfigurationBinder.GetValue_T_key_defaultValue; + overload = MethodsToGen.ConfigBinder_GetValue_T_key_defaultValue; } } else if (paramCount > 4) @@ -236,45 +229,56 @@ private void ParseGetValueInvocation(BinderInvocation invocation) if (paramCount is 3) { - overload = MethodsToGen_ConfigurationBinder.GetValue_TypeOf_key; + overload = MethodsToGen.ConfigBinder_GetValue_TypeOf_key; } else if (paramCount is 4 && @params[3].Type.SpecialType is SpecialType.System_Object) { - overload = MethodsToGen_ConfigurationBinder.GetValue_TypeOf_key_defaultValue; + overload = MethodsToGen.ConfigBinder_GetValue_TypeOf_key_defaultValue; } } - ITypeSymbol effectiveType = (IsNullable(type, out ITypeSymbol? underlyingType) ? underlyingType : type)!; - if (!IsValidRootConfigType(type)) { - _context.ReportDiagnostic(Diagnostic.Create(Diagnostics.CouldNotDetermineTypeInfo, invocation.Location)); + RecordDiagnostic(DiagnosticDescriptors.CouldNotDetermineTypeInfo, invocation.Location); return; } - if (IsParsableFromString(effectiveType, out _) && - GetTargetTypeForRootInvocationCore(type, invocation.Location) is TypeSpec typeSpec) + ITypeSymbol effectiveType = IsNullable(type, out ITypeSymbol? underlyingType) ? underlyingType : type; + + if (IsParsableFromString(effectiveType, out _)) { - RegisterInvocation(overload, invocation.Operation); - RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.GetValueCore, typeSpec); + EnqueueTargetTypeForRootInvocation(type, overload, invocation); } } - private void RegisterInvocation(MethodsToGen_ConfigurationBinder overload, IInvocationOperation operation) + private void RegisterInterceptor_ConfigurationBinder(TypeParseInfo typeParseInfo, TypeSpec typeSpec) { - _sourceGenSpec.MethodsToGen_ConfigurationBinder |= overload; - RegisterInterceptor(overload, operation); - } + MethodsToGen overload = typeParseInfo.BindingOverload; + IInvocationOperation invocationOperation = typeParseInfo.BinderInvocation!.Operation; + Debug.Assert((MethodsToGen.ConfigBinder_Any & overload) is not 0); - /// - /// Registers generated Bind methods as interceptors. This is done differently from other root - /// methods because we need to - /// explicitly account for the type to bind, to avoid type-check issues for polymorphic objects. - /// - private void RegisterInterceptor(MethodsToGen_ConfigurationBinder overload, TypeSpec typeSpec, IInvocationOperation operation) - { - _sourceGenSpec.MethodsToGen_ConfigurationBinder |= overload; - _sourceGenSpec.InterceptionInfo_ConfigBinder.RegisterOverloadInfo(overload, typeSpec, operation); + if ((MethodsToGen.ConfigBinder_Bind & overload) is not 0) + { + if (typeSpec is ComplexTypeSpec complexTypeSpec && + _helperInfoBuilder!.TryRegisterTransitiveTypesForMethodGen(complexTypeSpec.TypeRef)) + { + _interceptorInfoBuilder.RegisterInterceptor_ConfigBinder_Bind(overload, complexTypeSpec, invocationOperation); + } + } + else + { + Debug.Assert((MethodsToGen.ConfigBinder_Get & overload) is not 0 || + (MethodsToGen.ConfigBinder_GetValue & overload) is not 0); + + bool registered = (MethodsToGen.ConfigBinder_Get & overload) is not 0 + ? _helperInfoBuilder!.TryRegisterTypeForGetGen(typeSpec) + : _helperInfoBuilder!.TryRegisterTypeForGetValueGen(typeSpec); + + if (registered) + { + _interceptorInfoBuilder.RegisterInterceptor(overload, invocationOperation); + } + } } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/Diagnostics.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/DiagnosticDescriptors.cs similarity index 82% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/Diagnostics.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/DiagnosticDescriptors.cs index d6d816545bcd0a..3f694c78be8309 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/Diagnostics.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/DiagnosticDescriptors.cs @@ -9,9 +9,9 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { public sealed partial class ConfigurationBindingGenerator { - private sealed partial class Parser + internal sealed partial class Parser { - internal static class Diagnostics + private static class DiagnosticDescriptors { public static DiagnosticDescriptor TypeNotSupported { get; } = CreateTypeNotSupportedDescriptor(nameof(SR.TypeNotSupported)); public static DiagnosticDescriptor MissingPublicInstanceConstructor { get; } = CreateTypeNotSupportedDescriptor(nameof(SR.MissingPublicInstanceConstructor)); @@ -62,6 +62,20 @@ private static DiagnosticDescriptor CreateTypeNotSupportedDescriptor(string name category: ProjectName, defaultSeverity: DiagnosticSeverity.Warning, isEnabledByDefault: true); + + public static DiagnosticDescriptor GetNotSupportedDescriptor(NotSupportedReason reason) => + reason switch + { + NotSupportedReason.UnknownType => TypeNotSupported, + NotSupportedReason.MissingPublicInstanceConstructor => MissingPublicInstanceConstructor, + NotSupportedReason.CollectionNotSupported => CollectionNotSupported, + NotSupportedReason.DictionaryKeyNotSupported => DictionaryKeyNotSupported, + NotSupportedReason.ElementTypeNotSupported => ElementTypeNotSupported, + NotSupportedReason.MultipleParameterizedConstructors => MultipleParameterizedConstructors, + NotSupportedReason.MultiDimArraysNotSupported => MultiDimArraysNotSupported, + NotSupportedReason.NullableUnderlyingTypeNotSupported => NullableUnderlyingTypeNotSupported, + _ => throw new InvalidOperationException() + }; } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/Extensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/Extensions.cs index fa0b3691ec4047..f685842639966a 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/Extensions.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/Extensions.cs @@ -8,6 +8,54 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { + public sealed partial class ConfigurationBindingGenerator + { + internal sealed partial class Parser + { + private readonly struct TypeParseInfo + { + public ITypeSymbol TypeSymbol { get; private init; } + public string TypeName { get; private init; } + public MethodsToGen BindingOverload { get; private init; } + public BinderInvocation BinderInvocation { get; private init; } + public ContainingTypeDiagnosticInfo? ContainingTypeDiagnosticInfo { get; private init; } + + public static TypeParseInfo Create(ITypeSymbol typeSymbol, MethodsToGen overload, BinderInvocation invocation, ContainingTypeDiagnosticInfo? containingTypeDiagInfo = null) => + new TypeParseInfo + { + TypeSymbol = typeSymbol, + TypeName = typeSymbol.GetName(), + BindingOverload = overload, + BinderInvocation = invocation, + ContainingTypeDiagnosticInfo = containingTypeDiagInfo, + }; + + public TypeParseInfo ToTransitiveTypeParseInfo(ITypeSymbol memberType, DiagnosticDescriptor? diagDescriptor = null, string? memberName = null) + { + ContainingTypeDiagnosticInfo? diagnosticInfo = diagDescriptor is null + ? null + : new() + { + TypeName = TypeName, + Descriptor = diagDescriptor, + MemberName = memberName, + ContainingTypeInfo = ContainingTypeDiagnosticInfo, + }; + + return Create(memberType, BindingOverload, BinderInvocation, diagnosticInfo); + } + } + + private sealed class ContainingTypeDiagnosticInfo + { + public required string TypeName { get; init; } + public required string? MemberName { get; init; } + public required DiagnosticDescriptor Descriptor { get; init; } + public required ContainingTypeDiagnosticInfo? ContainingTypeInfo { get; init; } + } + } + } + internal static class ParserExtensions { private static readonly SymbolDisplayFormat s_identifierCompatibleFormat = new SymbolDisplayFormat( @@ -16,6 +64,12 @@ internal static class ParserExtensions genericsOptions: SymbolDisplayGenericsOptions.None, miscellaneousOptions: SymbolDisplayMiscellaneousOptions.UseSpecialTypes); + private static readonly SymbolDisplayFormat s_minimalDisplayFormat = new SymbolDisplayFormat( + globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Omitted, + typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypes, + genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters, + miscellaneousOptions: SymbolDisplayMiscellaneousOptions.UseSpecialTypes); + public static void RegisterCacheEntry(this Dictionary cache, TKey key, TEntry entry) where TKey : notnull where TValue : ICollection, new() @@ -28,12 +82,6 @@ public static void RegisterCacheEntry(this Dictionary> source, out ComplexTypeSpec Key, out List Value) - { - Key = (ComplexTypeSpec)source.Key; - Value = source.Value; - } - public static string ToIdentifierCompatibleSubstring(this ITypeSymbol type) { if (type is IArrayTypeSymbol arrayType) @@ -64,5 +112,15 @@ public static string ToIdentifierCompatibleSubstring(this ITypeSymbol type) return sb.ToString(); } + + public static (string? Namespace, string DisplayString, string Name) GetTypeName(this ITypeSymbol type) + { + string? @namespace = type.ContainingNamespace is { IsGlobalNamespace: false } containingNamespace ? containingNamespace.ToDisplayString() : null; + string displayString = type.ToDisplayString(s_minimalDisplayFormat); + string name = (@namespace is null ? string.Empty : @namespace + ".") + displayString.Replace(".", "+"); + return (@namespace, displayString, name); + } + + public static string GetName(this ITypeSymbol type) => GetTypeName(type).Name; } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/KnownTypeSymbols.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/KnownTypeSymbols.cs similarity index 96% rename from src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/KnownTypeSymbols.cs rename to src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/KnownTypeSymbols.cs index e381dc9c7c43ee..07dae8689782e4 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/KnownTypeSymbols.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/KnownTypeSymbols.cs @@ -11,7 +11,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal sealed record KnownTypeSymbols + internal sealed class KnownTypeSymbols { public CSharpCompilation Compilation { get; } @@ -37,7 +37,7 @@ internal sealed record KnownTypeSymbols public INamedTypeSymbol? OptionsConfigurationServiceCollectionExtensions { get; } public INamedTypeSymbol GenericIList_Unbound { get; } - public INamedTypeSymbol GenericICollection_Unbound { get; } + public INamedTypeSymbol? GenericICollection_Unbound { get; } public INamedTypeSymbol GenericICollection { get; } public INamedTypeSymbol GenericIEnumerable_Unbound { get; } public INamedTypeSymbol IEnumerable { get; } @@ -61,7 +61,8 @@ public KnownTypeSymbols(CSharpCompilation compilation) { Compilation = compilation; - // Primitives (needed because they are Microsoft.CodeAnalysis.SpecialType.None) + // Primitives + String = compilation.GetSpecialType(SpecialType.System_String); CultureInfo = compilation.GetBestTypeByMetadataName(typeof(CultureInfo)); DateOnly = compilation.GetBestTypeByMetadataName("System.DateOnly"); DateTimeOffset = compilation.GetBestTypeByMetadataName(typeof(DateTimeOffset)); @@ -103,7 +104,7 @@ public KnownTypeSymbols(CSharpCompilation compilation) // Used for type equivalency checks for unbound generics. The parameters of the types // retured by the Roslyn Get*Type* APIs are not unbound, so we construct unbound // generics to equal those corresponding to generic types in the input type graphs. - GenericICollection_Unbound = GenericICollection?.ConstructUnboundGenericType(); + GenericICollection_Unbound = GenericICollection.ConstructUnboundGenericType(); GenericIDictionary_Unbound = GenericIDictionary?.ConstructUnboundGenericType(); GenericIEnumerable_Unbound = compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T).ConstructUnboundGenericType(); GenericIList_Unbound = compilation.GetSpecialType(SpecialType.System_Collections_Generic_IList_T).ConstructUnboundGenericType(); diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsBuilderConfigurationExtensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsBuilderConfigurationExtensions.cs index 9cf59a120e1fdc..eb0ab086bcd588 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsBuilderConfigurationExtensions.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsBuilderConfigurationExtensions.cs @@ -10,7 +10,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { public sealed partial class ConfigurationBindingGenerator { - private sealed partial class Parser + internal sealed partial class Parser { private void ParseInvocation_OptionsBuilderExt(BinderInvocation invocation) { @@ -29,22 +29,17 @@ private void ParseInvocation_OptionsBuilderExt(BinderInvocation invocation) // This would violate generic type constraint; any such invocation could not have been included in the initial parser. Debug.Assert(typeSymbol?.IsValueType is not true); - if (GetTargetTypeForRootInvocation(typeSymbol, invocation.Location) is not ComplexTypeSpec typeSpec) - { - return; - } - if (targetMethod.Name is "Bind") { - ParseBindInvocation_OptionsBuilderExt(invocation, typeSpec); + ParseBindInvocation_OptionsBuilderExt(invocation, typeSymbol); } else if (targetMethod.Name is "BindConfiguration") { - ParseBindConfigurationInvocation(invocation, typeSpec); + ParseBindConfigurationInvocation(invocation, typeSymbol); } } - private void ParseBindInvocation_OptionsBuilderExt(BinderInvocation invocation, ComplexTypeSpec typeSpec) + private void ParseBindInvocation_OptionsBuilderExt(BinderInvocation invocation, ITypeSymbol? type) { IInvocationOperation operation = invocation.Operation!; IMethodSymbol targetMethod = operation.TargetMethod; @@ -58,22 +53,21 @@ private void ParseBindInvocation_OptionsBuilderExt(BinderInvocation invocation, return; } - MethodsToGen_Extensions_OptionsBuilder overload = paramCount switch + MethodsToGen overload = paramCount switch { - 2 => MethodsToGen_Extensions_OptionsBuilder.Bind_T, + 2 => MethodsToGen.OptionsBuilderExt_Bind_T, 3 when SymbolEqualityComparer.Default.Equals(_typeSymbols.ActionOfBinderOptions, @params[2].Type) => - MethodsToGen_Extensions_OptionsBuilder.Bind_T_BinderOptions, - _ => MethodsToGen_Extensions_OptionsBuilder.None + MethodsToGen.OptionsBuilderExt_Bind_T_BinderOptions, + _ => MethodsToGen.None }; - if (overload is not MethodsToGen_Extensions_OptionsBuilder.None && - TryRegisterTypeForMethodGen(MethodsToGen_Extensions_ServiceCollection.Configure_T_name_BinderOptions, typeSpec)) + if (overload is not MethodsToGen.None) { - RegisterInvocation(overload, operation); + EnqueueTargetTypeForRootInvocation(type, overload, invocation); } } - private void ParseBindConfigurationInvocation(BinderInvocation invocation, ComplexTypeSpec typeSpec) + private void ParseBindConfigurationInvocation(BinderInvocation invocation, ITypeSymbol? type) { IMethodSymbol targetMethod = invocation.Operation.TargetMethod; ImmutableArray @params = targetMethod.Parameters; @@ -83,23 +77,41 @@ private void ParseBindConfigurationInvocation(BinderInvocation invocation, Compl if (paramCount is 3 && @params[1].Type.SpecialType is SpecialType.System_String && - SymbolEqualityComparer.Default.Equals(_typeSymbols.ActionOfBinderOptions, @params[2].Type) && - TryRegisterTypeForBindCoreMainGen(typeSpec)) + SymbolEqualityComparer.Default.Equals(_typeSymbols.ActionOfBinderOptions, @params[2].Type)) { - RegisterInvocation(MethodsToGen_Extensions_OptionsBuilder.BindConfiguration_T_path_BinderOptions, invocation.Operation); + EnqueueTargetTypeForRootInvocation(type, MethodsToGen.OptionsBuilderExt_BindConfiguration_T_path_BinderOptions, invocation); } } - private void RegisterInvocation(MethodsToGen_Extensions_OptionsBuilder overload, IInvocationOperation operation) + private void RegisterInterceptor_OptionsBuilderExt(TypeParseInfo typeParseInfo, TypeSpec typeSpec) { - _sourceGenSpec.MethodsToGen_OptionsBuilderExt |= overload; - RegisterInterceptor(overload, operation); + MethodsToGen overload = typeParseInfo.BindingOverload; + Debug.Assert((MethodsToGen.OptionsBuilderExt_Any & overload) is not 0); + + if (typeSpec is not ComplexTypeSpec complexTypeSpec) + { + return; + } + + if ((MethodsToGen.OptionsBuilderExt_Bind & overload) is not 0) + { + if (!TryRegisterTypeForOverloadGen_ServiceCollectionExt(MethodsToGen.ServiceCollectionExt_Configure_T_name_BinderOptions, complexTypeSpec)) + { + return; + } + } + else if (!_helperInfoBuilder!.TryRegisterTypeForBindCoreMainGen(complexTypeSpec)) + { + return; + } + + _interceptorInfoBuilder.RegisterInterceptor(typeParseInfo.BindingOverload, typeParseInfo.BinderInvocation.Operation); // Emitting refs to IOptionsChangeTokenSource, ConfigurationChangeTokenSource. - _sourceGenSpec.Namespaces.Add("Microsoft.Extensions.Options"); + _helperInfoBuilder!.RegisterNamespace("Microsoft.Extensions.Options"); // Emitting refs to OptionsBuilder. - _sourceGenSpec.Namespaces.Add("Microsoft.Extensions.DependencyInjection"); + _helperInfoBuilder!.RegisterNamespace("Microsoft.Extensions.DependencyInjection"); } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsConfigurationServiceCollectionExtensions.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsConfigurationServiceCollectionExtensions.cs index e86231f32e42ab..1ccef24bc6b71f 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsConfigurationServiceCollectionExtensions.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Parser/OptionsConfigurationServiceCollectionExtensions.cs @@ -10,7 +10,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { public sealed partial class ConfigurationBindingGenerator { - private sealed partial class Parser + internal sealed partial class Parser { private void ParseInvocation_ServiceCollectionExt(BinderInvocation invocation) { @@ -30,11 +30,11 @@ private void ParseInvocation_ServiceCollectionExt(BinderInvocation invocation) return; } - MethodsToGen_Extensions_ServiceCollection overload; + MethodsToGen overload; if (paramCount is 2 && SymbolEqualityComparer.Default.Equals(_typeSymbols.IConfiguration, @params[1].Type)) { - overload = MethodsToGen_Extensions_ServiceCollection.Configure_T; + overload = MethodsToGen.ServiceCollectionExt_Configure_T; } else if (paramCount is 3) { @@ -44,12 +44,12 @@ private void ParseInvocation_ServiceCollectionExt(BinderInvocation invocation) if (secondParamType.SpecialType is SpecialType.System_String && SymbolEqualityComparer.Default.Equals(_typeSymbols.IConfiguration, thirdParamType)) { - overload = MethodsToGen_Extensions_ServiceCollection.Configure_T_name; + overload = MethodsToGen.ServiceCollectionExt_Configure_T_name; } else if (SymbolEqualityComparer.Default.Equals(_typeSymbols.IConfiguration, secondParamType) && SymbolEqualityComparer.Default.Equals(_typeSymbols.ActionOfBinderOptions, thirdParamType)) { - overload = MethodsToGen_Extensions_ServiceCollection.Configure_T_BinderOptions; + overload = MethodsToGen.ServiceCollectionExt_Configure_T_BinderOptions; } else { @@ -61,7 +61,7 @@ @params[1].Type.SpecialType is SpecialType.System_String && SymbolEqualityComparer.Default.Equals(_typeSymbols.IConfiguration, @params[2].Type) && SymbolEqualityComparer.Default.Equals(_typeSymbols.ActionOfBinderOptions, @params[3].Type)) { - overload = MethodsToGen_Extensions_ServiceCollection.Configure_T_name_BinderOptions; + overload = MethodsToGen.ServiceCollectionExt_Configure_T_name_BinderOptions; } else { @@ -73,25 +73,34 @@ @params[1].Type.SpecialType is SpecialType.System_String && // This would violate generic type constraint; any such invocation could not have been included in the initial parser. Debug.Assert(typeSymbol?.IsValueType is not true); - if (GetTargetTypeForRootInvocation(typeSymbol, invocation.Location) is ComplexTypeSpec typeSpec && - TryRegisterTypeForMethodGen(overload, typeSpec)) + EnqueueTargetTypeForRootInvocation(typeSymbol, overload, invocation); + } + + private void RegisterInterceptor_ServiceCollectionExt(TypeParseInfo typeParseInfo, TypeSpec typeSpec) + { + MethodsToGen overload = typeParseInfo.BindingOverload; + + if (typeSpec is ComplexTypeSpec complexTypeSpec && + TryRegisterTypeForOverloadGen_ServiceCollectionExt(overload, complexTypeSpec)) { - RegisterInterceptor(overload, operation); + _interceptorInfoBuilder.RegisterInterceptor(overload, typeParseInfo.BinderInvocation.Operation); } } - private bool TryRegisterTypeForMethodGen(MethodsToGen_Extensions_ServiceCollection overload, ComplexTypeSpec typeSpec) + private bool TryRegisterTypeForOverloadGen_ServiceCollectionExt(MethodsToGen overload, ComplexTypeSpec typeSpec) { - if (TryRegisterTypeForBindCoreMainGen(typeSpec)) + Debug.Assert((MethodsToGen.ServiceCollectionExt_Any & overload) is not 0); + + if (!_helperInfoBuilder!.TryRegisterTypeForBindCoreMainGen(typeSpec)) { - _sourceGenSpec.MethodsToGen_ServiceCollectionExt |= overload; - _sourceGenSpec.Namespaces.Add("Microsoft.Extensions.DependencyInjection"); - // Emitting refs to IOptionsChangeTokenSource, ConfigurationChangeTokenSource, IConfigureOptions<>, ConfigureNamedOptions<>. - _sourceGenSpec.Namespaces.Add("Microsoft.Extensions.Options"); - return true; + return false; } - return false; + _interceptorInfoBuilder.MethodsToGen |= overload; + _helperInfoBuilder!.RegisterNamespace("Microsoft.Extensions.DependencyInjection"); + // Emitting refs to IOptionsChangeTokenSource, ConfigurationChangeTokenSource, IConfigureOptions<>, ConfigureNamedOptions<>. + _helperInfoBuilder!.RegisterNamespace("Microsoft.Extensions.Options"); + return true; } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/BindingHelperInfo.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/BindingHelperInfo.cs new file mode 100644 index 00000000000000..096c8410717ae7 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/BindingHelperInfo.cs @@ -0,0 +1,237 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed record BindingHelperInfo + { + public required ImmutableEquatableArray Namespaces { get; init; } + public required bool EmitConfigurationKeyCaches { get; init; } + + public required MethodsToGen_CoreBindingHelper MethodsToGen { get; init; } + public required ImmutableEquatableArray? TypesForGen_BindCoreMain { get; init; } + public required ImmutableEquatableArray? TypesForGen_GetCore { get; init; } + public required ImmutableEquatableArray? TypesForGen_GetValueCore { get; init; } + public required ImmutableEquatableArray? TypesForGen_BindCore { get; init; } + public required ImmutableEquatableArray? TypesForGen_Initialize { get; init; } + public required ImmutableEquatableArray? TypesForGen_ParsePrimitive { get; init; } + + internal sealed class Builder(TypeIndex _typeIndex) + { + private readonly Dictionary _seenTransitiveTypes = new(); + + private MethodsToGen_CoreBindingHelper _methodsToGen; + private bool _emitConfigurationKeyCaches; + + private readonly Dictionary> _typesForGen = new(); + + private readonly SortedSet _namespaces = new() + { + "System", + "System.CodeDom.Compiler", + "System.Globalization", + "System.Runtime.CompilerServices", + "Microsoft.Extensions.Configuration", + }; + + public BindingHelperInfo ToIncrementalValue() + { + return new BindingHelperInfo + { + Namespaces = _namespaces.ToImmutableEquatableArray(), + EmitConfigurationKeyCaches = _emitConfigurationKeyCaches, + + MethodsToGen = _methodsToGen, + TypesForGen_GetCore = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.GetCore), + TypesForGen_BindCoreMain = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.BindCoreMain), + TypesForGen_GetValueCore = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.GetValueCore), + TypesForGen_BindCore = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.BindCore), + TypesForGen_Initialize = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.Initialize), + TypesForGen_ParsePrimitive = GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper.ParsePrimitive) + }; + + ImmutableEquatableArray? GetTypesForGen_CoreBindingHelper(MethodsToGen_CoreBindingHelper overload) + where TSpec : TypeSpec, IEquatable + { + _typesForGen.TryGetValue(overload, out HashSet? typesAsBase); + + if (typesAsBase is null) + { + return null; + } + + IEnumerable types = typeof(TSpec) == typeof(TypeSpec) + ? (HashSet)(object)typesAsBase + : typesAsBase.Select(t => (TSpec)t); + + return GetTypesForGen(types); + } + + static ImmutableEquatableArray GetTypesForGen(IEnumerable types) + where TSpec : TypeSpec, IEquatable => + types.ToImmutableEquatableArray(); + } + + public bool TryRegisterTypeForGetGen(TypeSpec type) + { + if (TryRegisterTransitiveTypesForMethodGen(type.TypeRef)) + { + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.GetCore, type); + RegisterForGen_AsConfigWithChildrenHelper(); + return true; + } + + return false; + } + + public bool TryRegisterTypeForGetValueGen(TypeSpec typeSpec) + { + ParsableFromStringSpec effectiveType = (ParsableFromStringSpec)_typeIndex.GetEffectiveTypeSpec(typeSpec); + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.GetValueCore, typeSpec); + RegisterStringParsableTypeIfApplicable(effectiveType); + return true; + } + + public bool TryRegisterTypeForBindCoreMainGen(ComplexTypeSpec type) + { + if (TryRegisterTransitiveTypesForMethodGen(type.TypeRef)) + { + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCoreMain, type); + RegisterForGen_AsConfigWithChildrenHelper(); + return true; + } + + return false; + } + + public bool TryRegisterTransitiveTypesForMethodGen(TypeRef typeRef) + { + return _seenTransitiveTypes.TryGetValue(typeRef, out bool isValid) + ? isValid + : (_seenTransitiveTypes[typeRef] = TryRegisterCore()); + + bool TryRegisterCore() + { + switch (_typeIndex.GetTypeSpec(typeRef)) + { + case NullableSpec nullableSpec: + { + return TryRegisterTransitiveTypesForMethodGen(nullableSpec.EffectiveTypeRef); + } + case ParsableFromStringSpec stringParsableSpec: + { + RegisterStringParsableTypeIfApplicable(stringParsableSpec); + return true; + } + case DictionarySpec dictionarySpec: + { + bool shouldRegister = _typeIndex.CanBindTo(typeRef) && + TryRegisterTransitiveTypesForMethodGen(dictionarySpec.KeyTypeRef) && + TryRegisterTransitiveTypesForMethodGen(dictionarySpec.ElementTypeRef) && + TryRegisterTypeForBindCoreGen(dictionarySpec); + + if (shouldRegister && dictionarySpec.InstantiationStrategy is CollectionInstantiationStrategy.LinqToDictionary) + { + _namespaces.Add("System.Linq"); + } + + return shouldRegister; + } + case CollectionSpec collectionSpec: + { + return TryRegisterTransitiveTypesForMethodGen(collectionSpec.ElementTypeRef) && + TryRegisterTypeForBindCoreGen(collectionSpec); + } + case ObjectSpec objectSpec: + { + // Base case to avoid stack overflow for recursive object graphs. + // Register all object types for gen; we need to throw runtime exceptions in some cases. + bool shouldRegister = true; + _seenTransitiveTypes.Add(typeRef, shouldRegister); + + // List is used in generated code as a temp holder for formatting + // an error for config properties that don't map to object properties. + _namespaces.Add("System.Collections.Generic"); + + if (_typeIndex.HasBindableMembers(objectSpec)) + { + foreach (PropertySpec property in objectSpec.Properties!) + { + TryRegisterTransitiveTypesForMethodGen(property.TypeRef); + + if (_typeIndex.GetTypeSpec(property.TypeRef) is ComplexTypeSpec) + { + RegisterForGen_AsConfigWithChildrenHelper(); + } + } + + bool registeredForBindCore = TryRegisterTypeForBindCoreGen(objectSpec); + Debug.Assert(registeredForBindCore); + + if (objectSpec is { InstantiationStrategy: ObjectInstantiationStrategy.ParameterizedConstructor, InitExceptionMessage: null }) + { + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.Initialize, objectSpec); + } + } + + return true; + } + default: + { + return true; + } + } + } + } + + public void RegisterNamespace(string @namespace) => _namespaces.Add(@namespace); + + private bool TryRegisterTypeForBindCoreGen(ComplexTypeSpec type) + { + if (_typeIndex.HasBindableMembers(type)) + { + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.BindCore, type); + _emitConfigurationKeyCaches = true; + return true; + } + + return false; + } + + private void RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper method, TypeSpec type) + { + if (!_typesForGen.TryGetValue(method, out HashSet? types)) + { + _typesForGen[method] = types = new HashSet(); + } + + if (types.Add(type)) + { + _methodsToGen |= method; + + if (type is { Namespace: string @namespace }) + { + _namespaces.Add(@namespace); + } + } + } + + private void RegisterStringParsableTypeIfApplicable(ParsableFromStringSpec type) + { + if (type.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue) + { + _methodsToGen |= MethodsToGen_CoreBindingHelper.ParsePrimitive; + RegisterTypeForMethodGen(MethodsToGen_CoreBindingHelper.ParsePrimitive, type); + } + } + + private void RegisterForGen_AsConfigWithChildrenHelper() => _methodsToGen |= MethodsToGen_CoreBindingHelper.AsConfigWithChildren; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/InterceptorInfo.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/InterceptorInfo.cs new file mode 100644 index 00000000000000..999ed6514f99d7 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/InterceptorInfo.cs @@ -0,0 +1,202 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using Microsoft.CodeAnalysis.Text; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + public sealed record InterceptorInfo + { + public required MethodsToGen MethodsToGen { get; init; } + + public required ImmutableEquatableArray? ConfigBinder_Bind_instance { get; init; } + public required ImmutableEquatableArray? ConfigBinder_Bind_instance_BinderOptions { get; init; } + public required ImmutableEquatableArray? ConfigBinder_Bind_key_instance { get; init; } + + + public required ImmutableEquatableArray? ConfigBinder { get; init; } + public required ImmutableEquatableArray? OptionsBuilderExt { get; init; } + public required ImmutableEquatableArray? ServiceCollectionExt { get; init; } + + public IEnumerable? GetInfo(MethodsToGen interceptor) + { + Debug.Assert((MethodsToGen.ConfigBinder_Bind & interceptor) is 0); + + ImmutableEquatableArray? infoList; + if ((MethodsToGen.ConfigBinder_Any ^ MethodsToGen.ConfigBinder_Bind & interceptor) is not 0) + { + infoList = ConfigBinder; + } + else if ((MethodsToGen.OptionsBuilderExt_Any & interceptor) is not 0) + { + infoList = OptionsBuilderExt; + } + else + { + Debug.Assert((MethodsToGen.ServiceCollectionExt_Any & interceptor) is not 0); + infoList = ServiceCollectionExt; + } + + return infoList?.Where(i => i.Interceptor == interceptor); + } + + internal sealed class Builder + { + private TypedInterceptorInfoBuildler? _configBinder_InfoBuilder_Bind_instance; + private TypedInterceptorInfoBuildler? _configBinder_InfoBuilder_Bind_instance_BinderOptions; + private TypedInterceptorInfoBuildler? _configBinder_InfoBuilder_Bind_key_instance; + + private List? _interceptors_configBinder; + private List? _interceptors_OptionsBuilderExt; + private List? _interceptors_serviceCollectionExt; + + public MethodsToGen MethodsToGen { get; set; } + + public void RegisterInterceptor_ConfigBinder_Bind(MethodsToGen overload, ComplexTypeSpec type, IInvocationOperation invocation) + { + Debug.Assert((MethodsToGen.ConfigBinder_Bind & overload) is not 0); + + switch (overload) + { + case MethodsToGen.ConfigBinder_Bind_instance: + RegisterInterceptor(ref _configBinder_InfoBuilder_Bind_instance); + break; + case MethodsToGen.ConfigBinder_Bind_instance_BinderOptions: + RegisterInterceptor(ref _configBinder_InfoBuilder_Bind_instance_BinderOptions); + break; + case MethodsToGen.ConfigBinder_Bind_key_instance: + RegisterInterceptor(ref _configBinder_InfoBuilder_Bind_key_instance); + break; + } + + MethodsToGen |= overload; + + void RegisterInterceptor(ref TypedInterceptorInfoBuildler? infoBuilder) + { + infoBuilder ??= new TypedInterceptorInfoBuildler(); + infoBuilder.RegisterInterceptor(overload, type, invocation); + } + } + + public void RegisterInterceptor(MethodsToGen overload, IInvocationOperation operation) + { + Debug.Assert((MethodsToGen.ConfigBinder_Bind & overload) is 0); + + if ((MethodsToGen.ConfigBinder_Any ^ MethodsToGen.ConfigBinder_Bind & overload) is not 0) + { + RegisterInterceptor(ref _interceptors_configBinder); + } + else if ((MethodsToGen.OptionsBuilderExt_Any & overload) is not 0) + { + RegisterInterceptor(ref _interceptors_OptionsBuilderExt); + } + else + { + Debug.Assert((MethodsToGen.ServiceCollectionExt_Any & overload) is not 0); + RegisterInterceptor(ref _interceptors_serviceCollectionExt); + } + + MethodsToGen |= overload; + + void RegisterInterceptor(ref List? infoList) + { + infoList ??= new List(); + infoList.Add(new InvocationLocationInfo(overload, operation)); + } + } + + public InterceptorInfo ToIncrementalValue() => + new InterceptorInfo + { + MethodsToGen = MethodsToGen, + + ConfigBinder = _interceptors_configBinder?.ToImmutableEquatableArray(), + OptionsBuilderExt = _interceptors_OptionsBuilderExt?.ToImmutableEquatableArray(), + ServiceCollectionExt = _interceptors_serviceCollectionExt?.ToImmutableEquatableArray(), + + ConfigBinder_Bind_instance = _configBinder_InfoBuilder_Bind_instance?.ToIncrementalValue(), + ConfigBinder_Bind_instance_BinderOptions = _configBinder_InfoBuilder_Bind_instance_BinderOptions?.ToIncrementalValue(), + ConfigBinder_Bind_key_instance = _configBinder_InfoBuilder_Bind_key_instance?.ToIncrementalValue(), + }; + } + } + + internal sealed class TypedInterceptorInfoBuildler + { + private readonly Dictionary _invocationInfoBuilderCache = new(); + + public void RegisterInterceptor(MethodsToGen overload, ComplexTypeSpec type, IInvocationOperation invocation) + { + if (!_invocationInfoBuilderCache.TryGetValue(type, out TypedInterceptorInvocationInfo.Builder? invocationInfoBuilder)) + { + _invocationInfoBuilderCache[type] = invocationInfoBuilder = new TypedInterceptorInvocationInfo.Builder(overload, type); + } + + invocationInfoBuilder.RegisterInvocation(invocation); + } + + public ImmutableEquatableArray? ToIncrementalValue() => + _invocationInfoBuilderCache.Values + .Select(b => b.ToIncrementalValue()) + .ToImmutableEquatableArray(); + } + + public sealed record TypedInterceptorInvocationInfo(ComplexTypeSpec TargetType, ImmutableEquatableArray Locations) + { + public sealed class Builder(MethodsToGen Overload, ComplexTypeSpec TargetType) + { + private readonly List _infoList = new(); + + public void RegisterInvocation(IInvocationOperation invocation) => + _infoList.Add(new InvocationLocationInfo(Overload, invocation)); + + public TypedInterceptorInvocationInfo ToIncrementalValue() => new( + TargetType, + Locations: _infoList.ToImmutableEquatableArray()); + } + } + + public sealed record InvocationLocationInfo + { + public InvocationLocationInfo(MethodsToGen interceptor, IInvocationOperation invocation) + { + Debug.Assert(BinderInvocation.IsBindingOperation(invocation)); + + if (invocation.Syntax is not InvocationExpressionSyntax { Expression: MemberAccessExpressionSyntax memberAccessExprSyntax }) + { + const string InvalidInvocationErrMsg = "The invocation should have been validated upstream when selecting invocations to emit interceptors for."; + throw new ArgumentException(InvalidInvocationErrMsg, nameof(invocation)); + } + + SyntaxTree operationSyntaxTree = invocation.Syntax.SyntaxTree; + TextSpan memberNameSpan = memberAccessExprSyntax.Name.Span; + FileLinePositionSpan linePosSpan = operationSyntaxTree.GetLineSpan(memberNameSpan); + + Interceptor = interceptor; + LineNumber = linePosSpan.StartLinePosition.Line + 1; + CharacterNumber = linePosSpan.StartLinePosition.Character + 1; + FilePath = GetInterceptorFilePath(); + + // Use the same logic used by the interceptors API for resolving the source mapped value of a path. + // https://github.com/dotnet/roslyn/blob/f290437fcc75dad50a38c09e0977cce13a64f5ba/src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs#L1063-L1064 + string GetInterceptorFilePath() + { + SourceReferenceResolver? sourceReferenceResolver = invocation.SemanticModel?.Compilation.Options.SourceReferenceResolver; + return sourceReferenceResolver?.NormalizePath(operationSyntaxTree.FilePath, baseFilePath: null) ?? operationSyntaxTree.FilePath; + } + } + + public MethodsToGen Interceptor { get; } + public string FilePath { get; } + public int LineNumber { get; } + public int CharacterNumber { get; } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/InterceptorLocationInfo.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/InterceptorLocationInfo.cs deleted file mode 100644 index 441acbe6a7444f..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/InterceptorLocationInfo.cs +++ /dev/null @@ -1,89 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.Collections; -using System.Collections.Generic; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Operations; -using Microsoft.CodeAnalysis.Text; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal sealed record InterceptorLocationInfo - { - public InterceptorLocationInfo(IInvocationOperation operation) - { - MemberAccessExpressionSyntax memberAccessExprSyntax = ((MemberAccessExpressionSyntax)((InvocationExpressionSyntax)operation.Syntax).Expression); - SyntaxTree operationSyntaxTree = operation.Syntax.SyntaxTree; - TextSpan memberNameSpan = memberAccessExprSyntax.Name.Span; - FileLinePositionSpan linePosSpan = operationSyntaxTree.GetLineSpan(memberNameSpan); - - LineNumber = linePosSpan.StartLinePosition.Line + 1; - CharacterNumber = linePosSpan.StartLinePosition.Character + 1; - FilePath = GetInterceptorFilePath(); - - // Use the same logic used by the interceptors API for resolving the source mapped value of a path. - // https://github.com/dotnet/roslyn/blob/f290437fcc75dad50a38c09e0977cce13a64f5ba/src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs#L1063-L1064 - string GetInterceptorFilePath() - { - SourceReferenceResolver? sourceReferenceResolver = operation.SemanticModel?.Compilation.Options.SourceReferenceResolver; - return sourceReferenceResolver?.NormalizePath(operationSyntaxTree.FilePath, baseFilePath: null) ?? operationSyntaxTree.FilePath; - } - } - - public string FilePath { get; } - public int LineNumber { get; } - public int CharacterNumber { get; } - } - - internal sealed record ConfigurationBinderInterceptorInfo - { - private OverloadInterceptorInfo? _bind_Instance; - private OverloadInterceptorInfo? _bind_instance_BinderOptions; - private OverloadInterceptorInfo? _bind_key_instance; - - public void RegisterOverloadInfo(MethodsToGen_ConfigurationBinder overload, TypeSpec type, IInvocationOperation operation) - { - OverloadInterceptorInfo overloadInfo = DetermineOverload(overload, initIfNull: true); - overloadInfo.RegisterLocationInfo(type, operation); - } - - public OverloadInterceptorInfo GetOverloadInfo(MethodsToGen_ConfigurationBinder overload) => - DetermineOverload(overload, initIfNull: false) ?? throw new ArgumentOutOfRangeException(nameof(overload)); - - private OverloadInterceptorInfo? DetermineOverload(MethodsToGen_ConfigurationBinder overload, bool initIfNull) - { - return overload switch - { - MethodsToGen_ConfigurationBinder.Bind_instance => InitIfNull(ref _bind_Instance), - MethodsToGen_ConfigurationBinder.Bind_instance_BinderOptions => InitIfNull(ref _bind_instance_BinderOptions), - MethodsToGen_ConfigurationBinder.Bind_key_instance => InitIfNull(ref _bind_key_instance), - _ => throw new InvalidOperationException(nameof(overload)), - }; - - OverloadInterceptorInfo InitIfNull(ref OverloadInterceptorInfo? info) - { - if (initIfNull) - { - info ??= new OverloadInterceptorInfo(); - } - - return info; - } - } - } - - internal sealed record OverloadInterceptorInfo : IEnumerable>> - { - private readonly Dictionary> _typeInterceptionInfo = new(); - - public void RegisterLocationInfo(TypeSpec type, IInvocationOperation operation) => - _typeInterceptionInfo.RegisterCacheEntry(type, new InterceptorLocationInfo(operation)); - - public IEnumerator>> GetEnumerator() => _typeInterceptionInfo.GetEnumerator(); - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/MemberSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/MemberSpec.cs index effd550482595d..dc5b03087ac87a 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/MemberSpec.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/MemberSpec.cs @@ -3,10 +3,11 @@ using System.Diagnostics; using Microsoft.CodeAnalysis; +using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal abstract record MemberSpec + public abstract record MemberSpec { public MemberSpec(ISymbol member) { @@ -18,7 +19,7 @@ public MemberSpec(ISymbol member) public string Name { get; } public string DefaultValueExpr { get; protected set; } - public required TypeSpec Type { get; init; } + public required TypeRef TypeRef { get; init; } public required string ConfigurationKeyName { get; init; } public abstract bool CanGet { get; } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/ParameterSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/ParameterSpec.cs index 0f17a6247f74d2..62c781e1f1631f 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/ParameterSpec.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/ParameterSpec.cs @@ -6,7 +6,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal sealed record ParameterSpec : MemberSpec + public sealed record ParameterSpec : MemberSpec { public ParameterSpec(IParameterSymbol parameter) : base(parameter) { @@ -14,7 +14,7 @@ public ParameterSpec(IParameterSymbol parameter) : base(parameter) if (parameter.HasExplicitDefaultValue) { - string formatted = SymbolDisplay.FormatPrimitive(parameter.ExplicitDefaultValue, quoteStrings: true, useHexadecimalNumbers: false); + string formatted = SymbolDisplay.FormatPrimitive(parameter.ExplicitDefaultValue!, quoteStrings: true, useHexadecimalNumbers: false); if (formatted is not "null") { DefaultValueExpr = formatted; diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/PropertySpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/PropertySpec.cs index 4e9c468c4e3352..443e39d32e4933 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/PropertySpec.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Members/PropertySpec.cs @@ -5,7 +5,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal sealed record PropertySpec : MemberSpec + public sealed record PropertySpec : MemberSpec { public PropertySpec(IPropertySymbol property) : base(property) { @@ -28,7 +28,5 @@ public PropertySpec(IPropertySymbol property) : base(property) public override bool CanGet { get; } public override bool CanSet { get; } - - public bool ShouldBindTo => CanGet || CanSet; } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/MethodsToGen.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/MethodsToGen.cs index 6165a3e6d46dcb..af2a33fa6c2f80 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/MethodsToGen.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/MethodsToGen.cs @@ -16,137 +16,130 @@ public enum MethodsToGen_CoreBindingHelper Initialize = 0x10, HasValueOrChildren = 0x20, AsConfigWithChildren = 0x40, + ParsePrimitive = 0x80, } /// /// Methods on Microsoft.Extensions.Configuration.ConfigurationBinder /// [Flags] - internal enum MethodsToGen_ConfigurationBinder + public enum MethodsToGen { None = 0x0, + Any = ConfigBinder_Any | OptionsBuilderExt_Any | ServiceCollectionExt_Any, + #region IConfiguration ext. method overloads: 0x1 - 0x400 /// /// Bind(IConfiguration, object?). /// - Bind_instance = 0x1, + ConfigBinder_Bind_instance = 0x1, /// /// Bind(IConfiguration, object?, Action?). /// - Bind_instance_BinderOptions = 0x2, + ConfigBinder_Bind_instance_BinderOptions = 0x2, /// /// Bind(IConfiguration, string, object?). /// - Bind_key_instance = 0x4, + ConfigBinder_Bind_key_instance = 0x4, /// /// Get(IConfiguration). /// - Get_T = 0x8, + ConfigBinder_Get_T = 0x8, /// /// Get(IConfiguration, Action?). /// - Get_T_BinderOptions = 0x10, + ConfigBinder_Get_T_BinderOptions = 0x10, /// /// Get(IConfiguration, Type). /// - Get_TypeOf = 0x20, + ConfigBinder_Get_TypeOf = 0x20, /// /// Get(IConfiguration, Type, Action?). /// - Get_TypeOf_BinderOptions = 0x40, + ConfigBinder_Get_TypeOf_BinderOptions = 0x40, /// /// GetValue(IConfiguration, Type, string). /// - GetValue_TypeOf_key = 0x80, + ConfigBinder_GetValue_TypeOf_key = 0x80, /// /// GetValue(IConfiguration, Type, object?). /// - GetValue_TypeOf_key_defaultValue = 0x100, + ConfigBinder_GetValue_TypeOf_key_defaultValue = 0x100, /// /// GetValue(IConfiguration, string). /// - GetValue_T_key = 0x200, + ConfigBinder_GetValue_T_key = 0x200, /// /// GetValue(IConfiguration, string, T). /// - GetValue_T_key_defaultValue = 0x400, + ConfigBinder_GetValue_T_key_defaultValue = 0x400, // Method groups - Bind = Bind_instance | Bind_instance_BinderOptions | Bind_key_instance, - Get = Get_T | Get_T_BinderOptions | Get_TypeOf | Get_TypeOf_BinderOptions, - GetValue = GetValue_T_key | GetValue_T_key_defaultValue | GetValue_TypeOf_key | GetValue_TypeOf_key_defaultValue, + ConfigBinder_Bind = ConfigBinder_Bind_instance | ConfigBinder_Bind_instance_BinderOptions | ConfigBinder_Bind_key_instance, + ConfigBinder_Get = ConfigBinder_Get_T | ConfigBinder_Get_T_BinderOptions | ConfigBinder_Get_TypeOf | ConfigBinder_Get_TypeOf_BinderOptions, + ConfigBinder_GetValue = ConfigBinder_GetValue_T_key | ConfigBinder_GetValue_T_key_defaultValue | ConfigBinder_GetValue_TypeOf_key | ConfigBinder_GetValue_TypeOf_key_defaultValue, - Any = Bind | Get | GetValue, - } - - [Flags] - internal enum MethodsToGen_Extensions_OptionsBuilder - { - None = 0x0, + ConfigBinder_Any = ConfigBinder_Bind | ConfigBinder_Get | ConfigBinder_GetValue, + #endregion ConfigurationBinder ext. method overloads. + #region OptionsBuilder ext. method overloads: 0x800 - 0x2000 /// /// Bind(OptionsBuilder, IConfiguration). /// - Bind_T = 0x1, + OptionsBuilderExt_Bind_T = 0x800, /// /// Bind(OptionsBuilder, IConfiguration, Action?). /// - Bind_T_BinderOptions = 0x2, + OptionsBuilderExt_Bind_T_BinderOptions = 0x1000, /// /// BindConfiguration(OptionsBuilder, string, Action?). /// - BindConfiguration_T_path_BinderOptions = 0x4, + OptionsBuilderExt_BindConfiguration_T_path_BinderOptions = 0x2000, // Method group. BindConfiguration_T is its own method group. - Bind = Bind_T | Bind_T_BinderOptions, - - BindConfiguration = BindConfiguration_T_path_BinderOptions, + OptionsBuilderExt_Bind = OptionsBuilderExt_Bind_T | OptionsBuilderExt_Bind_T_BinderOptions, - Any = Bind | BindConfiguration, - } + OptionsBuilderExt_BindConfiguration = OptionsBuilderExt_BindConfiguration_T_path_BinderOptions, - /// - /// Methods on Microsoft.Extensions.DependencyInjection.OptionsConfigurationServiceCollectionExtensions - /// - [Flags] - public enum MethodsToGen_Extensions_ServiceCollection - { - None = 0x0, + OptionsBuilderExt_Any = OptionsBuilderExt_Bind | OptionsBuilderExt_BindConfiguration, + #endregion OptionsBuilder ext. method overloads. + #region IServiceCollection ext. method overloads: 0x4000 - 0x20000 /// /// Configure(IServiceCollection, IConfiguration). /// - Configure_T = 0x1, + ServiceCollectionExt_Configure_T = 0x4000, /// /// Configure(IServiceCollection, string, IConfiguration). /// - Configure_T_name = 0x2, + ServiceCollectionExt_Configure_T_name = 0x8000, /// /// Configure(IServiceCollection, IConfiguration, Action?). /// - Configure_T_BinderOptions = 0x4, + ServiceCollectionExt_Configure_T_BinderOptions = 0x10000, /// /// Configure(IServiceCollection, string, IConfiguration, Action?). /// - Configure_T_name_BinderOptions = 0x8, + ServiceCollectionExt_Configure_T_name_BinderOptions = 0x20000, - Configure = Configure_T | Configure_T_name | Configure_T_BinderOptions | Configure_T_name_BinderOptions, + ServiceCollectionExt_Configure = ServiceCollectionExt_Configure_T | ServiceCollectionExt_Configure_T_name | ServiceCollectionExt_Configure_T_BinderOptions | ServiceCollectionExt_Configure_T_name_BinderOptions, - Any = Configure, + ServiceCollectionExt_Any = ServiceCollectionExt_Configure, + #endregion IServiceCollection ext. method overloads: 0x4000 - 0x20000 } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/SourceGenerationSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/SourceGenerationSpec.cs index 760d57b1dcc888..4f57316429e2b1 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/SourceGenerationSpec.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/SourceGenerationSpec.cs @@ -1,31 +1,14 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; +using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal sealed record SourceGenerationSpec + public sealed record SourceGenerationSpec { - public Dictionary> InterceptionInfo { get; } = new(); - public ConfigurationBinderInterceptorInfo InterceptionInfo_ConfigBinder { get; } = new(); - - public Dictionary> TypesForGen_CoreBindingHelper_Methods { get; } = new(); - - public HashSet PrimitivesForHelperGen { get; } = new(); - public HashSet Namespaces { get; } = new() - { - "System", - "System.CodeDom.Compiler", - "System.Globalization", - "System.Runtime.CompilerServices", - "Microsoft.Extensions.Configuration", - }; - - public MethodsToGen_CoreBindingHelper MethodsToGen_CoreBindingHelper { get; set; } - public MethodsToGen_ConfigurationBinder MethodsToGen_ConfigurationBinder { get; set; } - public MethodsToGen_Extensions_OptionsBuilder MethodsToGen_OptionsBuilderExt { get; set; } - public MethodsToGen_Extensions_ServiceCollection MethodsToGen_ServiceCollectionExt { get; set; } + public required InterceptorInfo InterceptorInfo { get; init; } + public required BindingHelperInfo BindingHelperInfo { get; init; } + public required ImmutableEquatableArray ConfigTypes { get; init; } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/TypeIndex.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/TypeIndex.cs new file mode 100644 index 00000000000000..5b59577b392921 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/TypeIndex.cs @@ -0,0 +1,122 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using SourceGenerators; + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + internal sealed class TypeIndex(IEnumerable typeSpecs) + { + private readonly Dictionary _index = typeSpecs.ToDictionary(spec => spec.TypeRef); + + public bool CanBindTo(TypeRef typeRef) => GetEffectiveTypeSpec(typeRef) switch + { + SimpleTypeSpec => true, + ComplexTypeSpec complexTypeSpec => CanInstantiate(complexTypeSpec) || HasBindableMembers(complexTypeSpec), + _ => throw new InvalidOperationException(), + }; + + public bool CanInstantiate(ComplexTypeSpec typeSpec) => typeSpec switch + { + ObjectSpec objectSpec => objectSpec is { InstantiationStrategy: not ObjectInstantiationStrategy.None, InitExceptionMessage: null }, + DictionarySpec dictionarySpec => KeyIsSupported(dictionarySpec), + CollectionSpec collectionSpec => CanBindTo(collectionSpec.ElementTypeRef), + _ => throw new InvalidOperationException(), + }; + + public bool HasBindableMembers(ComplexTypeSpec typeSpec) => + typeSpec switch + { + ObjectSpec objectSpec => objectSpec.Properties?.Any(ShouldBindTo) is true, + DictionarySpec dictSpec => KeyIsSupported(dictSpec) && CanBindTo(dictSpec.ElementTypeRef), + CollectionSpec collectionSpec => CanBindTo(collectionSpec.ElementTypeRef), + _ => throw new InvalidOperationException(), + }; + + public bool ShouldBindTo(PropertySpec property) + { + TypeSpec propTypeSpec = GetEffectiveTypeSpec(property.TypeRef); + return IsAccessible() && !IsCollectionAndCannotOverride() && !IsDictWithUnsupportedKey(); + + bool IsAccessible() => property.CanGet || property.CanSet; + + bool IsDictWithUnsupportedKey() => propTypeSpec is DictionarySpec dictionarySpec && !KeyIsSupported(dictionarySpec); + + bool IsCollectionAndCannotOverride() => !property.CanSet && + propTypeSpec is CollectionWithCtorInitSpec + { + InstantiationStrategy: CollectionInstantiationStrategy.CopyConstructor or CollectionInstantiationStrategy.LinqToDictionary + }; + } + + public TypeSpec GetEffectiveTypeSpec(TypeRef typeRef) + { + TypeSpec typeSpec = GetTypeSpec(typeRef); + return GetEffectiveTypeSpec(typeSpec); + } + + public TypeSpec GetEffectiveTypeSpec(TypeSpec typeSpec) + { + TypeRef effectiveRef = typeSpec.EffectiveTypeRef; + TypeSpec effectiveSpec = effectiveRef == typeSpec.TypeRef ? typeSpec : _index[effectiveRef]; + return effectiveSpec; + } + + public TypeSpec GetTypeSpec(TypeRef typeRef) => _index[typeRef]; + + public string GetInstantiationTypeDisplayString(CollectionWithCtorInitSpec type) + { + CollectionInstantiationConcreteType concreteType = type.InstantiationConcreteType; + return concreteType is CollectionInstantiationConcreteType.Self + ? type.DisplayString + : GetGenericTypeDisplayString(type, concreteType); + } + + public string GetPopulationCastTypeDisplayString(CollectionWithCtorInitSpec type) + { + CollectionPopulationCastType castType = type.PopulationCastType; + Debug.Assert(castType is not CollectionPopulationCastType.NotApplicable); + return GetGenericTypeDisplayString(type, castType); + } + + public string GetGenericTypeDisplayString(CollectionWithCtorInitSpec type, Enum genericProxyTypeName) + { + string proxyTypeNameStr = genericProxyTypeName.ToString(); + string elementTypeDisplayString = GetTypeSpec(type.ElementTypeRef).DisplayString; + + if (type is EnumerableSpec) + { + return $"{proxyTypeNameStr}<{elementTypeDisplayString}>"; + } + + string keyTypeDisplayString = GetTypeSpec(((DictionarySpec)type).KeyTypeRef).DisplayString; + return $"{proxyTypeNameStr}<{keyTypeDisplayString}, {elementTypeDisplayString}>"; + } + + public bool KeyIsSupported(DictionarySpec typeSpec) => + // Only types that are parsable from string are supported. + // Nullable keys not allowed; that would cause us to emit + // code that violates dictionary key notnull constraint. + GetTypeSpec(typeSpec.KeyTypeRef) is ParsableFromStringSpec; + + public static string GetConfigKeyCacheFieldName(ObjectSpec type) => $"s_configKeys_{type.IdentifierCompatibleSubstring}"; + + public static string GetParseMethodName(ParsableFromStringSpec type) + { + Debug.Assert(type.StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue); + + string displayString = type.DisplayString; + + string parseMethod = type.StringParsableTypeKind is StringParsableTypeKind.ByteArray + ? "ParseByteArray" + // MinimalDisplayString.Length is certainly > 2. + : $"Parse{(char.ToUpper(displayString[0]) + displayString.Substring(1)).Replace(".", "")}"; + + return parseMethod; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/CollectionSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/CollectionSpec.cs index f565d245cc5502..f891328f77af7c 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/CollectionSpec.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/CollectionSpec.cs @@ -2,48 +2,67 @@ // The .NET Foundation licenses this file to you under the MIT license. using Microsoft.CodeAnalysis; +using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { internal abstract record CollectionSpec : ComplexTypeSpec { - public CollectionSpec(ITypeSymbol type) : base(type) { } + protected CollectionSpec(ITypeSymbol type) : base(type) { } - public sealed override bool CanInstantiate => TypeToInstantiate?.CanInstantiate ?? InstantiationStrategy is not InstantiationStrategy.None; + public required TypeRef ElementTypeRef { get; init; } - public required TypeSpec ElementType { get; init; } + } + + internal abstract record CollectionWithCtorInitSpec : CollectionSpec + { + protected CollectionWithCtorInitSpec(ITypeSymbol type) : base(type) { } - public required CollectionPopulationStrategy PopulationStrategy { get; init; } + public required CollectionInstantiationStrategy InstantiationStrategy { get; init; } - public required CollectionSpec? TypeToInstantiate { get; init; } + public required CollectionInstantiationConcreteType InstantiationConcreteType { get; init; } - public required CollectionSpec? PopulationCastType { get; init; } + public required CollectionPopulationCastType PopulationCastType { get; init; } } - internal sealed record EnumerableSpec : CollectionSpec + internal sealed record ArraySpec : CollectionSpec { - public EnumerableSpec(ITypeSymbol type) : base(type) { } - - public override TypeSpecKind SpecKind => TypeSpecKind.Enumerable; + public ArraySpec(ITypeSymbol type) : base(type) { } + } - public override bool HasBindableMembers => PopulationStrategy is not CollectionPopulationStrategy.Unknown && ElementType.CanBindTo; + internal sealed record EnumerableSpec : CollectionWithCtorInitSpec + { + public EnumerableSpec(ITypeSymbol type) : base(type) { } } - internal sealed record DictionarySpec : CollectionSpec + internal sealed record DictionarySpec : CollectionWithCtorInitSpec { public DictionarySpec(INamedTypeSymbol type) : base(type) { } - public override TypeSpecKind SpecKind => TypeSpecKind.Dictionary; + public required TypeRef KeyTypeRef { get; init; } + } - public override bool HasBindableMembers => PopulationStrategy is not CollectionPopulationStrategy.Unknown; + internal enum CollectionInstantiationStrategy + { + NotApplicable = 0, + ParameterlessConstructor = 1, + CopyConstructor = 2, + LinqToDictionary = 3, + } - public required ParsableFromStringSpec KeyType { get; init; } + internal enum CollectionInstantiationConcreteType + { + Self = 0, + Dictionary = 1, + List = 2, + HashSet = 3, } - internal enum CollectionPopulationStrategy + internal enum CollectionPopulationCastType { - Unknown = 0, - Add = 1, - Cast_Then_Add = 2, + NotApplicable = 0, + IDictionary = 1, + ICollection = 2, + ISet = 3, } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/ComplexTypeSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/ComplexTypeSpec.cs deleted file mode 100644 index da5a5130141a53..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/ComplexTypeSpec.cs +++ /dev/null @@ -1,29 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.CodeAnalysis; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal abstract record ComplexTypeSpec : TypeSpec - { - public ComplexTypeSpec(ITypeSymbol type) : base(type) { } - - public InstantiationStrategy InstantiationStrategy { get; set; } - - public sealed override bool CanBindTo => CanInstantiate || HasBindableMembers; - - public sealed override TypeSpec EffectiveType => this; - - public abstract bool HasBindableMembers { get; } - } - - internal enum InstantiationStrategy - { - None = 0, - ParameterlessConstructor = 1, - ParameterizedConstructor = 2, - ToEnumerableMethod = 3, - Array = 4, - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/NullableSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/NullableSpec.cs deleted file mode 100644 index 3de6d7d465ad98..00000000000000 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/NullableSpec.cs +++ /dev/null @@ -1,22 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.CodeAnalysis; - -namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration -{ - internal sealed record NullableSpec : TypeSpec - { - private readonly TypeSpec _underlyingType; - - public NullableSpec(ITypeSymbol type, TypeSpec underlyingType) : base(type) => _underlyingType = underlyingType; - - public override bool CanBindTo => _underlyingType.CanBindTo; - - public override bool CanInstantiate => _underlyingType.CanInstantiate; - - public override TypeSpecKind SpecKind => TypeSpecKind.Nullable; - - public override TypeSpec EffectiveType => _underlyingType; - } -} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/ObjectSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/ObjectSpec.cs index f6978fa9cf470a..abc01258d4190c 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/ObjectSpec.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/ObjectSpec.cs @@ -1,27 +1,39 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; -using System.Linq; using Microsoft.CodeAnalysis; +using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal sealed record ObjectSpec : ComplexTypeSpec + public sealed record ObjectSpec : ComplexTypeSpec { - public ObjectSpec(INamedTypeSymbol type) : base(type) { } - - public override TypeSpecKind SpecKind => TypeSpecKind.Object; - - public override bool HasBindableMembers => Properties.Values.Any(p => p.ShouldBindTo); - - public override bool CanInstantiate => InstantiationStrategy is not InstantiationStrategy.None && InitExceptionMessage is null; - - public Dictionary Properties { get; } = new(StringComparer.OrdinalIgnoreCase); - - public List ConstructorParameters { get; } = new(); + public ObjectSpec( + INamedTypeSymbol type, + ObjectInstantiationStrategy instantiationStrategy, + ImmutableEquatableArray? properties, + ImmutableEquatableArray? constructorParameters, + string? initExceptionMessage) : base(type) + { + InstantiationStrategy = instantiationStrategy; + Properties = properties; + ConstructorParameters = constructorParameters; + InitExceptionMessage = initExceptionMessage; + } + + public ObjectInstantiationStrategy InstantiationStrategy { get; } + + public ImmutableEquatableArray? Properties { get; } + + public ImmutableEquatableArray? ConstructorParameters { get; } + + public string? InitExceptionMessage { get; } + } - public string? InitExceptionMessage { get; set; } + public enum ObjectInstantiationStrategy + { + None = 0, + ParameterlessConstructor = 1, + ParameterizedConstructor = 2, } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/SimpleTypeSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/SimpleTypeSpec.cs index 2dfe08dc5f547a..70c7a8042e0359 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/SimpleTypeSpec.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/SimpleTypeSpec.cs @@ -1,55 +1,28 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Diagnostics; using Microsoft.CodeAnalysis; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { - internal abstract record SimpleTypeSpec : TypeSpec + public abstract record SimpleTypeSpec : TypeSpec { public SimpleTypeSpec(ITypeSymbol type) : base(type) { } - - public sealed override bool CanBindTo => true; - - public sealed override TypeSpec EffectiveType => this; - - public sealed override bool CanInstantiate => true; } internal sealed record ConfigurationSectionSpec : SimpleTypeSpec { public ConfigurationSectionSpec(ITypeSymbol type) : base(type) { } - - public override TypeSpecKind SpecKind => TypeSpecKind.IConfigurationSection; } - internal sealed record ParsableFromStringSpec : SimpleTypeSpec + public sealed record ParsableFromStringSpec : SimpleTypeSpec { public ParsableFromStringSpec(ITypeSymbol type) : base(type) { } - public override TypeSpecKind SpecKind => TypeSpecKind.ParsableFromString; - public required StringParsableTypeKind StringParsableTypeKind { get; init; } - - private string? _parseMethodName; - public string ParseMethodName - { - get - { - Debug.Assert(StringParsableTypeKind is not StringParsableTypeKind.AssignFromSectionValue); - - _parseMethodName ??= StringParsableTypeKind is StringParsableTypeKind.ByteArray - ? "ParseByteArray" - // MinimalDisplayString.Length is certainly > 2. - : $"Parse{(char.ToUpper(DisplayString[0]) + DisplayString.Substring(1)).Replace(".", "")}"; - - return _parseMethodName; - } - } } - internal enum StringParsableTypeKind + public enum StringParsableTypeKind { None = 0, diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/TypeSpec.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/TypeSpec.cs index 651a40639f0ced..1c243ae1cdc7c1 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/TypeSpec.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/gen/Specs/Types/TypeSpec.cs @@ -3,27 +3,26 @@ using System.Diagnostics; using Microsoft.CodeAnalysis; +using SourceGenerators; namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { [DebuggerDisplay("Name={DisplayString}, Kind={SpecKind}")] - internal abstract record TypeSpec + public abstract record TypeSpec { - private static readonly SymbolDisplayFormat s_minimalDisplayFormat = new SymbolDisplayFormat( - globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Omitted, - typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypes, - genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters, - miscellaneousOptions: SymbolDisplayMiscellaneousOptions.UseSpecialTypes); - public TypeSpec(ITypeSymbol type) { - Namespace = type.ContainingNamespace?.ToDisplayString(); - DisplayString = type.ToDisplayString(s_minimalDisplayFormat); - Name = (Namespace is null ? string.Empty : Namespace + ".") + DisplayString.Replace(".", "+"); + TypeRef = new TypeRef(type); + EffectiveTypeRef = TypeRef; // Overriden by NullableSpec. + (Namespace, DisplayString, Name) = type.GetTypeName(); IdentifierCompatibleSubstring = type.ToIdentifierCompatibleSubstring(); IsValueType = type.IsValueType; } + public TypeRef TypeRef { get; } + + public TypeRef EffectiveTypeRef { get; protected init; } + public string Name { get; } public string DisplayString { get; } @@ -33,24 +32,35 @@ public TypeSpec(ITypeSymbol type) public string? Namespace { get; } public bool IsValueType { get; } + } - public abstract TypeSpecKind SpecKind { get; } + public abstract record ComplexTypeSpec : TypeSpec + { + protected ComplexTypeSpec(ITypeSymbol type) : base(type) { } + } - public abstract bool CanBindTo { get; } + internal sealed record NullableSpec : TypeSpec + { + public NullableSpec(ITypeSymbol type, TypeRef underlyingTypeRef) : base(type) => + EffectiveTypeRef = underlyingTypeRef; + } - public abstract bool CanInstantiate { get; } + internal sealed record UnsupportedTypeSpec : TypeSpec + { + public UnsupportedTypeSpec(ITypeSymbol type) : base(type) { } - public abstract TypeSpec EffectiveType { get; } + public required NotSupportedReason NotSupportedReason { get; init; } } - internal enum TypeSpecKind + public enum NotSupportedReason { - Unknown = 0, - ParsableFromString = 1, - Object = 2, - Enumerable = 3, - Dictionary = 4, - IConfigurationSection = 5, - Nullable = 6, + UnknownType = 1, + MissingPublicInstanceConstructor = 2, + CollectionNotSupported = 3, + DictionaryKeyNotSupported = 4, + ElementTypeNotSupported = 5, + MultipleParameterizedConstructors = 6, + MultiDimArraysNotSupported = 7, + NullableUnderlyingTypeNotSupported = 8, } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/src/ConfigurationBinder.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/src/ConfigurationBinder.cs index 74365d8084aa3d..dfc35d80208553 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/src/ConfigurationBinder.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/src/ConfigurationBinder.cs @@ -298,6 +298,11 @@ private static void BindInstance( return; } + if (config is null) + { + return; + } + var section = config as IConfigurationSection; string? configValue = section?.Value; if (configValue != null && TryConvertValue(type, configValue, section?.Path, out object? convertedValue, out Exception? error)) diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.TestClasses.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.TestClasses.cs index f47cdbe6dbbb54..7d10f66c822fc0 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.TestClasses.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.TestClasses.cs @@ -842,6 +842,11 @@ internal sealed class DerivedWithAnotherProp : AbstractBase { public int Value2 { get; set; } } + + internal class ClassWithAbstractProp + { + public AbstractBase AbstractProp { get; set; } + } internal class ClassWithAbstractCtorParam { @@ -888,5 +893,11 @@ public int MyIntProperty } } + public class SimplePoco + { + public string A { get; set; } + public string B { get; set; } + } + } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.cs index 7c955e789184c8..296cb790c22ba5 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/Common/ConfigurationBinderTests.cs @@ -11,6 +11,7 @@ #if BUILDING_SOURCE_GENERATOR_TESTS using Microsoft.Extensions.Configuration; #endif +using Microsoft.Extensions.Configuration.Memory; using Microsoft.Extensions.Configuration.Test; using Xunit; @@ -1767,7 +1768,7 @@ public void EnsureCallingThePropertySetter() Assert.Equal(0, options.OtherCodeNullable); Assert.Equal("default", options.OtherCodeString); Assert.Null(options.OtherCodeNull); - Assert.Null(options.OtherCodeUri); + Assert.Null(options.OtherCodeUri); } [Fact] @@ -2238,7 +2239,7 @@ void TestUntypedOverloads(IConfiguration? configuration, string? key) Assert.Throws(() => configuration.GetValue(typeof(GeolocationClass), key, new GeolocationClass())); Assert.Throws(() => configuration.GetValue(typeof(Geolocation), key)); Assert.Throws(() => configuration.GetValue(typeof(Geolocation), key, defaultValue: null)); - Assert.Throws(() => configuration.GetValue(typeof(Geolocation), key, default(Geolocation))); + Assert.Throws(() => configuration.GetValue(typeof(Geolocation), key, default(Geolocation))); } } @@ -2323,6 +2324,25 @@ public static void TestBindingAbstractMember_AsCtorParam() Assert.Throws(configuration.Get); } + [Fact] + public static void TestBindingInitializedAbstractMember() + { + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{ ""AbstractProp"": {""Value"":1} }"); + ClassWithAbstractProp c = new(); + c.AbstractProp = new Derived(); + configuration.Bind(c); + Assert.Equal(1, c.AbstractProp.Value); + } + + [Fact] + public static void TestBindingUninitializedAbstractMember() + { + IConfiguration configuration = TestHelpers.GetConfigurationFromJsonString(@"{ ""AbstractProp"": {""Value"":1} }"); + ClassWithAbstractProp c = new(); + c.AbstractProp = null; + Assert.Throws(() => configuration.Bind(c)); + } + [Fact] public void GetIConfigurationSection() { @@ -2404,5 +2424,38 @@ public void SharedChildInstance() config.GetSection("A").Bind(instance); Assert.Equal("localhost", instance.ConnectionString); } + + [Fact] + public void CanBindToMockConfigurationSection() + { + const string expectedA = "hello"; + + var configSource = new MemoryConfigurationSource() + { + InitialData = new Dictionary() + { + [$":{nameof(SimplePoco.A)}"] = expectedA, + } + }; + var configRoot = new MockConfigurationRoot(new[] { configSource.Build(null) }); + var configSection = new ConfigurationSection(configRoot, string.Empty); + + SimplePoco result = new(); + configSection.Bind(result); + + Assert.Equal(expectedA, result.A); + Assert.Equal(default(string), result.B); + } + + // a mock configuration root that will return null for undefined Sections, + // as is common when Configuration interfaces are mocked + class MockConfigurationRoot : ConfigurationRoot, IConfigurationRoot + { + public MockConfigurationRoot(IList providers) : base(providers) + { } + + IConfigurationSection IConfiguration.GetSection(string key) => + this[key] is null ? null : new ConfigurationSection(this, key); + } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Collections.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Collections.generated.txt index ddd52c68b99892..ea4fba79cbc465 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Collections.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/Collections.generated.txt @@ -37,7 +37,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration #endregion IConfiguration extensions. #region Core binding extensions. - private readonly static Lazy> s_configKeys_ProgramMyClassWithCustomCollections = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "CustomDictionary", "CustomList", "IReadOnlyList", "IReadOnlyDictionary" }); + private readonly static Lazy> s_configKeys_ProgramMyClassWithCustomCollections = new(() => new HashSet(StringComparer.OrdinalIgnoreCase) { "CustomDictionary", "CustomList", "ICustomDictionary", "ICustomCollection", "IReadOnlyList", "UnsupportedIReadOnlyDictionaryUnsupported", "IReadOnlyDictionary" }); public static object? GetCore(this IConfiguration configuration, Type type, Action? configureOptions) { @@ -85,28 +85,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration } } - public static void BindCore(IConfiguration configuration, ref List instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) - { - foreach (IConfigurationSection section in configuration.GetChildren()) - { - if (section.Value is string value) - { - instance.Add(ParseInt(value, () => section.Path)); - } - } - } - - public static void BindCore(IConfiguration configuration, ref ICollection instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) - { - foreach (IConfigurationSection section in configuration.GetChildren()) - { - if (section.Value is string value) - { - instance.Add(ParseInt(value, () => section.Path)); - } - } - } - public static void BindCore(IConfiguration configuration, ref IReadOnlyList instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { if (instance is not ICollection temp) @@ -123,28 +101,6 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration } } - public static void BindCore(IConfiguration configuration, ref Dictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) - { - foreach (IConfigurationSection section in configuration.GetChildren()) - { - if (section.Value is string value) - { - instance[section.Key] = ParseInt(value, () => section.Path); - } - } - } - - public static void BindCore(IConfiguration configuration, ref IDictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) - { - foreach (IConfigurationSection section in configuration.GetChildren()) - { - if (section.Value is string value) - { - instance[section.Key] = ParseInt(value, () => section.Path); - } - } - } - public static void BindCore(IConfiguration configuration, ref IReadOnlyDictionary instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { if (instance is not IDictionary temp) @@ -184,7 +140,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (AsConfigWithChildren(configuration.GetSection("IReadOnlyList")) is IConfigurationSection section7) { IReadOnlyList? temp9 = instance.IReadOnlyList; - temp9 = temp9 is null ? new List() : new List(temp9); + temp9 = temp9 is null ? (IReadOnlyList)new List() : (IReadOnlyList)new List(temp9); BindCore(section7, ref temp9, defaultValueIfNotFound: false, binderOptions); instance.IReadOnlyList = temp9; } @@ -192,7 +148,7 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration if (AsConfigWithChildren(configuration.GetSection("IReadOnlyDictionary")) is IConfigurationSection section10) { IReadOnlyDictionary? temp12 = instance.IReadOnlyDictionary; - temp12 = temp12 is null ? new Dictionary() : temp12.ToDictionary(pair => pair.Key, pair => pair.Value); + temp12 = temp12 is null ? (IReadOnlyDictionary)new Dictionary() : (IReadOnlyDictionary)temp12.ToDictionary(pair => pair.Key, pair => pair.Value); BindCore(section10, ref temp12, defaultValueIfNotFound: false, binderOptions); instance.IReadOnlyDictionary = temp12; } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get.generated.txt index 5e7eeae29254a4..b6fb659d544d42 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get.generated.txt @@ -95,7 +95,15 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration public static void BindCore(IConfiguration configuration, ref int[] instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { var temp2 = new List(); - BindCore(configuration, ref temp2, defaultValueIfNotFound: false, binderOptions); + + foreach (IConfigurationSection section in configuration.GetChildren()) + { + if (section.Value is string value) + { + temp2.Add(ParseInt(value, () => section.Path)); + } + } + int originalCount = instance.Length; Array.Resize(ref instance, originalCount + temp2.Count); temp2.CopyTo(instance, originalCount); @@ -116,42 +124,42 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - if (configuration["MyString"] is string value4) + if (configuration["MyString"] is string value3) { - instance.MyString = value4; + instance.MyString = value3; } - if (configuration["MyInt"] is string value5) + if (configuration["MyInt"] is string value4) { - instance.MyInt = ParseInt(value5, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); } else if (defaultValueIfNotFound) { instance.MyInt = default; } - if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section6) + if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section5) { - List? temp8 = instance.MyList; - temp8 ??= new List(); - BindCore(section6, ref temp8, defaultValueIfNotFound: false, binderOptions); - instance.MyList = temp8; + List? temp7 = instance.MyList; + temp7 ??= new List(); + BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp7; } - if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section9) + if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section8) { - int[]? temp11 = instance.MyArray; - temp11 ??= new int[0]; - BindCore(section9, ref temp11, defaultValueIfNotFound: false, binderOptions); - instance.MyArray = temp11; + int[]? temp10 = instance.MyArray; + temp10 ??= new int[0]; + BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); + instance.MyArray = temp10; } - if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section12) + if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section11) { - Dictionary? temp14 = instance.MyDictionary; - temp14 ??= new Dictionary(); - BindCore(section12, ref temp14, defaultValueIfNotFound: false, binderOptions); - instance.MyDictionary = temp14; + Dictionary? temp13 = instance.MyDictionary; + temp13 ??= new Dictionary(); + BindCore(section11, ref temp13, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp13; } } @@ -159,9 +167,9 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { ValidateConfigurationKeys(typeof(Program.MyClass2), s_configKeys_ProgramMyClass2, configuration, binderOptions); - if (configuration["MyInt"] is string value15) + if (configuration["MyInt"] is string value14) { - instance.MyInt = ParseInt(value15, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value14, () => configuration.GetSection("MyInt").Path); } else if (defaultValueIfNotFound) { diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_PrimitivesOnly.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_PrimitivesOnly.generated.txt new file mode 100644 index 00000000000000..b703fb5f1c864b --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_PrimitivesOnly.generated.txt @@ -0,0 +1,182 @@ +// +#nullable enable +#pragma warning disable CS0612, CS0618 // Suppress warnings about [Obsolete] member usage in generated code. + +namespace System.Runtime.CompilerServices +{ + using System; + using System.CodeDom.Compiler; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : Attribute + { + public InterceptsLocationAttribute(string filePath, int line, int column) + { + } + } +} + +namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration +{ + using Microsoft.Extensions.Configuration; + using System; + using System.CodeDom.Compiler; + using System.Globalization; + using System.Runtime.CompilerServices; + + [GeneratedCode("Microsoft.Extensions.Configuration.Binder.SourceGeneration", "42.42.42.42")] + file static class BindingExtensions + { + #region IConfiguration extensions. + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 10, 16)] + public static T? Get(this IConfiguration configuration) => (T?)(GetCore(configuration, typeof(T), configureOptions: null) ?? default(T)); + + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 12, 16)] + public static T? Get(this IConfiguration configuration, Action? configureOptions) => (T?)(GetCore(configuration, typeof(T), configureOptions) ?? default(T)); + + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 11, 16)] + public static object? Get(this IConfiguration configuration, Type type) => GetCore(configuration, type, configureOptions: null); + + /// Attempts to bind the configuration instance to a new instance of type T. + [InterceptsLocation(@"src-0.cs", 13, 16)] + public static object? Get(this IConfiguration configuration, Type type, Action? configureOptions) => GetCore(configuration, type, configureOptions); + #endregion IConfiguration extensions. + + #region Core binding extensions. + public static object? GetCore(this IConfiguration configuration, Type type, Action? configureOptions) + { + if (configuration is null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + BinderOptions? binderOptions = GetBinderOptions(configureOptions); + + if (!HasValueOrChildren(configuration)) + { + return null; + } + + if (type == typeof(int)) + { + if (configuration is not IConfigurationSection section) + { + throw new InvalidOperationException(); + } + if (section.Value is string value) + { + return ParseInt(value, () => section.Path); + } + } + else if (type == typeof(string)) + { + if (configuration is not IConfigurationSection section) + { + throw new InvalidOperationException(); + } + return section.Value; + } + else if (type == typeof(float)) + { + if (configuration is not IConfigurationSection section) + { + throw new InvalidOperationException(); + } + if (section.Value is string value) + { + return ParseFloat(value, () => section.Path); + } + } + else if (type == typeof(double)) + { + if (configuration is not IConfigurationSection section) + { + throw new InvalidOperationException(); + } + if (section.Value is string value) + { + return ParseDouble(value, () => section.Path); + } + } + + throw new NotSupportedException($"Unable to bind to type '{type}': generator did not detect the type as input."); + } + + public static bool HasValueOrChildren(IConfiguration configuration) + { + if ((configuration as IConfigurationSection)?.Value is not null) + { + return true; + } + return AsConfigWithChildren(configuration) is not null; + } + + public static IConfiguration? AsConfigWithChildren(IConfiguration configuration) + { + foreach (IConfigurationSection _ in configuration.GetChildren()) + { + return configuration; + } + return null; + } + + public static BinderOptions? GetBinderOptions(Action? configureOptions) + { + if (configureOptions is null) + { + return null; + } + + BinderOptions binderOptions = new(); + configureOptions(binderOptions); + + if (binderOptions.BindNonPublicProperties) + { + throw new NotSupportedException($"The configuration binding source generator does not support 'BinderOptions.BindNonPublicProperties'."); + } + + return binderOptions; + } + + public static int ParseInt(string value, Func getPath) + { + try + { + return int.Parse(value, NumberStyles.Integer, CultureInfo.InvariantCulture); + } + catch (Exception exception) + { + throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(int)}'.", exception); + } + } + + public static float ParseFloat(string value, Func getPath) + { + try + { + return float.Parse(value, NumberStyles.Float, CultureInfo.InvariantCulture); + } + catch (Exception exception) + { + throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(float)}'.", exception); + } + } + + public static double ParseDouble(string value, Func getPath) + { + try + { + return double.Parse(value, NumberStyles.Float, CultureInfo.InvariantCulture); + } + catch (Exception exception) + { + throw new InvalidOperationException($"Failed to convert configuration value at '{getPath()}' to type '{typeof(double)}'.", exception); + } + } + #endregion Core binding extensions. + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T.generated.txt index 3fc5176bf50f09..c2e8f167bb4750 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T.generated.txt @@ -76,7 +76,15 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration public static void BindCore(IConfiguration configuration, ref int[] instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { var temp1 = new List(); - BindCore(configuration, ref temp1, defaultValueIfNotFound: false, binderOptions); + + foreach (IConfigurationSection section in configuration.GetChildren()) + { + if (section.Value is string value) + { + temp1.Add(ParseInt(value, () => section.Path)); + } + } + int originalCount = instance.Length; Array.Resize(ref instance, originalCount + temp1.Count); temp1.CopyTo(instance, originalCount); @@ -97,42 +105,42 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - if (configuration["MyString"] is string value3) + if (configuration["MyString"] is string value2) { - instance.MyString = value3; + instance.MyString = value2; } - if (configuration["MyInt"] is string value4) + if (configuration["MyInt"] is string value3) { - instance.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value3, () => configuration.GetSection("MyInt").Path); } else if (defaultValueIfNotFound) { instance.MyInt = default; } - if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section5) + if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section4) { - List? temp7 = instance.MyList; - temp7 ??= new List(); - BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); - instance.MyList = temp7; + List? temp6 = instance.MyList; + temp6 ??= new List(); + BindCore(section4, ref temp6, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp6; } - if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section8) + if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section7) { - int[]? temp10 = instance.MyArray; - temp10 ??= new int[0]; - BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); - instance.MyArray = temp10; + int[]? temp9 = instance.MyArray; + temp9 ??= new int[0]; + BindCore(section7, ref temp9, defaultValueIfNotFound: false, binderOptions); + instance.MyArray = temp9; } - if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section11) + if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section10) { - Dictionary? temp13 = instance.MyDictionary; - temp13 ??= new Dictionary(); - BindCore(section11, ref temp13, defaultValueIfNotFound: false, binderOptions); - instance.MyDictionary = temp13; + Dictionary? temp12 = instance.MyDictionary; + temp12 ??= new Dictionary(); + BindCore(section10, ref temp12, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp12; } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T_BinderOptions.generated.txt b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T_BinderOptions.generated.txt index 81c23d7ceea65a..cd3f237917d4e3 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T_BinderOptions.generated.txt +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Baselines/ConfigurationBinder/Get_T_BinderOptions.generated.txt @@ -76,7 +76,15 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration public static void BindCore(IConfiguration configuration, ref int[] instance, bool defaultValueIfNotFound, BinderOptions? binderOptions) { var temp1 = new List(); - BindCore(configuration, ref temp1, defaultValueIfNotFound: false, binderOptions); + + foreach (IConfigurationSection section in configuration.GetChildren()) + { + if (section.Value is string value) + { + temp1.Add(ParseInt(value, () => section.Path)); + } + } + int originalCount = instance.Length; Array.Resize(ref instance, originalCount + temp1.Count); temp1.CopyTo(instance, originalCount); @@ -97,42 +105,42 @@ namespace Microsoft.Extensions.Configuration.Binder.SourceGeneration { ValidateConfigurationKeys(typeof(Program.MyClass), s_configKeys_ProgramMyClass, configuration, binderOptions); - if (configuration["MyString"] is string value3) + if (configuration["MyString"] is string value2) { - instance.MyString = value3; + instance.MyString = value2; } - if (configuration["MyInt"] is string value4) + if (configuration["MyInt"] is string value3) { - instance.MyInt = ParseInt(value4, () => configuration.GetSection("MyInt").Path); + instance.MyInt = ParseInt(value3, () => configuration.GetSection("MyInt").Path); } else if (defaultValueIfNotFound) { instance.MyInt = default; } - if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section5) + if (AsConfigWithChildren(configuration.GetSection("MyList")) is IConfigurationSection section4) { - List? temp7 = instance.MyList; - temp7 ??= new List(); - BindCore(section5, ref temp7, defaultValueIfNotFound: false, binderOptions); - instance.MyList = temp7; + List? temp6 = instance.MyList; + temp6 ??= new List(); + BindCore(section4, ref temp6, defaultValueIfNotFound: false, binderOptions); + instance.MyList = temp6; } - if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section8) + if (AsConfigWithChildren(configuration.GetSection("MyArray")) is IConfigurationSection section7) { - int[]? temp10 = instance.MyArray; - temp10 ??= new int[0]; - BindCore(section8, ref temp10, defaultValueIfNotFound: false, binderOptions); - instance.MyArray = temp10; + int[]? temp9 = instance.MyArray; + temp9 ??= new int[0]; + BindCore(section7, ref temp9, defaultValueIfNotFound: false, binderOptions); + instance.MyArray = temp9; } - if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section11) + if (AsConfigWithChildren(configuration.GetSection("MyDictionary")) is IConfigurationSection section10) { - Dictionary? temp13 = instance.MyDictionary; - temp13 ??= new Dictionary(); - BindCore(section11, ref temp13, defaultValueIfNotFound: false, binderOptions); - instance.MyDictionary = temp13; + Dictionary? temp12 = instance.MyDictionary; + temp12 ??= new Dictionary(); + BindCore(section10, ref temp12, defaultValueIfNotFound: false, binderOptions); + instance.MyDictionary = temp12; } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigBindingGenTestDriver.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigBindingGenTestDriver.cs new file mode 100644 index 00000000000000..4373b404fc67f0 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/ConfigBindingGenTestDriver.cs @@ -0,0 +1,156 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Globalization; +using System.Linq; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.Extensions.Configuration.Binder.SourceGeneration; +using SourceGenerators.Tests; +using Xunit; + +namespace Microsoft.Extensions.SourceGeneration.Configuration.Binder.Tests +{ + [ActiveIssue("https://github.com/dotnet/runtime/issues/52062", TestPlatforms.Browser)] + public partial class ConfigurationBindingGeneratorTests : ConfigurationBinderTestsBase + { + internal sealed class ConfigBindingGenTestDriver + { + private readonly CSharpParseOptions _parseOptions; + private GeneratorDriver _generatorDriver; + private SourceGenerationSpec? _genSpec; + + private readonly LanguageVersion _langVersion; + private readonly IEnumerable? _assemblyReferences; + private Compilation _compilation = null; + + public ConfigBindingGenTestDriver( + LanguageVersion langVersion = LanguageVersion.LatestMajor, + IEnumerable? assemblyReferences = null) + { + _langVersion = langVersion; + + _assemblyReferences = assemblyReferences ?? s_compilationAssemblyRefs; + + _parseOptions = new CSharpParseOptions(langVersion).WithFeatures(new[] { + new KeyValuePair("InterceptorsPreview", "") , + new KeyValuePair("InterceptorsPreviewNamespaces", "Microsoft.Extensions.Configuration.Binder.SourceGeneration") + }); + + ConfigurationBindingGenerator generator = new() { OnSourceEmitting = spec => _genSpec = spec }; + _generatorDriver = CSharpGeneratorDriver.Create( + new ISourceGenerator[] { generator.AsSourceGenerator() }, + parseOptions: _parseOptions, + driverOptions: new GeneratorDriverOptions( + disabledOutputs: IncrementalGeneratorOutputKind.None, + trackIncrementalGeneratorSteps: true)); + } + + public async Task RunGeneratorAndUpdateCompilation(string? source = null) + { + await UpdateCompilationWithSource(source); + Assert.NotNull(_compilation); + + _generatorDriver = _generatorDriver.RunGeneratorsAndUpdateCompilation(_compilation, out Compilation outputCompilation, out _, CancellationToken.None); + GeneratorDriverRunResult runResult = _generatorDriver.GetRunResult(); + + return new ConfigBindingGenRunResult + { + OutputCompilation = outputCompilation, + Diagnostics = runResult.Diagnostics, + GeneratedSource = runResult.Results[0].GeneratedSources is { Length: not 0 } sources ? sources[0] : null, + TrackedSteps = runResult.Results[0].TrackedSteps[ConfigurationBindingGenerator.GenSpecTrackingName], + GenerationSpec = _genSpec + }; + } + + private async Task UpdateCompilationWithSource(string? source = null) + { + if (_compilation is not null && source is not null) + { + SyntaxTree newTree = CSharpSyntaxTree.ParseText(source, _parseOptions); + _compilation = _compilation.ReplaceSyntaxTree(_compilation.SyntaxTrees.First(), newTree); + } + else if (_compilation is null) + { + Assert.True(source is not null, "Generator test requires input source."); + using AdhocWorkspace workspace = RoslynTestUtils.CreateTestWorkspace(); + + Project project = RoslynTestUtils.CreateTestProject(workspace, _assemblyReferences, langVersion: _langVersion) + .WithCompilationOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary).WithNullableContextOptions(NullableContextOptions.Annotations)) + .WithParseOptions(_parseOptions) + .WithDocuments(new string[] { source }); + Assert.True(project.Solution.Workspace.TryApplyChanges(project.Solution)); + + _compilation = (await project.GetCompilationAsync(CancellationToken.None).ConfigureAwait(false))!; + } + } + } + } + + internal struct ConfigBindingGenRunResult + { + public required Compilation OutputCompilation { get; init; } + + public required GeneratedSourceResult? GeneratedSource { get; init; } + + /// + /// Diagnostics produced by the generator alone. Doesn't include any from other build participants. + /// + public required ImmutableArray Diagnostics { get; init; } + + public required ImmutableArray TrackedSteps { get; init; } + + public required SourceGenerationSpec? GenerationSpec { get; init; } + } + + internal enum ExpectedDiagnostics + { + None, + FromGeneratorOnly, + } + + internal static class ConfigBindingGenTestDriverExtensions + { + public static void ValidateIncrementalResult(this ConfigBindingGenRunResult result, + IncrementalStepRunReason inputReason, + IncrementalStepRunReason outputReason) + { + Assert.Collection(result.TrackedSteps, step => + { + Assert.Collection(step.Inputs, source => Assert.Equal(inputReason, source.Source.Outputs[source.OutputIndex].Reason)); + Assert.Collection(step.Outputs, output => Assert.Equal(outputReason, output.Reason)); + }); + } + + public static void ValidateDiagnostics(this ConfigBindingGenRunResult result, ExpectedDiagnostics expectedDiags) + { + ImmutableArray outputDiagnostics = result.OutputCompilation.GetDiagnostics(); + + if (expectedDiags is ExpectedDiagnostics.None) + { + foreach (Diagnostic diagnostic in outputDiagnostics) + { + Assert.True( + IsPermitted(diagnostic), + $"Generator caused dagnostic in output compilation: {diagnostic.GetMessage(CultureInfo.InvariantCulture)}."); + } + } + else + { + Debug.Assert(expectedDiags is ExpectedDiagnostics.FromGeneratorOnly); + + Assert.NotEmpty(result.Diagnostics); + Assert.False(outputDiagnostics.Any(diag => !IsPermitted(diag))); + } + + static bool IsPermitted(Diagnostic diagnostic) => diagnostic.Severity <= DiagnosticSeverity.Info; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Baselines.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Baselines.cs index 3c46f5f99818b1..e05a7737137128 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Baselines.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Baselines.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Immutable; using System.Linq; using System.Threading.Tasks; using Microsoft.CodeAnalysis; @@ -141,7 +142,7 @@ public static void Main() public class MyClass { - public string MyString { get; set; } + public string? MyString { get; set; } public int MyInt { get; set; } public List MyList { get; set; } public Dictionary MyDictionary { get; set; } @@ -314,6 +315,30 @@ public class MyClass4 await VerifyAgainstBaselineUsingFile("Get.generated.txt", source, extType: ExtensionClassType.ConfigurationBinder); } + [Fact] + public async Task Get_PrimitivesOnly() + { + string source = """ + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + config.Get(); + config.Get(typeof(string)); + config.Get(binderOptions => { }); + config.Get(typeof(double), binderOptions => { }); + } + } + """; + + await VerifyAgainstBaselineUsingFile("Get_PrimitivesOnly.generated.txt", source, extType: ExtensionClassType.ConfigurationBinder); + } + [Fact] public async Task Get_T() { @@ -654,9 +679,9 @@ public class MyClass2 }" ; - var (d, r) = await RunGenerator(source); - Assert.Empty(r); - Assert.Empty(d); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source); + Assert.False(result.GeneratedSource.HasValue); + Assert.Empty(result.Diagnostics); } [Fact] @@ -736,6 +761,7 @@ public static void Main() section.Get(); } + // Diagnostic warning because we don't know how to instantiate two properties on this type. public class MyClassWithCustomCollections { public CustomDictionary CustomDictionary { get; set; } @@ -743,6 +769,7 @@ public class MyClassWithCustomCollections public ICustomDictionary ICustomDictionary { get; set; } public ICustomSet ICustomCollection { get; set; } public IReadOnlyList IReadOnlyList { get; set; } + // Diagnostic warning because we don't know how to instantiate the property type. public IReadOnlyDictionary UnsupportedIReadOnlyDictionaryUnsupported { get; set; } public IReadOnlyDictionary IReadOnlyDictionary { get; set; } } @@ -755,21 +782,26 @@ public class CustomList : List { } + // Diagnostic warning because we don't know how to instantiate this type. public interface ICustomDictionary : IDictionary { } + // Diagnostic warning because we don't know how to instantiate this type. public interface ICustomSet : ISet { } } """; - await VerifyAgainstBaselineUsingFile("Collections.generated.txt", source, validateOutputCompDiags: false, assessDiagnostics: (d) => - { - Assert.Equal(3, d.Where(diag => diag.Id == Diagnostics.TypeNotSupported.Id).Count()); - Assert.Equal(6, d.Where(diag => diag.Id == Diagnostics.PropertyNotSupported.Id).Count()); - }); + ConfigBindingGenRunResult result = await VerifyAgainstBaselineUsingFile( + "Collections.generated.txt", + source, + expectedDiags: ExpectedDiagnostics.FromGeneratorOnly); + + ImmutableArray diagnostics = result.Diagnostics; + Assert.Equal(3, diagnostics.Where(diag => diag.Id == Diagnostics.TypeNotSupported.Id).Count()); + Assert.Equal(3, diagnostics.Where(diag => diag.Id == Diagnostics.PropertyNotSupported.Id).Count()); } [Fact] @@ -811,14 +843,12 @@ public abstract class AbstractType_CannotInit } """; - await VerifyAgainstBaselineUsingFile( + ConfigBindingGenRunResult result = await VerifyAgainstBaselineUsingFile( "EmptyConfigType.generated.txt", source, - assessDiagnostics: (d) => - { - Assert.Equal(2, d.Where(diag => diag.Id == Diagnostics.TypeNotSupported.Id).Count()); - }, - validateOutputCompDiags: false); + expectedDiags: ExpectedDiagnostics.FromGeneratorOnly); + + Assert.Equal(2, result.Diagnostics.Where(diag => diag.Id == Diagnostics.TypeNotSupported.Id).Count()); } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Helpers.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Helpers.cs index 7a47ad1cb27251..cbbd34e7fc41da 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Helpers.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Helpers.cs @@ -9,10 +9,10 @@ using System.IO; using System.Linq; using System.Reflection; -using System.Threading; using System.Threading.Tasks; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Text; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Configuration.Binder.SourceGeneration; using Microsoft.Extensions.DependencyInjection; @@ -24,6 +24,9 @@ namespace Microsoft.Extensions.SourceGeneration.Configuration.Binder.Tests { public partial class ConfigurationBindingGeneratorTests { + /// + /// Keep in sync with variants, e.g. . + /// private const string BindCallSampleCode = """ using System.Collections.Generic; using Microsoft.Extensions.Configuration; @@ -63,6 +66,7 @@ private static class Diagnostics } private static readonly Assembly[] s_compilationAssemblyRefs = new[] { + typeof(BitArray).Assembly, typeof(ConfigurationBinder).Assembly, typeof(ConfigurationBuilder).Assembly, typeof(CultureInfo).Assembly, @@ -87,18 +91,19 @@ private enum ExtensionClassType private static async Task VerifyThatSourceIsGenerated(string testSourceCode) { - var (d, r) = await RunGenerator(testSourceCode); - Assert.Equal(1, r.Length); - Assert.Empty(d); - Assert.True(r[0].SourceText.Lines.Count > 10); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(testSourceCode); + GeneratedSourceResult? source = result.GeneratedSource; + + Assert.NotNull(source); + Assert.Empty(result.Diagnostics); + Assert.True(source.Value.SourceText.Lines.Count > 10); } - private static async Task VerifyAgainstBaselineUsingFile( + private static async Task VerifyAgainstBaselineUsingFile( string filename, string testSourceCode, - Action>? assessDiagnostics = null, ExtensionClassType extType = ExtensionClassType.None, - bool validateOutputCompDiags = true) + ExpectedDiagnostics expectedDiags = ExpectedDiagnostics.None) { string path = extType is ExtensionClassType.None ? Path.Combine("Baselines", filename) @@ -107,70 +112,52 @@ private static async Task VerifyAgainstBaselineUsingFile( string[] expectedLines = baseline.Replace("%VERSION%", typeof(ConfigurationBindingGenerator).Assembly.GetName().Version?.ToString()) .Split(Environment.NewLine); - var (d, r) = await RunGenerator(testSourceCode, validateOutputCompDiags); - bool success = RoslynTestUtils.CompareLines(expectedLines, r[0].SourceText, out string errorMessage); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(testSourceCode); + result.ValidateDiagnostics(expectedDiags); + + SourceText resultSourceText = result.GeneratedSource.Value.SourceText; + bool resultEqualsBaseline = RoslynTestUtils.CompareLines(expectedLines, resultSourceText, out string errorMessage); #if UPDATE_BASELINES - if (!success) + if (!resultEqualsBaseline) { - string? repoRootDir = Environment.GetEnvironmentVariable("RepoRootDir"); - Assert.True(repoRootDir is not null, "To update baselines, specifiy the root runtime repo dir"); + const string envVarName = "RepoRootDir"; + string errMessage = $"To update baselines, specify a '{envVarName}' environment variable. See this assembly's README.md doc for more details."; + + string? repoRootDir = Environment.GetEnvironmentVariable(envVarName); + Assert.True(repoRootDir is not null, errMessage); - IEnumerable lines = r[0].SourceText.Lines.Select(l => l.ToString()); + IEnumerable lines = resultSourceText.Lines.Select(l => l.ToString()); string source = string.Join(Environment.NewLine, lines).TrimEnd(Environment.NewLine.ToCharArray()) + Environment.NewLine; path = Path.Combine($"{repoRootDir}\\src\\libraries\\Microsoft.Extensions.Configuration.Binder\\tests\\SourceGenerationTests\\", path); await File.WriteAllTextAsync(path, source).ConfigureAwait(false); - success = true; + resultEqualsBaseline = true; } #endif - Assert.Single(r); - (assessDiagnostics ?? ((d) => Assert.Empty(d))).Invoke(d); - Assert.True(success, errorMessage); + Assert.True(resultEqualsBaseline, errorMessage); + + return result; } - private static async Task<(ImmutableArray, ImmutableArray)> RunGenerator( - string testSourceCode, - bool validateOutputCompDiags = false, + private static async Task RunGeneratorAndUpdateCompilation( + string source, LanguageVersion langVersion = LanguageVersion.CSharp12, - IEnumerable? references = null) + IEnumerable? assemblyReferences = null) { - using var workspace = RoslynTestUtils.CreateTestWorkspace(); - CSharpParseOptions parseOptions = new CSharpParseOptions(langVersion).WithFeatures(new[] { - new KeyValuePair("InterceptorsPreview", ""), - new KeyValuePair("InterceptorsPreviewNamespaces", "Microsoft.Extensions.Configuration.Binder.SourceGeneration") - }); - - Project proj = RoslynTestUtils.CreateTestProject(workspace, references ?? s_compilationAssemblyRefs, langVersion: langVersion) - .WithCompilationOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary).WithNullableContextOptions(NullableContextOptions.Annotations)) - .WithDocuments(new string[] { testSourceCode }) - .WithParseOptions(parseOptions); - - Assert.True(proj.Solution.Workspace.TryApplyChanges(proj.Solution)); - - Compilation comp = await proj.GetCompilationAsync(CancellationToken.None).ConfigureAwait(false); - CSharpGeneratorDriver cgd = CSharpGeneratorDriver.Create(new[] { new ConfigurationBindingGenerator().AsSourceGenerator() }, parseOptions: parseOptions); - GeneratorDriver gd = cgd.RunGeneratorsAndUpdateCompilation(comp, out Compilation outputCompilation, out _, CancellationToken.None); - GeneratorDriverRunResult runResult = gd.GetRunResult(); - - if (validateOutputCompDiags) - { - ImmutableArray diagnostics = outputCompilation.GetDiagnostics(); - Assert.False(diagnostics.Any(d => d.Severity > DiagnosticSeverity.Info)); - } - - return (runResult.Results[0].Diagnostics, runResult.Results[0].GeneratedSources); + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(langVersion, assemblyReferences); + return await driver.RunGeneratorAndUpdateCompilation(source); } - public static List GetAssemblyRefsWithAdditional(params Type[] additional) + private static List GetAssemblyRefsWithAdditional(params Type[] additional) { List assemblies = new(s_compilationAssemblyRefs); assemblies.AddRange(additional.Select(t => t.Assembly)); return assemblies; } - public static HashSet GetFilteredAssemblyRefs(IEnumerable exclusions) + private static HashSet GetFilteredAssemblyRefs(IEnumerable exclusions) { HashSet assemblies = new(s_compilationAssemblyRefs); foreach (Type exclusion in exclusions) diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Incremental.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Incremental.cs new file mode 100644 index 00000000000000..aff9a0c20364ca --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.Incremental.cs @@ -0,0 +1,362 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.Extensions.Configuration.Binder.SourceGeneration; +using SourceGenerators.Tests; +using Xunit; + +namespace Microsoft.Extensions.SourceGeneration.Configuration.Binder.Tests +{ + public partial class ConfigurationBindingGeneratorTests : ConfigurationBinderTestsBase + { + [ActiveIssue("https://github.com/dotnet/runtime/issues/52062", TestPlatforms.Browser)] + public sealed class IncrementalTests + { + [Fact] + public async Task CompilingTheSameSourceResultsInEqualModels() + { + SourceGenerationSpec spec1 = (await new ConfigBindingGenTestDriver().RunGeneratorAndUpdateCompilation(BindCallSampleCode)).GenerationSpec; + SourceGenerationSpec spec2 = (await new ConfigBindingGenTestDriver().RunGeneratorAndUpdateCompilation(BindCallSampleCode)).GenerationSpec; + + Assert.NotSame(spec1, spec2); + GeneratorTestHelpers.AssertStructurallyEqual(spec1, spec2); + + Assert.Equal(spec1, spec2); + Assert.Equal(spec1.GetHashCode(), spec2.GetHashCode()); + } + + [Fact] + public async Task RunWithNoDiags_Then_NoEdit() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCode); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Unchanged); + } + + [Fact] + public async Task RunWithNoDiags_Then_ChangeInputOrder() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCode); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + // We expect different spec because diag locations are different. + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_ReorderedInvocations); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + + // We expect different spec because members are reordered. + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_ReorderedConfigTypeMembers); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + + [Fact] + public async Task RunWithNoDiags_Then_EditWithNoDiags() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCode); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithDifferentConfigTypeName); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + + [Fact] + public async Task RunWithNoDiags_Then_EditWithDiags() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCode); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + + [Fact] + public async Task RunWithDiags_Then_NoEdit() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Unchanged); + } + + [Fact] + public async Task RunWithDiags_Then_ChangeInputOrder() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + // We expect different spec because diag locations are different. + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember_ReorderedInvocations); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + + // We expect different spec because members are reordered. + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember_ReorderedConfigTypeMembers); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + + [Fact] + public async Task RunWithDiags_Then_EditWithNoDiags() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCode); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + + [Fact] + public async Task RunWithDiags_Then_EditWithDiags() + { + ConfigBindingGenTestDriver driver = new ConfigBindingGenTestDriver(); + + ConfigBindingGenRunResult result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember); + result.ValidateIncrementalResult(IncrementalStepRunReason.New, IncrementalStepRunReason.New); + + result = await driver.RunGeneratorAndUpdateCompilation(BindCallSampleCodeVariant_WithUnsupportedMember_WithDiffMemberName); + result.ValidateIncrementalResult(IncrementalStepRunReason.Modified, IncrementalStepRunReason.Modified); + } + } + + #region Incremental test sources. + /// + /// Keep in sync with . + /// + private const string BindCallSampleCodeVariant_ReorderedInvocations = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind(configObj, options => { }); + config.Bind("key", configObj); + config.Bind(configObj); + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public Dictionary MyComplexDictionary { get; set; } + } + + public class MyClass2 { } + } + """; + + /// + /// Keep in sync with . + /// + private const string BindCallSampleCodeVariant_ReorderedConfigTypeMembers = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind(configObj, options => { }); + config.Bind("key", configObj); + config.Bind(configObj); + } + + public class MyClass + { + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public string MyString { get; set; } + public int MyInt { get; set; } + public Dictionary MyComplexDictionary { get; set; } + } + + public class MyClass2 { } + } + """; + + /// + /// Keep in sync with . + /// + private const string BindCallSampleCodeVariant_WithDifferentConfigTypeName = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass0 configObj = new(); + config.Bind(configObj, options => { }); + config.Bind("key", configObj); + config.Bind(configObj); + } + + public class MyClass0 + { + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public string MyString { get; set; } + public int MyInt { get; set; } + public Dictionary MyComplexDictionary { get; set; } + } + + public class MyClass2 { } + } + """; + + private const string BindCallSampleCodeVariant_WithUnsupportedMember = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind(configObj); + config.Bind(configObj, options => { }); + config.Bind("key", configObj); + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public Dictionary MyComplexDictionary { get; set; } + public int[,] UnsupportedMember { get; set; } + } + + public class MyClass2 { } + } + """; + + private const string BindCallSampleCodeVariant_WithUnsupportedMember_ReorderedInvocations = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind("key", configObj); + config.Bind(configObj); + config.Bind(configObj, options => { }); + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public Dictionary MyComplexDictionary { get; set; } + public int[,] UnsupportedMember { get; set; } + } + + public class MyClass2 { } + } + """; + + private const string BindCallSampleCodeVariant_WithUnsupportedMember_ReorderedConfigTypeMembers = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind("key", configObj); + config.Bind(configObj); + config.Bind(configObj, options => { }); + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public int[,] UnsupportedMember { get; set; } + public Dictionary MyDictionary { get; set; } + public Dictionary MyComplexDictionary { get; set; } + public List MyList { get; set; } + } + + public class MyClass2 { } + } + """; + + private const string BindCallSampleCodeVariant_WithUnsupportedMember_WithDiffMemberName = """ + using System.Collections.Generic; + using Microsoft.Extensions.Configuration; + + public class Program + { + public static void Main() + { + ConfigurationBuilder configurationBuilder = new(); + IConfigurationRoot config = configurationBuilder.Build(); + + MyClass configObj = new(); + config.Bind(configObj); + config.Bind(configObj, options => { }); + config.Bind("key", configObj); + } + + public class MyClass + { + public string MyString { get; set; } + public int MyInt { get; set; } + public List MyList { get; set; } + public Dictionary MyDictionary { get; set; } + public Dictionary MyComplexDictionary { get; set; } + public int[,] UnsupportedMember_DiffMemberName { get; set; } + } + + public class MyClass2 { } + } + """; + #endregion Incremental test sources. + } +} diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.cs b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.cs index 846e64d904d531..d93607d3763996 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.cs +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/GeneratorTests.cs @@ -27,10 +27,10 @@ public partial class ConfigurationBindingGeneratorTests : ConfigurationBinderTes [InlineData(LanguageVersion.CSharp10)] public async Task LangVersionMustBeCharp12OrHigher(LanguageVersion langVersion) { - var (d, r) = await RunGenerator(BindCallSampleCode, langVersion: langVersion); - Assert.Empty(r); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(BindCallSampleCode, langVersion: langVersion); + Assert.False(result.GeneratedSource.HasValue); - Diagnostic diagnostic = Assert.Single(d); + Diagnostic diagnostic = Assert.Single(result.Diagnostics); Assert.True(diagnostic.Id == "SYSLIB1102"); Assert.Contains("C# 12", diagnostic.Descriptor.Title.ToString(CultureInfo.InvariantCulture)); Assert.Equal(DiagnosticSeverity.Error, diagnostic.Severity); @@ -75,11 +75,11 @@ public record struct MyRecordStruct { } } """; - var (d, r) = await RunGenerator(source); - Assert.Empty(r); - Assert.Equal(7, d.Count()); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source); + Assert.False(result.GeneratedSource.HasValue); + Assert.Equal(7, result.Diagnostics.Count()); - foreach (Diagnostic diagnostic in d) + foreach (Diagnostic diagnostic in result.Diagnostics) { Assert.True(diagnostic.Id == Diagnostics.ValueTypesInvalidForBind.Id); Assert.Contains(Diagnostics.ValueTypesInvalidForBind.Title, diagnostic.Descriptor.Title.ToString(CultureInfo.InvariantCulture)); @@ -111,11 +111,11 @@ public record struct MyRecordStruct { } } """; - var (d, r) = await RunGenerator(source); - Assert.Empty(r); - Assert.Equal(2, d.Count()); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source); + Assert.False(result.GeneratedSource.HasValue); + Assert.Equal(2, result.Diagnostics.Count()); - foreach (Diagnostic diagnostic in d) + foreach (Diagnostic diagnostic in result.Diagnostics) { Assert.True(diagnostic.Id == Diagnostics.CouldNotDetermineTypeInfo.Id); Assert.Contains(Diagnostics.CouldNotDetermineTypeInfo.Title, diagnostic.Descriptor.Title.ToString(CultureInfo.InvariantCulture)); @@ -163,11 +163,11 @@ public class MyClass { } } """; - var (d, r) = await RunGenerator(source); - Assert.Empty(r); - Assert.Equal(6, d.Count()); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source); + Assert.False(result.GeneratedSource.HasValue); + Assert.Equal(6, result.Diagnostics.Count()); - foreach (Diagnostic diagnostic in d) + foreach (Diagnostic diagnostic in result.Diagnostics) { Assert.True(diagnostic.Id == Diagnostics.CouldNotDetermineTypeInfo.Id); Assert.Contains(Diagnostics.CouldNotDetermineTypeInfo.Title, diagnostic.Descriptor.Title.ToString(CultureInfo.InvariantCulture)); @@ -218,22 +218,15 @@ public class MyClass0 { } async Task Test(bool expectOutput) { - var (d, r) = await RunGenerator(source, references: GetFilteredAssemblyRefs(exclusions)); - - Assert.Empty(d); - - if (expectOutput) - { - Assert.Single(r); - } - else - { - Assert.Empty(r); - } + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source, assemblyReferences: GetFilteredAssemblyRefs(exclusions)); + Assert.Empty(result.Diagnostics); + Action ValidateSourceResult = expectOutput ? () => Assert.NotNull(result.GeneratedSource) : () => Assert.False(result.GeneratedSource.HasValue); + ValidateSourceResult(); } } [Fact] + [ActiveIssue("Work out why we aren't getting all the expected diagnostics.")] public async Task IssueDiagnosticsForAllOffendingCallsites() { string source = """ @@ -282,10 +275,10 @@ public class AnotherGraphWithUnsupportedMembers } """; - var (d, r) = await RunGenerator(source, references: GetAssemblyRefsWithAdditional(typeof(ImmutableArray<>), typeof(Encoding), typeof(JsonSerializer))); - Assert.Single(r); - Assert.Equal(47, d.Where(diag => diag.Id == Diagnostics.TypeNotSupported.Id).Count()); - Assert.Equal(44, d.Where(diag => diag.Id == Diagnostics.PropertyNotSupported.Id).Count()); + ConfigBindingGenRunResult result = await RunGeneratorAndUpdateCompilation(source, assemblyReferences: GetAssemblyRefsWithAdditional(typeof(ImmutableArray<>), typeof(Encoding), typeof(JsonSerializer))); + Assert.NotNull(result.GeneratedSource); + Assert.True(result.Diagnostics.Any(diag => diag.Id == Diagnostics.TypeNotSupported.Id)); + Assert.True(result.Diagnostics.Any(diag => diag.Id == Diagnostics.PropertyNotSupported.Id)); } } } diff --git a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Microsoft.Extensions.Configuration.Binder.SourceGeneration.Tests.csproj b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Microsoft.Extensions.Configuration.Binder.SourceGeneration.Tests.csproj index fc8db157eddeea..848d93b32a475a 100644 --- a/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Microsoft.Extensions.Configuration.Binder.SourceGeneration.Tests.csproj +++ b/src/libraries/Microsoft.Extensions.Configuration.Binder/tests/SourceGenerationTests/Microsoft.Extensions.Configuration.Binder.SourceGeneration.Tests.csproj @@ -2,8 +2,10 @@ $(NetCoreAppCurrent);$(NetFrameworkMinimum) true - - SYSLIB1100,SYSLIB1101 + + $(NoWarn);SYSLIB1100,SYSLIB1101 + + $(NoWarn);SYSLIB1103,SYSLIB1104 $(Features);InterceptorsPreview $(InterceptorsPreviewNamespaces);Microsoft.Extensions.Configuration.Binder.SourceGeneration @@ -22,6 +24,7 @@ + @@ -46,17 +49,16 @@ - + PreserveNewest + + diff --git a/src/libraries/Microsoft.Extensions.FileProviders.Abstractions/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.FileProviders.Abstractions/src/PACKAGE.md new file mode 100644 index 00000000000000..bbb9b68a06e4a6 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.FileProviders.Abstractions/src/PACKAGE.md @@ -0,0 +1,51 @@ +## About + + + +Serves as the foundation for creating file providers in .NET, offering core abstractions to develop custom file providers capable of fetching files from various sources. + +## Key Features + + + +* Core abstractions for creating and managing file providers. +* Flexibility to develop custom file providers for fetching files from distinct sources. + +## How to Use + + + +This package is typically used with an implementation of the file provider abstractions, such as `Microsoft.Extensions.FileProviders.Composite` or `Microsoft.Extensions.FileProviders.Physical`. + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.FileProviders.IFileProvider` +* `Microsoft.Extensions.FileProviders.IDirectoryContents` +* `Microsoft.Extensions.FileProviders.IFileInfo` +* `Microsoft.Extensions.FileProviders.NullFileProvider` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/aspnet/core/fundamentals/file-providers) +* [Detect changes with change tokens](https://learn.microsoft.com/aspnet/core/fundamentals/change-tokens) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.fileproviders) + +## Related Packages + + + +* File provider for physical files: [Microsoft.Extensions.FileProviders.Physical](https://www.nuget.org/packages/Microsoft.Extensions.FileProviders.Physical/) +* File provider for files in embedded resources: [Microsoft.Extensions.FileProviders.Embedded](https://www.nuget.org/packages/Microsoft.Extensions.FileProviders.Embedded/) +* Composite file and directory providers: [Microsoft.Extensions.FileProviders.Composite](https://www.nuget.org/packages/Microsoft.Extensions.FileProviders.Composite/) + +## Feedback & Contributing + + + +Microsoft.Extensions.FileProviders.Abstractions is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.FileProviders.Physical/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.FileProviders.Physical/src/PACKAGE.md new file mode 100644 index 00000000000000..6ffcd733120209 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.FileProviders.Physical/src/PACKAGE.md @@ -0,0 +1,70 @@ +## About + + + +Provides an implementation of a physical file provider, facilitating file access and monitoring on the disk. The primary type, [`PhysicalFileProvider`](https://learn.microsoft.com/dotnet/api/microsoft.extensions.fileproviders.physicalfileprovider), enables the lookup of files on disk and can watch for changes either via `FileSystemWatcher` or polling mechanisms. + + +## Key Features + + + +* Easy access and monitoring of files on the disk. +* Ability to watch for file changes either by using `FileSystemWatcher` or through polling. + +## How to Use + + + +This library can be used to look up files on disk and monitor file changes effectively. +Below is an example of how to use the `PhysicalFileProvider` to access files on disk and monitor changes: + +```c# +using Microsoft.Extensions.FileProviders; +using Microsoft.Extensions.FileProviders.Physical; + +using var provider = new PhysicalFileProvider(AppContext.BaseDirectory); + +Environment.SetEnvironmentVariable("DOTNET_USE_POLLING_FILE_WATCHER", "1"); + +var contents = provider.GetDirectoryContents(string.Empty); +foreach (PhysicalFileInfo fileInfo in contents) +{ + Console.WriteLine(fileInfo.PhysicalPath); +} + +var changeToken = provider.Watch("*.txt"); +changeToken.RegisterChangeCallback(_ => Console.WriteLine("Text file changed"), null); + +Console.ReadLine(); +``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.FileProviders.PhysicalFileProvider` +* `Microsoft.Extensions.FileProviders.PhysicalDirectoryInfo` +* `Microsoft.Extensions.FileProviders.PhysicalFileInfo` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/aspnet/core/fundamentals/file-providers#physical-file-provider) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.fileproviders.physical) + +## Related Packages + + + +* Abstractions of files and directories: [Microsoft.Extensions.FileProviders.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.FileProviders.Abstractions/) +* File system globbing to find files matching a specified pattern: [Microsoft.Extensions.FileSystemGlobbing](https://www.nuget.org/packages/Microsoft.Extensions.FileSystemGlobbing/) + +## Feedback & Contributing + + + +Microsoft.Extensions.FileProviders.Physical is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.FileSystemGlobbing/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.FileSystemGlobbing/src/PACKAGE.md new file mode 100644 index 00000000000000..25bd9129c3968b --- /dev/null +++ b/src/libraries/Microsoft.Extensions.FileSystemGlobbing/src/PACKAGE.md @@ -0,0 +1,52 @@ +## About + + + +Provides support for matching file system names/paths using [glob patterns](https://en.wikipedia.org/wiki/Glob_(programming)). + +## Key Features + + + +* Contains the `Matcher` type, which can be used to match files in the file system based on user-defined patterns. + +## How to Use + + + +Get all matching files: + +```c# +using Microsoft.Extensions.FileSystemGlobbing; + +Matcher matcher = new(); +matcher.AddIncludePatterns(new[] { "*.txt", "*.asciidoc", "*.md" }); + +string searchDirectory = "../starting-folder/"; + +IEnumerable matchingFiles = matcher.GetResultsInFullPath(searchDirectory); + +// Use matchingFiles if there are any found. +// The files in this collection are fully qualified file system paths. +``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.FileSystemGlobbing.Matcher` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/file-globbing) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.filesystemglobbing) + +## Feedback & Contributing + + + +Microsoft.Extensions.FileSystemGlobbing is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Logging.Console/src/Microsoft.Extensions.Logging.Console.csproj b/src/libraries/Microsoft.Extensions.Logging.Console/src/Microsoft.Extensions.Logging.Console.csproj index 1f12dab5b9ac44..e83340eb0eae55 100644 --- a/src/libraries/Microsoft.Extensions.Logging.Console/src/Microsoft.Extensions.Logging.Console.csproj +++ b/src/libraries/Microsoft.Extensions.Logging.Console/src/Microsoft.Extensions.Logging.Console.csproj @@ -11,7 +11,8 @@ $(InterceptorsPreviewNamespaces);Microsoft.Extensions.Configuration.Binder.SourceGeneration true - true + + $(NoWarn);SYSLIB1100;SYSLIB1101 Console logger provider implementation for Microsoft.Extensions.Logging. diff --git a/src/libraries/Microsoft.Extensions.Logging.TraceSource/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Logging.TraceSource/src/PACKAGE.md new file mode 100644 index 00000000000000..a58e190ec552b9 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Logging.TraceSource/src/PACKAGE.md @@ -0,0 +1,75 @@ +## About + + + +Implements a trace logger provider for the .NET logging infrastructre facilitating enhanced logging capabilities and trace-level diagnostics in application by writing messages to a trace listener using System.Diagnostic.TraceSource. + +## Key Features + + + +* Seamless integration with .NET logging infrastructure. +* Fine-grained control over trace messages using SourceSwitch. +* A set of builder methods to configure logging infrastructure. + +## How to Use + + + +The Microsoft.Extensions.Logging.TraceSource library provides extension methods to the logger factory and the logger builder to add a trace source with trace listeners. + +```csharp +using System.Diagnostics; +using Microsoft.Extensions.Logging; + +using var consoleTraceListener = new ConsoleTraceListener(); +using var textWriterTraceListener = new TextWriterTraceListener("/traces.txt"); +using var loggerFactory = LoggerFactory.Create(builder => +{ + builder + .AddTraceSource(new SourceSwitch("Something") { Level = SourceLevels.All }, consoleTraceListener) + .AddTraceSource(new SourceSwitch("HouseKeeping") { Level = SourceLevels.All }, textWriterTraceListener); +}); + +var logger = loggerFactory.CreateLogger(); + +logger.LogInformation("Information message."); +// Program Information: 0 : Information message. +logger.LogWarning("Warning message."); +// Program Warning: 0 : Warning message. + +var traceSource = new TraceSource("HouseKeeping", SourceLevels.All); +traceSource.Listeners.Add(consoleTraceListener); +traceSource.Listeners.Add(textWriterTraceListener); + +traceSource.TraceEvent(TraceEventType.Error, 0, "Error message."); +//HouseKeeping Error: 0 : Error message. +``` + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Logging.TraceSource.TraceSourceLoggerProvider` + +## Additional Documentation + + + +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.logging.tracesource) + +## Related Packages + + + +* Abstractions for dependency injection: [Microsoft.Extensions.DependencyInjection.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.DependencyInjection.Abstractions/) +* Default implementation of logging infrastructure: [Microsoft.Extensions.Logging](https://www.nuget.org/packages/Microsoft.Extensions.Logging/) +* Abstractions for logging: [Microsoft.Extensions.Logging.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Abstractions/) + +## Feedback & Contributing + + + +Microsoft.Extensions.Logging.TraceSource is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/DataAnnotationValidateOptions.cs b/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/DataAnnotationValidateOptions.cs index e8f4b606882cf9..5b734edf32738c 100644 --- a/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/DataAnnotationValidateOptions.cs +++ b/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/DataAnnotationValidateOptions.cs @@ -95,7 +95,8 @@ private static bool TryValidateOptions(object options, string qualifiedName, Lis foreach (PropertyInfo propertyInfo in options.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public)) { - if (propertyInfo.GetMethod is null) + // Indexers are properties which take parameters. Ignore them. + if (propertyInfo.GetMethod is null || propertyInfo.GetMethod.GetParameters().Length > 0) { continue; } diff --git a/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/PACKAGE.md b/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/PACKAGE.md new file mode 100644 index 00000000000000..368518bbedb313 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options.DataAnnotations/src/PACKAGE.md @@ -0,0 +1,75 @@ +## About + + + +Microsoft.Extensions.Options.DataAnnotations is a library that adds extra validation functionality to configuration options using data annotations. + +It allows to apply validation rules to configuration classes to ensure they are correctly configured before the application starts running. + +This way, misconfiguration issues are catched early during the application startup rather than facing them later in production. + +## Key Features + + + +* Enables validation of configuration options using data annotations. +* Early detection of misconfiguration issues during application startup. + +## How to Use + + + +While configuring services, chain the `ValidateDataAnnotations()` and `ValidateOnStart()` methods to the `AddOptions` method for your configuration class. + +Here is a simple example demonstrating how to validate options on application startup: + +```csharp +services + .AddOptions() + .ValidateDataAnnotations() + .ValidateOnStart(); +``` + +In the configuration class, use data annotations to specify the validation rules. + +For instance, in the following `MyOptions` class, the `Name` property is marked as required: + +```csharp +using System.ComponentModel.DataAnnotations; + +public class MyOptions +{ + [Required(AllowEmptyStrings = false)] + public string Name { get; set; } +} +``` + +With this setup, an error indicating that the `Name` field is required will be thrown upon startup if it hasn't been configured. + +## Main Types + + + +The main types provided by this library are: + +* `Microsoft.Extensions.Options.DataAnnotationsValidateOptions` +* `Microsoft.Extensions.DependencyInjection.OptionsBuilderDataAnnotationsExtensions` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/core/extensions/options) +* [API documentation](https://learn.microsoft.com/dotnet/api/microsoft.extensions.options.dataannotationvalidateoptions-1) + +## Related Packages + + + +Core options: [Microsoft.Extensions.Options](https://www.nuget.org/packages/Microsoft.Extensions.Options/) + +## Feedback & Contributing + + + +Microsoft.Extensions.Options.DataAnnotations is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/Microsoft.Extensions.Options/gen/DiagDescriptors.cs b/src/libraries/Microsoft.Extensions.Options/gen/DiagDescriptors.cs index 141fc6b9c7f9ae..49562a0a128c2b 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/DiagDescriptors.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/DiagDescriptors.cs @@ -112,5 +112,12 @@ internal sealed class DiagDescriptors : DiagDescriptorsBase messageFormat: SR.OptionsUnsupportedLanguageVersionMessage, category: Category, defaultSeverity: DiagnosticSeverity.Error); + + public static DiagnosticDescriptor IncompatibleWithTypeForValidationAttribute { get; } = Make( + id: "SYSLIB1217", + title: SR.TypeCannotBeUsedWithTheValidationAttributeTitle, + messageFormat: SR.TypeCannotBeUsedWithTheValidationAttributeMessage, + category: Category, + defaultSeverity: DiagnosticSeverity.Warning); } } diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Emitter.cs b/src/libraries/Microsoft.Extensions.Options/gen/Emitter.cs index 9e0cb659a6c94a..41609ad4b2010a 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Emitter.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/Emitter.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; +using System.Text; using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; @@ -18,36 +19,39 @@ namespace Microsoft.Extensions.Options.Generators internal sealed class Emitter : EmitterBase { private const string StaticFieldHolderClassesNamespace = "__OptionValidationStaticInstances"; + internal const string StaticGeneratedValidationAttributesClassesNamespace = "__OptionValidationGeneratedAttributes"; + internal const string StaticAttributeClassNamePrefix = "__SourceGen_"; + internal const string StaticGeneratedMaxLengthAttributeClassesName = "__SourceGen_MaxLengthAttribute"; private const string StaticListType = "global::System.Collections.Generic.List"; private const string StaticValidationResultType = "global::System.ComponentModel.DataAnnotations.ValidationResult"; private const string StaticValidationAttributeType = "global::System.ComponentModel.DataAnnotations.ValidationAttribute"; - + private const string StaticValidationContextType = "global::System.ComponentModel.DataAnnotations.ValidationContext"; private string _staticValidationAttributeHolderClassName = "__Attributes"; private string _staticValidatorHolderClassName = "__Validators"; private string _staticValidationAttributeHolderClassFQN; private string _staticValidatorHolderClassFQN; - private string _modifier; private string _TryGetValueNullableAnnotation; + private readonly SymbolHolder _symbolHolder; + private readonly OptionsSourceGenContext _optionsSourceGenContext; + private sealed record StaticFieldInfo(string FieldTypeFQN, int FieldOrder, string FieldName, IList InstantiationLines); - public Emitter(Compilation compilation, bool emitPreamble = true) : base(emitPreamble) + public Emitter(Compilation compilation, SymbolHolder symbolHolder, OptionsSourceGenContext optionsSourceGenContext, bool emitPreamble = true) : base(emitPreamble) { - if (((CSharpCompilation)compilation).LanguageVersion >= Microsoft.CodeAnalysis.CSharp.LanguageVersion.CSharp11) - { - _modifier = "file"; - } - else + _optionsSourceGenContext = optionsSourceGenContext; + + if (!_optionsSourceGenContext.IsLangVersion11AndAbove) { - _modifier = "internal"; - string suffix = $"_{GetNonRandomizedHashCode(compilation.SourceModule.Name):X8}"; - _staticValidationAttributeHolderClassName += suffix; - _staticValidatorHolderClassName += suffix; + _staticValidationAttributeHolderClassName += _optionsSourceGenContext.Suffix; + _staticValidatorHolderClassName += _optionsSourceGenContext.Suffix; } _staticValidationAttributeHolderClassFQN = $"global::{StaticFieldHolderClassesNamespace}.{_staticValidationAttributeHolderClassName}"; _staticValidatorHolderClassFQN = $"global::{StaticFieldHolderClassesNamespace}.{_staticValidatorHolderClassName}"; _TryGetValueNullableAnnotation = GetNullableAnnotationStringForTryValidateValueToUseInGeneratedCode(compilation); + + _symbolHolder = symbolHolder; } public string Emit( @@ -65,6 +69,7 @@ public string Emit( GenStaticClassWithStaticReadonlyFields(staticValidationAttributesDict.Values, StaticFieldHolderClassesNamespace, _staticValidationAttributeHolderClassName); GenStaticClassWithStaticReadonlyFields(staticValidatorsDict.Values, StaticFieldHolderClassesNamespace, _staticValidatorHolderClassName); + GenValidationAttributesClasses(); return Capture(); } @@ -146,7 +151,7 @@ private void GenStaticClassWithStaticReadonlyFields(IEnumerable OutOpenBrace(); OutGeneratedCodeAttribute(); - OutLn($"{_modifier} static class {className}"); + OutLn($"{_optionsSourceGenContext.ClassModifier} static class {className}"); OutOpenBrace(); var staticValidationAttributes = staticFields @@ -186,6 +191,396 @@ private void GenStaticClassWithStaticReadonlyFields(IEnumerable OutCloseBrace(); } + public void EmitMaxLengthAttribute(string modifier, string prefix, string className, string linesToInsert, string suffix) + { + OutGeneratedCodeAttribute(); + + string qualifiedClassName = $"{prefix}{suffix}_{className}"; + + OutLn($$""" +[global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + {{modifier}} class {{qualifiedClassName}} : {{StaticValidationAttributeType}} + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public {{qualifiedClassName}}(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public {{qualifiedClassName}}(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + {{linesToInsert}}else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } +"""); + } + + public void EmitMinLengthAttribute(string modifier, string prefix, string className, string linesToInsert, string suffix) + { + OutGeneratedCodeAttribute(); + + string qualifiedClassName = $"{prefix}{suffix}_{className}"; + + OutLn($$""" +[global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + {{modifier}} class {{qualifiedClassName}} : {{StaticValidationAttributeType}} + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public {{qualifiedClassName}}(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + {{linesToInsert}}else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } +"""); + } + + public void EmitLengthAttribute(string modifier, string prefix, string className, string linesToInsert, string suffix) + { + OutGeneratedCodeAttribute(); + + string qualifiedClassName = $"{prefix}{suffix}_{className}"; + + OutLn($$""" +[global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + {{modifier}} class {{qualifiedClassName}} : {{StaticValidationAttributeType}} + { + private static string DefaultErrorMessageString => "The field {0} must be a string or collection type with a minimum length of '{1}' and maximum length of '{2}'."; + public {{qualifiedClassName}}(int minimumLength, int maximumLength) : base(() => DefaultErrorMessageString) { MinimumLength = minimumLength; MaximumLength = maximumLength; } + public int MinimumLength { get; } + public int MaximumLength { get; } + public override bool IsValid(object? value) + { + if (MinimumLength < 0) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MinimumLength value that is zero or greater."); + } + if (MaximumLength < MinimumLength) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MaximumLength value that is greater than or equal to MinimumLength."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + {{linesToInsert}}else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return (uint)(length - MinimumLength) <= (uint)(MaximumLength - MinimumLength); + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, MinimumLength, MaximumLength); + } +"""); + } + + public void EmitCompareAttribute(string modifier, string prefix, string className, string linesToInsert, string suffix) + { + OutGeneratedCodeAttribute(); + + string qualifiedClassName = $"{prefix}{suffix}_{className}"; + + OutLn($$""" +[global::System.AttributeUsage(global::System.AttributeTargets.Property, AllowMultiple = false)] + {{modifier}} class {{qualifiedClassName}} : {{StaticValidationAttributeType}} + { + private static string DefaultErrorMessageString => "'{0}' and '{1}' do not match."; + public {{qualifiedClassName}}(string otherProperty) : base(() => DefaultErrorMessageString) + { + if (otherProperty == null) + { + throw new global::System.ArgumentNullException(nameof(otherProperty)); + } + OtherProperty = otherProperty; + } + public string OtherProperty { get; } + public override bool RequiresValidationContext => true; + + protected override {{StaticValidationResultType}}? IsValid(object? value, {{StaticValidationContextType}} validationContext) + { + bool result = true; + + {{linesToInsert}} + if (!result) + { + string[]? memberNames = validationContext.MemberName is null ? null : new string[] { validationContext.MemberName }; + return new {{StaticValidationResultType}}(FormatErrorMessage(validationContext.DisplayName), memberNames); + } + + return null; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, OtherProperty); + } +"""); + } + + public void EmitRangeAttribute(string modifier, string prefix, string className, string suffix) + { + OutGeneratedCodeAttribute(); + + string qualifiedClassName = $"{prefix}{suffix}_{className}"; + + OutLn($$""" +[global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + {{modifier}} class {{qualifiedClassName}} : {{StaticValidationAttributeType}} + { + public {{qualifiedClassName}}(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public {{qualifiedClassName}}(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public {{qualifiedClassName}}(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +"""); + } + + private string GenerateStronglyTypedCodeForLengthAttributes(HashSet data) + { + if (data.Count == 0) + { + return string.Empty; + } + + StringBuilder sb = new(); + string padding = GetPaddingString(3); + + foreach (var type in data) + { + string typeName = (string)type; + sb.AppendLine($"else if (value is {typeName})"); + sb.AppendLine($"{padding}{{"); + sb.AppendLine($"{padding} length = (({typeName})value).Count;"); + sb.AppendLine($"{padding}}}"); + sb.Append($"{padding}"); + } + + return sb.ToString(); + } + + private string GenerateStronglyTypedCodeForCompareAttribute(HashSet? data) + { + if (data is null || data.Count == 0) + { + return string.Empty; + } + + StringBuilder sb = new(); + string padding = GetPaddingString(3); + bool first = true; + + foreach (var obj in data) + { + (string type, string property) = ((string, string))obj; + sb.Append(first ? $"if " : $"{padding}else if "); + sb.AppendLine($"(validationContext.ObjectInstance is {type} && OtherProperty == \"{property}\")"); + sb.AppendLine($"{padding}{{"); + sb.AppendLine($"{padding} result = Equals(value, (({type})validationContext.ObjectInstance).{property});"); + sb.AppendLine($"{padding}}}"); + first = false; + } + + return sb.ToString(); + } + + private void GenValidationAttributesClasses() + { + if (_optionsSourceGenContext.AttributesToGenerate.Count == 0) + { + return; + } + + var attributesData = _optionsSourceGenContext.AttributesToGenerate.OrderBy(static kvp => kvp.Key, StringComparer.Ordinal).ToArray(); + + OutLn($"namespace {StaticGeneratedValidationAttributesClassesNamespace}"); + OutOpenBrace(); + + foreach (var attributeData in attributesData) + { + if (attributeData.Key == _symbolHolder.MaxLengthAttributeSymbol.Name) + { + string linesToInsert = attributeData.Value is not null ? GenerateStronglyTypedCodeForLengthAttributes((HashSet)attributeData.Value) : string.Empty; + EmitMaxLengthAttribute(_optionsSourceGenContext.ClassModifier, Emitter.StaticAttributeClassNamePrefix, attributeData.Key, linesToInsert, _optionsSourceGenContext.Suffix); + } + else if (attributeData.Key == _symbolHolder.MinLengthAttributeSymbol.Name) + { + string linesToInsert = attributeData.Value is not null ? GenerateStronglyTypedCodeForLengthAttributes((HashSet)attributeData.Value) : string.Empty; + EmitMinLengthAttribute(_optionsSourceGenContext.ClassModifier, Emitter.StaticAttributeClassNamePrefix, attributeData.Key, linesToInsert, _optionsSourceGenContext.Suffix); + } + else if (_symbolHolder.LengthAttributeSymbol is not null && attributeData.Key == _symbolHolder.LengthAttributeSymbol.Name) + { + string linesToInsert = attributeData.Value is not null ? GenerateStronglyTypedCodeForLengthAttributes((HashSet)attributeData.Value) : string.Empty; + EmitLengthAttribute(_optionsSourceGenContext.ClassModifier, Emitter.StaticAttributeClassNamePrefix, attributeData.Key, linesToInsert, _optionsSourceGenContext.Suffix); + } + else if (attributeData.Key == _symbolHolder.CompareAttributeSymbol.Name && attributeData.Value is not null) + { + string linesToInsert = GenerateStronglyTypedCodeForCompareAttribute((HashSet)attributeData.Value); + EmitCompareAttribute(_optionsSourceGenContext.ClassModifier, Emitter.StaticAttributeClassNamePrefix, attributeData.Key, linesToInsert: linesToInsert, _optionsSourceGenContext.Suffix); + } + else if (attributeData.Key == _symbolHolder.RangeAttributeSymbol.Name) + { + EmitRangeAttribute(_optionsSourceGenContext.ClassModifier, Emitter.StaticAttributeClassNamePrefix, attributeData.Key, _optionsSourceGenContext.Suffix); + } + } + + OutCloseBrace(); + } + private void GenModelSelfValidationIfNecessary(ValidatedModel modelToValidate) { if (modelToValidate.SelfValidates) @@ -209,10 +604,18 @@ private void GenModelValidationMethod( OutLn($"/// Validation result."); OutGeneratedCodeAttribute(); + if (_symbolHolder.UnconditionalSuppressMessageAttributeSymbol is not null) + { + // We disable the warning on `new ValidationContext(object)` usage as we use it in a safe way that not require executing the reflection code. + // This is done by initializing the DisplayName in the context which is the part trigger reflection if it is not initialized. + OutLn($"[System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage(\"Trimming\", \"IL2026:RequiresUnreferencedCode\","); + OutLn($" Justification = \"The created ValidationContext object is used in a way that never call reflection\")]"); + } + OutLn($"public {(makeStatic ? "static " : string.Empty)}global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, {modelToValidate.Name} options)"); OutOpenBrace(); OutLn($"global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null;"); - OutLn($"var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options);"); + OutLn($"var context = new {StaticValidationContextType}(options);"); int capacity = modelToValidate.MembersToValidate.Max(static vm => vm.ValidationAttributes.Count); if (capacity > 0) @@ -438,19 +841,5 @@ private StaticFieldInfo GetOrAddStaticValidator(ref Dictionary - /// Returns a non-randomized hash code for the given string. - /// We always return a positive value. - /// - internal static int GetNonRandomizedHashCode(string s) - { - uint result = 2166136261u; - foreach (char c in s) - { - result = (c ^ result) * 16777619; - } - return Math.Abs((int)result); - } } } diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Generator.cs b/src/libraries/Microsoft.Extensions.Options/gen/Generator.cs index 34533fc0a96b05..02fb554ca0b487 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Generator.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/Generator.cs @@ -32,18 +32,25 @@ public void Initialize(IncrementalGeneratorInitializationContext context) private static void HandleAnnotatedTypes(Compilation compilation, ImmutableArray<(TypeDeclarationSyntax? TypeSyntax, SemanticModel SemanticModel)> types, SourceProductionContext context) { + if (types.Length == 0) + { + return; + } + if (!SymbolLoader.TryLoad(compilation, out var symbolHolder)) { // Not eligible compilation return; } - var parser = new Parser(compilation, context.ReportDiagnostic, symbolHolder!, context.CancellationToken); + OptionsSourceGenContext optionsSourceGenContext = new(compilation); + + var parser = new Parser(compilation, context.ReportDiagnostic, symbolHolder!, optionsSourceGenContext, context.CancellationToken); var validatorTypes = parser.GetValidatorTypes(types); if (validatorTypes.Count > 0) { - var emitter = new Emitter(compilation); + var emitter = new Emitter(compilation, symbolHolder!, optionsSourceGenContext); var result = emitter.Emit(validatorTypes, context.CancellationToken); context.AddSource("Validators.g.cs", SourceText.From(result, Encoding.UTF8)); diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Microsoft.Extensions.Options.SourceGeneration.csproj b/src/libraries/Microsoft.Extensions.Options/gen/Microsoft.Extensions.Options.SourceGeneration.csproj index 5571341b06060e..f5bad279371755 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Microsoft.Extensions.Options.SourceGeneration.csproj +++ b/src/libraries/Microsoft.Extensions.Options/gen/Microsoft.Extensions.Options.SourceGeneration.csproj @@ -30,6 +30,7 @@ + diff --git a/src/libraries/Microsoft.Extensions.Options/gen/OptionsSourceGenContext.cs b/src/libraries/Microsoft.Extensions.Options/gen/OptionsSourceGenContext.cs new file mode 100644 index 00000000000000..8da3e317769627 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/gen/OptionsSourceGenContext.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.Versioning; + +namespace Microsoft.Extensions.Options.Generators +{ + internal sealed class OptionsSourceGenContext + { + public OptionsSourceGenContext(Compilation compilation) + { + IsLangVersion11AndAbove = ((CSharpCompilation)compilation).LanguageVersion >= Microsoft.CodeAnalysis.CSharp.LanguageVersion.CSharp11; + ClassModifier = IsLangVersion11AndAbove ? "file" : "internal"; + Suffix = IsLangVersion11AndAbove ? "" : $"_{GetNonRandomizedHashCode(compilation.SourceModule.Name):X8}"; + } + + internal string Suffix { get; } + internal string ClassModifier { get; } + internal bool IsLangVersion11AndAbove { get; } + internal Dictionary?> AttributesToGenerate { get; set; } = new Dictionary?>(); + + internal void EnsureTrackingAttribute(string attributeName, bool createValue, out HashSet? value) + { + bool exist = AttributesToGenerate.TryGetValue(attributeName, out value); + if (value is null) + { + if (createValue) + { + value = new HashSet(); + } + + if (!exist || createValue) + { + AttributesToGenerate[attributeName] = value; + } + } + } + + internal static bool IsConvertibleBasicType(ITypeSymbol typeSymbol) + { + return typeSymbol.SpecialType switch + { + SpecialType.System_Boolean => true, + SpecialType.System_Byte => true, + SpecialType.System_Char => true, + SpecialType.System_DateTime => true, + SpecialType.System_Decimal => true, + SpecialType.System_Double => true, + SpecialType.System_Int16 => true, + SpecialType.System_Int32 => true, + SpecialType.System_Int64 => true, + SpecialType.System_SByte => true, + SpecialType.System_Single => true, + SpecialType.System_UInt16 => true, + SpecialType.System_UInt32 => true, + SpecialType.System_UInt64 => true, + SpecialType.System_String => true, + _ => false, + }; + } + + /// + /// Returns a non-randomized hash code for the given string. + /// We always return a positive value. + /// + internal static int GetNonRandomizedHashCode(string s) + { + uint result = 2166136261u; + foreach (char c in s) + { + result = (c ^ result) * 16777619; + } + + return Math.Abs((int)result); + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Parser.cs b/src/libraries/Microsoft.Extensions.Options/gen/Parser.cs index 010b89562a9179..47cb71c3411cde 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Parser.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/Parser.cs @@ -25,6 +25,7 @@ internal sealed class Parser private readonly Compilation _compilation; private readonly Action _reportDiagnostic; private readonly SymbolHolder _symbolHolder; + private readonly OptionsSourceGenContext _optionsSourceGenContext; private readonly Dictionary _synthesizedValidators = new(SymbolEqualityComparer.Default); private readonly HashSet _visitedModelTypes = new(SymbolEqualityComparer.Default); @@ -32,12 +33,14 @@ public Parser( Compilation compilation, Action reportDiagnostic, SymbolHolder symbolHolder, + OptionsSourceGenContext optionsSourceGenContext, CancellationToken cancellationToken) { _compilation = compilation; _cancellationToken = cancellationToken; _reportDiagnostic = reportDiagnostic; _symbolHolder = symbolHolder; + _optionsSourceGenContext = optionsSourceGenContext; } public IReadOnlyList GetValidatorTypes(IEnumerable<(TypeDeclarationSyntax TypeSyntax, SemanticModel SemanticModel)> classes) @@ -288,7 +291,7 @@ private List GetMembersToValidate(ITypeSymbol modelType, bool s ? memberLocation : lowerLocationInCompilation; - var memberInfo = GetMemberInfo(member, speculate, location, validatorType); + var memberInfo = GetMemberInfo(member, speculate, location, modelType, validatorType); if (memberInfo is not null) { if (member.DeclaredAccessibility != Accessibility.Public) @@ -304,7 +307,7 @@ private List GetMembersToValidate(ITypeSymbol modelType, bool s return membersToValidate; } - private ValidatedMember? GetMemberInfo(ISymbol member, bool speculate, Location location, ITypeSymbol validatorType) + private ValidatedMember? GetMemberInfo(ISymbol member, bool speculate, Location location, ITypeSymbol modelType, ITypeSymbol validatorType) { ITypeSymbol memberType; switch (member) @@ -325,7 +328,7 @@ private List GetMembersToValidate(ITypeSymbol modelType, bool s break; */ default: - // we only care about properties and fields + // we only care about properties return null; } @@ -467,7 +470,26 @@ private List GetMembersToValidate(ITypeSymbol modelType, bool s continue; } - var validationAttr = new ValidationAttributeInfo(attributeType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + string attributeFullQualifiedName = attributeType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + if (SymbolEqualityComparer.Default.Equals(attributeType, _symbolHolder.MaxLengthAttributeSymbol) || + SymbolEqualityComparer.Default.Equals(attributeType, _symbolHolder.MinLengthAttributeSymbol) || + (_symbolHolder.LengthAttributeSymbol is not null && SymbolEqualityComparer.Default.Equals(attributeType, _symbolHolder.LengthAttributeSymbol))) + { + if (!LengthBasedAttributeIsTrackedForSubstitution(memberType, location, attributeType, ref attributeFullQualifiedName)) + { + continue; + } + } + else if (SymbolEqualityComparer.Default.Equals(attributeType, _symbolHolder.CompareAttributeSymbol)) + { + TrackCompareAttributeForSubstitution(attribute, modelType, ref attributeFullQualifiedName); + } + else if (SymbolEqualityComparer.Default.Equals(attributeType, _symbolHolder.RangeAttributeSymbol)) + { + TrackRangeAttributeForSubstitution(attribute, memberType, ref attributeFullQualifiedName); + } + + var validationAttr = new ValidationAttributeInfo(attributeFullQualifiedName); validationAttrs.Add(validationAttr); ImmutableArray parameters = attribute.AttributeConstructor?.Parameters ?? ImmutableArray.Empty; @@ -567,6 +589,79 @@ private List GetMembersToValidate(ITypeSymbol modelType, bool s return null; } + private bool LengthBasedAttributeIsTrackedForSubstitution(ITypeSymbol memberType, Location location, ITypeSymbol attributeType, ref string attributeFullQualifiedName) + { + if (memberType.SpecialType == SpecialType.System_String || ConvertTo(memberType, _symbolHolder.ICollectionSymbol)) + { + _optionsSourceGenContext.EnsureTrackingAttribute(attributeType.Name, createValue: false, out _); + } + else if (ParserUtilities.TypeHasProperty(memberType, "Count", SpecialType.System_Int32)) + { + _optionsSourceGenContext.EnsureTrackingAttribute(attributeType.Name, createValue: true, out HashSet? trackedTypeList); + trackedTypeList!.Add(memberType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + } + else + { + Diag(DiagDescriptors.IncompatibleWithTypeForValidationAttribute, location, attributeType.Name, memberType.Name); + return false; + } + + attributeFullQualifiedName = $"{Emitter.StaticGeneratedValidationAttributesClassesNamespace}.{Emitter.StaticAttributeClassNamePrefix}{_optionsSourceGenContext.Suffix}_{attributeType.Name}"; + return true; + } + + private void TrackCompareAttributeForSubstitution(AttributeData attribute, ITypeSymbol modelType, ref string attributeFullQualifiedName) + { + ImmutableArray constructorParameters = attribute.AttributeConstructor?.Parameters ?? ImmutableArray.Empty; + if (constructorParameters.Length == 1 && constructorParameters[0].Name == "otherProperty" && constructorParameters[0].Type.SpecialType == SpecialType.System_String) + { + _optionsSourceGenContext.EnsureTrackingAttribute(attribute.AttributeClass!.Name, createValue: true, out HashSet? trackedTypeList); + trackedTypeList!.Add((modelType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), (string)attribute.ConstructorArguments[0].Value!)); + attributeFullQualifiedName = $"{Emitter.StaticGeneratedValidationAttributesClassesNamespace}.{Emitter.StaticAttributeClassNamePrefix}{_optionsSourceGenContext.Suffix}_{attribute.AttributeClass!.Name}"; + } + } + + private void TrackRangeAttributeForSubstitution(AttributeData attribute, ITypeSymbol memberType, ref string attributeFullQualifiedName) + { + ImmutableArray constructorParameters = attribute.AttributeConstructor?.Parameters ?? ImmutableArray.Empty; + SpecialType argumentSpecialType = SpecialType.None; + if (constructorParameters.Length == 2) + { + argumentSpecialType = constructorParameters[0].Type.SpecialType; + } + else if (constructorParameters.Length == 3) + { + object? argumentValue = null; + for (int i = 0; i < constructorParameters.Length; i++) + { + if (constructorParameters[i].Name == "type") + { + argumentValue = attribute.ConstructorArguments[i].Value; + break; + } + } + + if (argumentValue is INamedTypeSymbol namedTypeSymbol && OptionsSourceGenContext.IsConvertibleBasicType(namedTypeSymbol)) + { + argumentSpecialType = namedTypeSymbol.SpecialType; + } + } + + ITypeSymbol typeSymbol = memberType; + if (typeSymbol.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T) + { + typeSymbol = ((INamedTypeSymbol)typeSymbol).TypeArguments[0]; + } + + if (argumentSpecialType != SpecialType.None && + OptionsSourceGenContext.IsConvertibleBasicType(typeSymbol) && + (constructorParameters.Length != 3 || typeSymbol.SpecialType == argumentSpecialType)) // When type is provided as a parameter, it has to match the property type. + { + _optionsSourceGenContext.EnsureTrackingAttribute(attribute.AttributeClass!.Name, createValue: false, out _); + attributeFullQualifiedName = $"{Emitter.StaticGeneratedValidationAttributesClassesNamespace}.{Emitter.StaticAttributeClassNamePrefix}{_optionsSourceGenContext.Suffix}_{attribute.AttributeClass!.Name}"; + } + } + private string? AddSynthesizedValidator(ITypeSymbol modelType, ISymbol member, Location location, ITypeSymbol validatorType) { var mt = modelType.WithNullableAnnotation(NullableAnnotation.None); diff --git a/src/libraries/Microsoft.Extensions.Options/gen/ParserUtilities.cs b/src/libraries/Microsoft.Extensions.Options/gen/ParserUtilities.cs index d79ad4cccb653d..0b63cc90c800ed 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/ParserUtilities.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/ParserUtilities.cs @@ -68,6 +68,41 @@ internal static bool ImplementsInterface(this ITypeSymbol type, ITypeSymbol inte return false; } + internal static bool TypeHasProperty(ITypeSymbol typeSymbol, string propertyName, SpecialType returnType) + { + ITypeSymbol? type = typeSymbol; + do + { + if (type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T) + { + type = ((INamedTypeSymbol)type).TypeArguments[0]; // extract the T from a Nullable + } + + if (type.GetMembers(propertyName).OfType().Any(property => + property.Type.SpecialType == returnType && property.DeclaredAccessibility == Accessibility.Public && + property.Kind == SymbolKind.Property && !property.IsStatic && property.GetMethod != null && property.Parameters.IsEmpty)) + { + return true; + } + + type = type.BaseType; + } while (type is not null && type.SpecialType != SpecialType.System_Object); + + // When we have an interface type, we need to check all the interfaces that it extends. + // Like IList extends ICollection where the property we're looking for is defined. + foreach (var interfaceType in typeSymbol.AllInterfaces) + { + if (interfaceType.GetMembers(propertyName).OfType().Any(property => + property.Type.SpecialType == returnType && property.Kind == SymbolKind.Property && + !property.IsStatic && property.GetMethod != null && property.Parameters.IsEmpty)) + { + return true; + } + } + + return false; + } + // Check if parameter has either simplified (i.e. "int?") or explicit (Nullable) nullable type declaration: internal static bool IsNullableOfT(this ITypeSymbol type) => type.SpecialType == SpecialType.System_Nullable_T || type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T; diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/Strings.resx b/src/libraries/Microsoft.Extensions.Options/gen/Resources/Strings.resx index 6293431eb7f90e..7100030eecf132 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/Strings.resx +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/Strings.resx @@ -213,4 +213,10 @@ The options validation source generator is not available in C# {0}. Please use language version {1} or greater. + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.cs.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.cs.xlf index 953376c434d310..c8490d951761cb 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.cs.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.cs.xlf @@ -152,6 +152,16 @@ U člena potenciálně chybí přenositelné ověření. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + Ověřovací atribut {0} by měl být použit pouze na vlastnosti typu string, array nebo ICollection. Použití s typem {1} může vést k selháním modulu runtime. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + Ověřovací atribut se vztahuje pouze na vlastnosti typu string, array nebo ICollection; nelze použít s jinými typy. + + Validator type {0} doesn't have a parameterless constructor. Typ validátoru {0} nemá konstruktor bez parametrů. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.de.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.de.xlf index bbd4d29cfc247e..eb7a7f423a3e70 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.de.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.de.xlf @@ -152,6 +152,16 @@ Dem Member fehlt möglicherweise die transitive Validierung. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + Das Validierungsattribut {0} sollte nur auf Eigenschaften vom Typ "string", "array" oder "ICollection" angewendet werden. Die Verwendung mit dem Typ {1} kann zu Laufzeitfehlern führen. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + Das Validierungsattribut gilt nur für Eigenschaften vom Typ "string", "array" oder "ICollection"; es kann nicht mit anderen Typen verwendet werden. + + Validator type {0} doesn't have a parameterless constructor. Der Validierungssteuerelementtyp "{0}" hat keinen parameterlosen Konstruktor. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.es.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.es.xlf index 0c4661573bbd40..19c69bb8c864f9 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.es.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.es.xlf @@ -152,6 +152,16 @@ Posiblemente falta la validación transitiva en el miembro. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + El atributo de validación {0} solo se debe aplicar a propiedades de tipo cadena, matriz o ICollection. Si la usa con el tipo {1}, podrían producirse errores en tiempo de ejecución. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + El atributo de validación solo es aplicable a propiedades de tipo cadena, matriz o ICollection; no se puede usar con otros tipos. + + Validator type {0} doesn't have a parameterless constructor. El tipo de validador {0} no tiene un constructor sin parámetros. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.fr.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.fr.xlf index e1996eeca163cd..651ab8b303719a 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.fr.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.fr.xlf @@ -152,6 +152,16 @@ Le membre n’a peut-être pas de validation transitive. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + L’attribut de validation {0} doit uniquement être appliqué aux propriétés de type chaîne, tableau ou ICollection. Son utilisation avec le type {1} peut entraîner des échecs d’exécution. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + L’attribut de validation s’applique uniquement aux propriétés de type chaîne, tableau ou ICollection ; il ne peut pas être utilisé avec d’autres types. + + Validator type {0} doesn't have a parameterless constructor. Le type de validateur {0} n’a pas de constructeur sans paramètre. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.it.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.it.xlf index edcf16b967dd32..575f60469e48dc 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.it.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.it.xlf @@ -152,6 +152,16 @@ Il membro potrebbe non avere una convalida transitiva. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + L'attributo {0} di convalida deve essere applicato solo alle proprietà di tipo stringa, matrice o ICollection. L'uso con il tipo {1} potrebbe causare errori di runtime. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + L'attributo di convalida è applicabile solo alle proprietà di tipo stringa, matrice o ICollection; non può essere usato con altri tipi. + + Validator type {0} doesn't have a parameterless constructor. Il tipo di convalida {0} non dispone di un costruttore senza parametri. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ja.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ja.xlf index 89b2f23777fb2c..8251ac27c6da15 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ja.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ja.xlf @@ -152,6 +152,16 @@ メンバーに推移性の検証がない可能性があります。 + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + 検証属性 {0} は、型文字列、配列、または ICollection のプロパティにしか適用できません。型 {1} と共に使用すると、ランタイム エラーが発生する可能性があります。 + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + 検証属性は、型文字列、配列、または ICollection のプロパティにのみ適用でき、他の型では使用できません。 + + Validator type {0} doesn't have a parameterless constructor. バリデーター型 {0} にパラメーターなしのコンストラクターがありません。 diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ko.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ko.xlf index 817bc64eba8ac4..4c8ce2865d1070 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ko.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ko.xlf @@ -152,6 +152,16 @@ 멤버에 전이적 유효성 검사가 누락되었을 수 있습니다. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + 유효성 검사 특성 {0}은(는) 문자열, 배열 또는 ICollection 형식의 속성에만 적용해야 합니다. {1} 형식과 함께 사용하면 런타임 오류가 발생할 수 있습니다. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + 유효성 검사 특성은 문자열, 배열 또는 ICollection 형식의 속성에만 적용할 수 있습니다. 다른 형식과 함께 사용할 수 없습니다. + + Validator type {0} doesn't have a parameterless constructor. 유효성 검사기 형식 {0}은(는) 매개 변수가 없는 생성자가 없습니다. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pl.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pl.xlf index 190ca17b561792..e07e568d131874 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pl.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pl.xlf @@ -152,6 +152,16 @@ W przypadku elementu członkowskiego może potencjalnie brakować weryfikacji przechodniej. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + Atrybut sprawdzania poprawności {0} powinien być stosowany tylko do właściwości typu ciąg, tablica lub ICollection. Użycie go z typem {1} może prowadzić do błędów środowiska uruchomieniowego. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + Atrybut sprawdzania poprawności ma zastosowanie tylko do właściwości typu ciąg, tablica lub ICollection; nie można go używać z innymi typami. + + Validator type {0} doesn't have a parameterless constructor. Typ modułu sprawdzania poprawności {0} nie ma konstruktora bez parametrów. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pt-BR.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pt-BR.xlf index 24a4203391c016..d422309a40995b 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pt-BR.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.pt-BR.xlf @@ -152,6 +152,16 @@ Membro potencialmente ausente na validação transitiva. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + O atributo de validação {0} deve ser aplicado somente às propriedades do tipo cadeia de caracteres, matriz ou ICollection. Usá-lo com o tipo {1} pode levar a falhas de runtime. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + O atributo de validação só é aplicável às propriedades do tipo cadeia de caracteres, matriz ou ICollection; não pode ser usado com outros tipos. + + Validator type {0} doesn't have a parameterless constructor. O tipo de validador {0} não tem um construtor sem parâmetros. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ru.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ru.xlf index f14d71833a2355..bd3a0be55eb270 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ru.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.ru.xlf @@ -152,6 +152,16 @@ Возможно, в элементе отсутствует транзитивная проверка. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + Атрибут проверки {0} следует применять только к свойствам строки типа, массива или ICollection. Использование его с типом {1} может привести к сбоям во время выполнения. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + Атрибут проверки применим только к свойствам строки типа, массива или ICollection; его нельзя использовать с другими типами. + + Validator type {0} doesn't have a parameterless constructor. Тип проверяющего элемента управления {0} не имеет конструктора без параметров. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.tr.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.tr.xlf index 79e094b36b3122..e478a77cd3c238 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.tr.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.tr.xlf @@ -152,6 +152,16 @@ Üyede geçişli doğrulama eksik olabilir. + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + Doğrulama özniteliği {0} yalnızca string, dizi veya ICollection türündeki özelliklere uygulanmalıdır. {1} türüyle kullanılması çalışma zamanı hatalarına neden olabilir. + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + Doğrulama özniteliği yalnızca string, dizi veya ICollection türündeki özelliklere uygulanabilir; diğer türlerle kullanılamaz. + + Validator type {0} doesn't have a parameterless constructor. {0} doğrulayıcı türü parametresiz bir oluşturucuya sahip değil. diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hans.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hans.xlf index 7688e0e628811f..76a26db32a673c 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hans.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hans.xlf @@ -152,6 +152,16 @@ 成员可能缺少可传递验证。 + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + 验证特性 {0} 只能应用于字符串、数组或 ICollection 类型的属性。将它与 {1} 类型一起使用可能会导致运行时故障。 + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + 验证特性仅适用于字符串、数组或 ICollection 类型的属性;它不能与其他类型一起使用。 + + Validator type {0} doesn't have a parameterless constructor. 验证程序类型 {0} 没有无参数构造函数。 diff --git a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hant.xlf b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hant.xlf index 663bff9c183861..9997092e3be6f3 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hant.xlf +++ b/src/libraries/Microsoft.Extensions.Options/gen/Resources/xlf/Strings.zh-Hant.xlf @@ -152,6 +152,16 @@ 成員可能遺漏轉移的驗證。 + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + 驗證屬性 {0} 只能套用至類型字串、陣列或 ICollection 的屬性。搭配 {1} 類型使用可能會導致執行階段失敗。 + + + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + 驗證屬性只適用於類型字串、陣列或 ICollection 的屬性;無法與其他類型搭配使用。 + + Validator type {0} doesn't have a parameterless constructor. 驗證程式類型 {0} 沒有無參數建構函式。 diff --git a/src/libraries/Microsoft.Extensions.Options/gen/SymbolHolder.cs b/src/libraries/Microsoft.Extensions.Options/gen/SymbolHolder.cs index 55d382e4036219..3447a07d398305 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/SymbolHolder.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/SymbolHolder.cs @@ -11,6 +11,13 @@ namespace Microsoft.Extensions.Options.Generators internal sealed record class SymbolHolder( INamedTypeSymbol OptionsValidatorSymbol, INamedTypeSymbol ValidationAttributeSymbol, + INamedTypeSymbol MaxLengthAttributeSymbol, + INamedTypeSymbol MinLengthAttributeSymbol, + INamedTypeSymbol CompareAttributeSymbol, + INamedTypeSymbol? LengthAttributeSymbol, + INamedTypeSymbol? UnconditionalSuppressMessageAttributeSymbol, + INamedTypeSymbol RangeAttributeSymbol, + INamedTypeSymbol ICollectionSymbol, INamedTypeSymbol DataTypeAttributeSymbol, INamedTypeSymbol ValidateOptionsSymbol, INamedTypeSymbol IValidatableObjectSymbol, diff --git a/src/libraries/Microsoft.Extensions.Options/gen/SymbolLoader.cs b/src/libraries/Microsoft.Extensions.Options/gen/SymbolLoader.cs index 94035cedacbf98..ea556228929756 100644 --- a/src/libraries/Microsoft.Extensions.Options/gen/SymbolLoader.cs +++ b/src/libraries/Microsoft.Extensions.Options/gen/SymbolLoader.cs @@ -9,6 +9,12 @@ internal static class SymbolLoader { public const string OptionsValidatorAttribute = "Microsoft.Extensions.Options.OptionsValidatorAttribute"; internal const string ValidationAttribute = "System.ComponentModel.DataAnnotations.ValidationAttribute"; + internal const string MaxLengthAttribute = "System.ComponentModel.DataAnnotations.MaxLengthAttribute"; + internal const string MinLengthAttribute = "System.ComponentModel.DataAnnotations.MinLengthAttribute"; + internal const string CompareAttribute = "System.ComponentModel.DataAnnotations.CompareAttribute"; + internal const string LengthAttribute = "System.ComponentModel.DataAnnotations.LengthAttribute"; + internal const string RangeAttribute = "System.ComponentModel.DataAnnotations.RangeAttribute"; + internal const string ICollectionType = "System.Collections.ICollection"; internal const string DataTypeAttribute = "System.ComponentModel.DataAnnotations.DataTypeAttribute"; internal const string IValidatableObjectType = "System.ComponentModel.DataAnnotations.IValidatableObject"; internal const string IValidateOptionsType = "Microsoft.Extensions.Options.IValidateOptions`1"; @@ -16,6 +22,7 @@ internal static class SymbolLoader internal const string ValidateObjectMembersAttribute = "Microsoft.Extensions.Options.ValidateObjectMembersAttribute"; internal const string ValidateEnumeratedItemsAttribute = "Microsoft.Extensions.Options.ValidateEnumeratedItemsAttribute"; internal const string GenericIEnumerableType = "System.Collections.Generic.IEnumerable`1"; + internal const string UnconditionalSuppressMessageAttributeType = "System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessageAttribute"; public static bool TryLoad(Compilation compilation, out SymbolHolder? symbolHolder) { @@ -24,6 +31,12 @@ public static bool TryLoad(Compilation compilation, out SymbolHolder? symbolHold // required var optionsValidatorSymbol = GetSymbol(OptionsValidatorAttribute); var validationAttributeSymbol = GetSymbol(ValidationAttribute); + var maxLengthAttributeSymbol = GetSymbol(MaxLengthAttribute); + var minLengthAttributeSymbol = GetSymbol(MinLengthAttribute); + var compareAttributeSymbol = GetSymbol(CompareAttribute); + var lengthAttributeSymbol = GetSymbol(LengthAttribute); + var rangeAttributeSymbol = GetSymbol(RangeAttribute); + var iCollectionSymbol = GetSymbol(ICollectionType); var dataTypeAttributeSymbol = GetSymbol(DataTypeAttribute); var ivalidatableObjectSymbol = GetSymbol(IValidatableObjectType); var validateOptionsSymbol = GetSymbol(IValidateOptionsType); @@ -31,10 +44,27 @@ public static bool TryLoad(Compilation compilation, out SymbolHolder? symbolHold var typeSymbol = GetSymbol(TypeOfType); var validateObjectMembersAttribute = GetSymbol(ValidateObjectMembersAttribute); var validateEnumeratedItemsAttribute = GetSymbol(ValidateEnumeratedItemsAttribute); + var unconditionalSuppressMessageAttributeSymbol = GetSymbol(UnconditionalSuppressMessageAttributeType); + if (unconditionalSuppressMessageAttributeSymbol is not null) + { + var containingAssemblyName = unconditionalSuppressMessageAttributeSymbol.ContainingAssembly.Identity.Name; + if (!containingAssemblyName.Equals("System.Private.CoreLib", System.StringComparison.OrdinalIgnoreCase) && + !containingAssemblyName.Equals("System.Runtime", System.StringComparison.OrdinalIgnoreCase)) + { + // The compilation returns UnconditionalSuppressMessageAttribute symbol even if the attribute is not available like the case when running on .NET Framework. + // We need to make sure that the attribute is really available by checking the containing assembly which in .NET Core will be either System.Private.CoreLib or System.Runtime. + unconditionalSuppressMessageAttributeSymbol = null; + } + } #pragma warning disable S1067 // Expressions should not be too complex if (optionsValidatorSymbol == null || validationAttributeSymbol == null || + maxLengthAttributeSymbol == null || + minLengthAttributeSymbol == null || + compareAttributeSymbol == null || + rangeAttributeSymbol == null || + iCollectionSymbol == null || dataTypeAttributeSymbol == null || ivalidatableObjectSymbol == null || validateOptionsSymbol == null || @@ -51,6 +81,13 @@ public static bool TryLoad(Compilation compilation, out SymbolHolder? symbolHold symbolHolder = new( optionsValidatorSymbol, validationAttributeSymbol, + maxLengthAttributeSymbol, + minLengthAttributeSymbol, + compareAttributeSymbol, + lengthAttributeSymbol, + unconditionalSuppressMessageAttributeSymbol, + rangeAttributeSymbol, + iCollectionSymbol, dataTypeAttributeSymbol, validateOptionsSymbol, ivalidatableObjectSymbol, diff --git a/src/libraries/Microsoft.Extensions.Options/src/Microsoft.Extensions.Options.csproj b/src/libraries/Microsoft.Extensions.Options/src/Microsoft.Extensions.Options.csproj index 7225edf84ba537..c7ea3e00049e3e 100644 --- a/src/libraries/Microsoft.Extensions.Options/src/Microsoft.Extensions.Options.csproj +++ b/src/libraries/Microsoft.Extensions.Options/src/Microsoft.Extensions.Options.csproj @@ -28,12 +28,13 @@ - + diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/DataAnnotationAttributesWithParams.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/DataAnnotationAttributesWithParams.g.cs new file mode 100644 index 00000000000000..c51c551222f42f --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/DataAnnotationAttributesWithParams.g.cs @@ -0,0 +1,135 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace Test +{ + partial class MyOptionsValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Test.MyOptions options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P1" : $"{name}.P1"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P2" : $"{name}.P2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P3" : $"{name}.P3"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute( + (int)10, + (int)20); + + internal static readonly global::System.ComponentModel.DataAnnotations.AllowedValuesAttribute A3 = new global::System.ComponentModel.DataAnnotations.AllowedValuesAttribute( + (int)10, (int)20, (int)30); + + internal static readonly global::System.ComponentModel.DataAnnotations.DeniedValuesAttribute A4 = new global::System.ComponentModel.DataAnnotations.DeniedValuesAttribute( + "One", "Ten", "Hundred"); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__LengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or collection type with a minimum length of '{1}' and maximum length of '{2}'."; + public __SourceGen__LengthAttribute(int minimumLength, int maximumLength) : base(() => DefaultErrorMessageString) { MinimumLength = minimumLength; MaximumLength = maximumLength; } + public int MinimumLength { get; } + public int MaximumLength { get; } + public override bool IsValid(object? value) + { + if (MinimumLength < 0) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MinimumLength value that is zero or greater."); + } + if (MaximumLength < MinimumLength) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MaximumLength value that is greater than or equal to MinimumLength."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return (uint)(length - MinimumLength) <= (uint)(MaximumLength - MinimumLength); + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, MinimumLength, MaximumLength); + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netcore.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netcore.g.cs new file mode 100644 index 00000000000000..2c5af12c5b5f24 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netcore.g.cs @@ -0,0 +1,175 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace HelloWorld +{ + partial struct MyOptionsValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::HelloWorld.MyOptions options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "Val1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val1" : $"{name}.Val1"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "Val2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val2" : $"{name}.Val2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( + (int)1, + (int)3); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netfx.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netfx.g.cs new file mode 100644 index 00000000000000..9dc3ded5bd4624 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/EmitterWithCustomValidator.netfx.g.cs @@ -0,0 +1,173 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace HelloWorld +{ + partial struct MyOptionsValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::HelloWorld.MyOptions options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "Val1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val1" : $"{name}.Val1"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "Val2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val2" : $"{name}.Val2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( + (int)1, + (int)3); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang10.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang10.g.cs new file mode 100644 index 00000000000000..cc9864a2619c45 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang10.g.cs @@ -0,0 +1,471 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace ValidationTest +{ + partial class OptionsUsingGeneratedAttributesValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValidationTest.OptionsUsingGeneratedAttributes options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P0"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P0" : $"{name}.P0"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P0, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P1" : $"{name}.P1"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P2" : $"{name}.P2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P3" : $"{name}.P3"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A5); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P7"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P7" : $"{name}.P7"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P7, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P8"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P8" : $"{name}.P8"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P8, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P9"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P9" : $"{name}.P9"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P9, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P10"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P10" : $"{name}.P10"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P10, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P11"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P11" : $"{name}.P11"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P11, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P12"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P12" : $"{name}.P12"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P12, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + internal static class __Attributes_2C497155 + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_LengthAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_LengthAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_RangeAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_RangeAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MinLengthAttribute A3 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MinLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MaxLengthAttribute A4 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MaxLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_CompareAttribute A5 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_CompareAttribute( + "P5"); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + internal static class __Validators_2C497155 + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property, AllowMultiple = false)] + internal class __SourceGen__2C497155_CompareAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "'{0}' and '{1}' do not match."; + public __SourceGen__2C497155_CompareAttribute(string otherProperty) : base(() => DefaultErrorMessageString) + { + if (otherProperty == null) + { + throw new global::System.ArgumentNullException(nameof(otherProperty)); + } + OtherProperty = otherProperty; + } + public string OtherProperty { get; } + public override bool RequiresValidationContext => true; + + protected override global::System.ComponentModel.DataAnnotations.ValidationResult? IsValid(object? value, global::System.ComponentModel.DataAnnotations.ValidationContext validationContext) + { + bool result = true; + + if (validationContext.ObjectInstance is global::ValidationTest.OptionsUsingGeneratedAttributes && OtherProperty == "P5") + { + result = Equals(value, ((global::ValidationTest.OptionsUsingGeneratedAttributes)validationContext.ObjectInstance).P5); + } + + if (!result) + { + string[]? memberNames = validationContext.MemberName is null ? null : new string[] { validationContext.MemberName }; + return new global::System.ComponentModel.DataAnnotations.ValidationResult(FormatErrorMessage(validationContext.DisplayName), memberNames); + } + + return null; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, OtherProperty); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_LengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or collection type with a minimum length of '{1}' and maximum length of '{2}'."; + public __SourceGen__2C497155_LengthAttribute(int minimumLength, int maximumLength) : base(() => DefaultErrorMessageString) { MinimumLength = minimumLength; MaximumLength = maximumLength; } + public int MinimumLength { get; } + public int MaximumLength { get; } + public override bool IsValid(object? value) + { + if (MinimumLength < 0) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MinimumLength value that is zero or greater."); + } + if (MaximumLength < MinimumLength) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MaximumLength value that is greater than or equal to MinimumLength."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return (uint)(length - MinimumLength) <= (uint)(MaximumLength - MinimumLength); + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, MinimumLength, MaximumLength); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__2C497155_MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__2C497155_MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__2C497155_MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__2C497155_RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__2C497155_RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__2C497155_RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang11.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang11.g.cs new file mode 100644 index 00000000000000..2a33e51b0b6175 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netcore.lang11.g.cs @@ -0,0 +1,471 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace ValidationTest +{ + partial class OptionsUsingGeneratedAttributesValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValidationTest.OptionsUsingGeneratedAttributes options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P0"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P0" : $"{name}.P0"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P0, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P1" : $"{name}.P1"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P2" : $"{name}.P2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P3" : $"{name}.P3"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A5); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P7"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P7" : $"{name}.P7"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P7, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P8"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P8" : $"{name}.P8"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P8, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P9"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P9" : $"{name}.P9"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P9, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P10"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P10" : $"{name}.P10"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P10, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P11"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P11" : $"{name}.P11"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P11, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P12"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P12" : $"{name}.P12"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P12, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A3 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute A4 = new __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__CompareAttribute A5 = new __OptionValidationGeneratedAttributes.__SourceGen__CompareAttribute( + "P5"); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property, AllowMultiple = false)] + file class __SourceGen__CompareAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "'{0}' and '{1}' do not match."; + public __SourceGen__CompareAttribute(string otherProperty) : base(() => DefaultErrorMessageString) + { + if (otherProperty == null) + { + throw new global::System.ArgumentNullException(nameof(otherProperty)); + } + OtherProperty = otherProperty; + } + public string OtherProperty { get; } + public override bool RequiresValidationContext => true; + + protected override global::System.ComponentModel.DataAnnotations.ValidationResult? IsValid(object? value, global::System.ComponentModel.DataAnnotations.ValidationContext validationContext) + { + bool result = true; + + if (validationContext.ObjectInstance is global::ValidationTest.OptionsUsingGeneratedAttributes && OtherProperty == "P5") + { + result = Equals(value, ((global::ValidationTest.OptionsUsingGeneratedAttributes)validationContext.ObjectInstance).P5); + } + + if (!result) + { + string[]? memberNames = validationContext.MemberName is null ? null : new string[] { validationContext.MemberName }; + return new global::System.ComponentModel.DataAnnotations.ValidationResult(FormatErrorMessage(validationContext.DisplayName), memberNames); + } + + return null; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, OtherProperty); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__LengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or collection type with a minimum length of '{1}' and maximum length of '{2}'."; + public __SourceGen__LengthAttribute(int minimumLength, int maximumLength) : base(() => DefaultErrorMessageString) { MinimumLength = minimumLength; MaximumLength = maximumLength; } + public int MinimumLength { get; } + public int MaximumLength { get; } + public override bool IsValid(object? value) + { + if (MinimumLength < 0) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MinimumLength value that is zero or greater."); + } + if (MaximumLength < MinimumLength) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MaximumLength value that is greater than or equal to MinimumLength."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return (uint)(length - MinimumLength) <= (uint)(MaximumLength - MinimumLength); + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, MinimumLength, MaximumLength); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang10.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang10.g.cs new file mode 100644 index 00000000000000..7f5eb90a202815 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang10.g.cs @@ -0,0 +1,386 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace ValidationTest +{ + partial class OptionsUsingGeneratedAttributesValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValidationTest.OptionsUsingGeneratedAttributes options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P3" : $"{name}.P3"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P7"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P7" : $"{name}.P7"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P7, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P8"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P8" : $"{name}.P8"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P8, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P9"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P9" : $"{name}.P9"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P9, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P10"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P10" : $"{name}.P10"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P10, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P11"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P11" : $"{name}.P11"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P11, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P12"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P12" : $"{name}.P12"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes_2C497155.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P12, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + internal static class __Attributes_2C497155 + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_RangeAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_RangeAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MinLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MinLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MaxLengthAttribute A3 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_MaxLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__2C497155_CompareAttribute A4 = new __OptionValidationGeneratedAttributes.__SourceGen__2C497155_CompareAttribute( + "P5"); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + internal static class __Validators_2C497155 + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property, AllowMultiple = false)] + internal class __SourceGen__2C497155_CompareAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "'{0}' and '{1}' do not match."; + public __SourceGen__2C497155_CompareAttribute(string otherProperty) : base(() => DefaultErrorMessageString) + { + if (otherProperty == null) + { + throw new global::System.ArgumentNullException(nameof(otherProperty)); + } + OtherProperty = otherProperty; + } + public string OtherProperty { get; } + public override bool RequiresValidationContext => true; + + protected override global::System.ComponentModel.DataAnnotations.ValidationResult? IsValid(object? value, global::System.ComponentModel.DataAnnotations.ValidationContext validationContext) + { + bool result = true; + + if (validationContext.ObjectInstance is global::ValidationTest.OptionsUsingGeneratedAttributes && OtherProperty == "P5") + { + result = Equals(value, ((global::ValidationTest.OptionsUsingGeneratedAttributes)validationContext.ObjectInstance).P5); + } + + if (!result) + { + string[]? memberNames = validationContext.MemberName is null ? null : new string[] { validationContext.MemberName }; + return new global::System.ComponentModel.DataAnnotations.ValidationResult(FormatErrorMessage(validationContext.DisplayName), memberNames); + } + + return null; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, OtherProperty); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__2C497155_MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__2C497155_MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__2C497155_MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + internal class __SourceGen__2C497155_RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__2C497155_RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__2C497155_RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__2C497155_RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang11.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang11.g.cs new file mode 100644 index 00000000000000..3ab56e21320a07 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/GeneratedAttributesTest.netfx.lang11.g.cs @@ -0,0 +1,386 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace ValidationTest +{ + partial class OptionsUsingGeneratedAttributesValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValidationTest.OptionsUsingGeneratedAttributes options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P3" : $"{name}.P3"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P7"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P7" : $"{name}.P7"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P7, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P8"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P8" : $"{name}.P8"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P8, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P9"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P9" : $"{name}.P9"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P9, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P10"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P10" : $"{name}.P10"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P10, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P11"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P11" : $"{name}.P11"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P11, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P12"; + context.DisplayName = string.IsNullOrEmpty(name) ? "OptionsUsingGeneratedAttributes.P12" : $"{name}.P12"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P12, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( + (int)1, + (int)3); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute A3 = new __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute( + (int)5); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__CompareAttribute A4 = new __OptionValidationGeneratedAttributes.__SourceGen__CompareAttribute( + "P5"); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property, AllowMultiple = false)] + file class __SourceGen__CompareAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "'{0}' and '{1}' do not match."; + public __SourceGen__CompareAttribute(string otherProperty) : base(() => DefaultErrorMessageString) + { + if (otherProperty == null) + { + throw new global::System.ArgumentNullException(nameof(otherProperty)); + } + OtherProperty = otherProperty; + } + public string OtherProperty { get; } + public override bool RequiresValidationContext => true; + + protected override global::System.ComponentModel.DataAnnotations.ValidationResult? IsValid(object? value, global::System.ComponentModel.DataAnnotations.ValidationContext validationContext) + { + bool result = true; + + if (validationContext.ObjectInstance is global::ValidationTest.OptionsUsingGeneratedAttributes && OtherProperty == "P5") + { + result = Equals(value, ((global::ValidationTest.OptionsUsingGeneratedAttributes)validationContext.ObjectInstance).P5); + } + + if (!result) + { + string[]? memberNames = validationContext.MemberName is null ? null : new string[] { validationContext.MemberName }; + return new global::System.ComponentModel.DataAnnotations.ValidationResult(FormatErrorMessage(validationContext.DisplayName), memberNames); + } + + return null; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, OtherProperty); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::ValidationTest.FakeCount) + { + length = ((global::ValidationTest.FakeCount)value).Count; + } + else if (value is global::ValidationTest.FakeCountChild) + { + length = ((global::ValidationTest.FakeCountChild)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netcore.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netcore.g.cs new file mode 100644 index 00000000000000..1cd942fab0f1bb --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netcore.g.cs @@ -0,0 +1,252 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace Test +{ + partial class MyOptionsValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Test.MyOptions options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P1"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P1" : $"{name}.P1"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P2" : $"{name}.P2"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P3" : $"{name}.P3"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P4"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P4" : $"{name}.P4"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__LengthAttribute( + (int)10, + (int)20); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( + (int)4); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute A3 = new __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute( + (int)5); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__LengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or collection type with a minimum length of '{1}' and maximum length of '{2}'."; + public __SourceGen__LengthAttribute(int minimumLength, int maximumLength) : base(() => DefaultErrorMessageString) { MinimumLength = minimumLength; MaximumLength = maximumLength; } + public int MinimumLength { get; } + public int MaximumLength { get; } + public override bool IsValid(object? value) + { + if (MinimumLength < 0) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MinimumLength value that is zero or greater."); + } + if (MaximumLength < MinimumLength) + { + throw new global::System.InvalidOperationException("LengthAttribute must have a MaximumLength value that is greater than or equal to MinimumLength."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::System.Collections.Generic.IList) + { + length = ((global::System.Collections.Generic.IList)value).Count; + } + else if (value is global::System.Collections.Generic.ICollection) + { + length = ((global::System.Collections.Generic.ICollection)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return (uint)(length - MinimumLength) <= (uint)(MaximumLength - MinimumLength); + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, MinimumLength, MaximumLength); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::System.Collections.Generic.IList) + { + length = ((global::System.Collections.Generic.IList)value).Count; + } + else if (value is global::System.Collections.Generic.ICollection) + { + length = ((global::System.Collections.Generic.ICollection)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::System.Collections.Generic.IList) + { + length = ((global::System.Collections.Generic.IList)value).Count; + } + else if (value is global::System.Collections.Generic.ICollection) + { + length = ((global::System.Collections.Generic.ICollection)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netfx.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netfx.g.cs new file mode 100644 index 00000000000000..603680a9ec732c --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netfx.g.cs @@ -0,0 +1,177 @@ + + // + #nullable enable + #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 + namespace Test +{ + partial class MyOptionsValidator + { + /// + /// Validates a specific named options instance (or all when is ). + /// + /// The name of the options instance being validated. + /// The options instance. + /// Validation result. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Test.MyOptions options) + { + global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; + var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); + var validationResults = new global::System.Collections.Generic.List(); + var validationAttributes = new global::System.Collections.Generic.List(1); + + context.MemberName = "P2"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P2" : $"{name}.P2"; + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P3"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P3" : $"{name}.P3"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P5"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P5" : $"{name}.P5"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P5, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + context.MemberName = "P6"; + context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P6" : $"{name}.P6"; + validationResults.Clear(); + validationAttributes.Clear(); + validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); + if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P6, context, validationResults, validationAttributes)) + { + (builder ??= new()).AddResults(validationResults); + } + + return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); + } + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Attributes + { + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A1 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( + (int)4); + + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__MaxLengthAttribute( + (int)5); + } +} +namespace __OptionValidationStaticInstances +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + file static class __Validators + { + } +} +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MaxLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private const int MaxAllowableLength = -1; + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a maximum length of '{1}'."; + public __SourceGen__MaxLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public __SourceGen__MaxLengthAttribute(): base(() => DefaultErrorMessageString) { Length = MaxAllowableLength; } + public int Length { get; } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + public override bool IsValid(object? value) + { + if (Length == 0 || Length < -1) + { + throw new global::System.InvalidOperationException("MaxLengthAttribute must have a Length value that is greater than zero. Use MaxLength() without parameters to indicate that the string or array can have the maximum allowable length."); + } + if (value == null || MaxAllowableLength == Length) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::System.Collections.Generic.IList) + { + length = ((global::System.Collections.Generic.IList)value).Count; + } + else if (value is global::System.Collections.Generic.ICollection) + { + length = ((global::System.Collections.Generic.ICollection)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length <= Length; + } + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else if (value is global::System.Collections.Generic.IList) + { + length = ((global::System.Collections.Generic.IList)value).Count; + } + else if (value is global::System.Collections.Generic.ICollection) + { + length = ((global::System.Collections.Generic.ICollection)value).Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Main.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Main.cs index aa51c9dfbae73a..623251707f87ba 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Main.cs +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Main.cs @@ -52,77 +52,15 @@ public partial struct MyOptionsValidator : IValidateOptions } """; - string generatedSource = """ - - // - #nullable enable - #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 - namespace HelloWorld -{ - partial struct MyOptionsValidator - { - /// - /// Validates a specific named options instance (or all when is ). - /// - /// The name of the options instance being validated. - /// The options instance. - /// Validation result. - [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] - public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::HelloWorld.MyOptions options) - { - global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; - var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); - var validationResults = new global::System.Collections.Generic.List(); - var validationAttributes = new global::System.Collections.Generic.List(1); - - context.MemberName = "Val1"; - context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val1" : $"{name}.Val1"; - validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val1, context, validationResults, validationAttributes)) - { - (builder ??= new()).AddResults(validationResults); - } - - context.MemberName = "Val2"; - context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.Val2" : $"{name}.Val2"; - validationResults.Clear(); - validationAttributes.Clear(); - validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.Val2, context, validationResults, validationAttributes)) - { - (builder ??= new()).AddResults(validationResults); - } - - return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); - } - } -} -namespace __OptionValidationStaticInstances -{ - [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] - file static class __Attributes - { - internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); - - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A2 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( - (int)1, - (int)3); - } -} -namespace __OptionValidationStaticInstances -{ - [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] - file static class __Validators - { - } -} - -"""; - var (diagnostics, generatedSources) = await RunGeneratorOnOptionsSource(source); Assert.Empty(diagnostics); _ = Assert.Single(generatedSources); +#if NETCOREAPP + string generatedSource = File.ReadAllText(@"Baselines/EmitterWithCustomValidator.netcore.g.cs"); +#else + string generatedSource = File.ReadAllText(@"Baselines/EmitterWithCustomValidator.netfx.g.cs"); +#endif // NETCOREAPP Assert.Equal(generatedSource.Replace("\r\n", "\n"), generatedSources[0].SourceText.ToString().Replace("\r\n", "\n")); } @@ -1443,7 +1381,7 @@ internal sealed partial class ExtOptionsValidator : IValidateOptions Assert.Single(diagnostics); Assert.Equal(DiagDescriptors.InaccessibleValidationAttribute.Id, diagnostics[0].Id); string generatedSource = generatedSources[0].SourceText.ToString(); - Assert.Contains("global::System.ComponentModel.DataAnnotations.RangeAttribute", generatedSource); + Assert.Contains("__OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute", generatedSource); Assert.Contains("global::System.ComponentModel.DataAnnotations.RequiredAttribute", generatedSource); Assert.DoesNotContain("Timeout", generatedSource); @@ -1666,113 +1604,22 @@ public partial class MyOptionsValidator : IValidateOptions Assert.Empty(diagnostics); Assert.Single(generatedSources); - var generatedSource = """ - - // - #nullable enable - #pragma warning disable CS1591 // Compensate for https://github.com/dotnet/roslyn/issues/54103 - namespace Test -{ - partial class MyOptionsValidator - { - /// - /// Validates a specific named options instance (or all when is ). - /// - /// The name of the options instance being validated. - /// The options instance. - /// Validation result. - [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] - public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Test.MyOptions options) - { - global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; - var context = new global::System.ComponentModel.DataAnnotations.ValidationContext(options); - var validationResults = new global::System.Collections.Generic.List(); - var validationAttributes = new global::System.Collections.Generic.List(1); - - context.MemberName = "P1"; - context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P1" : $"{name}.P1"; - validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A1); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P1, context, validationResults, validationAttributes)) - { - (builder ??= new()).AddResults(validationResults); - } - - context.MemberName = "P2"; - context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P2" : $"{name}.P2"; - validationResults.Clear(); - validationAttributes.Clear(); - validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A2); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P2, context, validationResults, validationAttributes)) - { - (builder ??= new()).AddResults(validationResults); - } - - context.MemberName = "P3"; - context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P3" : $"{name}.P3"; - validationResults.Clear(); - validationAttributes.Clear(); - validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A3); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P3, context, validationResults, validationAttributes)) - { - (builder ??= new()).AddResults(validationResults); - } - - context.MemberName = "P4"; - context.DisplayName = string.IsNullOrEmpty(name) ? "MyOptions.P4" : $"{name}.P4"; - validationResults.Clear(); - validationAttributes.Clear(); - validationAttributes.Add(global::__OptionValidationStaticInstances.__Attributes.A4); - if (!global::System.ComponentModel.DataAnnotations.Validator.TryValidateValue(options.P4, context, validationResults, validationAttributes)) - { - (builder ??= new()).AddResults(validationResults); - } - - return builder is null ? global::Microsoft.Extensions.Options.ValidateOptionsResult.Success : builder.Build(); - } - } -} -namespace __OptionValidationStaticInstances -{ - [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] - file static class __Attributes - { - internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); - - internal static readonly global::System.ComponentModel.DataAnnotations.LengthAttribute A2 = new global::System.ComponentModel.DataAnnotations.LengthAttribute( - (int)10, - (int)20); - - internal static readonly global::System.ComponentModel.DataAnnotations.AllowedValuesAttribute A3 = new global::System.ComponentModel.DataAnnotations.AllowedValuesAttribute( - (int)10, (int)20, (int)30); - - internal static readonly global::System.ComponentModel.DataAnnotations.DeniedValuesAttribute A4 = new global::System.ComponentModel.DataAnnotations.DeniedValuesAttribute( - "One", "Ten", "Hundred"); - } -} -namespace __OptionValidationStaticInstances -{ - [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] - file static class __Validators - { - } -} - -"""; + string generatedSource = File.ReadAllText(@"Baselines/DataAnnotationAttributesWithParams.g.cs"); Assert.Equal(generatedSource.Replace("\r\n", "\n"), generatedSources[0].SourceText.ToString().Replace("\r\n", "\n")); } - private static CSharpCompilation CreateCompilationForOptionsSource(string assemblyName, string source, string? refAssemblyPath = null) + private static CSharpCompilation CreateCompilationForOptionsSource(string assemblyName, string source, string? refAssemblyPath = null, LanguageVersion languageVersion = LanguageVersion.Default) { // Ensure the generated source compiles var compilation = CSharpCompilation - .Create(Path.GetRandomFileName()+".dll", options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)) + .Create($"{assemblyName}.dll", options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)) .AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().FirstOrDefault(a => a.GetName().Name == "System.Runtime").Location)) .AddReferences(MetadataReference.CreateFromFile(typeof(string).Assembly.Location)) .AddReferences(MetadataReference.CreateFromFile(typeof(RequiredAttribute).Assembly.Location)) .AddReferences(MetadataReference.CreateFromFile(typeof(OptionsValidatorAttribute).Assembly.Location)) .AddReferences(MetadataReference.CreateFromFile(typeof(IValidateOptions).Assembly.Location)) .AddReferences(MetadataReference.CreateFromFile(typeof(System.CodeDom.Compiler.GeneratedCodeAttribute).Assembly.Location)) - .AddSyntaxTrees(CSharpSyntaxTree.ParseText(source)); + .AddSyntaxTrees(CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(languageVersion))); if (refAssemblyPath is not null) { @@ -1861,4 +1708,153 @@ private static CSharpCompilation CreateCompilationForOptionsSource(string assemb return result; } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + [InlineData(LanguageVersion.CSharp10)] + [InlineData(LanguageVersion.CSharp11)] + public async Task GeneratedAttributesTest(LanguageVersion languageVersion) + { + +#if NETCOREAPP + string lengthAttribute = $$""" + [LengthAttribute(1, 3)] + public string? P0 { get; set; } + + [LengthAttribute(1, 3)] + public FakeCount? P1 { get; set; } + + [LengthAttribute(1, 3)] + public FakeCountChild? P2 { get; set; } + """; +#else +string lengthAttribute = ""; +#endif //NETCOREAPP + + string source = $$""" + using System.Collections.Generic; + using Microsoft.Extensions.Options; + using System.ComponentModel.DataAnnotations; + + #nullable enable + + namespace ValidationTest + { + public class FakeCount + { + public FakeCount(int count) { Count = count; } + public int Count { get; } + } + public class FakeCountChild : FakeCount + { + public FakeCountChild(int count) : base(count) { } + } + + public class OptionsUsingGeneratedAttributes + { + {{lengthAttribute}} + + [RangeAttribute(1, 3)] + public int P3 { get; set; } + + [MinLengthAttribute(5)] + public string? P4 { get; set; } + + [MaxLengthAttribute(5)] + public string? P5 { get; set; } + + [CompareAttribute("P5")] + public string? P6 { get; set; } + + [MinLengthAttribute(5)] + public FakeCount? P7 { get; set; } + + [MinLengthAttribute(5)] + public FakeCountChild? P8 { get; set; } + + [MaxLengthAttribute(5)] + public FakeCount? P9 { get; set; } + + [MaxLengthAttribute(5)] + public FakeCountChild? P10 { get; set; } + + [MinLengthAttribute(5)] + public List? P11 { get; set; } + + [MaxLengthAttribute(5)] + public List? P12 { get; set; } + } + + [OptionsValidator] + public sealed partial class OptionsUsingGeneratedAttributesValidator : IValidateOptions + { + } + } + """; + + var (diagnostics, generatedSources) = await RunGeneratorOnOptionsSource(source, null, languageVersion); + Assert.Empty(diagnostics); + Assert.Single(generatedSources); + + string emittedSource = generatedSources[0].SourceText.ToString(); + SyntaxTree syntaxTree = SyntaxFactory.ParseSyntaxTree(emittedSource, new CSharpParseOptions(languageVersion)); + var diags = syntaxTree.GetDiagnostics().ToArray(); + Assert.Empty(diags); + +#if NETCOREAPP + string generatedSource = File.ReadAllText(languageVersion == LanguageVersion.CSharp10 ? @"Baselines/GeneratedAttributesTest.netcore.lang10.g.cs" : @"Baselines/GeneratedAttributesTest.netcore.lang11.g.cs"); +#else + string generatedSource = File.ReadAllText(languageVersion == LanguageVersion.CSharp10 ? @"Baselines/GeneratedAttributesTest.netfx.lang10.g.cs" : @"Baselines/GeneratedAttributesTest.netfx.lang11.g.cs"); +#endif // NET8_0_OR_GREATER + Assert.Equal(generatedSource.Replace("\r\n", "\n"), emittedSource.Replace("\r\n", "\n")); + + CSharpCompilation compilation = CreateCompilationForOptionsSource(Path.GetRandomFileName(), source + emittedSource, refAssemblyPath: null, languageVersion); + var emitResult = compilation.Emit(new MemoryStream()); + + Assert.True(emitResult.Success); + // Console.WriteLine(emittedSource); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + public async Task UsingInterfaceAsPropertyTypeForLengthAttributesTests() + { + var (diagnostics, generatedSources) = await RunGenerator(@""" + using System.Collections.Generic; + + public class MyOptions + { + [Length(10, 20)] + public IList P1 { get; set; } + + [MinLength(4)] + public IList P2 { get; set; } + + [MaxLength(5)] + public IList P3 { get; set; } + + [Length(10, 20)] + public ICollection P4 { get; set; } + + [MinLength(4)] + public ICollection P5 { get; set; } + + [MaxLength(5)] + public ICollection P6 { get; set; } + } + + [OptionsValidator] + public partial class MyOptionsValidator : IValidateOptions + { + } + """); + + Assert.Empty(diagnostics); + Assert.Single(generatedSources); + +#if NETCOREAPP + string generatedSource = File.ReadAllText(@"Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netcore.g.cs"); +#else + string generatedSource = File.ReadAllText(@"Baselines/UsingInterfaceAsPropertyTypeForLengthAttributesTests.netfx.g.cs"); +#endif // NETCOREAPP + Assert.Equal(generatedSource.Replace("\r\n", "\n"), generatedSources[0].SourceText.ToString().Replace("\r\n", "\n")); + } } diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Microsoft.Extensions.Options.SourceGeneration.Unit.Tests.csproj b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Microsoft.Extensions.Options.SourceGeneration.Unit.Tests.csproj index d76cbd45302f91..f3a5f33b4a2b45 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Microsoft.Extensions.Options.SourceGeneration.Unit.Tests.csproj +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/Microsoft.Extensions.Options.SourceGeneration.Unit.Tests.csproj @@ -14,6 +14,7 @@ + @@ -35,6 +36,12 @@ OutputItemType="Analyzer" ReferenceOutputAssembly="true" SetTargetFramework="TargetFramework=netstandard2.0"/> + + + PreserveNewest + + + diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/OptionsRuntimeTests.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/OptionsRuntimeTests.cs index b644eea74120f7..4c701e4b9f498f 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/OptionsRuntimeTests.cs +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGeneration.Unit.Tests/OptionsRuntimeTests.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.ComponentModel.DataAnnotations; +using System.Globalization; using System.Linq; using System.Threading.Tasks; using Xunit; @@ -177,6 +178,30 @@ public void TestValidationWithEnumeration() result2.Failures); } + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + public void TestObjectsWithIndexerProperties() + { + DataAnnotationValidateOptions dataAnnotationValidateOptions1 = new("MyDictionaryOptions"); + MyDictionaryOptionsOptionsValidator sourceGenOptionsValidator1 = new(); + + var options1 = new MyDictionaryOptions(); + ValidateOptionsResult result1 = sourceGenOptionsValidator1.Validate("MyDictionaryOptions", options1); + ValidateOptionsResult result2 = dataAnnotationValidateOptions1.Validate("MyDictionaryOptions", options1); + + Assert.True(result1.Succeeded); + Assert.True(result2.Succeeded); + + DataAnnotationValidateOptions> dataAnnotationValidateOptions2 = new("MyListOptions"); + MyListOptionsOptionsValidator sourceGenOptionsValidator2 = new(); + + var options2 = new MyListOptions() { Prop = "test" }; + result1 = sourceGenOptionsValidator2.Validate("MyListOptions", options2); + result2 = dataAnnotationValidateOptions2.Validate("MyListOptions", options2); + + Assert.True(result1.Succeeded); + Assert.True(result2.Succeeded); + } + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] public void TestValidationWithCyclicReferences() { @@ -242,6 +267,253 @@ public void TestNewDataAnnotationFailures() }, result.Failures); } #endif // NET8_0_OR_GREATER + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsNotBrowser))] + public void TestCustomGeneratedAttributes() + { + OptionsUsingGeneratedAttributes noFailures = new OptionsUsingGeneratedAttributes() + { +#if NET8_0_OR_GREATER + P0 = "123", + P11 = new DateTime(2023, 2, 1), + P12 = 6, + P13 = 9, + P14 = new List() { "1", "2" }, + P15 = new FakeCount(5), + P16 = new FakeCountChild(5), + P17 = new int[] { 1, 2 }, + P18 = new List() { "1", "2", "3" }, + P19 = new FakeCount(3), + P20 = new FakeCountChild(3), + P23 = new List() { "1", "2", "3", "4" }, + P24 = new FakeCount(4), + P25 = new FakeCountChild(4), + P27 = new List { "1", "2" }, + P28 = new HashSet { "1", "2" }, + P29 = new List { "1", "2", "3" }, + P30 = new HashSet { "1", "2", "3" }, + P31 = new List { 1, 2, 3, 4 }, + P32 = new HashSet { 1, 2, 3, 4 }, +#endif // NET8_0_OR_GREATER + P1 = 2, + P2 = "12345", + P3 = "12345", + P4 = "12345", + P5 = 4, + P6 = 4, + P7 = 15, + P8 = 15, + P9 = 2.5m, + P10 = 14.0, + P21 = new int[] { 1, 2, 3 }, + P22 = new int[] { 1, 2, 3, 4 }, + P26 = 14.0, + }; + List results = new(); + Assert.True(Validator.TryValidateObject(noFailures, new ValidationContext(noFailures), results, true)); + + OptionsUsingGeneratedAttributesValidator validator = new(); + Assert.True(validator.Validate("OptionsUsingGeneratedAttributes", noFailures).Succeeded); + + OptionsUsingGeneratedAttributes failing = new OptionsUsingGeneratedAttributes() + { +#if NET8_0_OR_GREATER + P0 = "", + P11 = new DateTime(2023, 1, 1), + P12 = 5, + P13 = 10, + P14 = new List() { "1" }, + P15 = new FakeCount(1), + P16 = new FakeCountChild(11), + P17 = new int[] { 1 }, + P18 = new List() { "1", "2" }, + P19 = new FakeCount(2), + P20 = new FakeCountChild(1), + P23 = new List() { "1", "2", "3", "4", "5" }, + P24 = new FakeCount(5), + P25 = new FakeCountChild(5), + P27 = new List { "1" }, + P28 = new HashSet { "1" }, + P29 = new List { "1", "2" }, + P30 = new HashSet { "1", "2" }, + P31 = new List { 1, 2, 3, 4, 5 }, + P32 = new HashSet { 1, 2, 3, 4, 5 }, +#endif // NET8_0_OR_GREATER + P1 = 4, + P2 = "1234", + P3 = "123456", + P4 = "12345", + P5 = 10, + P6 = 10, + P7 = 5, + P8 = 5, + P9 = 4.0m, + P10 = 20.0, + P21 = new int[] { 1, 2 }, + P22 = new int[] { 1, 2, 3, 4, 5 }, + P26 = 20.0, + }; + + Assert.False(Validator.TryValidateObject(failing, new ValidationContext(failing), results, true)); + + ValidateOptionsResult generatorResult = validator.Validate("OptionsUsingGeneratedAttributes", failing); + Assert.True(generatorResult.Failed); + + Assert.Equal(new [] { +#if NET8_0_OR_GREATER + "P0: The field OptionsUsingGeneratedAttributes.P0 must be a string or collection type with a minimum length of '1' and maximum length of '3'.", + string.Format(CultureInfo.CurrentCulture, "P11: The field OptionsUsingGeneratedAttributes.P11 must be between {0} and {1}.", new DateTime(2023, 1, 30), new DateTime(2023, 12, 30)), + "P12: The field OptionsUsingGeneratedAttributes.P12 must be between 5 exclusive and 10.", + "P13: The field OptionsUsingGeneratedAttributes.P13 must be between 5 and 10 exclusive.", + "P14: The field OptionsUsingGeneratedAttributes.P14 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P15: The field OptionsUsingGeneratedAttributes.P15 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P16: The field OptionsUsingGeneratedAttributes.P16 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P17: The field OptionsUsingGeneratedAttributes.P17 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P18: The field OptionsUsingGeneratedAttributes.P18 must be a string or array type with a minimum length of '3'.", + "P19: The field OptionsUsingGeneratedAttributes.P19 must be a string or array type with a minimum length of '3'.", + "P20: The field OptionsUsingGeneratedAttributes.P20 must be a string or array type with a minimum length of '3'.", + "P23: The field OptionsUsingGeneratedAttributes.P23 must be a string or array type with a maximum length of '4'.", + "P24: The field OptionsUsingGeneratedAttributes.P24 must be a string or array type with a maximum length of '4'.", + "P25: The field OptionsUsingGeneratedAttributes.P25 must be a string or array type with a maximum length of '4'.", + "P27: The field OptionsUsingGeneratedAttributes.P27 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P28: The field OptionsUsingGeneratedAttributes.P28 must be a string or collection type with a minimum length of '2' and maximum length of '10'.", + "P29: The field OptionsUsingGeneratedAttributes.P29 must be a string or array type with a minimum length of '3'.", + "P30: The field OptionsUsingGeneratedAttributes.P30 must be a string or array type with a minimum length of '3'.", + "P31: The field OptionsUsingGeneratedAttributes.P31 must be a string or array type with a maximum length of '4'.", + "P32: The field OptionsUsingGeneratedAttributes.P32 must be a string or array type with a maximum length of '4'.", +#endif // NET8_0_OR_GREATER + "P1: The field OptionsUsingGeneratedAttributes.P1 must be between 1 and 3.", + "P2: The field OptionsUsingGeneratedAttributes.P2 must be a string or array type with a minimum length of '5'.", + "P3: The field OptionsUsingGeneratedAttributes.P3 must be a string or array type with a maximum length of '5'.", + "P4: 'OptionsUsingGeneratedAttributes.P4' and 'P2' do not match.", + "P5: The field OptionsUsingGeneratedAttributes.P5 must be between 2 and 8.", + "P6: The field OptionsUsingGeneratedAttributes.P6 must be between 2 and 8.", + "P7: The field OptionsUsingGeneratedAttributes.P7 must be between 10 and 20.", + "P8: The field OptionsUsingGeneratedAttributes.P8 must be between 10 and 20.", + "P9: The field OptionsUsingGeneratedAttributes.P9 must be between 1.5 and 3.14.", + "P10: The field OptionsUsingGeneratedAttributes.P10 must be between 12.4 and 16.5.", + "P21: The field OptionsUsingGeneratedAttributes.P21 must be a string or array type with a minimum length of '3'.", + "P22: The field OptionsUsingGeneratedAttributes.P22 must be a string or array type with a maximum length of '4'.", + "P26: The field OptionsUsingGeneratedAttributes.P26 must be between 12.4 and 16.5.", + }, generatorResult.Failures); + + Assert.Equal(results.Count(), generatorResult.Failures.Count()); + } + } + + public class FakeCount(int count) { public int Count { get { return count; } } } + public class FakeCountChild(int count) : FakeCount(count) { } + + public class OptionsUsingGeneratedAttributes + { +#if NET8_0_OR_GREATER + [LengthAttribute(1, 3)] + public string? P0 { get; set; } + + [RangeAttribute(typeof(DateTime), "01/30/2023", "12/30/2023", ParseLimitsInInvariantCulture = true, ConvertValueInInvariantCulture = true)] + public DateTime P11 { get; set; } + + [RangeAttribute(5, 10, MinimumIsExclusive = true)] + public int P12 { get; set; } + + [RangeAttribute(5, 10, MaximumIsExclusive = true)] + public int P13 { get; set; } + + [LengthAttribute(2, 10)] + public List P14 { get; set; } + + [LengthAttribute(2, 10)] + public FakeCount P15 { get; set; } + + [LengthAttribute(2, 10)] + public FakeCountChild P16 { get; set; } + + [LengthAttribute(2, 10)] + public int[] P17 { get; set; } + + // Although MinLength and MaxLength attributes defined in NETFX but the implementation there has a bug which can produce exception like the following when using types like List: + // System.InvalidCastException : Unable to cast object of type 'System.Collections.Generic.List`1[System.String]' to type 'System.Array'. + + [MinLengthAttribute(3)] + public List P18 { get; set; } + + [MinLengthAttribute(3)] + public FakeCount P19 { get; set; } + + [MinLengthAttribute(3)] + public FakeCountChild P20 { get; set; } + + [MaxLengthAttribute(4)] + public List P23 { get; set; } + + [MaxLengthAttribute(4)] + public FakeCount P24 { get; set; } + + [MaxLengthAttribute(4)] + public FakeCountChild P25 { get; set; } + + [LengthAttribute(2, 10)] + public IList P27 { get; set; } + + [LengthAttribute(2, 10)] + public ICollection P28 { get; set; } + + [MinLengthAttribute(3)] + public IList P29 { get; set; } + + [MinLengthAttribute(3)] + public ICollection P30 { get; set; } + + [MaxLengthAttribute(4)] + public IList P31 { get; set; } + + [MaxLengthAttribute(4)] + public ICollection P32 { get; set; } +#endif // NET8_0_OR_GREATER + + [RangeAttribute(1, 3)] + public int P1 { get; set; } + + [MinLengthAttribute(5)] + public string? P2 { get; set; } + + [MaxLengthAttribute(5)] + public string? P3 { get; set; } + + [CompareAttribute("P2")] + public string? P4 { get; set; } + + [RangeAttribute(typeof(byte), "2", "8")] + public byte P5 { get; set; } + + [RangeAttribute(typeof(sbyte), "2", "8")] + public sbyte P6 { get; set; } + + [RangeAttribute(typeof(short), "10", "20")] + public short P7 { get; set; } + + [RangeAttribute(typeof(ulong), "10", "20")] + public ulong P8 { get; set; } + + [RangeAttribute(typeof(decimal), "1.5", "3.14")] + public decimal P9 { get; set; } + + [RangeAttribute(typeof(double), "12.40", "16.50")] + public double P10 { get; set; } + + [MinLengthAttribute(3)] + public int[] P21 { get; set; } + + [MaxLengthAttribute(4)] + public int[] P22 { get; set; } + + [RangeAttribute(typeof(double), "12.40", "16.50")] + public double? P26 { get; set; } + } + + [OptionsValidator] + public partial class OptionsUsingGeneratedAttributesValidator : IValidateOptions + { } public class MyOptions @@ -302,6 +574,12 @@ public partial class MySourceGenOptionsValidator : IValidateOptions { } + public class MyDictionaryOptions : Dictionary { [Required] public string Prop { get; set; } = "test"; } + [OptionsValidator] public partial class MyDictionaryOptionsOptionsValidator : IValidateOptions { } + + public class MyListOptions : List { [Required] public T Prop { get; set; } = default; } + [OptionsValidator] public partial class MyListOptionsOptionsValidator : IValidateOptions> { } + #if NET8_0_OR_GREATER public class OptionsUsingNewAttributes { @@ -326,4 +604,5 @@ public partial class NewAttributesValidator : IValidateOptions The options validation source generator is not available in C# {0}. Please use language version {1} or greater. - \ No newline at end of file + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + + diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetCoreApp/Validators.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetCoreApp/Validators.g.cs index cc1ebda6414e6b..c487888c9f16bb 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetCoreApp/Validators.g.cs +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetCoreApp/Validators.g.cs @@ -12,6 +12,8 @@ internal sealed partial class __ThirdModelNoNamespaceValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ThirdModelNoNamespace options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -40,6 +42,8 @@ partial class FirstValidatorNoNamespace /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::FirstModelNoNamespace options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -78,6 +82,8 @@ partial class SecondValidatorNoNamespace /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::SecondModelNoNamespace options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -108,6 +114,8 @@ partial class FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::CustomAttr.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -149,6 +157,8 @@ internal sealed partial class __SecondModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -181,6 +191,8 @@ internal sealed partial class __ThirdModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.ThirdModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -211,6 +223,8 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -363,6 +377,8 @@ partial struct SecondValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Enumeration.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -394,6 +410,8 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::FileScopedNamespace.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -425,6 +443,8 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::FunnyStrings.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -456,6 +476,8 @@ internal sealed partial class __SecondModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Generics.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -487,6 +509,8 @@ partial class FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Generics.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -523,6 +547,8 @@ partial struct MultiValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::MultiModelValidator.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -553,6 +579,8 @@ partial struct MultiValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::MultiModelValidator.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -585,6 +613,8 @@ internal sealed partial class __ThirdModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.ThirdModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -618,6 +648,8 @@ partial record struct FifthValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -654,6 +686,8 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -701,6 +735,8 @@ partial struct FourthValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -737,6 +773,8 @@ partial struct SecondValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -774,6 +812,8 @@ partial struct ThirdValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::Nested.Container1.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -807,6 +847,8 @@ partial class FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RandomMembers.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -839,6 +881,8 @@ internal sealed partial class __ThirdModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.ThirdModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -870,6 +914,8 @@ partial record struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -913,6 +959,8 @@ partial record struct SecondValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -944,6 +992,8 @@ partial record class ThirdValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RecordTypes.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -976,6 +1026,8 @@ internal sealed partial class __SecondModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RepeatedTypes.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1012,6 +1064,8 @@ internal sealed partial class __ThirdModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RepeatedTypes.ThirdModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1043,6 +1097,8 @@ partial class FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::RepeatedTypes.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1108,6 +1164,8 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::SelfValidation.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1141,6 +1199,8 @@ internal sealed partial class __RangeAttributeModelDoubleValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelDouble options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1172,6 +1232,8 @@ internal sealed partial class __RequiredAttributeModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RequiredAttributeModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1203,6 +1265,8 @@ internal sealed partial class __TypeWithoutOptionsValidatorValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.TypeWithoutOptionsValidator options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1248,6 +1312,8 @@ partial class AttributePropertyModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.AttributePropertyModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1288,6 +1354,8 @@ partial class ComplexModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.ComplexModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1318,6 +1386,8 @@ partial class CustomTypeCustomValidationAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.CustomTypeCustomValidationAttributeModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1348,6 +1418,8 @@ partial class CustomValidationAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.CustomValidationAttributeModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1378,6 +1450,8 @@ partial class DataTypeAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.DataTypeAttributeModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1408,6 +1482,8 @@ partial class DerivedModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.DerivedModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1458,6 +1534,8 @@ partial class EmailAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.EmailAttributeModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1488,6 +1566,8 @@ partial class LeafModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.LeafModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1538,6 +1618,8 @@ partial class MultipleAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.MultipleAttributeModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1599,6 +1681,8 @@ partial class RangeAttributeModelDateValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelDate options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1629,6 +1713,8 @@ partial class RangeAttributeModelDoubleValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelDouble options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1659,6 +1745,8 @@ partial class RangeAttributeModelIntValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RangeAttributeModelInt options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1689,6 +1777,8 @@ partial class RegularExpressionAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RegularExpressionAttributeModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1719,6 +1809,8 @@ partial class RequiredAttributeModelValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::TestClasses.OptionsValidation.RequiredAttributeModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1750,6 +1842,8 @@ internal sealed partial class __SecondModelValidator__ /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public static global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValueTypes.SecondModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1781,6 +1875,8 @@ partial struct FirstValidator /// The options instance. /// Validation result. [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("Trimming", "IL2026:RequiresUnreferencedCode", + Justification = "The created ValidationContext object is used in a way that never call reflection")] public global::Microsoft.Extensions.Options.ValidateOptionsResult Validate(string? name, global::ValueTypes.FirstModel options) { global::Microsoft.Extensions.Options.ValidateOptionsResultBuilder? builder = null; @@ -1820,7 +1916,7 @@ file static class __Attributes { internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); - internal static readonly global::System.ComponentModel.DataAnnotations.MinLengthAttribute A2 = new global::System.ComponentModel.DataAnnotations.MinLengthAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( (int)5); internal static readonly global::CustomAttr.CustomAttribute A3 = new global::CustomAttr.CustomAttribute( @@ -1833,30 +1929,30 @@ file static class __Attributes false, "X"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A5 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A5 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)0, (int)10); internal static readonly global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute A6 = new global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute( "\"\r\n\\\\"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A7 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A7 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (double)0.5, (double)0.9); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A8 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A8 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( typeof(global::System.DateTime), "1/2/2004", "3/4/2004"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A9 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A9 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3) { ErrorMessage = "ErrorMessage" }; - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A10 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A10 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3) { @@ -1880,19 +1976,19 @@ file static class __Attributes internal static readonly global::System.ComponentModel.DataAnnotations.DataTypeAttribute A15 = new global::System.ComponentModel.DataAnnotations.DataTypeAttribute( (global::System.ComponentModel.DataAnnotations.DataType)11); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A16 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A16 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A17 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A17 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)3, (int)5); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A18 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A18 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)5, (int)9); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A19 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A19 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( typeof(global::System.DateTime), "1/2/2004", "3/4/2004") @@ -1924,3 +2020,150 @@ file static class __Validators internal static readonly global::RecordTypes.ThirdValidator V7 = new global::RecordTypes.ThirdValidator(); } } +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetFX/Validators.g.cs b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetFX/Validators.g.cs index ebdcb1ad6d6ba7..7e998cea22cddf 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetFX/Validators.g.cs +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Baselines/NetFX/Validators.g.cs @@ -1820,7 +1820,7 @@ file static class __Attributes { internal static readonly global::System.ComponentModel.DataAnnotations.RequiredAttribute A1 = new global::System.ComponentModel.DataAnnotations.RequiredAttribute(); - internal static readonly global::System.ComponentModel.DataAnnotations.MinLengthAttribute A2 = new global::System.ComponentModel.DataAnnotations.MinLengthAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute A2 = new __OptionValidationGeneratedAttributes.__SourceGen__MinLengthAttribute( (int)5); internal static readonly global::CustomAttr.CustomAttribute A3 = new global::CustomAttr.CustomAttribute( @@ -1833,30 +1833,30 @@ file static class __Attributes false, "X"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A5 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A5 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)0, (int)10); internal static readonly global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute A6 = new global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute( "\"\r\n\\\\"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A7 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A7 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (double)0.5, (double)0.9); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A8 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A8 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( typeof(global::System.DateTime), "1/2/2004", "3/4/2004"); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A9 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A9 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3) { ErrorMessage = "ErrorMessage" }; - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A10 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A10 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3) { @@ -1880,15 +1880,15 @@ file static class __Attributes internal static readonly global::System.ComponentModel.DataAnnotations.DataTypeAttribute A15 = new global::System.ComponentModel.DataAnnotations.DataTypeAttribute( (global::System.ComponentModel.DataAnnotations.DataType)11); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A16 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A16 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)1, (int)3); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A17 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A17 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)3, (int)5); - internal static readonly global::System.ComponentModel.DataAnnotations.RangeAttribute A18 = new global::System.ComponentModel.DataAnnotations.RangeAttribute( + internal static readonly __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute A18 = new __OptionValidationGeneratedAttributes.__SourceGen__RangeAttribute( (int)5, (int)9); @@ -1916,3 +1916,150 @@ file static class __Validators internal static readonly global::RecordTypes.ThirdValidator V7 = new global::RecordTypes.ThirdValidator(); } } +namespace __OptionValidationGeneratedAttributes +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__MinLengthAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + private static string DefaultErrorMessageString => "The field {0} must be a string or array type with a minimum length of '{1}'."; + + public __SourceGen__MinLengthAttribute(int length) : base(() => DefaultErrorMessageString) { Length = length; } + public int Length { get; } + public override bool IsValid(object? value) + { + if (Length < -1) + { + throw new global::System.InvalidOperationException("MinLengthAttribute must have a Length value that is zero or greater."); + } + if (value == null) + { + return true; + } + + int length; + if (value is string stringValue) + { + length = stringValue.Length; + } + else if (value is System.Collections.ICollection collectionValue) + { + length = collectionValue.Count; + } + else + { + throw new global::System.InvalidCastException($"The field of type {value.GetType()} must be a string, array, or ICollection type."); + } + + return length >= Length; + } + public override string FormatErrorMessage(string name) => string.Format(global::System.Globalization.CultureInfo.CurrentCulture, ErrorMessageString, name, Length); + } + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Extensions.Options.SourceGeneration", "42.42.42.42")] + [global::System.AttributeUsage(global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Parameter, AllowMultiple = false)] + file class __SourceGen__RangeAttribute : global::System.ComponentModel.DataAnnotations.ValidationAttribute + { + public __SourceGen__RangeAttribute(int minimum, int maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(int); + } + public __SourceGen__RangeAttribute(double minimum, double maximum) : base() + { + Minimum = minimum; + Maximum = maximum; + OperandType = typeof(double); + } + public __SourceGen__RangeAttribute(global::System.Type type, string minimum, string maximum) : base() + { + OperandType = type; + NeedToConvertMinMax = true; + Minimum = minimum; + Maximum = maximum; + } + public object Minimum { get; private set; } + public object Maximum { get; private set; } + public bool MinimumIsExclusive { get; set; } + public bool MaximumIsExclusive { get; set; } + public global::System.Type OperandType { get; } + public bool ParseLimitsInInvariantCulture { get; set; } + public bool ConvertValueInInvariantCulture { get; set; } + public override string FormatErrorMessage(string name) => + string.Format(global::System.Globalization.CultureInfo.CurrentCulture, GetValidationErrorMessage(), name, Minimum, Maximum); + private bool NeedToConvertMinMax { get; } + private bool Initialized { get; set; } + public override bool IsValid(object? value) + { + if (!Initialized) + { + if (Minimum is null || Maximum is null) + { + throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + if (NeedToConvertMinMax) + { + System.Globalization.CultureInfo culture = ParseLimitsInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + Minimum = ConvertValue(Minimum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + Maximum = ConvertValue(Maximum, culture) ?? throw new global::System.InvalidOperationException("The minimum and maximum values must be set to valid values."); + } + int cmp = ((global::System.IComparable)Minimum).CompareTo((global::System.IComparable)Maximum); + if (cmp > 0) + { + throw new global::System.InvalidOperationException("The maximum value '{Maximum}' must be greater than or equal to the minimum value '{Minimum}'."); + } + else if (cmp == 0 && (MinimumIsExclusive || MaximumIsExclusive)) + { + throw new global::System.InvalidOperationException("Cannot use exclusive bounds when the maximum value is equal to the minimum value."); + } + Initialized = true; + } + + if (value is null or string { Length: 0 }) + { + return true; + } + + System.Globalization.CultureInfo formatProvider = ConvertValueInInvariantCulture ? global::System.Globalization.CultureInfo.InvariantCulture : global::System.Globalization.CultureInfo.CurrentCulture; + object? convertedValue; + + try + { + convertedValue = ConvertValue(value, formatProvider); + } + catch (global::System.Exception e) when (e is global::System.FormatException or global::System.InvalidCastException or global::System.NotSupportedException) + { + return false; + } + + var min = (global::System.IComparable)Minimum; + var max = (global::System.IComparable)Maximum; + + return + (MinimumIsExclusive ? min.CompareTo(convertedValue) < 0 : min.CompareTo(convertedValue) <= 0) && + (MaximumIsExclusive ? max.CompareTo(convertedValue) > 0 : max.CompareTo(convertedValue) >= 0); + } + private string GetValidationErrorMessage() + { + return (MinimumIsExclusive, MaximumIsExclusive) switch + { + (false, false) => "The field {0} must be between {1} and {2}.", + (true, false) => "The field {0} must be between {1} exclusive and {2}.", + (false, true) => "The field {0} must be between {1} and {2} exclusive.", + (true, true) => "The field {0} must be between {1} exclusive and {2} exclusive.", + }; + } + private object? ConvertValue(object? value, System.Globalization.CultureInfo formatProvider) + { + if (value is string stringValue) + { + value = global::System.Convert.ChangeType(stringValue, OperandType, formatProvider); + } + else + { + value = global::System.Convert.ChangeType(value, OperandType, formatProvider); + } + return value; + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Resources/Strings.resx b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Resources/Strings.resx index 6b8167e94b119a..90e5b01ed3b1b7 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Resources/Strings.resx +++ b/src/libraries/Microsoft.Extensions.Options/tests/SourceGenerationTests/Resources/Strings.resx @@ -216,4 +216,10 @@ The options validation source generator is not available in C# {0}. Please use language version {1} or greater. - \ No newline at end of file + + The validation attribute is only applicable to properties of type string, array, or ICollection; it cannot be used with other types. + + + The validation attribute {0} should only be applied to properties of type string, array, or ICollection. Using it with the type {1} could lead to runtime failures. + + diff --git a/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/ConfigureTests.cs b/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/ConfigureTests.cs index 870b27bbe9e100..7a39d8a8810fd3 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/ConfigureTests.cs +++ b/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/ConfigureTests.cs @@ -4,6 +4,8 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; class Program { @@ -37,6 +39,22 @@ optionsC is null || return -1; } + LocalOptionsValidator localOptionsValidator = new LocalOptionsValidator(); + OptionsUsingValidationAttributes optionsUsingValidationAttributes = new OptionsUsingValidationAttributes + { + P1 = "12345", + P2 = new List { "1234", "12345" }, + P3 = "123456", + P4 = "12345", + P5 = 7 + }; + + ValidateOptionsResult result = localOptionsValidator.Validate("", optionsUsingValidationAttributes); + if (result.Failed) + { + return -2; + } + return 100; } @@ -76,3 +94,29 @@ private class OptionsD public string OptionString { get; set; } } } + +public class OptionsUsingValidationAttributes +{ + [Required] + [MinLength(5)] + public string P1 { get; set; } + + [Required] + [MaxLength(5)] + public List P2 { get; set; } + + [Length(2, 8)] + public string P3 { get; set; } + + [Compare("P1")] + public string P4 { get; set; } + + [Range(1, 10, MinimumIsExclusive = true, MaximumIsExclusive = true)] + public int P5 { get; set; } +} + +[OptionsValidator] +public partial class LocalOptionsValidator : IValidateOptions +{ +} + diff --git a/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/Microsoft.Extensions.Options.TrimmingTests.proj b/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/Microsoft.Extensions.Options.TrimmingTests.proj index 669ac862ad7b16..15b6dc0a6ea0e2 100644 --- a/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/Microsoft.Extensions.Options.TrimmingTests.proj +++ b/src/libraries/Microsoft.Extensions.Options/tests/TrimmingTests/Microsoft.Extensions.Options.TrimmingTests.proj @@ -7,10 +7,15 @@ Microsoft.Extensions.DependencyInjection - + + + + <_additionalProjectReference Include="<ProjectReference Include="$(LibrariesProjectRoot)Microsoft.Extensions.Options\gen\Microsoft.Extensions.Options.SourceGeneration.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="true" SetTargetFramework="TargetFramework=netstandard2.0" />" /> + + diff --git a/src/libraries/System.Collections.Immutable/src/PACKAGE.md b/src/libraries/System.Collections.Immutable/src/PACKAGE.md index 202fb6d0137833..0ca0b161aa448d 100644 --- a/src/libraries/System.Collections.Immutable/src/PACKAGE.md +++ b/src/libraries/System.Collections.Immutable/src/PACKAGE.md @@ -53,6 +53,10 @@ The main types provided by this library are: * `System.Collections.Immutable.ImmutableSortedSet` * `System.Collections.Immutable.ImmutableStack` * `System.Collections.Immutable.ImmutableStack` +* `System.Collections.Frozen.FrozenDictionary` +* `System.Collections.Frozen.FrozenDictionary` +* `System.Collections.Frozen.FrozenSet` +* `System.Collections.Frozen.FrozenSet` ## Additional Documentation @@ -65,4 +69,4 @@ The main types provided by this library are: -System.Collections.Immutable is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file +System.Collections.Immutable is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs b/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs index c7de89b28629b5..0bd6a70a8f2a57 100644 --- a/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs +++ b/src/libraries/System.Collections/tests/Generic/PriorityQueue/PriorityQueue.Tests.cs @@ -287,7 +287,7 @@ void trimAndEnsureCapacity() private static int GetUnderlyingBufferCapacity(PriorityQueue queue) { - FieldInfo nodesField = queue.GetType().GetField("_nodes", BindingFlags.NonPublic | BindingFlags.Instance); + FieldInfo nodesField = typeof(PriorityQueue).GetField("_nodes", BindingFlags.NonPublic | BindingFlags.Instance); Assert.NotNull(nodesField); var nodes = ((TElement Element, TPriority Priority)[])nodesField.GetValue(queue); return nodes.Length; diff --git a/src/libraries/System.DirectoryServices.Protocols/src/PACKAGE.md b/src/libraries/System.DirectoryServices.Protocols/src/PACKAGE.md new file mode 100644 index 00000000000000..2938eee70fa8b0 --- /dev/null +++ b/src/libraries/System.DirectoryServices.Protocols/src/PACKAGE.md @@ -0,0 +1,69 @@ +## About + + + +System.DirectoryServices.Protocols provides a managed implementation of Lightweight Directory Access Protocol (LDAP) version 3 and Directory Services Markup Language (DSML) version 2.0 (V2) standards. + +It primarily uses the `LdapConnection` type for interacting with LDAP servers, using system native libraries to establish TCP/IP or UDP LDAP connections. +Supports both Windows and Unix, but certain features, such as setting client or server certificate options, are not available on Unix. + +## Key Features + + + +* Managed implementation of LDAP v3 and DSML V2 standards. + +## How to Use + + + +Using the `LdapConnection` type, you can establish connections to LDAP servers and issue requests. + +Here is a simple example: + +```csharp +using System.DirectoryServices.Protocols; + +// Create a new LdapConnection instance using the server URL. +using (LdapConnection connection = new LdapConnection("ldap.example.com")) { + + // Some credentials + connection.Credential = new NetworkCredential(dn, password); + + // Connect to the server + connection.Bind(); + + // Perform LDAP operations +} +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.DirectoryServices.Protocols.LdapConnection` +* `System.DirectoryServices.Protocols.DirectoryAttribute` +* `System.DirectoryServices.Protocols.DirectoryOperation` +* `System.DirectoryServices.Protocols.DirectoryRequest` +* `System.DirectoryServices.Protocols.DirectoryResponse` + +## Additional Documentation + + + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.directoryservices.protocols) +* [Active Directory Domain Services](https://learn.microsoft.com/windows/win32/ad/active-directory-domain-services) + +## Related Packages + + + +* [System.DirectoryServices](https://www.nuget.org/packages/System.DirectoryServices/) + +## Feedback & Contributing + + + +System.DirectoryServices.Protocols is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Formats.Cbor/src/PACKAGE.md b/src/libraries/System.Formats.Cbor/src/PACKAGE.md new file mode 100644 index 00000000000000..b6549d0941e8c2 --- /dev/null +++ b/src/libraries/System.Formats.Cbor/src/PACKAGE.md @@ -0,0 +1,95 @@ +## About + + + +Provides support for reading and writing values in Concise Binary Object Representation (CBOR) format, as originally defined in [IETF RFC 7049](https://www.ietf.org/rfc/rfc7049.html). + + +## Key Features + + + +* Reader and writer types for the CBOR format. +* Built-in support for different CBOR conformance modes. + +## How to Use + + + +Write and read primitives: + +```csharp +using System.Formats.Cbor; + +var cborWriter = new CborWriter(CborConformanceMode.Lax); +cborWriter.WriteTextString("Hello World"); + +var cborReader = new CborReader(cborWriter.Encode(), CborConformanceMode.Lax); +Console.WriteLine(cborReader.ReadTextString()); +// Hello World +``` + +Write and read an array: + +```csharp +var cborWriter = new CborWriter(CborConformanceMode.Lax); +cborWriter.WriteStartArray(5); +for (var index = 0; index < 5; index++) +{ + cborWriter.WriteInt32(index); +} +cborWriter.WriteEndArray(); + +var cborReader = new CborReader(cborWriter.Encode(), CborConformanceMode.Lax); +var arrayLength = cborReader.ReadStartArray(); +for (var index = 0; index < arrayLength; index++) +{ + Console.Write(cborReader.ReadInt32()); +} +// 01234 +cborReader.ReadEndArray(); +``` + +Inspect writer and reader state: + +```csharp +var cborWriter = new CborWriter(CborConformanceMode.Lax); +cborWriter.WriteTextString("SomeArray"); +Console.WriteLine(cborWriter.BytesWritten); +// 10 +Console.WriteLine(cborWriter.IsWriteCompleted); +// True + +var cborReader = new CborReader(cborWriter.Encode(), CborConformanceMode.Lax); +Console.WriteLine(cborReader.BytesRemaining); +// 10 +Console.WriteLine(cborReader.ReadTextString()); +// SomeArray +Console.WriteLine(cborReader.BytesRemaining); +// 0 +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.Formats.Cbor.CborReader` +* `System.Formats.Cbor.CborWriter` +* `System.Formats.Cbor.CborReaderState` +* `System.Formats.Cbor.CborConformanceMode` +* `System.Formats.Cbor.CborContentException` +* `System.Formats.Cbor.CborTag` + +## Additional Documentation + + + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.formats.cbor) + +## Feedback & Contributing + + + +System.Formats.Cbor is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). \ No newline at end of file diff --git a/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Reader/CborReader.Tag.cs b/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Reader/CborReader.Tag.cs index 238217856c1905..2dd4c4d1d9abe1 100644 --- a/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Reader/CborReader.Tag.cs +++ b/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Reader/CborReader.Tag.cs @@ -72,7 +72,7 @@ public DateTimeOffset ReadDateTimeOffset() string dateString = ReadTextString(); // TODO determine if conformance modes should allow inexact date sting parsing - if (!DateTimeOffset.TryParseExact(dateString, CborWriter.Rfc3339FormatString, null, DateTimeStyles.RoundtripKind, out DateTimeOffset result)) + if (!DateTimeOffset.TryParseExact(dateString, CborWriter.Rfc3339FormatString, CultureInfo.InvariantCulture, DateTimeStyles.RoundtripKind, out DateTimeOffset result)) { throw new CborContentException(SR.Cbor_Reader_InvalidDateTimeEncoding); } diff --git a/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Writer/CborWriter.Tag.cs b/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Writer/CborWriter.Tag.cs index e5e772dcfd1be5..3ca04f085e04c0 100644 --- a/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Writer/CborWriter.Tag.cs +++ b/src/libraries/System.Formats.Cbor/src/System/Formats/Cbor/Writer/CborWriter.Tag.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Globalization; using System.Numerics; namespace System.Formats.Cbor @@ -42,8 +43,8 @@ public void WriteDateTimeOffset(DateTimeOffset value) #else value.Offset == TimeSpan.Zero ? #endif // NET8_0_OR_GREATER - value.UtcDateTime.ToString(Rfc3339FormatString) : // prefer 'Z' over '+00:00' - value.ToString(Rfc3339FormatString); + value.UtcDateTime.ToString(Rfc3339FormatString, CultureInfo.InvariantCulture) : // prefer 'Z' over '+00:00' + value.ToString(Rfc3339FormatString, CultureInfo.InvariantCulture); WriteTag(CborTag.DateTimeString); WriteTextString(dateString); diff --git a/src/libraries/System.Formats.Cbor/tests/Reader/CborReaderTests.Tag.cs b/src/libraries/System.Formats.Cbor/tests/Reader/CborReaderTests.Tag.cs index af9adfbe67b500..c6fba6d2a3981c 100644 --- a/src/libraries/System.Formats.Cbor/tests/Reader/CborReaderTests.Tag.cs +++ b/src/libraries/System.Formats.Cbor/tests/Reader/CborReaderTests.Tag.cs @@ -2,8 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Globalization; using System.Linq; using System.Numerics; +using System.Threading; +using Microsoft.DotNet.RemoteExecutor; using Test.Cryptography; using Xunit; @@ -192,6 +195,31 @@ public static void ReadDateTimeOffset_SingleValue_HappyPath(string expectedValue Assert.Equal(expectedValue.Offset, result.Offset); } + [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework)] + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public static void ReadDateTimeOffset_IsCultureInvariant() + { + // Regression test for https://github.com/dotnet/runtime/pull/92539 + RemoteExecutor.Invoke(static () => + { + DateTimeOffset expectedValue = DateTimeOffset.Parse("2020-04-09T14:31:21.3535941+01:00", CultureInfo.InvariantCulture); + byte[] data = "c07821323032302d30342d30395431343a33313a32312e333533353934312b30313a3030".HexToByteArray(); + + // Install a non-Gregorian calendar + var culture = new CultureInfo("he-IL"); + culture.DateTimeFormat.Calendar = new HebrewCalendar(); + Thread.CurrentThread.CurrentCulture = culture; + + var reader = new CborReader(data); + + DateTimeOffset result = reader.ReadDateTimeOffset(); + + Assert.Equal(CborReaderState.Finished, reader.PeekState()); + Assert.Equal(expectedValue, result); + Assert.Equal(expectedValue.Offset, result.Offset); + }).Dispose(); + } + [Theory] [InlineData("c01a514b67b0")] // string datetime tag with unix time payload public static void ReadDateTimeOffset_InvalidTagPayload_ShouldThrowCborContentException(string hexEncoding) @@ -206,6 +234,7 @@ public static void ReadDateTimeOffset_InvalidTagPayload_ShouldThrowCborContentEx [Theory] [InlineData("c07330392f30342f323032302031393a35313a3530")] // 0("09/04/2020 19:51:50") [InlineData("c06e4c617374204368726973746d6173")] // 0("Last Christmas") + [InlineData("c07828d7aad7a922d7a42dd796272dd79822d7955431343a33313a32312e333533353934312b30313a3030")] // Non-Gregorian calendar date. public static void ReadDateTimeOffset_InvalidDateString_ShouldThrowCborContentException(string hexEncoding) { byte[] encoding = hexEncoding.HexToByteArray(); diff --git a/src/libraries/System.Formats.Cbor/tests/System.Formats.Cbor.Tests.csproj b/src/libraries/System.Formats.Cbor/tests/System.Formats.Cbor.Tests.csproj index 2ade4c628c7fbc..bf7b2f2b4aac54 100644 --- a/src/libraries/System.Formats.Cbor/tests/System.Formats.Cbor.Tests.csproj +++ b/src/libraries/System.Formats.Cbor/tests/System.Formats.Cbor.Tests.csproj @@ -1,6 +1,7 @@ - + $(NetCoreAppCurrent);$(NetFrameworkCurrent) + true enable $(NoWarn);CS8002 diff --git a/src/libraries/System.Formats.Cbor/tests/Writer/CborWriterTests.Tag.cs b/src/libraries/System.Formats.Cbor/tests/Writer/CborWriterTests.Tag.cs index 3413eadc84cc31..ff480bca39e119 100644 --- a/src/libraries/System.Formats.Cbor/tests/Writer/CborWriterTests.Tag.cs +++ b/src/libraries/System.Formats.Cbor/tests/Writer/CborWriterTests.Tag.cs @@ -2,8 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Globalization; using System.Linq; using System.Numerics; +using System.Threading; +using Microsoft.DotNet.RemoteExecutor; using Test.Cryptography; using Xunit; @@ -88,6 +91,30 @@ public static void WriteDateTimeOffset_SingleValue_HappyPath(string valueString, AssertHelper.HexEqual(expectedHexEncoding.HexToByteArray(), encoding); } + [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework)] + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public static void WriteDateTimeOffset_IsCultureInvariant() + { + // Regression test for https://github.com/dotnet/runtime/pull/92539 + RemoteExecutor.Invoke(static () => + { + DateTimeOffset value = DateTimeOffset.Parse("2020-04-09T14:31:21.3535941+01:00", CultureInfo.InvariantCulture); + string expectedHexEncoding = "c07821323032302d30342d30395431343a33313a32312e333533353934312b30313a3030"; + + // Install a non-Gregorian calendar + var culture = new CultureInfo("he-IL"); + culture.DateTimeFormat.Calendar = new HebrewCalendar(); + Thread.CurrentThread.CurrentCulture = culture; + + var writer = new CborWriter(); + + writer.WriteDateTimeOffset(value); + + byte[] encoding = writer.Encode(); + AssertHelper.HexEqual(expectedHexEncoding.HexToByteArray(), encoding); + }).Dispose(); + } + [Theory] [InlineData(1363896240, "c11a514b67b0")] [InlineData(1586439081, "c11a5e8f23a9")] diff --git a/src/libraries/System.Globalization/tests/CultureInfo/CultureInfoGetCultures.cs b/src/libraries/System.Globalization/tests/CultureInfo/CultureInfoGetCultures.cs new file mode 100644 index 00000000000000..5dcdfb54c1f835 --- /dev/null +++ b/src/libraries/System.Globalization/tests/CultureInfo/CultureInfoGetCultures.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace System.Globalization.Tests +{ + public class CultureInfoGetCultures + { + [Fact] + public void GetSpecificCultures() + { + var specificCultures = CultureInfo.GetCultures(CultureTypes.SpecificCultures); + Assert.True(specificCultures.Length > 0); + Assert.All(specificCultures, c => Assert.True(c.IsNeutralCulture == false)); + } + + [Fact] + public void GetAllCultures() + { + var allCultures = CultureInfo.GetCultures(CultureTypes.AllCultures); + Assert.True(allCultures.Length > 0); + } + } +} diff --git a/src/libraries/System.Globalization/tests/Hybrid/System.Globalization.IOS.Tests.csproj b/src/libraries/System.Globalization/tests/Hybrid/System.Globalization.IOS.Tests.csproj index eac50b0d2f3e0f..bb38f00b88f9d8 100644 --- a/src/libraries/System.Globalization/tests/Hybrid/System.Globalization.IOS.Tests.csproj +++ b/src/libraries/System.Globalization/tests/Hybrid/System.Globalization.IOS.Tests.csproj @@ -6,6 +6,7 @@ + diff --git a/src/libraries/System.IO.Hashing/src/PACKAGE.md b/src/libraries/System.IO.Hashing/src/PACKAGE.md new file mode 100644 index 00000000000000..41b90205eac84e --- /dev/null +++ b/src/libraries/System.IO.Hashing/src/PACKAGE.md @@ -0,0 +1,91 @@ +## About + + + +System.IO.Hashing offers a variety of hash code algorithms. + +Hash code algorithms are pivotal for generating unique values for objects based on their content, facilitating object comparisons, and detecting content alterations. +The namespace encompasses algorithms like CRC-32, CRC-64, xxHash3, xxHash32, xxHash64, and xxHash128, all engineered for swift and efficient hash code generation, with xxHash being an "Extremely fast hash algorithm". + +**Warning**: The hash functions provided by System.IO.Hashing are not suitable for security purposes such as handling passwords or verifying untrusted content. +For such security-critical applications, consider using cryptographic hash functions provided by the [System.Security.Cryptography](https://learn.microsoft.com/dotnet/api/system.security.cryptography) namespace. + +## Key Features + + + +* Variety of hash code algorithms including CRC-32, CRC-64, xxHash3, xxHash32, xxHash64, and xxHash128. +* Implementations of CRC-32 and CRC-64 algorithms, as used in IEEE 802.3, and described in ECMA-182, Annex B respectively. +* Implementations of XxHash32 for generating 32-bit hashes, XxHash3 and XxHash64 for generating 64-bit hashes, and xxHash128 for generating 128-bit hashes. + +## How to Use + + + +Creating hash codes is straightforward. +Call the `Hash` method with the content to be hashed. + +Here is a practical example: + +```csharp +using System; +using System.IO.Hashing; + +byte[] data = new byte[] { 1, 2, 3, 4 }; + +byte[] crc32Value = Crc32.Hash(data); +Console.WriteLine($"CRC-32 Hash: {BitConverter.ToString(crc32Value)}"); +// CRC-32 Hash: CD-FB-3C-B6 + +byte[] crc64Value = Crc64.Hash(data); +Console.WriteLine($"CRC-64 Hash: {BitConverter.ToString(crc64Value)}"); +// CRC-64 Hash: 58-8D-5A-D4-2A-70-1D-B2 + +byte[] xxHash3Value = XxHash3.Hash(data); +Console.WriteLine($"XxHash3 Hash: {BitConverter.ToString(xxHash3Value)}"); +// XxHash3 Hash: 98-8B-7B-90-33-AC-46-22 + +byte[] xxHash32Value = XxHash32.Hash(data); +Console.WriteLine($"XxHash32 Hash: {BitConverter.ToString(xxHash32Value)}"); +// XxHash32 Hash: FE-96-D1-9C + +byte[] xxHash64Value = XxHash64.Hash(data); +Console.WriteLine($"XxHash64 Hash: {BitConverter.ToString(xxHash64Value)}"); +// XxHash64 Hash: 54-26-20-E3-A2-A9-2E-D1 + +byte[] xxHash128Value = XxHash128.Hash(data); +Console.WriteLine($"XxHash128 Hash: {BitConverter.ToString(xxHash128Value)}"); +// XxHash128 Hash: 49-A0-48-99-59-7A-35-67-53-76-53-A0-D9-95-5B-86 +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.IO.Hashing.Crc32` +* `System.IO.Hashing.Crc64` +* `System.IO.Hashing.XxHash3` +* `System.IO.Hashing.XxHash32` +* `System.IO.Hashing.XxHash64` +* `System.IO.Hashing.XxHash128` + +## Additional Documentation + + + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.io.hashing) +* [xxHash - Extremely fast hash algorithm](https://github.com/Cyan4973/xxHash/blob/release/doc/xxhash_spec.md) + +## Related Packages + + + +Cryptographic services, including secure encryption and decryption of data: [System.Security.Cryptography](https://learn.microsoft.com/dotnet/api/system.security.cryptography) + +## Feedback & Contributing + + + +System.IO.Hashing is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.IO/tests/BinaryWriter/BinaryWriter.EncodingTests.cs b/src/libraries/System.IO/tests/BinaryWriter/BinaryWriter.EncodingTests.cs index 43e5441ea0676d..4b8a9577e0ee59 100644 --- a/src/libraries/System.IO/tests/BinaryWriter/BinaryWriter.EncodingTests.cs +++ b/src/libraries/System.IO/tests/BinaryWriter/BinaryWriter.EncodingTests.cs @@ -190,7 +190,7 @@ public void WriteString_NotUtf8(int stringLengthInChars) private static bool IsUsingFastUtf8(BinaryWriter writer) { - return (bool)writer.GetType().GetField("_useFastUtf8", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(writer); + return (bool)typeof(BinaryWriter).GetField("_useFastUtf8", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(writer); } private static string GenerateLargeUnicodeString(int charCount) diff --git a/src/libraries/System.Memory.Data/src/PACKAGE.md b/src/libraries/System.Memory.Data/src/PACKAGE.md new file mode 100644 index 00000000000000..33996a16b08252 --- /dev/null +++ b/src/libraries/System.Memory.Data/src/PACKAGE.md @@ -0,0 +1,102 @@ +## About + + + +System.Memory.Data introduces the `BinaryData` type, a lightweight abstraction for a byte payload. +It makes it easy to convert between string, bytes, and stream. + +This abstraction can simplify the API surface by exposing a single type instead of numerous overloads or properties. +The `BinaryData` type handles data ownership efficiently, wrapping passed-in bytes when using `byte[]` or `ReadOnlyMemory` constructors or methods, and managing data as bytes when dealing with streams, strings, or rich model types serialized as JSON. + + +## Key Features + + + +* Lightweight abstraction for byte payload via `BinaryData` type. +* Convenient helper methods for common conversions among string, bytes, and stream. +* Efficient data ownership handling. + +## How to Use + + + +To/From String: + +```csharp +var data = new BinaryData("some data"); + +// ToString will decode the bytes using UTF-8 +Console.WriteLine(data.ToString()); // prints "some data" +``` + +To/From Bytes: + +```csharp +byte[] bytes = Encoding.UTF8.GetBytes("some data"); + +// Create BinaryData using a constructor ... +BinaryData data = new BinaryData(bytes); + +// Or using a static factory method. +data = BinaryData.FromBytes(bytes); + +// There is an implicit cast defined for ReadOnlyMemory +ReadOnlyMemory rom = data; + +// There is also an implicit cast defined for ReadOnlySpan +ReadOnlySpan ros = data; + +// there is also a ToMemory method that gives access to the ReadOnlyMemory. +rom = data.ToMemory(); + +// and a ToArray method that converts into a byte array. +byte[] array = data.ToArray(); +``` + +To/From stream: + +```csharp +var bytes = Encoding.UTF8.GetBytes("some data"); +Stream stream = new MemoryStream(bytes); +var data = BinaryData.FromStream(stream); + +// Calling ToStream will give back a stream that is backed by ReadOnlyMemory, so it is not writable. +stream = data.ToStream(); +Console.WriteLine(stream.CanWrite); // prints false +``` + +`BinaryData` also can be used to integrate with `ObjectSerializer`. +By default, the `JsonObjectSerializer` will be used, but any serializer deriving from `ObjectSerializer` can be used. + +```csharp +var model = new CustomModel +{ + A = "some text", + B = 5, + C = true +}; + +var data = BinaryData.FromObjectAsJson(model); +model = data.ToObjectFromJson(); +``` + +## Main Types + + + +The main types provided by this library are: + +* `System.BinaryData` + +## Additional Documentation + + + +* [API documentation](https://learn.microsoft.com/dotnet/api/system.binarydata) + +## Feedback & Contributing + + + +System.Memory.Data is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.cs b/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.cs index dd7d005dbc3c27..4b04787d3acdce 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/HttpClientHandler.AnyMobile.cs @@ -51,10 +51,7 @@ private HttpMessageHandler Handler MetricsHandler metricsHandler = new MetricsHandler(handler, _nativeMeterFactory, out _); // Ensure a single handler is used for all requests. - if (Interlocked.CompareExchange(ref _nativeMetricsHandler, metricsHandler, null) != null) - { - handler.Dispose(); - } + Interlocked.CompareExchange(ref _nativeMetricsHandler, metricsHandler, null); } return _nativeMetricsHandler; @@ -87,7 +84,7 @@ protected override void Dispose(bool disposing) if (IsNativeHandlerEnabled) { - _nativeHandler!.Dispose(); + Handler.Dispose(); } else { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Metrics/MetricsHandler.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Metrics/MetricsHandler.cs index 1f353f167150e5..4de6df347972fe 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Metrics/MetricsHandler.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Metrics/MetricsHandler.cs @@ -169,7 +169,7 @@ private static bool TryGetErrorType(HttpResponseMessage? response, Exception? ex HttpRequestError.ConfigurationLimitExceeded => "configuration_limit_exceeded", // Fall back to the exception type name in case of HttpRequestError.Unknown or when exception is not an HttpRequestException. - _ => exception.GetType().Name + _ => exception.GetType().FullName! }; return true; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Metrics/ConnectionMetrics.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Metrics/ConnectionMetrics.cs index a9498cdc948dfb..12dbdface7398f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Metrics/ConnectionMetrics.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Metrics/ConnectionMetrics.cs @@ -14,10 +14,10 @@ internal sealed class ConnectionMetrics private readonly object _schemeTag; private readonly object _hostTag; private readonly object? _portTag; - private readonly object? _socketAddressTag; + private readonly object? _peerAddressTag; private bool _currentlyIdle; - public ConnectionMetrics(SocketsHttpHandlerMetrics metrics, string protocolVersion, string scheme, string host, int? port, string? socketAddress) + public ConnectionMetrics(SocketsHttpHandlerMetrics metrics, string protocolVersion, string scheme, string host, int? port, string? peerAddress) { _metrics = metrics; _openConnectionsEnabled = _metrics.OpenConnections.Enabled; @@ -25,7 +25,7 @@ public ConnectionMetrics(SocketsHttpHandlerMetrics metrics, string protocolVersi _schemeTag = scheme; _hostTag = host; _portTag = port; - _socketAddressTag = socketAddress; + _peerAddressTag = peerAddress; } // TagList is a huge struct, so we avoid storing it in a field to reduce the amount we allocate on the heap. @@ -42,9 +42,9 @@ private TagList GetTags() tags.Add("server.port", _portTag); } - if (_socketAddressTag is not null) + if (_peerAddressTag is not null) { - tags.Add("server.socket.address", _socketAddressTag); + tags.Add("network.peer.address", _peerAddressTag); } return tags; diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs index 15bb4bcbe2bb97..4bf2638f9d3515 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs @@ -30,7 +30,7 @@ protected static class InstrumentNames public const string ConnectionDuration = "http.client.connection.duration"; public const string TimeInQueue = "http.client.request.time_in_queue"; } - + protected HttpMetricsTestBase(ITestOutputHelper output) : base(output) { } @@ -47,9 +47,9 @@ protected static void VerifyTag(KeyValuePair[] tags, string } } - private static void VerifySocketAddress(KeyValuePair[] tags) + private static void VerifyPeerAddress(KeyValuePair[] tags) { - string ipString = (string)tags.Single(t => t.Key == "server.socket.address").Value; + string ipString = (string)tags.Single(t => t.Key == "network.peer.address").Value; IPAddress ip = IPAddress.Parse(ipString); Assert.True(ip.Equals(IPAddress.Loopback.MapToIPv6()) || ip.Equals(IPAddress.Loopback) || @@ -122,7 +122,7 @@ protected static void VerifyOpenConnections(string actualName, object measuremen VerifySchemeHostPortTags(tags, uri); VerifyTag(tags, "network.protocol.version", GetVersionString(protocolVersion)); VerifyTag(tags, "http.connection.state", state); - VerifySocketAddress(tags); + VerifyPeerAddress(tags); } protected static void VerifyConnectionDuration(string instrumentName, object measurement, KeyValuePair[] tags, Uri uri, Version? protocolVersion) @@ -132,7 +132,7 @@ protected static void VerifyConnectionDuration(string instrumentName, object mea Assert.InRange(value, double.Epsilon, 60); VerifySchemeHostPortTags(tags, uri); VerifyTag(tags, "network.protocol.version", GetVersionString(protocolVersion)); - VerifySocketAddress(tags); + VerifyPeerAddress(tags); } protected static void VerifyTimeInQueue(string instrumentName, object measurement, KeyValuePair[] tags, Uri uri, Version? protocolVersion, string method = "GET") @@ -347,7 +347,7 @@ public Task RequestDuration_CustomTags_Recorded() { ctx.AddCustomTag("route", "/test"); }); - + using HttpResponseMessage response = await SendAsync(client, request); Measurement m = Assert.Single(recorder.GetMeasurements()); @@ -455,6 +455,21 @@ await LoopbackServerFactory.CreateClientAndServerAsync(async uri => using InstrumentRecorder recorder = SetupInstrumentRecorder(InstrumentNames.RequestDuration); using HttpRequestMessage request = new(HttpMethod.Get, uri) { Version = UseVersion }; using HttpResponseMessage response = await client.SendAsync(TestAsync, request, completionOption); + string responseContent = await response.Content.ReadAsStringAsync(); + + if (responseContentType == ResponseContentType.ContentLength) + { + Assert.NotNull(response.Content.Headers.ContentLength); + } + else if (responseContentType == ResponseContentType.TransferEncodingChunked) + { + Assert.NotNull(response.Headers.TransferEncodingChunked); + } + else + { + // Empty + Assert.Empty(responseContent); + } Measurement m = Assert.Single(recorder.GetMeasurements()); VerifyRequestDuration(m, uri, UseVersion, 200); ; @@ -655,8 +670,8 @@ await LoopbackServerFactory.CreateClientAndServerAsync(async uri => _output.WriteLine($"Client exception: {clientException}"); string[] expectedExceptionTypes = TestAsync - ? [nameof(TaskCanceledException)] - : [nameof(TaskCanceledException), nameof(OperationCanceledException)]; + ? [typeof(TaskCanceledException).FullName] + : [typeof(TaskCanceledException).FullName, typeof(OperationCanceledException).FullName]; Measurement m = Assert.Single(recorder.GetMeasurements()); VerifyRequestDuration(m, uri, acceptedErrorTypes: expectedExceptionTypes); @@ -783,7 +798,7 @@ await Assert.ThrowsAsync(async () => using HttpResponseMessage response = await SendAsync(client, request); }); } - + Measurement m = Assert.Single(recorder.GetMeasurements()); VerifyRequestDuration(m, uri, UseVersion, 200); Assert.Equal("before!", m.Tags.ToArray().Single(t => t.Key == "before").Value); @@ -837,7 +852,7 @@ await LoopbackServerFactory.CreateClientAndServerAsync(async uri => Assert.True(ex is HttpRequestException or TaskCanceledException); Measurement m = Assert.Single(recorder.GetMeasurements()); - VerifyRequestDuration(m, uri, acceptedErrorTypes: [nameof(TaskCanceledException), "response_ended"]); + VerifyRequestDuration(m, uri, acceptedErrorTypes: [typeof(TaskCanceledException).FullName, "response_ended"]); }, async server => { try diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs index f4b09f346b3d9a..8227d58ebb61b6 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs @@ -23,13 +23,13 @@ public static string GetHostName() { name = NameResolutionPal.GetHostName(); } - catch when (LogFailure(string.Empty, startingTimestamp)) + catch (Exception ex) when (LogFailure(string.Empty, startingTimestamp, ex)) { Debug.Fail("LogFailure should return false"); throw; } - NameResolutionTelemetry.Log.AfterResolution(string.Empty, startingTimestamp, successful: true); + NameResolutionTelemetry.Log.AfterResolution(string.Empty, startingTimestamp); if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, name); return name; @@ -394,13 +394,13 @@ private static object GetHostEntryOrAddressesCore(string hostName, bool justAddr Aliases = aliases }; } - catch when (LogFailure(hostName, startingTimestamp)) + catch (Exception ex) when (LogFailure(hostName, startingTimestamp, ex)) { Debug.Fail("LogFailure should return false"); throw; } - NameResolutionTelemetry.Log.AfterResolution(hostName, startingTimestamp, successful: true); + NameResolutionTelemetry.Log.AfterResolution(hostName, startingTimestamp); return result; } @@ -434,13 +434,13 @@ private static object GetHostEntryOrAddressesCore(IPAddress address, bool justAd } Debug.Assert(name != null); } - catch when (LogFailure(address, startingTimestamp)) + catch (Exception ex) when (LogFailure(address, startingTimestamp, ex)) { Debug.Fail("LogFailure should return false"); throw; } - NameResolutionTelemetry.Log.AfterResolution(address, startingTimestamp, successful: true); + NameResolutionTelemetry.Log.AfterResolution(address, startingTimestamp); // Do the forward lookup to get the IPs for that host name startingTimestamp = NameResolutionTelemetry.Log.BeforeResolution(name); @@ -464,13 +464,13 @@ private static object GetHostEntryOrAddressesCore(IPAddress address, bool justAd AddressList = addresses }; } - catch when (LogFailure(name, startingTimestamp)) + catch (Exception ex) when (LogFailure(name, startingTimestamp, ex)) { Debug.Fail("LogFailure should return false"); throw; } - NameResolutionTelemetry.Log.AfterResolution(name, startingTimestamp, successful: true); + NameResolutionTelemetry.Log.AfterResolution(name, startingTimestamp); // One of three things happened: // 1. Success. @@ -577,7 +577,7 @@ private static Task GetHostEntryOrAddressesCoreAsync(string hostName, bool justR } private static Task? GetAddrInfoWithTelemetryAsync(string hostName, bool justAddresses, AddressFamily addressFamily, CancellationToken cancellationToken) - where T : class + where T : class { long startingTimestamp = Stopwatch.GetTimestamp(); Task? task = NameResolutionPal.GetAddrInfoAsync(hostName, justAddresses, addressFamily, cancellationToken); @@ -594,15 +594,19 @@ private static Task GetHostEntryOrAddressesCoreAsync(string hostName, bool justR static async Task CompleteAsync(Task task, string hostName, long startingTimestamp) { _ = NameResolutionTelemetry.Log.BeforeResolution(hostName); - T? result = null; + Exception? exception = null; try { - result = await ((Task)task).ConfigureAwait(false); - return result; + return await ((Task)task).ConfigureAwait(false); + } + catch (Exception ex) + { + exception = ex; + throw; } finally { - NameResolutionTelemetry.Log.AfterResolution(hostName, startingTimestamp, successful: result is not null); + NameResolutionTelemetry.Log.AfterResolution(hostName, startingTimestamp, exception); } } } @@ -627,9 +631,9 @@ private static void ValidateHostName(string hostName) } } - private static bool LogFailure(object hostNameOrAddress, long? startingTimestamp) + private static bool LogFailure(object hostNameOrAddress, long? startingTimestamp, Exception exception) { - NameResolutionTelemetry.Log.AfterResolution(hostNameOrAddress, startingTimestamp, successful: false); + NameResolutionTelemetry.Log.AfterResolution(hostNameOrAddress, startingTimestamp, exception); return false; } diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionMetrics.cs b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionMetrics.cs index 180f492b3408e2..fe1048e90b22de 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionMetrics.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionMetrics.cs @@ -13,15 +13,35 @@ internal static class NameResolutionMetrics private static readonly Meter s_meter = new("System.Net.NameResolution"); private static readonly Histogram s_lookupDuration = s_meter.CreateHistogram( - name: "dns.lookups.duration", + name: "dns.lookup.duration", unit: "s", description: "Measures the time taken to perform a DNS lookup."); public static bool IsEnabled() => s_lookupDuration.Enabled; - public static void AfterResolution(TimeSpan duration, string hostName) + public static void AfterResolution(TimeSpan duration, string hostName, Exception? exception) { - s_lookupDuration.Record(duration.TotalSeconds, KeyValuePair.Create("dns.question.name", (object?)hostName)); + var hostNameTag = KeyValuePair.Create("dns.question.name", (object?)hostName); + + if (exception is null) + { + s_lookupDuration.Record(duration.TotalSeconds, hostNameTag); + } + else + { + var errorTypeTag = KeyValuePair.Create("error.type", (object?)GetErrorType(exception)); + s_lookupDuration.Record(duration.TotalSeconds, hostNameTag, errorTypeTag); + } } + + private static string GetErrorType(Exception exception) => (exception as SocketException)?.SocketErrorCode switch + { + SocketError.HostNotFound => "host_not_found", + SocketError.TryAgain => "try_again", + SocketError.AddressFamilyNotSupported => "address_family_not_supported", + SocketError.NoRecovery => "no_recovery", + + _ => exception.GetType().FullName! + }; } } diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionTelemetry.cs b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionTelemetry.cs index ef43b59d15a139..73ed325712ac52 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionTelemetry.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionTelemetry.cs @@ -81,7 +81,7 @@ public long BeforeResolution(object hostNameOrAddress) } [NonEvent] - public void AfterResolution(object hostNameOrAddress, long? startingTimestamp, bool successful) + public void AfterResolution(object hostNameOrAddress, long? startingTimestamp, Exception? exception = null) { Debug.Assert(startingTimestamp.HasValue); if (startingTimestamp == 0) @@ -99,7 +99,7 @@ public void AfterResolution(object hostNameOrAddress, long? startingTimestamp, b if (IsEnabled(EventLevel.Informational, EventKeywords.None)) { - if (!successful) + if (exception is not null) { ResolutionFailed(); } @@ -110,7 +110,7 @@ public void AfterResolution(object hostNameOrAddress, long? startingTimestamp, b if (NameResolutionMetrics.IsEnabled()) { - NameResolutionMetrics.AfterResolution(duration, GetHostnameFromStateObject(hostNameOrAddress)); + NameResolutionMetrics.AfterResolution(duration, GetHostnameFromStateObject(hostNameOrAddress), exception); } } diff --git a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/MetricsTest.cs b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/MetricsTest.cs index d3c0990dbf9c4c..a19d9edc476abc 100644 --- a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/MetricsTest.cs +++ b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/MetricsTest.cs @@ -14,7 +14,7 @@ namespace System.Net.NameResolution.Tests { public class MetricsTest { - private const string DnsLookupDuration = "dns.lookups.duration"; + private const string DnsLookupDuration = "dns.lookup.duration"; [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] public static void ResolveValidHostName_MetricsRecorded() @@ -57,17 +57,26 @@ public static async Task ResolveInvalidHostName_MetricsRecorded() Assert.ThrowsAny(() => Dns.EndGetHostEntry(Dns.BeginGetHostEntry(InvalidHostName, null, null))); Assert.ThrowsAny(() => Dns.EndGetHostAddresses(Dns.BeginGetHostAddresses(InvalidHostName, null, null))); - double[] measurements = GetMeasurementsForHostname(recorder, InvalidHostName); + double[] measurements = GetMeasurementsForHostname(recorder, InvalidHostName, "host_not_found"); Assert.Equal(6, measurements.Length); Assert.All(measurements, m => Assert.True(m > double.Epsilon)); } - private static double[] GetMeasurementsForHostname(InstrumentRecorder recorder, string hostname) + private static double[] GetMeasurementsForHostname(InstrumentRecorder recorder, string hostname, string? expectedErrorType = null) { return recorder .GetMeasurements() - .Where(m => m.Tags.ToArray().Any(t => t.Key == "dns.question.name" && t.Value is string hostnameTag && hostnameTag == hostname)) + .Where(m => + { + KeyValuePair[] tags = m.Tags.ToArray(); + if (!tags.Any(t => t.Key == "dns.question.name" && t.Value is string hostnameTag && hostnameTag == hostname)) + { + return false; + } + string? actualErrorType = tags.FirstOrDefault(t => t.Key == "error.type").Value as string; + return expectedErrorType == actualErrorType; + }) .Select(m => m.Value) .ToArray(); } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs index d5fbf0b1c83981..32a7bdcfbb6acf 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs @@ -1019,11 +1019,12 @@ public async Task TcpReceiveSendGetsCanceledByDispose(bool receiveOrSend, bool i return; } - // RHEL7 kernel has a bug preventing close(AF_UNKNOWN) to succeed with IPv6 sockets. - // In this case Dispose will trigger a graceful shutdown, which means that receive will succeed on socket2. - // This bug is fixed in kernel 3.10.0-1160.25+. - // TODO: Remove this, once CI machines are updated to a newer kernel. - bool mayShutdownGraceful = UsesSync && PlatformDetection.IsRedHatFamily7 && receiveOrSend && (ipv6Server || dualModeClient); + // .NET uses connect(AF_UNSPEC) to abort on-going operations on Linux. + // Linux 6.4+ introduced a change (4faeee0cf8a5d88d63cdbc3bab124fb0e6aed08c) which disallows + // this operation while operations are on-going. + // When the connect fails, .NET falls back to use shutdown(SHUT_RDWR). + // This causes the receive on socket2 to succeed instead of failing with ConnectionReset. + bool mayShutdownGraceful = UsesSync && PlatformDetection.IsLinux && receiveOrSend; // We try this a couple of times to deal with a timing race: if the Dispose happens // before the operation is started, the peer won't see a ConnectionReset SocketException and we won't diff --git a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln index cc3000d60ef88f..015b65250931ac 100644 --- a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln +++ b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln @@ -1,18 +1,34 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.8.34205.153 +MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestUtilities", "..\Common\tests\TestUtilities\TestUtilities.csproj", "{9F20CEA1-2216-4432-BBBD-F01E05D17F23}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics", "..\Microsoft.Bcl.Numerics\ref\Microsoft.Bcl.Numerics.csproj", "{D311ABE4-10A9-4BB1-89CE-6358C55501A8}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics", "..\Microsoft.Bcl.Numerics\src\Microsoft.Bcl.Numerics.csproj", "{1578185F-C4FA-4866-936B-E62AAEDD03B7}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Numerics.Tensors", "ref\System.Numerics.Tensors.csproj", "{21CB448A-3882-4337-B416-D1A3E0BCFFC5}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Numerics.Tensors", "src\System.Numerics.Tensors.csproj", "{848DD000-3D22-4A25-A9D9-05AFF857A116}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Numerics.Tensors.Tests", "tests\System.Numerics.Tensors.Tests.csproj", "{4AF6A02D-82C8-4898-9EDF-01F107C25061}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ComInterfaceGenerator", "..\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj", "{8CA7C982-3EE4-4BCE-9493-7A63556736D3}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LibraryImportGenerator", "..\System.Runtime.InteropServices\gen\LibraryImportGenerator\LibraryImportGenerator.csproj", "{4588351F-4233-4957-B84C-7F8E22B8888A}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Interop.SourceGeneration", "..\System.Runtime.InteropServices\gen\Microsoft.Interop.SourceGeneration\Microsoft.Interop.SourceGeneration.csproj", "{DB954E01-898A-4FE2-A3AA-180D041AB08F}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.CodeFixProvider", "..\..\tools\illink\src\ILLink.CodeFix\ILLink.CodeFixProvider.csproj", "{04FC0651-B9D0-448A-A28B-11B1D4A897F4}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.RoslynAnalyzer", "..\..\tools\illink\src\ILLink.RoslynAnalyzer\ILLink.RoslynAnalyzer.csproj", "{683A7D28-CC55-4375-848D-E659075ECEE4}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.Tasks", "..\..\tools\illink\src\ILLink.Tasks\ILLink.Tasks.csproj", "{1CBEAEA8-2CA1-4B07-9930-35A785205852}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Mono.Linker", "..\..\tools\illink\src\linker\Mono.Linker.csproj", "{BA7828B1-7953-47A0-AE5A-E22B501C4BD0}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Mono.Linker", "..\..\tools\illink\src\linker\ref\Mono.Linker.csproj", "{57E57290-3A6A-43F8-8764-D4DC8151F89C}" +EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tests", "tests", "{DE94CA7D-BB10-4865-85A6-6B694631247F}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{6BC42E6D-848C-4533-B715-F116E7DB3610}" @@ -21,6 +37,14 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{AB415F5A-75E EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{083161E5-6049-4D84-9739-9D7797D7117D}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{841A2FA4-A95F-4612-A8B9-AD2EF769BC71}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DF0561A1-3AB8-4B51-AFB4-392EE1DD6247}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -31,6 +55,14 @@ Global {9F20CEA1-2216-4432-BBBD-F01E05D17F23}.Debug|Any CPU.Build.0 = Debug|Any CPU {9F20CEA1-2216-4432-BBBD-F01E05D17F23}.Release|Any CPU.ActiveCfg = Release|Any CPU {9F20CEA1-2216-4432-BBBD-F01E05D17F23}.Release|Any CPU.Build.0 = Release|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D311ABE4-10A9-4BB1-89CE-6358C55501A8}.Release|Any CPU.Build.0 = Release|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1578185F-C4FA-4866-936B-E62AAEDD03B7}.Release|Any CPU.Build.0 = Release|Any CPU {21CB448A-3882-4337-B416-D1A3E0BCFFC5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {21CB448A-3882-4337-B416-D1A3E0BCFFC5}.Debug|Any CPU.Build.0 = Debug|Any CPU {21CB448A-3882-4337-B416-D1A3E0BCFFC5}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -43,10 +75,6 @@ Global {4AF6A02D-82C8-4898-9EDF-01F107C25061}.Debug|Any CPU.Build.0 = Debug|Any CPU {4AF6A02D-82C8-4898-9EDF-01F107C25061}.Release|Any CPU.ActiveCfg = Release|Any CPU {4AF6A02D-82C8-4898-9EDF-01F107C25061}.Release|Any CPU.Build.0 = Release|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Debug|Any CPU.Build.0 = Debug|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Release|Any CPU.ActiveCfg = Release|Any CPU - {8CA7C982-3EE4-4BCE-9493-7A63556736D3}.Release|Any CPU.Build.0 = Release|Any CPU {4588351F-4233-4957-B84C-7F8E22B8888A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {4588351F-4233-4957-B84C-7F8E22B8888A}.Debug|Any CPU.Build.0 = Debug|Any CPU {4588351F-4233-4957-B84C-7F8E22B8888A}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -55,20 +83,53 @@ Global {DB954E01-898A-4FE2-A3AA-180D041AB08F}.Debug|Any CPU.Build.0 = Debug|Any CPU {DB954E01-898A-4FE2-A3AA-180D041AB08F}.Release|Any CPU.ActiveCfg = Release|Any CPU {DB954E01-898A-4FE2-A3AA-180D041AB08F}.Release|Any CPU.Build.0 = Release|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {04FC0651-B9D0-448A-A28B-11B1D4A897F4}.Release|Any CPU.Build.0 = Release|Any CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {683A7D28-CC55-4375-848D-E659075ECEE4}.Release|Any CPU.Build.0 = Release|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1CBEAEA8-2CA1-4B07-9930-35A785205852}.Release|Any CPU.Build.0 = Release|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0}.Release|Any CPU.Build.0 = Release|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {57E57290-3A6A-43F8-8764-D4DC8151F89C}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection GlobalSection(NestedProjects) = preSolution {9F20CEA1-2216-4432-BBBD-F01E05D17F23} = {DE94CA7D-BB10-4865-85A6-6B694631247F} - {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} + {D311ABE4-10A9-4BB1-89CE-6358C55501A8} = {6BC42E6D-848C-4533-B715-F116E7DB3610} + {1578185F-C4FA-4866-936B-E62AAEDD03B7} = {AB415F5A-75E5-4E03-8A92-15CEDEC4CD3A} {21CB448A-3882-4337-B416-D1A3E0BCFFC5} = {6BC42E6D-848C-4533-B715-F116E7DB3610} {848DD000-3D22-4A25-A9D9-05AFF857A116} = {AB415F5A-75E5-4E03-8A92-15CEDEC4CD3A} - {8CA7C982-3EE4-4BCE-9493-7A63556736D3} = {083161E5-6049-4D84-9739-9D7797D7117D} + {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} {4588351F-4233-4957-B84C-7F8E22B8888A} = {083161E5-6049-4D84-9739-9D7797D7117D} {DB954E01-898A-4FE2-A3AA-180D041AB08F} = {083161E5-6049-4D84-9739-9D7797D7117D} + {04FC0651-B9D0-448A-A28B-11B1D4A897F4} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} + {683A7D28-CC55-4375-848D-E659075ECEE4} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} + {1CBEAEA8-2CA1-4B07-9930-35A785205852} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} + {BA7828B1-7953-47A0-AE5A-E22B501C4BD0} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} + {57E57290-3A6A-43F8-8764-D4DC8151F89C} = {7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB} + {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} + {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} + {7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {10A5F2C3-5230-4916-9D4D-BBDB94851037} EndGlobalSection -EndGlobal + GlobalSection(SharedMSBuildProjectFiles) = preSolution + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{683a7d28-cc55-4375-848d-e659075ecee4}*SharedItemsImports = 5 + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{ba7828b1-7953-47a0-ae5a-e22b501c4bd0}*SharedItemsImports = 5 + EndGlobalSection +EndGlobal \ No newline at end of file diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs index 50eaa00160e5cb..99bd4703574e55 100644 --- a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs +++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs @@ -8,6 +8,7 @@ namespace System.Numerics.Tensors { public static partial class TensorPrimitives { + public static void Abs(System.ReadOnlySpan x, System.Span destination) { } public static void Add(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { } public static void Add(System.ReadOnlySpan x, float y, System.Span destination) { } public static void AddMultiply(System.ReadOnlySpan x, System.ReadOnlySpan y, System.ReadOnlySpan multiplier, System.Span destination) { } @@ -24,12 +25,17 @@ public static void Exp(System.ReadOnlySpan x, System.Span destinat public static int IndexOfMaxMagnitude(System.ReadOnlySpan x) { throw null; } public static int IndexOfMin(System.ReadOnlySpan x) { throw null; } public static int IndexOfMinMagnitude(System.ReadOnlySpan x) { throw null; } - public static float L2Normalize(System.ReadOnlySpan x) { throw null; } + public static float Norm(System.ReadOnlySpan x) { throw null; } public static void Log(System.ReadOnlySpan x, System.Span destination) { } + public static void Log2(System.ReadOnlySpan x, System.Span destination) { } public static float Max(System.ReadOnlySpan x) { throw null; } + public static void Max(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { throw null; } public static float MaxMagnitude(System.ReadOnlySpan x) { throw null; } + public static void MaxMagnitude(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { throw null; } public static float Min(System.ReadOnlySpan x) { throw null; } + public static void Min(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { throw null; } public static float MinMagnitude(System.ReadOnlySpan x) { throw null; } + public static void MinMagnitude(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { throw null; } public static void Multiply(System.ReadOnlySpan x, System.ReadOnlySpan y, System.Span destination) { } public static void Multiply(System.ReadOnlySpan x, float y, System.Span destination) { } public static void MultiplyAdd(System.ReadOnlySpan x, System.ReadOnlySpan y, System.ReadOnlySpan addend, System.Span destination) { } diff --git a/src/libraries/System.Numerics.Tensors/src/PACKAGE.md b/src/libraries/System.Numerics.Tensors/src/PACKAGE.md new file mode 100644 index 00000000000000..c5670c1c0f9893 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/PACKAGE.md @@ -0,0 +1,53 @@ +## About + +Provides methods for performing mathematical operations over _tensors_ represented as spans. These methods are accelerated to use SIMD (Single instruction, multiple data) operations supported by the CPU where available. + +## Key Features + +* Numerical operations on tensors represented as `ReadOnlySpan` +* Element-wise arithmetic: Add, Subtract, Multiply, Divide, Exp, Log, Cosh, Tanh, etc. +* Tensor arithmetic: CosineSimilarity, Distance, Dot, Normalize, Softmax, Sigmoid, etc. + +## How to Use + +```C# +using System.Numerics.Tensors; + +var movies = new[] { + new { Title="The Lion King", Embedding= new [] { 0.10022575f, -0.23998135f } }, + new { Title="Inception", Embedding= new [] { 0.10327095f, 0.2563685f } }, + new { Title="Toy Story", Embedding= new [] { 0.095857024f, -0.201278f } }, + new { Title="Pulp Function", Embedding= new [] { 0.106827796f, 0.21676421f } }, + new { Title="Shrek", Embedding= new [] { 0.09568083f, -0.21177962f } } +}; +var queryEmbedding = new[] { 0.12217915f, -0.034832448f }; + +var top3MoviesTensorPrimitives = + movies + .Select(movie => + ( + movie.Title, + Similarity: TensorPrimitives.CosineSimilarity(queryEmbedding, movie.Embedding) + )) + .OrderByDescending(movies => movies.Similarity) + .Take(3); + +foreach (var movie in top3MoviesTensorPrimitives) +{ + Console.WriteLine(movie); +} +``` + +## Main Types + +The main types provided by this library are: + +* `System.Numerics.Tensors.TensorPrimitives` + +## Additional Documentation + +* [API documentation](https://learn.microsoft.com/en-us/dotnet/api/system.numerics.tensors) + +## Feedback & Contributing + +System.Numerics.Tensors is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt b/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt new file mode 100644 index 00000000000000..a8f2d0192cfec9 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt @@ -0,0 +1,2 @@ +M:System.Numerics.Tensors.TensorPrimitives.ConvertToHalf(System.ReadOnlySpan{System.Single},System.Span{System.Half}) +M:System.Numerics.Tensors.TensorPrimitives.ConvertToSingle(System.ReadOnlySpan{System.Half},System.Span{System.Single}) \ No newline at end of file diff --git a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx index 45f0d8fa17893a..86b9f4d82b1f61 100644 --- a/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx +++ b/src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx @@ -126,4 +126,7 @@ Input span arguments must all have the same length. - \ No newline at end of file + + The destination span may only overlap with an input span if the two spans start at the same memory location. + + diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj index be4a04702af5e1..52c6cb65811e68 100644 --- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj +++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj @@ -9,6 +9,7 @@ Once this package has shipped a stable version, the following line should be removed in order to re-enable validation. --> true + ReferenceAssemblyExclusions.txt diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs index a400095b08ad39..03db1abb7f858a 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -1,824 +1,1097 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + namespace System.Numerics.Tensors { /// Performs primitive tensor operations over spans of memory. public static partial class TensorPrimitives { - /// Computes the element-wise result of: + . + /// Computes the element-wise absolute value of each single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = MathF.Abs([i]). + /// + /// + /// The absolute value of a is its numeric value without its sign. For example, the absolute value of both 1.2e-03 and -1.2e03 is 1.2e03. + /// + /// + /// If a value is equal to or , the result stored into the corresponding destination location is set to . + /// If a value is equal to , the result stored into the corresponding destination location is the original NaN value with the sign bit removed. + /// + /// + public static void Abs(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Computes the element-wise addition of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. + /// Length of must be same as length of . /// Destination is too short. - /// This method effectively does [i] = [i] + [i]. - public static unsafe void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] + [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Add(ReadOnlySpan x, ReadOnlySpan y, Span destination) => InvokeSpanSpanIntoSpan(x, y, destination); - /// Computes the element-wise result of: + . + /// Computes the element-wise addition of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a scalar. /// The destination tensor, represented as a span. /// Destination is too short. - /// This method effectively does [i] = [i] + . + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] + . + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// public static void Add(ReadOnlySpan x, float y, Span destination) => InvokeSpanScalarIntoSpan(x, y, destination); - /// Computes the element-wise result of: - . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] - [i]. - public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); - - /// Computes the element-wise result of: - . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = [i] - . - public static void Subtract(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); - - /// Computes the element-wise result of: * . + /// Computes the element-wise result of ( + ) * for the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. + /// The third tensor, represented as a span. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = [i] * . - public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); - - /// Computes the element-wise result of: * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The destination tensor, represented as a span. + /// Length of must be same as length of and the length of . /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. /// - /// This method effectively does [i] = [i] * . - /// This method corresponds to the scal method defined by BLAS1. + /// + /// This method effectively computes [i] = ([i] + [i]) * [i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// /// - public static void Multiply(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination); - /// Computes the element-wise result of: / . + /// Computes the element-wise result of ( + ) * for the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. + /// The third tensor, represented as a scalar. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. + /// Length of must be same as length of . /// Destination is too short. - /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => - InvokeSpanSpanIntoSpan(x, y, destination); + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] + [i]) * . + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination); - /// Computes the element-wise result of: / . + /// Computes the element-wise result of ( + ) * for the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a scalar. + /// The third tensor, represented as a span. /// The destination tensor, represented as a span. + /// Length of must be same as length of . /// Destination is too short. - /// This method effectively does [i] = [i] / . - public static void Divide(ReadOnlySpan x, float y, Span destination) => - InvokeSpanScalarIntoSpan(x, y, destination); + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] + ) * [i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination); - /// Computes the element-wise result of: -. + /// Computes the element-wise hyperbolic cosine of each single-precision floating-point radian angle in the specified tensor. /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. - /// This method effectively does [i] = -[i]. - public static void Negate(ReadOnlySpan x, Span destination) => - InvokeSpanIntoSpan(x, destination); + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Cosh([i]). + /// + /// + /// If a value is equal to or , the result stored into the corresponding destination location is set to . + /// If a value is equal to , the result stored into the corresponding destination location is also NaN. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Cosh(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); - /// Computes the element-wise result of: ( + ) * . + /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + [i]) * [i]. - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan multiplier, Span destination) => - InvokeSpanSpanSpanIntoSpan(x, y, multiplier, destination); + /// The cosine similarity of the two tensors. + /// Length of must be same as length of . + /// and must not be empty. + /// + /// + /// This method effectively computes TensorPrimitives.Dot(x, y) / (MathF.Sqrt(TensorPrimitives.SumOfSquares(x)) * MathF.Sqrt(TensorPrimitives.SumOfSquares(y)). + /// + /// + /// If any element in either input tensor is equal to , , or , + /// NaN is returned. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float CosineSimilarity(ReadOnlySpan x, ReadOnlySpan y) + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + return CosineSimilarityCore(x, y); + } - /// Computes the element-wise result of: ( + ) * . + /// Computes the distance between two points, specified as non-empty, equal-length tensors of single-precision floating-point numbers, in Euclidean space. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The third tensor, represented as a scalar. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + [i]) * . - public static void AddMultiply(ReadOnlySpan x, ReadOnlySpan y, float multiplier, Span destination) => - InvokeSpanSpanScalarIntoSpan(x, y, multiplier, destination); + /// The Euclidean distance. + /// Length of must be same as length of . + /// and must not be empty. + /// + /// + /// This method effectively computes the equivalent of: + /// + /// Span<float> difference = ...; + /// TensorPrimitives.Subtract(x, y, difference); + /// float result = MathF.Sqrt(TensorPrimitives.SumOfSquares(difference)); + /// + /// but without requiring additional temporary storage for the intermediate differences. + /// + /// + /// If any element in either input tensor is equal to , NaN is returned. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Distance(ReadOnlySpan x, ReadOnlySpan y) + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } - /// Computes the element-wise result of: ( + ) * . - /// The first tensor, represented as a span. - /// The second tensor, represented as a scalar. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] + ) * [i]. - public static void AddMultiply(ReadOnlySpan x, float y, ReadOnlySpan multiplier, Span destination) => - InvokeSpanScalarSpanIntoSpan(x, y, multiplier, destination); + return MathF.Sqrt(Aggregate(x, y)); + } - /// Computes the element-wise result of: ( * ) + . + /// Computes the element-wise division of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The third tensor, represented as a span. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Length of '' must be same as length of ''. + /// Length of must be same as length of . /// Destination is too short. - /// This method effectively does [i] = ([i] * [i]) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => - InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] / [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Divide(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - /// Computes the element-wise result of: ( * ) + . + /// Computes the element-wise division of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The third tensor, represented as a span. + /// The second tensor, represented as a scalar. /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. /// - /// This method effectively does [i] = ([i] * [i]) + . - /// This method corresponds to the axpy method defined by BLAS1. + /// + /// This method effectively computes [i] = [i] / . + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// /// - public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => - InvokeSpanSpanScalarIntoSpan(x, y, addend, destination); + public static void Divide(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); - /// Computes the element-wise result of: ( * ) + . + /// Computes the dot product of two tensors containing single-precision floating-point numbers. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The third tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Length of '' must be same as length of ''. - /// Destination is too short. - /// This method effectively does [i] = ([i] * ) + [i]. - public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => - InvokeSpanScalarSpanIntoSpan(x, y, addend, destination); + /// The dot product. + /// Length of must be same as length of . + /// + /// + /// This method effectively computes the equivalent of: + /// + /// Span<float> products = ...; + /// TensorPrimitives.Multiply(x, y, products); + /// float result = TensorPrimitives.Sum(products); + /// + /// but without requiring additional temporary storage for the intermediate products. It corresponds to the dot method defined by BLAS1. + /// + /// + /// If any of the input elements is equal to , the resulting value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Dot(ReadOnlySpan x, ReadOnlySpan y) => + Aggregate(x, y); - /// Computes the element-wise result of: pow(e, ). + /// Computes the element-wise result of raising e to the single-precision floating-point number powers in the specified tensor. /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. - /// This method effectively does [i] = .Exp([i]). - public static void Exp(ReadOnlySpan x, Span destination) + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Exp([i]). + /// + /// + /// If a value equals or , the result stored into the corresponding destination location is set to NaN. + /// If a value equals , the result stored into the corresponding destination location is set to 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Exp(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Searches for the index of the largest single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The index of the maximum element in , or -1 if is empty. + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If any value equal to + /// is present, the index of the first is returned. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMax(ReadOnlySpan x) { - if (x.Length > destination.Length) + if (x.IsEmpty) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + return -1; } - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Exp(x[i]); - } + return IndexOfMinMaxCore(x); } - /// Computes the element-wise result of: ln(). + /// Searches for the index of the single-precision floating-point number with the largest magnitude in the specified tensor. /// The tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = .Log([i]). - public static void Log(ReadOnlySpan x, Span destination) + /// The index of the element in with the largest magnitude (absolute value), or -1 if is empty. + /// + /// + /// The determination of the maximum magnitude matches the IEEE 754:2019 `maximumMagnitude` function. If any value equal to + /// is present, the index of the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the positive value is considered to have the larger magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMaxMagnitude(ReadOnlySpan x) { - if (x.Length > destination.Length) + if (x.IsEmpty) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + return -1; } - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Log(x[i]); - } + return IndexOfMinMaxCore(x); } - /// Computes the element-wise result of: cosh(). + /// Searches for the index of the smallest single-precision floating-point number in the specified tensor. /// The tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = .Cosh([i]). - public static void Cosh(ReadOnlySpan x, Span destination) + /// The index of the minimum element in , or -1 if is empty. + /// + /// + /// The determination of the minimum element matches the IEEE 754:2019 `minimum` function. If any value equal to + /// is present, the index of the first is returned. Negative 0 is considered smaller than positive 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMin(ReadOnlySpan x) { - if (x.Length > destination.Length) + if (x.IsEmpty) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + return -1; } - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Cosh(x[i]); - } + return IndexOfMinMaxCore(x); } - /// Computes the element-wise result of: sinh(). + /// Searches for the index of the single-precision floating-point number with the smallest magnitude in the specified tensor. /// The tensor, represented as a span. - /// The destination tensor, represented as a span. - /// Destination is too short. - /// This method effectively does [i] = .Sinh([i]). - public static void Sinh(ReadOnlySpan x, Span destination) + /// The index of the element in with the smallest magnitude (absolute value), or -1 if is empty. + /// + /// + /// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If any value equal to + /// is present, the index of the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMinMagnitude(ReadOnlySpan x) { - if (x.Length > destination.Length) + if (x.IsEmpty) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + return -1; } - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Sinh(x[i]); - } + return IndexOfMinMaxCore(x); } - /// Computes the element-wise result of: tanh(). + /// Computes the element-wise natural (base e) logarithm of single-precision floating-point numbers in the specified tensor. /// The tensor, represented as a span. /// The destination tensor, represented as a span. /// Destination is too short. - /// This method effectively does [i] = .Tanh([i]). - public static void Tanh(ReadOnlySpan x, Span destination) - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Log([i]). + /// + /// + /// If a value equals 0, the result stored into the corresponding destination location is set to . + /// If a value is negative or equal to , the result stored into the corresponding destination location is set to NaN. + /// If a value is positive infinity, the result stored into the corresponding destination location is set to . + /// Otherwise, if a value is positive, its natural logarithm is stored into the corresponding destination location. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Log(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Tanh(x[i]); - } - } + /// Computes the element-wise base 2 logarithm of single-precision floating-point numbers in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Log2([i]). + /// + /// + /// If a value equals 0, the result stored into the corresponding destination location is set to . + /// If a value is negative or equal to , the result stored into the corresponding destination location is set to NaN. + /// If a value is positive infinity, the result stored into the corresponding destination location is set to . + /// Otherwise, if a value is positive, its natural logarithm is stored into the corresponding destination location. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Log2(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); + + /// Searches for the largest single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The maximum element in . + /// Length of must be greater than zero. + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If any value equal to + /// is present, the first is returned. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Max(ReadOnlySpan x) => + MinMaxCore(x); - /// Computes the cosine similarity between two non-zero vectors. + /// Computes the element-wise maximum of the single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The cosine similarity between the two vectors. - /// Length of '' must be same as length of ''. - /// '' and '' must not be empty. - public static float CosineSimilarity(ReadOnlySpan x, ReadOnlySpan y) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = MathF.Max([i], [i]). + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If either value is equal to , + /// that value is stored as the result. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Max(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - return CosineSimilarityCore(x, y); - } + /// Searches for the single-precision floating-point number with the largest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The element in with the largest magnitude (absolute value). + /// Length of must be greater than zero. + /// + /// + /// The determination of the maximum magnitude matches the IEEE 754:2019 `maximumMagnitude` function. If any value equal to + /// is present, the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the positive value is considered to have the larger magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float MaxMagnitude(ReadOnlySpan x) => + MinMaxCore(x); - /// - /// Compute the distance between two points in Euclidean space. - /// + /// Computes the element-wise single-precision floating-point number with the largest magnitude in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The Euclidean distance. - /// Length of '' must be same as length of ''. - /// '' and '' must not be empty. - public static float Distance(ReadOnlySpan x, ReadOnlySpan y) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// This method effectively computes [i] = MathF.MaxMagnitude([i], [i]). + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void MaxMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - return MathF.Sqrt(Aggregate(0f, x, y)); - } + /// Searches for the smallest single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The minimum element in . + /// Length of must be greater than zero. + /// + /// + /// The determination of the minimum element matches the IEEE 754:2019 `minimum` function. If any value is equal to + /// is present, the first is returned. Negative 0 is considered smaller than positive 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Min(ReadOnlySpan x) => + MinMaxCore(x); - /// - /// A mathematical operation that takes two vectors and returns a scalar. - /// + /// Computes the element-wise minimum of the single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. /// The second tensor, represented as a span. - /// The dot product. - /// Length of '' must be same as length of ''. - public static float Dot(ReadOnlySpan x, ReadOnlySpan y) // BLAS1: dot - { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = MathF.Max([i], [i]). + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximum` function. If either value is equal to , + /// that value is stored as the result. Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Min(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - return Aggregate(0f, x, y); - } + /// Searches for the single-precision floating-point number with the smallest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The element in with the smallest magnitude (absolute value). + /// Length of must be greater than zero. + /// + /// + /// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If any value equal to + /// is present, the first is returned. If two values have the same magnitude and one is positive and the other is negative, + /// the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float MinMagnitude(ReadOnlySpan x) => + MinMaxCore(x); - /// - /// A mathematical operation that takes a vector and returns the L2 norm. - /// + /// Computes the element-wise single-precision floating-point number with the smallest magnitude in the specified tensors. /// The first tensor, represented as a span. - /// The L2 norm. - public static float L2Normalize(ReadOnlySpan x) // BLAS1: nrm2 - { - return MathF.Sqrt(Aggregate(0f, x)); - } + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// This method effectively computes [i] = MathF.MinMagnitude([i], [i]). + /// + /// + /// The determination of the maximum magnitude matches the IEEE 754:2019 `minimumMagnitude` function. If either value is equal to , + /// that value is stored as the result. If the two values have the same magnitude and one is positive and the other is negative, + /// the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void MinMagnitude(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - /// - /// A function that takes a collection of real numbers and returns a probability distribution. - /// + /// Computes the element-wise product of single-precision floating-point numbers in the specified tensors. /// The first tensor, represented as a span. - /// The destination tensor. + /// The second tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . /// Destination is too short. - /// '' must not be empty. - public static void SoftMax(ReadOnlySpan x, Span destination) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] * [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Multiply(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); - float expSum = 0f; + /// Computes the element-wise product of single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] * . + /// It corresponds to the scal method defined by BLAS1. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Multiply(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); - for (int i = 0; i < x.Length; i++) - { - expSum += MathF.Pow((float)Math.E, x[i]); - } + /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of and length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * [i]) + [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan addend, Span destination) => + InvokeSpanSpanSpanIntoSpan(x, y, addend, destination); - for (int i = 0; i < x.Length; i++) - { - destination[i] = MathF.Exp(x[i]) / expSum; - } - } + /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The third tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * [i]) + . + /// It corresponds to the axpy method defined by BLAS1. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, ReadOnlySpan y, float addend, Span destination) => + InvokeSpanSpanScalarIntoSpan(x, y, addend, destination); - /// - /// A function that takes a real number and returns a value between 0 and 1. - /// + /// Computes the element-wise result of ( * ) * for the specified tensors of single-precision floating-point numbers. /// The first tensor, represented as a span. - /// The destination tensor. + /// The second tensor, represented as a scalar. + /// The third tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . /// Destination is too short. - /// '' must not be empty. - public static void Sigmoid(ReadOnlySpan x, Span destination) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = ([i] * ) + [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void MultiplyAdd(ReadOnlySpan x, float y, ReadOnlySpan addend, Span destination) => + InvokeSpanScalarSpanIntoSpan(x, y, addend, destination); - for (int i = 0; i < x.Length; i++) - { - destination[i] = 1f / (1 + MathF.Exp(-x[i])); - } - } + /// Computes the element-wise negation of each single-precision floating-point number in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = -[i]. + /// + /// + /// If any of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Negate(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); - /// Computes the maximum element in . + /// Computes the Euclidean norm of the specified tensor of single-precision floating-point numbers. + /// The first tensor, represented as a span. + /// The norm. + /// + /// + /// This method effectively computes MathF.Sqrt(TensorPrimitives.SumOfSquares(x)). + /// This is often referred to as the Euclidean norm or L2 norm. + /// It corresponds to the nrm2 method defined by BLAS1. + /// + /// + /// If any of the input values is equal to , the result value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Norm(ReadOnlySpan x) => + MathF.Sqrt(SumOfSquares(x)); + + /// Computes the product of all elements in the specified non-empty tensor of single-precision floating-point numbers. /// The tensor, represented as a span. - /// The maximum element in . - /// Length of '' must be greater than zero. - public static float Max(ReadOnlySpan x) + /// The result of multiplying all elements in . + /// Length of must be greater than zero. + /// + /// + /// If any of the input values is equal to , the result value is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float Product(ReadOnlySpan x) { if (x.IsEmpty) { ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - float result = float.NegativeInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `maximum` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the greater of the inputs. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - - if (current != result) - { - if (float.IsNaN(current)) - { - return current; - } - - if (result < current) - { - result = current; - } - } - else if (IsNegative(result)) - { - result = current; - } - } - - return result; + return Aggregate(x); } - /// Computes the minimum element in . - /// The tensor, represented as a span. - /// The minimum element in . - /// Length of '' must be greater than zero. - public static float Min(ReadOnlySpan x) + /// Computes the product of the element-wise differences of the single-precision floating-point numbers in the specified non-empty tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The result of multiplying the element-wise subtraction of the elements in the second tensor from the first tensor. + /// Length of both input spans must be greater than zero. + /// and must have the same length. + /// + /// + /// This method effectively computes: + /// + /// Span<float> differences = ...; + /// TensorPrimitives.Subtract(x, y, differences); + /// float result = TensorPrimitives.Product(differences); + /// + /// but without requiring additional temporary storage for the intermediate differences. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float ProductOfDifferences(ReadOnlySpan x, ReadOnlySpan y) { if (x.IsEmpty) { ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - float result = float.PositiveInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `minimum` function - // It propagates NaN inputs back to the caller and - // otherwise returns the lesser of the inputs. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - - if (current != result) - { - if (float.IsNaN(current)) - { - return current; - } - - if (current < result) - { - result = current; - } - } - else if (IsNegative(current)) - { - result = current; - } - } - - return result; + return Aggregate(x, y); } - /// Computes the maximum magnitude of any element in . - /// The tensor, represented as a span. - /// The maximum magnitude of any element in . - /// Length of '' must be greater than zero. - public static float MaxMagnitude(ReadOnlySpan x) + /// Computes the product of the element-wise sums of the single-precision floating-point numbers in the specified non-empty tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a span. + /// The result of multiplying the element-wise additions of the elements in each tensor. + /// Length of both input spans must be greater than zero. + /// and must have the same length. + /// + /// + /// This method effectively computes: + /// + /// Span<float> sums = ...; + /// TensorPrimitives.Add(x, y, sums); + /// float result = TensorPrimitives.Product(sums); + /// + /// but without requiring additional temporary storage for the intermediate sums. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float ProductOfSums(ReadOnlySpan x, ReadOnlySpan y) { if (x.IsEmpty) { ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - float result = float.NegativeInfinity; - float resultMag = float.NegativeInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `maximumMagnitude` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the input with a greater magnitude. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - float currentMag = Math.Abs(current); - - if (currentMag != resultMag) - { - if (float.IsNaN(currentMag)) - { - return currentMag; - } - - if (resultMag < currentMag) - { - result = current; - resultMag = currentMag; - } - } - else if (IsNegative(result)) - { - result = current; - resultMag = currentMag; - } - } - - return resultMag; + return Aggregate(x, y); } - /// Computes the minimum magnitude of any element in . + /// Computes the element-wise sigmoid function on the specified non-empty tensor of single-precision floating-point numbers. /// The tensor, represented as a span. - /// The minimum magnitude of any element in . - /// Length of '' must be greater than zero. - public static float MinMagnitude(ReadOnlySpan x) + /// The destination tensor. + /// Destination is too short. + /// must not be empty. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = 1f / (1f + .Exp(-[i])). + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Sigmoid(ReadOnlySpan x, Span destination) { if (x.IsEmpty) { ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - float resultMag = float.PositiveInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `minimumMagnitude` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the input with a lesser magnitude. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - float currentMag = Math.Abs(current); - - if (currentMag != resultMag) - { - if (float.IsNaN(currentMag)) - { - return currentMag; - } - - if (currentMag < resultMag) - { - resultMag = currentMag; - } - } - else if (IsNegative(current)) - { - resultMag = currentMag; - } - } - - return resultMag; + InvokeSpanIntoSpan(x, destination); } - /// Computes the index of the maximum element in . + /// Computes the element-wise hyperbolic sine of each single-precision floating-point radian angle in the specified tensor. /// The tensor, represented as a span. - /// The index of the maximum element in , or -1 if is empty. - public static unsafe int IndexOfMax(ReadOnlySpan x) - { - int result = -1; - - if (!x.IsEmpty) - { - float max = float.NegativeInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `maximum` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the greater of the inputs. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - - if (current != max) - { - if (float.IsNaN(current)) - { - return i; - } - - if (max < current) - { - result = i; - max = current; - } - } - else if (IsNegative(max) && !IsNegative(current)) - { - result = i; - max = current; - } - } - } - - return result; - } + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Sinh([i]). + /// + /// + /// If a value is equal to , , or , + /// the corresponding destination location is set to that value. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Sinh(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); - /// Computes the index of the minimum element in . + /// Computes the softmax function over the specified non-empty tensor of single-precision floating-point numbers. /// The tensor, represented as a span. - /// The index of the minimum element in , or -1 if is empty. - public static unsafe int IndexOfMin(ReadOnlySpan x) + /// The destination tensor. + /// Destination is too short. + /// must not be empty. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes a sum of MathF.Exp(x[i]) for all elements in . + /// It then effectively computes [i] = MathF.Exp([i]) / sum. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void SoftMax(ReadOnlySpan x, Span destination) { - int result = -1; - - if (!x.IsEmpty) + if (x.IsEmpty) { - float min = float.PositiveInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `minimum` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the lesser of the inputs. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - - if (current != min) - { - if (float.IsNaN(current)) - { - return i; - } - - if (current < min) - { - result = i; - min = current; - } - } - else if (IsNegative(current) && !IsNegative(min)) - { - result = i; - min = current; - } - } + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); } - return result; - } - - /// Computes the index of the element in with the maximum magnitude. - /// The tensor, represented as a span. - /// The index of the element with the maximum magnitude, or -1 if is empty. - /// This method corresponds to the iamax method defined by BLAS1. - public static unsafe int IndexOfMaxMagnitude(ReadOnlySpan x) - { - int result = -1; - - if (!x.IsEmpty) + if (x.Length > destination.Length) { - float max = float.NegativeInfinity; - float maxMag = float.NegativeInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `maximumMagnitude` function. - // It propagates NaN inputs back to the caller and - // otherwise returns the input with a greater magnitude. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - float currentMag = Math.Abs(current); - - if (currentMag != maxMag) - { - if (float.IsNaN(currentMag)) - { - return i; - } - - if (maxMag < currentMag) - { - result = i; - max = current; - maxMag = currentMag; - } - } - else if (IsNegative(max) && !IsNegative(current)) - { - result = i; - max = current; - maxMag = currentMag; - } - } + ThrowHelper.ThrowArgument_DestinationTooShort(); } - return result; - } + ValidateInputOutputSpanNonOverlapping(x, destination); - /// Computes the index of the element in with the minimum magnitude. - /// The tensor, represented as a span. - /// The index of the element with the minimum magnitude, or -1 if is empty. - public static unsafe int IndexOfMinMagnitude(ReadOnlySpan x) - { - int result = -1; + float expSum = Aggregate(x); - if (!x.IsEmpty) - { - float min = float.PositiveInfinity; - float minMag = float.PositiveInfinity; - - for (int i = 0; i < x.Length; i++) - { - // This matches the IEEE 754:2019 `minimumMagnitude` function - // It propagates NaN inputs back to the caller and - // otherwise returns the input with a lesser magnitude. - // It treats +0 as greater than -0 as per the specification. - - float current = x[i]; - float currentMag = Math.Abs(current); - - if (currentMag != minMag) - { - if (float.IsNaN(currentMag)) - { - return i; - } - - if (currentMag < minMag) - { - result = i; - min = current; - minMag = currentMag; - } - } - else if (IsNegative(current) && !IsNegative(min)) - { - result = i; - min = current; - minMag = currentMag; - } - } - } - - return result; + InvokeSpanScalarIntoSpan(x, expSum, destination); } - /// Computes the sum of all elements in . + /// Computes the element-wise difference between single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Length of must be same as length of . + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] - [i]. + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Subtract(ReadOnlySpan x, ReadOnlySpan y, Span destination) => + InvokeSpanSpanIntoSpan(x, y, destination); + + /// Computes the element-wise difference between single-precision floating-point numbers in the specified tensors. + /// The first tensor, represented as a span. + /// The second tensor, represented as a scalar. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = [i] - . + /// + /// + /// If either of the element-wise input values is equal to , the resulting element-wise value is also NaN. + /// + /// + public static void Subtract(ReadOnlySpan x, float y, Span destination) => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// Computes the sum of all elements in the specified tensor of single-precision floating-point numbers. /// The tensor, represented as a span. /// The result of adding all elements in , or zero if is empty. + /// + /// + /// If any of the values in the input is equal to , the result is also NaN. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// public static float Sum(ReadOnlySpan x) => - Aggregate(0f, x); + Aggregate(x); - /// Computes the sum of the squares of every element in . - /// The tensor, represented as a span. - /// The result of adding every element in multiplied by itself, or zero if is empty. - /// This method effectively does .Sum(.Multiply(, )). - public static float SumOfSquares(ReadOnlySpan x) => - Aggregate(0f, x); - - /// Computes the sum of the absolute values of every element in . + /// Computes the sum of the absolute values of every element in the specified tensor of single-precision floating-point numbers. /// The tensor, represented as a span. /// The result of adding the absolute value of every element in , or zero if is empty. /// - /// This method effectively does .Sum(.Abs()). - /// This method corresponds to the asum method defined by BLAS1. + /// + /// This method effectively computes: + /// + /// Span<float> absoluteValues = ...; + /// TensorPrimitives.Abs(x, absoluteValues); + /// float result = TensorPrimitives.Sum(absoluteValues); + /// + /// but without requiring intermediate storage for the absolute values. It corresponds to the asum method defined by BLAS1. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// /// public static float SumOfMagnitudes(ReadOnlySpan x) => - Aggregate(0f, x); + Aggregate(x); - /// Computes the product of all elements in . + /// Computes the sum of the square of every element in the specified tensor of single-precision floating-point numbers. /// The tensor, represented as a span. - /// The result of multiplying all elements in . - /// Length of '' must be greater than zero. - public static float Product(ReadOnlySpan x) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } + /// The result of adding the square of every element in , or zero if is empty. + /// + /// + /// This method effectively computes: + /// + /// Span<float> squaredValues = ...; + /// TensorPrimitives.Multiply(x, x, squaredValues); + /// float result = TensorPrimitives.Sum(squaredValues); + /// + /// but without requiring intermediate storage for the squared values. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static float SumOfSquares(ReadOnlySpan x) => + Aggregate(x); - return Aggregate(1.0f, x); - } + /// Computes the element-wise hyperbolic tangent of each single-precision floating-point radian angle in the specified tensor. + /// The tensor, represented as a span. + /// The destination tensor, represented as a span. + /// Destination is too short. + /// and reference overlapping memory locations and do not begin at the same location. + /// + /// + /// This method effectively computes [i] = .Tanh([i]). + /// + /// + /// If a value is equal to , the corresponding destination location is set to -1. + /// If a value is equal to , the corresponding destination location is set to 1. + /// If a value is equal to , the corresponding destination location is set to NaN. + /// + /// + /// The angles in x must be in radians. Use or multiply by /180 to convert degrees to radians. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static void Tanh(ReadOnlySpan x, Span destination) => + InvokeSpanIntoSpan(x, destination); - /// Computes the product of the element-wise result of: + . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The result of multiplying the element-wise additions of the elements in each tensor. - /// Length of both input spans must be greater than zero. - /// and must have the same length. - /// This method effectively does .Product(.Add(, )). - public static float ProductOfSums(ReadOnlySpan x, ReadOnlySpan y) + /// Throws an exception if the and spans overlap and don't begin at the same memory location. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ValidateInputOutputSpanNonOverlapping(ReadOnlySpan input, Span output) { - if (x.IsEmpty) + if (!Unsafe.AreSame(ref MemoryMarshal.GetReference(input), ref MemoryMarshal.GetReference(output)) && + input.Overlaps(output)) { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } - - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + ThrowHelper.ThrowArgument_InputAndDestinationSpanMustNotOverlap(); } - - return Aggregate(1.0f, x, y); } - /// Computes the product of the element-wise result of: - . - /// The first tensor, represented as a span. - /// The second tensor, represented as a span. - /// The result of multiplying the element-wise subtraction of the elements in the second tensor from the first tensor. - /// Length of both input spans must be greater than zero. - /// and must have the same length. - /// This method effectively does .Product(.Subtract(, )). - public static float ProductOfDifferences(ReadOnlySpan x, ReadOnlySpan y) - { - if (x.IsEmpty) - { - ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); - } - - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } - - return Aggregate(1.0f, x, y); - } + /// Mask used to handle alignment elements before vectorized handling of the input. + /// + /// Logically 16 rows of 16 uints. The Nth row should be used to handle N alignment elements at the + /// beginning of the input, where elements in the vector after that will be zero'd. + /// + /// There actually exists 17 rows in the table with the last row being a repeat of the first. This is + /// done because it allows the main algorithms to use a simplified algorithm when computing the amount + /// of misalignment where we always skip the first 16 elements, even if already aligned, so we don't + /// double process them. This allows us to avoid an additional branch. + /// + private static ReadOnlySpan AlignmentUInt32Mask_16x16 => + [ + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + ]; + + /// Mask used to handle remaining elements after vectorized handling of the input. + /// + /// Logically 16 rows of 16 uints. The Nth row should be used to handle N remaining elements at the + /// end of the input, where elements in the vector prior to that will be zero'd. + /// + /// Much as with the AlignmentMask table, we actually have 17 rows where the last row is a repeat of + /// the first. Doing this allows us to avoid an additional branch and instead to always process the + /// last 16 elements via a conditional select instead. + /// + private static ReadOnlySpan RemainderUInt32Mask_16x16 => + [ + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, + ]; } } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index ae5af404ac1aff..498e4b58da77ca 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -1,14 +1,33 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; +using System.Runtime.Intrinsics.X86; namespace System.Numerics.Tensors { - public static partial class TensorPrimitives + public static unsafe partial class TensorPrimitives { + /// Defines the threshold, in bytes, at which non-temporal stores will be used. + /// + /// A non-temporal store is one that allows the CPU to bypass the cache when writing to memory. + /// + /// This can be beneficial when working with large amounts of memory where the writes would otherwise + /// cause large amounts of repeated updates and evictions. The hardware optimization manuals recommend + /// the threshold to be roughly half the size of the last level of on-die cache -- that is, if you have approximately + /// 4MB of L3 cache per core, you'd want this to be approx. 1-2MB, depending on if hyperthreading was enabled. + /// + /// However, actually computing the amount of L3 cache per core can be tricky or error prone. Native memcpy + /// algorithms use a constant threshold that is typically around 256KB and we match that here for simplicity. This + /// threshold accounts for most processors in the last 10-15 years that had approx. 1MB L3 per core and support + /// hyperthreading, giving a per core last level cache of approx. 512KB. + /// + private const nuint NonTemporalByteThreshold = 256 * 1024; + /// /// Copies to , converting each /// value to its nearest representable half-precision floating-point value. @@ -16,6 +35,14 @@ public static partial class TensorPrimitives /// The source span from which to copy values. /// The destination span into which the converted values should be written. /// Destination is too short. + /// + /// + /// This method effectively computes [i] = (Half)[i]. + /// + /// + /// and must not overlap. If they do, behavior is undefined. + /// + /// public static void ConvertToHalf(ReadOnlySpan source, Span destination) { if (source.Length > destination.Length) @@ -23,461 +50,354 @@ public static void ConvertToHalf(ReadOnlySpan source, Span destinat ThrowHelper.ThrowArgument_DestinationTooShort(); } - for (int i = 0; i < source.Length; i++) - { - destination[i] = (Half)source[i]; - } - } - - /// - /// Copies to , converting each half-precision - /// floating-point value to its nearest representable value. - /// - /// The source span from which to copy values. - /// The destination span into which the converted values should be written. - /// Destination is too short. - public static void ConvertToSingle(ReadOnlySpan source, Span destination) - { - if (source.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + ref float sourceRef = ref MemoryMarshal.GetReference(source); + ref ushort destinationRef = ref Unsafe.As(ref MemoryMarshal.GetReference(destination)); + int i = 0, twoVectorsFromEnd; - for (int i = 0; i < source.Length; i++) +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) { - destination[i] = (float)source[i]; - } - } + twoVectorsFromEnd = source.Length - (Vector512.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. + do + { + Vector512 lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + Vector512 upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512.Count))); + Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); - private static bool IsNegative(float f) => float.IsNegative(f); + i += Vector512.Count * 2; + } + while (i <= twoVectorsFromEnd); - private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) - { - // Compute the same as: - // TensorPrimitives.Dot(x, y) / (Math.Sqrt(TensorPrimitives.SumOfSquares(x)) * Math.Sqrt(TensorPrimitives.SumOfSquares(y))) - // but only looping over each span once. + // Handle any remaining elements with final vectors. + if (i != source.Length) + { + i = source.Length - (Vector512.Count * 2); - float dotProduct = 0f; - float xSumOfSquares = 0f; - float ySumOfSquares = 0f; + Vector512 lower = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + Vector512 upper = SingleToHalfAsWidenedUInt32_Vector512(Vector512.LoadUnsafe(ref sourceRef, (uint)(i + Vector512.Count))); + Vector512.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + } - int i = 0; + return; + } + } +#endif -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) + if (Vector256.IsHardwareAccelerated) { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + twoVectorsFromEnd = source.Length - (Vector256.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. + do + { + Vector256 lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + Vector256 upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256.Count))); + Vector256 halfs = Vector256.Narrow(lower, upper); + halfs.StoreUnsafe(ref destinationRef, (uint)i); - Vector512 dotProductVector = Vector512.Zero; - Vector512 xSumOfSquaresVector = Vector512.Zero; - Vector512 ySumOfSquaresVector = Vector512.Zero; + i += Vector256.Count * 2; + } + while (i <= twoVectorsFromEnd); - // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. - int oneVectorFromEnd = x.Length - Vector512.Count; - do - { - Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)i); - Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)i); + // Handle any remaining elements with final vectors. + if (i != source.Length) + { + i = source.Length - (Vector256.Count * 2); - dotProductVector += xVec * yVec; - xSumOfSquaresVector += xVec * xVec; - ySumOfSquaresVector += yVec * yVec; + Vector256 lower = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + Vector256 upper = SingleToHalfAsWidenedUInt32_Vector256(Vector256.LoadUnsafe(ref sourceRef, (uint)(i + Vector256.Count))); + Vector256.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + } - i += Vector512.Count; + return; } - while (i <= oneVectorFromEnd); - - // Sum the vector lanes into the scalar result. - dotProduct += Vector512.Sum(dotProductVector); - xSumOfSquares += Vector512.Sum(xSumOfSquaresVector); - ySumOfSquares += Vector512.Sum(ySumOfSquaresVector); } - else -#endif - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + + if (Vector128.IsHardwareAccelerated) { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + twoVectorsFromEnd = source.Length - (Vector128.Count * 2); + if (i <= twoVectorsFromEnd) + { + // Loop handling two input vectors / one output vector at a time. + do + { + Vector128 lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + Vector128 upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128.Count))); + Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); - Vector256 dotProductVector = Vector256.Zero; - Vector256 xSumOfSquaresVector = Vector256.Zero; - Vector256 ySumOfSquaresVector = Vector256.Zero; + i += Vector128.Count * 2; + } + while (i <= twoVectorsFromEnd); - // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. - int oneVectorFromEnd = x.Length - Vector256.Count; - do - { - Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)i); - Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)i); + // Handle any remaining elements with final vectors. + if (i != source.Length) + { + i = source.Length - (Vector128.Count * 2); - dotProductVector += xVec * yVec; - xSumOfSquaresVector += xVec * xVec; - ySumOfSquaresVector += yVec * yVec; + Vector128 lower = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + Vector128 upper = SingleToHalfAsWidenedUInt32_Vector128(Vector128.LoadUnsafe(ref sourceRef, (uint)(i + Vector128.Count))); + Vector128.Narrow(lower, upper).StoreUnsafe(ref destinationRef, (uint)i); + } - i += Vector256.Count; + return; } - while (i <= oneVectorFromEnd); + } - // Sum the vector lanes into the scalar result. - dotProduct += Vector256.Sum(dotProductVector); - xSumOfSquares += Vector256.Sum(xSumOfSquaresVector); - ySumOfSquares += Vector256.Sum(ySumOfSquaresVector); + while (i < source.Length) + { + Unsafe.Add(ref destinationRef, i) = BitConverter.HalfToUInt16Bits((Half)Unsafe.Add(ref sourceRef, i)); + i++; } - else if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + + // This implements a vectorized version of the `explicit operator Half(float value) operator`. + // See detailed description of the algorithm used here: + // https://github.com/dotnet/runtime/blob/ca8d6f0420096831766ec11c7d400e4f7ccc7a34/src/libraries/System.Private.CoreLib/src/System/Half.cs#L606-L714 + // The cast operator converts a float to a Half represented as a UInt32, then narrows to a UInt16, and reinterpret casts to Half. + // This does the same, with an input VectorXx and an output VectorXx. + // Loop handling two input vectors at a time; each input float is double the size of each output Half, + // so we need two vectors of floats to produce one vector of Halfs. Half isn't supported in VectorXx, + // so we convert the VectorXx to a VectorXx, and the caller then uses this twice, narrows the combination + // into a VectorXx, and then saves that out to the destination `ref Half` reinterpreted as `ref ushort`. + +#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948 + const uint MinExp = 0x3880_0000u; // Minimum exponent for rounding + const uint Exponent126 = 0x3f00_0000u; // Exponent displacement #1 + const uint SingleBiasedExponentMask = 0x7F80_0000; // float.BiasedExponentMask; // Exponent mask + const uint Exponent13 = 0x0680_0000u; // Exponent displacement #2 + const float MaxHalfValueBelowInfinity = 65520.0f; // Maximum value that is not Infinity in Half + const uint ExponentMask = 0x7C00; // Mask for exponent bits in Half + const uint SingleSignMask = 0x8000_0000u; // float.SignMask; // Mask for sign bit in float +#pragma warning restore IDE0059 + + static Vector128 SingleToHalfAsWidenedUInt32_Vector128(Vector128 value) { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + Vector128 bitValue = value.AsUInt32(); - Vector128 dotProductVector = Vector128.Zero; - Vector128 xSumOfSquaresVector = Vector128.Zero; - Vector128 ySumOfSquaresVector = Vector128.Zero; + // Extract sign bit + Vector128 sign = Vector128.ShiftRightLogical(bitValue & Vector128.Create(SingleSignMask), 16); - // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. - int oneVectorFromEnd = x.Length - Vector128.Count; - do - { - Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)i); - Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)i); + // Detecting NaN (0u if value is NaN; otherwise, ~0u) + Vector128 realMask = Vector128.Equals(value, value).AsUInt32(); - dotProductVector += xVec * yVec; - xSumOfSquaresVector += xVec * xVec; - ySumOfSquaresVector += yVec * yVec; + // Clear sign bit + value = Vector128.Abs(value); - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + // Rectify values that are Infinity in Half. + value = Vector128.Min(Vector128.Create(MaxHalfValueBelowInfinity), value); - // Sum the vector lanes into the scalar result. - dotProduct += Vector128.Sum(dotProductVector); - xSumOfSquares += Vector128.Sum(xSumOfSquaresVector); - ySumOfSquares += Vector128.Sum(ySumOfSquaresVector); - } + // Rectify lower exponent + Vector128 exponentOffset0 = Vector128.Max(value, Vector128.Create(MinExp).AsSingle()).AsUInt32(); - // Process any remaining elements past the last vector. - for (; (uint)i < (uint)x.Length; i++) - { - dotProduct += x[i] * y[i]; - xSumOfSquares += x[i] * x[i]; - ySumOfSquares += y[i] * y[i]; - } + // Extract exponent + exponentOffset0 &= Vector128.Create(SingleBiasedExponentMask); - // Sum(X * Y) / (|X| * |Y|) - return dotProduct / (MathF.Sqrt(xSumOfSquares) * MathF.Sqrt(ySumOfSquares)); - } + // Add exponent by 13 + exponentOffset0 += Vector128.Create(Exponent13); - private static float Aggregate( - float identityValue, ReadOnlySpan x) - where TLoad : IUnaryOperator - where TAggregate : IBinaryOperator - { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + // Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction) + value += exponentOffset0.AsSingle(); + bitValue = value.AsUInt32(); -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); + // Only exponent bits will be modified if NaN + Vector128 maskedHalfExponentForNaN = ~realMask & Vector128.Create(ExponentMask); - // Load the first vector as the initial set of results - Vector512 resultVector = TLoad.Invoke(Vector512.LoadUnsafe(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector512.Count; + // Subtract exponent by 126 + bitValue -= Vector128.Create(Exponent126); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector512.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TLoad.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i))); - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + // Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part. + Vector128 newExponent = Vector128.ShiftRightLogical(bitValue, 13); - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); - } - else -#endif - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); + // Clear the fraction parts if the value was NaN. + bitValue &= realMask; - // Load the first vector as the initial set of results - Vector256 resultVector = TLoad.Invoke(Vector256.LoadUnsafe(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector256.Count; + // Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow. + bitValue += newExponent; - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector256.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TLoad.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i))); - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + // Clear exponents if value is NaN + bitValue &= ~maskedHalfExponentForNaN; + + // Merge sign bit with possible NaN exponent + Vector128 signAndMaskedExponent = maskedHalfExponentForNaN | sign; - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); + // Merge sign bit and possible NaN exponent + bitValue |= signAndMaskedExponent; + + // The final result + return bitValue; } - else if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count * 2) + + static Vector256 SingleToHalfAsWidenedUInt32_Vector256(Vector256 value) { - ref float xRef = ref MemoryMarshal.GetReference(x); + Vector256 bitValue = value.AsUInt32(); - // Load the first vector as the initial set of results - Vector128 resultVector = TLoad.Invoke(Vector128.LoadUnsafe(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector128.Count; + // Extract sign bit + Vector256 sign = Vector256.ShiftRightLogical(bitValue & Vector256.Create(SingleSignMask), 16); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector128.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TLoad.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i))); - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + // Detecting NaN (0u if value is NaN; otherwise, ~0u) + Vector256 realMask = Vector256.Equals(value, value).AsUInt32(); - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); - } + // Clear sign bit + value = Vector256.Abs(value); - // Aggregate the remaining items in the input span. - for (; (uint)i < (uint)x.Length; i++) - { - result = TAggregate.Invoke(result, TLoad.Invoke(x[i])); - } + // Rectify values that are Infinity in Half. + value = Vector256.Min(Vector256.Create(MaxHalfValueBelowInfinity), value); - return result; - } + // Rectify lower exponent + Vector256 exponentOffset0 = Vector256.Max(value, Vector256.Create(MinExp).AsSingle()).AsUInt32(); - private static float Aggregate( - float identityValue, ReadOnlySpan x, ReadOnlySpan y) - where TBinary : IBinaryOperator - where TAggregate : IBinaryOperator - { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + // Extract exponent + exponentOffset0 &= Vector256.Create(SingleBiasedExponentMask); -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + // Add exponent by 13 + exponentOffset0 += Vector256.Create(Exponent13); - // Load the first vector as the initial set of results - Vector512 resultVector = TBinary.Invoke(Vector512.LoadUnsafe(ref xRef, 0), Vector512.LoadUnsafe(ref yRef, 0)); - int oneVectorFromEnd = x.Length - Vector512.Count; + // Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction) + value += exponentOffset0.AsSingle(); + bitValue = value.AsUInt32(); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector512.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TBinary.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), Vector512.LoadUnsafe(ref yRef, (uint)i))); - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + // Only exponent bits will be modified if NaN + Vector256 maskedHalfExponentForNaN = ~realMask & Vector256.Create(ExponentMask); - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); - } - else -#endif - if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + // Subtract exponent by 126 + bitValue -= Vector256.Create(Exponent126); - // Load the first vector as the initial set of results - Vector256 resultVector = TBinary.Invoke(Vector256.LoadUnsafe(ref xRef, 0), Vector256.LoadUnsafe(ref yRef, 0)); - int oneVectorFromEnd = x.Length - Vector256.Count; + // Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part. + Vector256 newExponent = Vector256.ShiftRightLogical(bitValue, 13); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector256.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TBinary.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), Vector256.LoadUnsafe(ref yRef, (uint)i))); - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + // Clear the fraction parts if the value was NaN. + bitValue &= realMask; - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); - } - else if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count * 2) - { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); + // Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow. + bitValue += newExponent; - // Load the first vector as the initial set of results - Vector128 resultVector = TBinary.Invoke(Vector128.LoadUnsafe(ref xRef, 0), Vector128.LoadUnsafe(ref yRef, 0)); - int oneVectorFromEnd = x.Length - Vector128.Count; + // Clear exponents if value is NaN + bitValue &= ~maskedHalfExponentForNaN; - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector128.Count; - do - { - resultVector = TAggregate.Invoke(resultVector, TBinary.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), Vector128.LoadUnsafe(ref yRef, (uint)i))); - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + // Merge sign bit with possible NaN exponent + Vector256 signAndMaskedExponent = maskedHalfExponentForNaN | sign; + + // Merge sign bit and possible NaN exponent + bitValue |= signAndMaskedExponent; - // Aggregate the lanes in the vector back into the scalar result - result = TAggregate.Invoke(result, TAggregate.Invoke(resultVector)); + // The final result + return bitValue; } - // Aggregate the remaining items in the input span. - for (; (uint)i < (uint)x.Length; i++) +#if NET8_0_OR_GREATER + static Vector512 SingleToHalfAsWidenedUInt32_Vector512(Vector512 value) { - result = TAggregate.Invoke(result, TBinary.Invoke(x[i], y[i])); - } + Vector512 bitValue = value.AsUInt32(); - return result; - } + // Extract sign bit + Vector512 sign = Vector512.ShiftRightLogical(bitValue & Vector512.Create(SingleSignMask), 16); - private static unsafe void InvokeSpanIntoSpan( - ReadOnlySpan x, Span destination) - where TUnaryOperator : IUnaryOperator - { - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + // Detecting NaN (0u if value is NaN; otherwise, ~0u) + Vector512 realMask = Vector512.Equals(value, value).AsUInt32(); - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + // Clear sign bit + value = Vector512.Abs(value); -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. - do - { - TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + // Rectify values that are Infinity in Half. + value = Vector512.Min(Vector512.Create(MaxHalfValueBelowInfinity), value); - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + // Rectify lower exponent + Vector512 exponentOffset0 = Vector512.Max(value, Vector512.Create(MinExp).AsSingle()).AsUInt32(); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + // Extract exponent + exponentOffset0 &= Vector512.Create(SingleBiasedExponentMask); - return; - } - } -#endif + // Add exponent by 13 + exponentOffset0 += Vector512.Create(Exponent13); - if (Vector256.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. - do - { - TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + // Round Single into Half's precision (NaN also gets modified here, just setting the MSB of fraction) + value += exponentOffset0.AsSingle(); + bitValue = value.AsUInt32(); - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + // Only exponent bits will be modified if NaN + Vector512 maskedHalfExponentForNaN = ~realMask & Vector512.Create(ExponentMask); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + // Subtract exponent by 126 + bitValue -= Vector512.Create(Exponent126); - return; - } - } + // Shift bitValue right by 13 bits to match the boundary of exponent part and fraction part. + Vector512 newExponent = Vector512.ShiftRightLogical(bitValue, 13); - if (Vector128.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. - do - { - TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + // Clear the fraction parts if the value was NaN. + bitValue &= realMask; - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + // Merge the exponent part with fraction part, and add the exponent part and fraction part's overflow. + bitValue += newExponent; - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + // Clear exponents if value is NaN + bitValue &= ~maskedHalfExponentForNaN; - return; - } - } + // Merge sign bit with possible NaN exponent + Vector512 signAndMaskedExponent = maskedHalfExponentForNaN | sign; - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, i)); + // Merge sign bit and possible NaN exponent + bitValue |= signAndMaskedExponent; - i++; + // The final result + return bitValue; } +#endif } - private static unsafe void InvokeSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, Span destination) - where TBinaryOperator : IBinaryOperator + /// + /// Copies to , converting each half-precision + /// floating-point value to its nearest representable value. + /// + /// The source span from which to copy values. + /// The destination span into which the converted values should be written. + /// Destination is too short. + /// + /// + /// This method effectively computes [i] = (float)[i]. + /// + /// + /// and must not overlap. If they do, behavior is undefined. + /// + /// + public static void ConvertToSingle(ReadOnlySpan source, Span destination) { - if (x.Length != y.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } - - if (x.Length > destination.Length) + if (source.Length > destination.Length) { ThrowHelper.ThrowArgument_DestinationTooShort(); } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float dRef = ref MemoryMarshal.GetReference(destination); + ref short sourceRef = ref Unsafe.As(ref MemoryMarshal.GetReference(source)); + ref float destinationRef = ref MemoryMarshal.GetReference(destination); int i = 0, oneVectorFromEnd; #if NET8_0_OR_GREATER if (Vector512.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector512.Count; + oneVectorFromEnd = source.Length - Vector512.Count; if (i <= oneVectorFromEnd) { - // Loop handling one vector at a time. + // Loop handling one input vector / two output vectors at a time. do { - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + (Vector512 lower, Vector512 upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512.Count)); - i += Vector512.Count; + i += Vector512.Count; } while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Handle any remaining elements with a final input vector. + if (i != source.Length) { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + i = source.Length - Vector512.Count; + + (Vector512 lower, Vector512 upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512.Count)); } return; @@ -487,25 +407,28 @@ private static unsafe void InvokeSpanSpanIntoSpan( if (Vector256.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector256.Count; + oneVectorFromEnd = source.Length - Vector256.Count; if (i <= oneVectorFromEnd) { - // Loop handling one vector at a time. + // Loop handling one input vector / two output vectors at a time. do { - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + (Vector256 lower, Vector256 upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256.Count)); - i += Vector256.Count; + i += Vector256.Count; } while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Handle any remaining elements with a final input vector. + if (i != source.Length) { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + i = source.Length - Vector256.Count; + + (Vector256 lower, Vector256 upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256.Count)); } return; @@ -514,724 +437,11129 @@ private static unsafe void InvokeSpanSpanIntoSpan( if (Vector128.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector128.Count; + oneVectorFromEnd = source.Length - Vector128.Count; if (i <= oneVectorFromEnd) { - // Loop handling one vector at a time. + // Loop handling one input vector / two output vectors at a time. do { - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + (Vector128 lower, Vector128 upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128.Count)); - i += Vector128.Count; + i += Vector128.Count; } while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Handle any remaining elements with a final input vector. + if (i != source.Length) { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); + i = source.Length - Vector128.Count; + + (Vector128 lower, Vector128 upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128.Count)); } return; } } - while (i < x.Length) + while (i < source.Length) { - Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i)); - + Unsafe.Add(ref destinationRef, i) = (float)Unsafe.As(ref Unsafe.Add(ref sourceRef, i)); i++; } - } - private static unsafe void InvokeSpanScalarIntoSpan( - ReadOnlySpan x, float y, Span destination) - where TBinaryOperator : IBinaryOperator - { - if (x.Length > destination.Length) + // This implements a vectorized version of the `explicit operator float(Half value) operator`. + // See detailed description of the algorithm used here: + // https://github.com/dotnet/runtime/blob/3bf40a378f00cb5bf18ff62796bc7097719b974c/src/libraries/System.Private.CoreLib/src/System/Half.cs#L1010-L1040 + // The cast operator converts a Half represented as uint to a float. This does the same, with an input VectorXx and an output VectorXx. + // The VectorXx is created by reading a vector of Halfs as a VectorXx then widened to two VectorXxs and cast to VectorXxs. + // We loop handling one input vector at a time, producing two output float vectors. + +#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948 + const uint ExponentLowerBound = 0x3880_0000u; // The smallest positive normal number in Half, converted to Single + const uint ExponentOffset = 0x3800_0000u; // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13) + const uint SingleSignMask = 0x8000_0000; // float.SignMask; // Mask for sign bit in Single + const uint HalfExponentMask = 0x7C00; // Mask for exponent bits in Half + const uint HalfToSingleBitsMask = 0x0FFF_E000; // Mask for bits in Single converted from Half +#pragma warning restore IDE0059 + + static Vector128 HalfAsWidenedUInt32ToSingle_Vector128(Vector128 value) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + // Extract sign bit of value + Vector128 sign = value & Vector128.Create(SingleSignMask); + + // Copy sign bit to upper bits + Vector128 bitValueInProcess = value; + + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector128 offsetExponent = bitValueInProcess & Vector128.Create(HalfExponentMask); + + // ~0u when value is subnormal, 0 otherwise + Vector128 subnormalMask = Vector128.Equals(offsetExponent, Vector128.Zero); + + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector128 infinityOrNaNMask = Vector128.Equals(offsetExponent, Vector128.Create(HalfExponentMask)); + + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector128 maskedExponentLowerBound = subnormalMask & Vector128.Create(ExponentLowerBound); + + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector128 offsetMaskedExponentLowerBound = Vector128.Create(ExponentOffset) | maskedExponentLowerBound; + + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector128.ShiftLeft(bitValueInProcess, 13); + + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector128.ConditionalSelect(Vector128.Equals(infinityOrNaNMask, Vector128.Zero), + offsetMaskedExponentLowerBound, + Vector128.ShiftLeft(offsetMaskedExponentLowerBound, 1)); + + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector128.Create(HalfToSingleBitsMask); + + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. + Vector128 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + static Vector256 HalfAsWidenedUInt32ToSingle_Vector256(Vector256 value) + { + // Extract sign bit of value + Vector256 sign = value & Vector256.Create(SingleSignMask); + + // Copy sign bit to upper bits + Vector256 bitValueInProcess = value; + + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector256 offsetExponent = bitValueInProcess & Vector256.Create(HalfExponentMask); + + // ~0u when value is subnormal, 0 otherwise + Vector256 subnormalMask = Vector256.Equals(offsetExponent, Vector256.Zero); + + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector256 infinityOrNaNMask = Vector256.Equals(offsetExponent, Vector256.Create(HalfExponentMask)); + + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector256 maskedExponentLowerBound = subnormalMask & Vector256.Create(ExponentLowerBound); + + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector256 offsetMaskedExponentLowerBound = Vector256.Create(ExponentOffset) | maskedExponentLowerBound; + + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector256.ShiftLeft(bitValueInProcess, 13); + + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector256.ConditionalSelect(Vector256.Equals(infinityOrNaNMask, Vector256.Zero), + offsetMaskedExponentLowerBound, + Vector256.ShiftLeft(offsetMaskedExponentLowerBound, 1)); + + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector256.Create(HalfToSingleBitsMask); + + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. + Vector256 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); + } #if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) + static Vector512 HalfAsWidenedUInt32ToSingle_Vector512(Vector512 value) { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) - { - Vector512 yVec = Vector512.Create(y); + // Extract sign bit of value + Vector512 sign = value & Vector512.Create(SingleSignMask); - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - yVec).StoreUnsafe(ref dRef, (uint)i); + // Copy sign bit to upper bits + Vector512 bitValueInProcess = value; - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector512 offsetExponent = bitValueInProcess & Vector512.Create(HalfExponentMask); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - yVec).StoreUnsafe(ref dRef, lastVectorIndex); - } + // ~0u when value is subnormal, 0 otherwise + Vector512 subnormalMask = Vector512.Equals(offsetExponent, Vector512.Zero); - return; - } + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector512 infinityOrNaNMask = Vector512.Equals(offsetExponent, Vector512.Create(HalfExponentMask)); + + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector512 maskedExponentLowerBound = subnormalMask & Vector512.Create(ExponentLowerBound); + + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector512 offsetMaskedExponentLowerBound = Vector512.Create(ExponentOffset) | maskedExponentLowerBound; + + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector512.ShiftLeft(bitValueInProcess, 13); + + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector512.ConditionalSelect(Vector512.Equals(infinityOrNaNMask, Vector512.Zero), + offsetMaskedExponentLowerBound, + Vector512.ShiftLeft(offsetMaskedExponentLowerBound, 1)); + + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector512.Create(HalfToSingleBitsMask); + + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. + Vector512 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); } #endif + } - if (Vector256.IsHardwareAccelerated) + /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers. + /// Assumes arguments have already been validated to be non-empty and equal length. + private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) + { + // Compute the same as: + // TensorPrimitives.Dot(x, y) / (Math.Sqrt(TensorPrimitives.SumOfSquares(x)) * Math.Sqrt(TensorPrimitives.SumOfSquares(y))) + // but only looping over each span once. + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + Vector512 dotProductVector = Vector512.Zero; + Vector512 xSumOfSquaresVector = Vector512.Zero; + Vector512 ySumOfSquaresVector = Vector512.Zero; + + // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. + int oneVectorFromEnd = x.Length - Vector512.Count; + int i = 0; + do { - Vector256 yVec = Vector256.Create(y); + Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)i); + Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)i); - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - yVec).StoreUnsafe(ref dRef, (uint)i); + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - yVec).StoreUnsafe(ref dRef, lastVectorIndex); - } + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + Vector512 xVec = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + Vector512 yVec = Vector512.LoadUnsafe(ref yRef, (uint)(x.Length - Vector512.Count)); - return; + Vector512 remainderMask = CreateRemainderMaskSingleVector512(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); } + + // Sum(X * Y) / (|X| * |Y|) + return + Vector512.Sum(dotProductVector) / + (MathF.Sqrt(Vector512.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector512.Sum(ySumOfSquaresVector))); } +#endif - if (Vector128.IsHardwareAccelerated) + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + Vector256 dotProductVector = Vector256.Zero; + Vector256 xSumOfSquaresVector = Vector256.Zero; + Vector256 ySumOfSquaresVector = Vector256.Zero; + + // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = 0; + do { - Vector128 yVec = Vector128.Create(y); + Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)i); + Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)i); - // Loop handling one vector at a time. - do - { - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - yVec).StoreUnsafe(ref dRef, (uint)i); + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - yVec).StoreUnsafe(ref dRef, lastVectorIndex); - } + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + Vector256 xVec = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + Vector256 yVec = Vector256.LoadUnsafe(ref yRef, (uint)(x.Length - Vector256.Count)); - return; + Vector256 remainderMask = CreateRemainderMaskSingleVector256(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); } + + // Sum(X * Y) / (|X| * |Y|) + return + Vector256.Sum(dotProductVector) / + (MathF.Sqrt(Vector256.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector256.Sum(ySumOfSquaresVector))); } - while (i < x.Length) + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) { - Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), - y); + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); - i++; - } - } + Vector128 dotProductVector = Vector128.Zero; + Vector128 xSumOfSquaresVector = Vector128.Zero; + Vector128 ySumOfSquaresVector = Vector128.Zero; - private static unsafe void InvokeSpanSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) - where TTernaryOperator : ITernaryOperator - { - if (x.Length != y.Length || x.Length != z.Length) - { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + // Process vectors, summing their dot products and squares, as long as there's a vector's worth remaining. + int oneVectorFromEnd = x.Length - Vector128.Count; + int i = 0; + do + { + Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)i); + Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)i); + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Process the last vector in the span, masking off elements already processed. + if (i != x.Length) + { + Vector128 xVec = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + Vector128 yVec = Vector128.LoadUnsafe(ref yRef, (uint)(x.Length - Vector128.Count)); + + Vector128 remainderMask = CreateRemainderMaskSingleVector128(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector = FusedMultiplyAdd(xVec, yVec, dotProductVector); + xSumOfSquaresVector = FusedMultiplyAdd(xVec, xVec, xSumOfSquaresVector); + ySumOfSquaresVector = FusedMultiplyAdd(yVec, yVec, ySumOfSquaresVector); + } + + // Sum(X * Y) / (|X| * |Y|) + return + Vector128.Sum(dotProductVector) / + (MathF.Sqrt(Vector128.Sum(xSumOfSquaresVector)) * MathF.Sqrt(Vector128.Sum(ySumOfSquaresVector))); } - if (x.Length > destination.Length) + // Vectorization isn't supported or there are too few elements to vectorize. + // Use a scalar implementation. + float dotProduct = 0f, xSumOfSquares = 0f, ySumOfSquares = 0f; + for (int i = 0; i < x.Length; i++) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + dotProduct = MathF.FusedMultiplyAdd(x[i], y[i], dotProduct); + xSumOfSquares = MathF.FusedMultiplyAdd(x[i], x[i], xSumOfSquares); + ySumOfSquares = MathF.FusedMultiplyAdd(y[i], y[i], ySumOfSquares); } + // Sum(X * Y) / (|X| * |Y|) + return + dotProduct / + (MathF.Sqrt(xSumOfSquares) * MathF.Sqrt(ySumOfSquares)); + } + + /// Performs an aggregation over all elements in to produce a single-precision floating-point value. + /// Specifies the transform operation that should be applied to each element loaded from . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied after the transform is applied to each element. + /// + private static float Aggregate( + ReadOnlySpan x) + where TTransformOperator : struct, IUnaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); #if NET8_0_OR_GREATER if (Vector512.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i), - Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); - - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + float result; - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + if (remainder >= (uint)(Vector512.Count)) + { + result = Vectorized512(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - return; + result = Vectorized512Small(ref xRef, remainder); } + + return result; } #endif if (Vector256.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i), - Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); - - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + float result; - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + if (remainder >= (uint)(Vector256.Count)) + { + result = Vectorized256(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - return; + result = Vectorized256Small(ref xRef, remainder); } + + return result; } if (Vector128.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) - { - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i), - Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); - - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + float result; - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + if (remainder >= (uint)(Vector128.Count)) + { + result = Vectorized128(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. - return; + result = Vectorized128Small(ref xRef, remainder); } + + return result; } - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - Unsafe.Add(ref zRef, i)); + // This is the software fallback when no acceleration is available. + // It requires no branches to hit. - i++; - } - } + return SoftwareFallback(ref xRef, remainder); - private static unsafe void InvokeSpanSpanScalarIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, float z, Span destination) - where TTernaryOperator : ITernaryOperator - { - if (x.Length != y.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, nuint length) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + float result = TAggregationOperator.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, i))); + } + + return result; } - if (x.Length > destination.Length) + static float Vectorized128(ref float xRef, nuint remainder) { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + // Preload the beginning and end so that overlapping accesses don't negatively impact the data -#if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector128.Count * 8)) { - Vector512 zVec = Vector512.Create(z); + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. - // Loop handling one vector at a time. - do + fixed (float* px = &xRef) { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - Vector512.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + float* xPtr = px; - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - Vector512.LoadUnsafe(ref yRef, lastVectorIndex), - zVec).StoreUnsafe(ref dRef, lastVectorIndex); - } + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; - return; - } - } -#endif + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. - if (Vector256.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) - { - Vector256 zVec = Vector256.Create(z); + misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(xPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - Vector256.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + xPtr += misalignment; - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector128))) == 0); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - Vector256.LoadUnsafe(ref yRef, lastVectorIndex), - zVec).StoreUnsafe(ref dRef, lastVectorIndex); - } + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector128.ConditionalSelect(CreateAlignmentMaskSingleVector128((int)(misalignment)), beg, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector128.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector128 vector = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(trailing)), end, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized128Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + + static float Vectorized256(ref float xRef, nuint remainder) + { + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(xPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector256.ConditionalSelect(CreateAlignmentMaskSingleVector256((int)(misalignment)), beg, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector256.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector256 vector = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(trailing)), end, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized256Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + +#if NET8_0_OR_GREATER + static float Vectorized512(ref float xRef, nuint remainder) + { + Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); + Vector512 end = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(xPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector512.ConditionalSelect(CreateAlignmentMaskSingleVector512((int)(misalignment)), beg, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector512.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector512 vector = TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector512.ConditionalSelect(CreateRemainderMaskSingleVector512((int)(trailing)), end, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized512Small(ref float xRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(remainder % (uint)(Vector256.Count))), end, Vector256.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TTransformOperator.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } +#endif + } + + /// Performs an aggregation over all pair-wise elements in and to produce a single-precision floating-point value. + /// Specifies the binary operation that should be applied to the pair-wise elements loaded from and . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied to the results of the binary operations on the pair-wise values. + /// + private static float Aggregate( + ReadOnlySpan x, ReadOnlySpan y) + where TBinaryOperator : struct, IBinaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector512.Count)) + { + result = Vectorized512(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized512Small(ref xRef, ref yRef, remainder); + } + + return result; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector256.Count)) + { + result = Vectorized256(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized256Small(ref xRef, ref yRef, remainder); + } + + return result; + } + + if (Vector128.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector128.Count)) + { + result = Vectorized128(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = Vectorized128Small(ref xRef, ref yRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + return SoftwareFallback(ref xRef, ref yRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, ref float yRef, nuint length) + { + float result = TAggregationOperator.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i))); + } + + return result; + } + + static float Vectorized128(ref float xRef, ref float yRef, nuint remainder) + { + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(xPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector128.ConditionalSelect(CreateAlignmentMaskSingleVector128((int)(misalignment)), beg, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector128.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 1)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(trailing)), end, Vector128.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized128Small(ref float xRef, ref float yRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 3: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + + static float Vectorized256(ref float xRef, ref float yRef, nuint remainder) + { + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(xPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector256.ConditionalSelect(CreateAlignmentMaskSingleVector256((int)(misalignment)), beg, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector256.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 1)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(trailing)), end, Vector256.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized256Small(ref float xRef, ref float yRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + +#if NET8_0_OR_GREATER + static float Vectorized512(ref float xRef, ref float yRef, nuint remainder) + { + Vector512 vresult = Vector512.Create(TAggregationOperator.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef)); + Vector512 end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(xPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vresult = TAggregationOperator.Invoke(vresult, vector1); + vresult = TAggregationOperator.Invoke(vresult, vector2); + vresult = TAggregationOperator.Invoke(vresult, vector3); + vresult = TAggregationOperator.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector512.ConditionalSelect(CreateAlignmentMaskSingleVector512((int)(misalignment)), beg, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + (nuint blocks, nuint trailing) = Math.DivRem(remainder, (nuint)(Vector512.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 1)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 1))); + vresult = TAggregationOperator.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector512.ConditionalSelect(CreateRemainderMaskSingleVector512((int)(trailing)), end, Vector512.Create(TAggregationOperator.IdentityValue)); + vresult = TAggregationOperator.Invoke(vresult, end); + break; + } + } + + return TAggregationOperator.Invoke(vresult); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float Vectorized512Small(ref float xRef, ref float yRef, nuint remainder) + { + float result = TAggregationOperator.IdentityValue; + + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + end = Vector256.ConditionalSelect(CreateRemainderMaskSingleVector256((int)(remainder % (uint)(Vector256.Count))), end, Vector256.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + Vector256 vresult = Vector256.Create(TAggregationOperator.IdentityValue); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + end = Vector128.ConditionalSelect(CreateRemainderMaskSingleVector128((int)(remainder % (uint)(Vector128.Count))), end, Vector128.Create(TAggregationOperator.IdentityValue)); + + vresult = TAggregationOperator.Invoke(vresult, beg); + vresult = TAggregationOperator.Invoke(vresult, end); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + Vector128 vresult = Vector128.Create(TAggregationOperator.IdentityValue); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + vresult = TAggregationOperator.Invoke(vresult, beg); + + result = TAggregationOperator.Invoke(vresult); + break; + } + + case 3: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = TAggregationOperator.Invoke(result, TBinaryOperator.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } +#endif + } + + /// + /// This is the same as + /// with an identity transform, except it early exits on NaN. + /// + private static float MinMaxCore(ReadOnlySpan x) + where TMinMaxOperator : struct, IAggregationOperator + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + // This matches the IEEE 754:2019 `maximum`/`minimum` functions. + // It propagates NaN inputs back to the caller and + // otherwise returns the greater of the inputs. + // It treats +0 as greater than -0 as per the specification. + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector512 result = Vector512.LoadUnsafe(ref xRef, 0), current; + if (!Vector512.EqualsAll(result, result)) + { + return GetFirstNaN(result); + } + + int oneVectorFromEnd = x.Length - Vector512.Count; + int i = Vector512.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector512.LoadUnsafe(ref xRef, (uint)i); + if (!Vector512.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = TMinMaxOperator.Invoke(result, current); + i += Vector512.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + if (!Vector512.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = Vector512.ConditionalSelect( + Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), + result, + TMinMaxOperator.Invoke(result, current)); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TMinMaxOperator.Invoke(result); + } +#endif + + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector256 result = Vector256.LoadUnsafe(ref xRef, 0), current; + if (!Vector256.EqualsAll(result, result)) + { + return GetFirstNaN(result); + } + + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = Vector256.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector256.LoadUnsafe(ref xRef, (uint)i); + if (!Vector256.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = TMinMaxOperator.Invoke(result, current); + i += Vector256.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + if (!Vector256.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = Vector256.ConditionalSelect( + Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), + result, + TMinMaxOperator.Invoke(result, current)); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TMinMaxOperator.Invoke(result); + } + + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector128 result = Vector128.LoadUnsafe(ref xRef, 0), current; + if (!Vector128.EqualsAll(result, result)) + { + return GetFirstNaN(result); + } + + int oneVectorFromEnd = x.Length - Vector128.Count; + int i = Vector128.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector128.LoadUnsafe(ref xRef, (uint)i); + if (!Vector128.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = TMinMaxOperator.Invoke(result, current); + i += Vector128.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + if (!Vector128.EqualsAll(current, current)) + { + return GetFirstNaN(current); + } + + result = Vector128.ConditionalSelect( + Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), + result, + TMinMaxOperator.Invoke(result, current)); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TMinMaxOperator.Invoke(result); + } + + // Scalar path used when either vectorization is not supported or the input is too small to vectorize. + { + float result = x[0]; + if (float.IsNaN(result)) + { + return result; + } + + for (int i = 1; i < x.Length; i++) + { + float current = x[i]; + if (float.IsNaN(current)) + { + return current; + } + + result = TMinMaxOperator.Invoke(result, current); + } + + return result; + } + } + + private static int IndexOfMinMaxCore(ReadOnlySpan x) where TIndexOfMinMax : struct, IIndexOfOperator + { + // This matches the IEEE 754:2019 `maximum`/`minimum` functions. + // It propagates NaN inputs back to the caller and + // otherwise returns the index of the greater of the inputs. + // It treats +0 as greater than -0 as per the specification. + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector512 resultIndex = Vector512.Create(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + Vector512 curIndex = resultIndex; + Vector512 increment = Vector512.Create(Vector512.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector512 result = Vector512.LoadUnsafe(ref xRef); + Vector512 current; + + Vector512 nanMask = ~Vector512.Equals(result, result); + if (nanMask != Vector512.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector512.Count; + int i = Vector512.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector512.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector512.Equals(current, current); + if (nanMask != Vector512.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector512.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + curIndex += Vector512.Create(x.Length - i); + + nanMask = ~Vector512.Equals(current, current); + if (nanMask != Vector512.Zero) + { + return curIndex[IndexOfFirstMatch(nanMask)]; + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TIndexOfMinMax.Invoke(result, resultIndex); + } +#endif + + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector256 resultIndex = Vector256.Create(0, 1, 2, 3, 4, 5, 6, 7); + Vector256 curIndex = resultIndex; + Vector256 increment = Vector256.Create(Vector256.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector256 result = Vector256.LoadUnsafe(ref xRef); + Vector256 current; + + Vector256 nanMask = ~Vector256.Equals(result, result); + if (nanMask != Vector256.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = Vector256.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector256.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector256.Equals(current, current); + if (nanMask != Vector256.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector256.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + curIndex += Vector256.Create(x.Length - i); + + nanMask = ~Vector256.Equals(current, current); + if (nanMask != Vector256.Zero) + { + return curIndex[IndexOfFirstMatch(nanMask)]; + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TIndexOfMinMax.Invoke(result, resultIndex); + } + + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector128 resultIndex = Vector128.Create(0, 1, 2, 3); + Vector128 curIndex = resultIndex; + Vector128 increment = Vector128.Create(Vector128.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector128 result = Vector128.LoadUnsafe(ref xRef); + Vector128 current; + + Vector128 nanMask = ~Vector128.Equals(result, result); + if (nanMask != Vector128.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector128.Count; + int i = Vector128.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector128.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector128.Equals(current, current); + if (nanMask != Vector128.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector128.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + curIndex += Vector128.Create(x.Length - i); + + current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + + nanMask = ~Vector128.Equals(current, current); + if (nanMask != Vector128.Zero) + { + return curIndex[IndexOfFirstMatch(nanMask)]; + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TIndexOfMinMax.Invoke(result, resultIndex); + } + + // Scalar path used when either vectorization is not supported or the input is too small to vectorize. + float curResult = x[0]; + int curIn = 0; + if (float.IsNaN(curResult)) + { + return curIn; + } + + for (int i = 1; i < x.Length; i++) + { + float current = x[i]; + if (float.IsNaN(current)) + { + return i; + } + + curIn = TIndexOfMinMax.Invoke(ref curResult, current, curIn, i); + } + + return curIn; + } + + private static int IndexOfFirstMatch(Vector128 mask) + { + return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); + } + + private static int IndexOfFirstMatch(Vector256 mask) + { + return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); + } + +#if NET8_0_OR_GREATER + private static int IndexOfFirstMatch(Vector512 mask) + { + return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); + } +#endif + + /// Performs an element-wise operation on and writes the results to . + /// Specifies the operation to perform on each element loaded from . + private static void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination) + where TUnaryOperator : struct, IUnaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TUnaryOperator.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TUnaryOperator.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); + Vector512 end = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TUnaryOperator.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and . + /// + private static void InvokeSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, Span destination) + where TBinaryOperator : struct, IBinaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef)); + Vector512 end = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4))); + vector2 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5))); + vector3 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6))); + vector4 = TBinaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + Vector256 end = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + Vector128 end = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from with . + /// + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination) + where TBinaryOperator : struct, IBinaryOperator => + InvokeSpanScalarIntoSpan(x, y, destination); + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from . + /// It is not used with . + /// + /// + /// Specifies the operation to perform on the transformed value from with . + /// + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination) + where TTransformOperator : struct, IUnaryOperator + where TBinaryOperator : struct, IBinaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, y, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, y, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, y, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, i)), + y); + } + } + + static void Vectorized128(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), + yVec); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + Vector128.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, float y, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 yVec = Vector512.Create(y); + + Vector512 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef)), + yVec); + Vector512 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))), + yVec); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + yVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))), + yVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))), + yVec); + vector2 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))), + yVec); + vector3 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))), + yVec); + vector4 = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))), + yVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))), + yVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, float y, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + yVec); + Vector256 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector256.LoadUnsafe(ref xRef)), + Vector256.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + yVec); + Vector128 end = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))), + yVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TBinaryOperator.Invoke(TTransformOperator.Invoke(Vector128.LoadUnsafe(ref xRef)), + Vector128.Create(y)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TBinaryOperator.Invoke(TTransformOperator.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = TBinaryOperator.Invoke(TTransformOperator.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from , , + /// and . + /// + private static void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length || x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + Unsafe.Add(ref zRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef), + Vector512.LoadUnsafe(ref zRef)); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and + /// with . + /// + private static void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, z, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + Unsafe.Add(ref yRef, i), + z); + } + } + + static void Vectorized128(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + Vector128.Load(yPtr + (uint)(Vector128.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + yPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 zVec = Vector256.Create(z); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + zVec); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + zVec); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + Vector256.Load(yPtr + (uint)(Vector256.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + yPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 zVec = Vector512.Create(z); + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + Vector512.LoadUnsafe(ref yRef), + zVec); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count)), + zVec); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + zVec); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 0)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 1)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 2)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 3)), + zVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 4)), + zVec); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 5)), + zVec); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 6)), + zVec); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + Vector512.Load(yPtr + (uint)(Vector512.Count * 7)), + zVec); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + yPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 8)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 7)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 6)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 5)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 4)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 3)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + Vector512.LoadUnsafe(ref yRef, remainder - (uint)(Vector512.Count * 2)), + zVec); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 zVec = Vector256.Create(z); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + zVec); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + Vector256.LoadUnsafe(ref yRef, remainder - (uint)(Vector256.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.LoadUnsafe(ref yRef), + Vector256.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 zVec = Vector128.Create(z); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + zVec); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + Vector128.LoadUnsafe(ref yRef, remainder - (uint)(Vector128.Count)), + zVec); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.LoadUnsafe(ref yRef), + Vector128.Create(z)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise element loaded from , with , + /// and the element loaded from . + /// + private static void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector512.Count)) + { + Vectorized512(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized128Small(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), + y, + Unsafe.Add(ref zRef, i)); + } + } + + static void Vectorized128(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7)), + yVec, + Vector128.Load(zPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + zPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; + } + + case 7: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + + static void Vectorized256(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + yVec, + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7)), + yVec, + Vector256.Load(zPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + zPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; + } + + case 7: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.Create(y), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 yVec = Vector512.Create(y); + + Vector512 beg = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef), + yVec, + Vector512.LoadUnsafe(ref zRef)); + Vector512 end = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if ((remainder > (NonTemporalByteThreshold / sizeof(float))) && canAlign) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. + + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 0))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 1))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 2))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 4))); + vector2 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 5))); + vector3 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 6))); + vector4 = TTernaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7)), + yVec, + Vector512.Load(zPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + zPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2)), + yVec, + Vector512.LoadUnsafe(ref zRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 yVec = Vector256.Create(y); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + yVec, + Vector256.LoadUnsafe(ref zRef)); + Vector256 end = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count)), + yVec, + Vector256.LoadUnsafe(ref zRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef), + Vector256.Create(y), + Vector256.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 yVec = Vector128.Create(y); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + yVec, + Vector128.LoadUnsafe(ref zRef)); + Vector128 end = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count)), + yVec, + Vector128.LoadUnsafe(ref zRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef), + Vector128.Create(y), + Vector128.LoadUnsafe(ref zRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TTernaryOperator.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } +#endif + } + + /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 FusedMultiplyAdd(Vector128 x, Vector128 y, Vector128 addend) + { + if (Fma.IsSupported) + { + return Fma.MultiplyAdd(x, y, addend); + } + + if (AdvSimd.IsSupported) + { + return AdvSimd.FusedMultiplyAdd(addend, x, y); + } + + return (x * y) + addend; + } + + /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 FusedMultiplyAdd(Vector256 x, Vector256 y, Vector256 addend) + { + if (Fma.IsSupported) + { + return Fma.MultiplyAdd(x, y, addend); + } + + return (x * y) + addend; + } + +#if NET8_0_OR_GREATER + /// Performs (x * y) + z. It will be rounded as one ternary operation if such an operation is accelerated on the current hardware. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 FusedMultiplyAdd(Vector512 x, Vector512 y, Vector512 addend) + { + if (Avx512F.IsSupported) + { + return Avx512F.FusedMultiplyAdd(x, y, addend); + } + + return (x * y) + addend; + } +#endif + + /// Aggregates all of the elements in the into a single value. + /// Specifies the operation to be performed on each pair of values. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static float HorizontalAggregate(Vector128 x) where TAggregate : struct, IBinaryOperator + { + // We need to do log2(count) operations to compute the total sum + + x = TAggregate.Invoke(x, Vector128.Shuffle(x, Vector128.Create(2, 3, 0, 1))); + x = TAggregate.Invoke(x, Vector128.Shuffle(x, Vector128.Create(1, 0, 3, 2))); + + return x.ToScalar(); + } + + /// Aggregates all of the elements in the into a single value. + /// Specifies the operation to be performed on each pair of values. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static float HorizontalAggregate(Vector256 x) where TAggregate : struct, IBinaryOperator => + HorizontalAggregate(TAggregate.Invoke(x.GetLower(), x.GetUpper())); + +#if NET8_0_OR_GREATER + /// Aggregates all of the elements in the into a single value. + /// Specifies the operation to be performed on each pair of values. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static float HorizontalAggregate(Vector512 x) where TAggregate : struct, IBinaryOperator => + HorizontalAggregate(TAggregate.Invoke(x.GetLower(), x.GetUpper())); +#endif + + /// Gets whether the specified is negative. + private static bool IsNegative(float f) => float.IsNegative(f); + + /// Gets whether each specified is negative. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 IsNegative(Vector128 vector) => + Vector128.LessThan(vector.AsInt32(), Vector128.Zero).AsSingle(); + + /// Gets whether each specified is negative. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 IsNegative(Vector256 vector) => + Vector256.LessThan(vector.AsInt32(), Vector256.Zero).AsSingle(); + +#if NET8_0_OR_GREATER + /// Gets whether each specified is negative. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 IsNegative(Vector512 vector) => + Vector512.LessThan(vector.AsInt32(), Vector512.Zero).AsSingle(); +#endif + + /// Gets whether the specified is positive. + private static bool IsPositive(float f) => float.IsPositive(f); + + /// Gets whether each specified is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 IsPositive(Vector128 vector) => + Vector128.GreaterThan(vector.AsInt32(), Vector128.AllBitsSet).AsSingle(); + + /// Gets whether each specified is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 IsPositive(Vector256 vector) => + Vector256.GreaterThan(vector.AsInt32(), Vector256.AllBitsSet).AsSingle(); + +#if NET8_0_OR_GREATER + /// Gets whether each specified is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 IsPositive(Vector512 vector) => + Vector512.GreaterThan(vector.AsInt32(), Vector512.AllBitsSet).AsSingle(); +#endif + + /// Finds and returns the first NaN value in . + /// The vector must have already been validated to contain a NaN. + private static float GetFirstNaN(Vector128 vector) + { + Debug.Assert(!Vector128.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return vector.GetElement(BitOperations.TrailingZeroCount((~Vector128.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the first NaN index value in . + /// The vector must have already been validated to contain a NaN. + private static int GetFirstNaNIndex(Vector128 vector, Vector128 index) + { + Debug.Assert(!Vector128.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector128.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the first NaN value in . + /// The vector must have already been validated to contain a NaN. + private static float GetFirstNaN(Vector256 vector) + { + Debug.Assert(!Vector256.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return vector.GetElement(BitOperations.TrailingZeroCount((~Vector256.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the first NaN index value in . + /// The vector must have already been validated to contain a NaN. + private static int GetFirstNaNIndex(Vector256 vector, Vector256 index) + { + Debug.Assert(!Vector256.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector256.Equals(vector, vector)).ExtractMostSignificantBits())); + } + +#if NET8_0_OR_GREATER + /// Finds and returns the first NaN value in . + /// The vector must have already been validated to contain a NaN. + private static float GetFirstNaN(Vector512 vector) + { + Debug.Assert(!Vector512.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return vector.GetElement(BitOperations.TrailingZeroCount((~Vector512.Equals(vector, vector)).ExtractMostSignificantBits())); + } + + /// Finds and returns the first NaN value in . + /// The vector must have already been validated to contain a NaN. + private static int GetFirstNaNIndex(Vector512 vector, Vector512 index) + { + Debug.Assert(!Vector512.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector512.Equals(vector, vector)).ExtractMostSignificantBits())); + } +#endif + + /// Gets the base 2 logarithm of . + private static float Log2(float x) => MathF.Log2(x); + + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 CreateAlignmentMaskSingleVector128(int count) => + Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // first four floats in the row + + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 CreateAlignmentMaskSingleVector256(int count) => + Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // first eight floats in the row + +#if NET8_0_OR_GREATER + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 CreateAlignmentMaskSingleVector512(int count) => + Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (uint)(count * 16)); // all sixteen floats in the row +#endif + + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 CreateRemainderMaskSingleVector128(int count) => + Vector128.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)((count * 16) + 12)); // last four floats in the row + + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 CreateRemainderMaskSingleVector256(int count) => + Vector256.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)((count * 16) + 8)); // last eight floats in the row + +#if NET8_0_OR_GREATER + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 CreateRemainderMaskSingleVector512(int count) => + Vector512.LoadUnsafe( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (uint)(count * 16)); // all sixteen floats in the row +#endif + + /// x + y + private readonly struct AddOperator : IAggregationOperator + { + public static float Invoke(float x, float y) => x + y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x + y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x + y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x + y; +#endif + + public static float Invoke(Vector128 x) => Vector128.Sum(x); + public static float Invoke(Vector256 x) => Vector256.Sum(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => Vector512.Sum(x); +#endif + + public static float IdentityValue => 0; + } + + /// x - y + private readonly struct SubtractOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x - y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x - y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x - y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x - y; +#endif + } + + /// (x - y) * (x - y) + private readonly struct SubtractSquaredOperator : IBinaryOperator + { + public static float Invoke(float x, float y) + { + float tmp = x - y; + return tmp * tmp; + } + + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 tmp = x - y; + return tmp * tmp; + } + + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 tmp = x - y; + return tmp * tmp; + } + +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 tmp = x - y; + return tmp * tmp; + } +#endif + } + + /// x * y + private readonly struct MultiplyOperator : IAggregationOperator + { + public static float Invoke(float x, float y) => x * y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x * y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x * y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x * y; +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + + public static float IdentityValue => 1; + } + + /// x / y + private readonly struct DivideOperator : IBinaryOperator + { + public static float Invoke(float x, float y) => x / y; + public static Vector128 Invoke(Vector128 x, Vector128 y) => x / y; + public static Vector256 Invoke(Vector256 x, Vector256 y) => x / y; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y) => x / y; +#endif + } + + /// MathF.Max(x, y) (but NaNs may not be propagated) + private readonly struct MaxOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => + x == y ? + (IsNegative(x) ? y : x) : + (y > x ? y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Max(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.Max(x, y)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.Max(x, y)); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.Max(x, y)); +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + private interface IIndexOfOperator + { + static abstract int Invoke(ref float result, float current, int resultIndex, int curIndex); + static abstract int Invoke(Vector128 result, Vector128 resultIndex); + static abstract void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex); + static abstract int Invoke(Vector256 result, Vector256 resultIndex); + static abstract void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex); +#if NET8_0_OR_GREATER + static abstract int Invoke(Vector512 result, Vector512 resultIndex); + static abstract void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex); +#endif + } + + /// Returns the index of MathF.Max(x, y) + private readonly struct IndexOfMaxOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 maxIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + return maxIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 max, Vector128 current, ref Vector128 maxIndex, Vector128 curIndex) + { + Vector128 greaterThanMask = Vector128.GreaterThan(max, current); + + Vector128 equalMask = Vector128.Equals(max, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanMask = Vector128.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 maxIndex) + { + // Max the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = maxIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, maxIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 max, Vector256 current, ref Vector256 maxIndex, Vector256 curIndex) + { + Vector256 greaterThanMask = Vector256.GreaterThan(max, current); + + Vector256 equalMask = Vector256.Equals(max, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanMask = Vector256.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 max, Vector512 current, ref Vector512 maxIndex, Vector512 curIndex) + { + Vector512 greaterThanMask = Vector512.GreaterThan(max, current); + + Vector512 equalMask = Vector512.Equals(max, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanMask = Vector512.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (current > result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + private readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 maxIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + return maxIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 max, Vector128 current, ref Vector128 maxIndex, Vector128 curIndex) + { + Vector128 maxMag = Vector128.Abs(max), currentMag = Vector128.Abs(current); + + Vector128 greaterThanMask = Vector128.GreaterThan(maxMag, currentMag); + + Vector128 equalMask = Vector128.Equals(max, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanMask = Vector128.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 maxIndex) + { + // Max the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = maxIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, maxIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 max, Vector256 current, ref Vector256 maxIndex, Vector256 curIndex) + { + Vector256 maxMag = Vector256.Abs(max), currentMag = Vector256.Abs(current); + + Vector256 greaterThanMask = Vector256.GreaterThan(maxMag, currentMag); + + Vector256 equalMask = Vector256.Equals(max, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanMask = Vector256.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 max, Vector512 current, ref Vector512 maxIndex, Vector512 curIndex) + { + Vector512 maxMag = Vector512.Abs(max), currentMag = Vector512.Abs(current); + Vector512 greaterThanMask = Vector512.GreaterThan(maxMag, currentMag); + + Vector512 equalMask = Vector512.Equals(max, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanMask = Vector512.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMaxAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + + if (curMaxAbs == currentAbs) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs > curMaxAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + /// Returns the index of MathF.Min(x, y) + private readonly struct IndexOfMinOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 resultIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + return resultIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex) + { + Vector128 lessThanMask = Vector128.LessThan(result, current); + + Vector128 equalMask = Vector128.Equals(result, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanIndexMask = Vector128.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 resultIndex) + { + // Min the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex) + { + Vector256 lessThanMask = Vector256.LessThan(result, current); + + Vector256 equalMask = Vector256.Equals(result, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanIndexMask = Vector256.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex) + { + Vector512 lessThanMask = Vector512.LessThan(result, current); + + Vector512 equalMask = Vector512.Equals(result, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanIndexMask = Vector512.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (current < result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + private readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 resultIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + return resultIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex) + { + Vector128 minMag = Vector128.Abs(result), currentMag = Vector128.Abs(current); + + Vector128 lessThanMask = Vector128.LessThan(minMag, currentMag); + + Vector128 equalMask = Vector128.Equals(result, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanIndexMask = Vector128.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 resultIndex) + { + // Min the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex) + { + Vector256 minMag = Vector256.Abs(result), currentMag = Vector256.Abs(current); + + Vector256 lessThanMask = Vector256.LessThan(minMag, currentMag); + + Vector256 equalMask = Vector256.Equals(result, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanIndexMask = Vector256.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex) + { + Vector512 minMag = Vector512.Abs(result), currentMag = Vector512.Abs(current); + + Vector512 lessThanMask = Vector512.LessThan(minMag, currentMag); + + Vector512 equalMask = Vector512.Equals(result, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanIndexMask = Vector512.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMinAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + if (curMinAbs == currentAbs) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs < curMinAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + /// MathF.Max(x, y) + private readonly struct MaxPropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.Max(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Max(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.Max(x, y)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.Max(x, y)), + y), + x); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.Max(x, y)), + y), + x); +#endif + } + + /// Operator to get x or y based on which has the larger MathF.Abs (but NaNs may not be propagated) + private readonly struct MaxMagnitudeOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return + xMag == yMag ? + (IsNegative(x) ? y : x) : + (xMag > yMag ? x : y); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(xMag, yMag), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.ConditionalSelect(Vector128.GreaterThan(xMag, yMag), x, y)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(xMag, yMag), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.ConditionalSelect(Vector256.GreaterThan(xMag, yMag), x, y)); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(xMag, yMag), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.ConditionalSelect(Vector512.GreaterThan(xMag, yMag), x, y)); + } +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + /// Operator to get x or y based on which has the larger MathF.Abs + private readonly struct MaxMagnitudePropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.MaxMagnitude(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), + Vector128.ConditionalSelect(IsNegative(x), y, x), + Vector128.ConditionalSelect(Vector128.GreaterThan(yMag, xMag), y, x)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(xMag, yMag), + Vector256.ConditionalSelect(IsNegative(x), y, x), + Vector256.ConditionalSelect(Vector256.GreaterThan(xMag, yMag), x, y)), + y), + x); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(xMag, yMag), + Vector512.ConditionalSelect(IsNegative(x), y, x), + Vector512.ConditionalSelect(Vector512.GreaterThan(xMag, yMag), x, y)), + y), + x); + } +#endif + } + + /// MathF.Min(x, y) (but NaNs may not be propagated) + private readonly struct MinOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => + x == y ? + (IsNegative(y) ? y : x) : + (y < x ? y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Min(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(y), y, x), + Vector128.Min(x, y)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(y), y, x), + Vector256.Min(x, y)); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(y), y, x), + Vector512.Min(x, y)); +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + /// MathF.Min(x, y) + private readonly struct MinPropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.Min(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + if (AdvSimd.IsSupported) + { + return AdvSimd.Min(x, y); + } + + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(x, y), + Vector128.ConditionalSelect(IsNegative(x), x, y), + Vector128.Min(x, y)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) => + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(x, y), + Vector256.ConditionalSelect(IsNegative(x), x, y), + Vector256.Min(x, y)), + y), + x); + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) => + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(x, y), + Vector512.ConditionalSelect(IsNegative(x), x, y), + Vector512.Min(x, y)), + y), + x); +#endif + } + + /// Operator to get x or y based on which has the smaller MathF.Abs (but NaNs may not be propagated) + private readonly struct MinMagnitudeOperator : IAggregationOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return xMag == yMag ? + (IsNegative(y) ? y : x) : + (yMag < xMag ? y : x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), + Vector128.ConditionalSelect(IsNegative(y), y, x), + Vector128.ConditionalSelect(Vector128.LessThan(yMag, xMag), y, x)); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(yMag, xMag), + Vector256.ConditionalSelect(IsNegative(y), y, x), + Vector256.ConditionalSelect(Vector256.LessThan(yMag, xMag), y, x)); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(yMag, xMag), + Vector512.ConditionalSelect(IsNegative(y), y, x), + Vector512.ConditionalSelect(Vector512.LessThan(yMag, xMag), y, x)); + } +#endif + + public static float Invoke(Vector128 x) => HorizontalAggregate(x); + public static float Invoke(Vector256 x) => HorizontalAggregate(x); +#if NET8_0_OR_GREATER + public static float Invoke(Vector512 x) => HorizontalAggregate(x); +#endif + } + + /// Operator to get x or y based on which has the smaller MathF.Abs + private readonly struct MinMagnitudePropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float Invoke(float x, float y) => MathF.MinMagnitude(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Invoke(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + return + Vector128.ConditionalSelect(Vector128.Equals(x, x), + Vector128.ConditionalSelect(Vector128.Equals(y, y), + Vector128.ConditionalSelect(Vector128.Equals(yMag, xMag), + Vector128.ConditionalSelect(IsNegative(x), x, y), + Vector128.ConditionalSelect(Vector128.LessThan(xMag, yMag), x, y)), + y), + x); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Invoke(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + return + Vector256.ConditionalSelect(Vector256.Equals(x, x), + Vector256.ConditionalSelect(Vector256.Equals(y, y), + Vector256.ConditionalSelect(Vector256.Equals(yMag, xMag), + Vector256.ConditionalSelect(IsNegative(x), x, y), + Vector256.ConditionalSelect(Vector256.LessThan(xMag, yMag), x, y)), + y), + x); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Invoke(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + return + Vector512.ConditionalSelect(Vector512.Equals(x, x), + Vector512.ConditionalSelect(Vector512.Equals(y, y), + Vector512.ConditionalSelect(Vector512.Equals(yMag, xMag), + Vector512.ConditionalSelect(IsNegative(x), x, y), + Vector512.ConditionalSelect(Vector512.LessThan(xMag, yMag), x, y)), + y), + x); + } +#endif + } + + /// -x + private readonly struct NegateOperator : IUnaryOperator + { + public static float Invoke(float x) => -x; + public static Vector128 Invoke(Vector128 x) => -x; + public static Vector256 Invoke(Vector256 x) => -x; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => -x; +#endif + } + + /// (x + y) * z + private readonly struct AddMultiplyOperator : ITernaryOperator + { + public static float Invoke(float x, float y, float z) => (x + y) * z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x + y) * z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x + y) * z; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x + y) * z; +#endif + } + + /// (x * y) + z + private readonly struct MultiplyAddOperator : ITernaryOperator + { + public static float Invoke(float x, float y, float z) => (x * y) + z; + public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x * y) + z; + public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x * y) + z; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x * y) + z; +#endif + } + + /// x + private readonly struct IdentityOperator : IUnaryOperator + { + public static float Invoke(float x) => x; + public static Vector128 Invoke(Vector128 x) => x; + public static Vector256 Invoke(Vector256 x) => x; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => x; +#endif + } + + /// x * x + private readonly struct SquaredOperator : IUnaryOperator + { + public static float Invoke(float x) => x * x; + public static Vector128 Invoke(Vector128 x) => x * x; + public static Vector256 Invoke(Vector256 x) => x * x; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => x * x; +#endif + } + + /// MathF.Abs(x) + private readonly struct AbsoluteOperator : IUnaryOperator + { + public static float Invoke(float x) => MathF.Abs(x); + public static Vector128 Invoke(Vector128 x) => Vector128.Abs(x); + public static Vector256 Invoke(Vector256 x) => Vector256.Abs(x); +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) => Vector512.Abs(x); +#endif + } + + /// MathF.Exp(x) + private readonly struct ExpOperator : IUnaryOperator + { + // This code is based on `vrs4_expf` from amd/aocl-libm-ose + // Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Implementation Notes: + // 1. Argument Reduction: + // e^x = 2^(x/ln2) --- (1) + // + // Let x/ln(2) = z --- (2) + // + // Let z = n + r , where n is an integer --- (3) + // |r| <= 1/2 + // + // From (1), (2) and (3), + // e^x = 2^z + // = 2^(N+r) + // = (2^N)*(2^r) --- (4) + // + // 2. Polynomial Evaluation + // From (4), + // r = z - N + // 2^r = C1 + C2*r + C3*r^2 + C4*r^3 + C5 *r^4 + C6*r^5 + // + // 4. Reconstruction + // Thus, + // e^x = (2^N) * (2^r) + + private const uint V_ARG_MAX = 0x42AE0000; + private const uint V_MASK = 0x7FFFFFFF; + + private const float V_EXPF_MIN = -103.97208f; + private const float V_EXPF_MAX = 88.72284f; + + private const double V_EXPF_HUGE = 6755399441055744; + private const double V_TBL_LN2 = 1.4426950408889634; + + private const double C1 = 1.0000000754895704; + private const double C2 = 0.6931472254087585; + private const double C3 = 0.2402210737432219; + private const double C4 = 0.05550297297702539; + private const double C5 = 0.009676036358193323; + private const double C6 = 0.001341000536524434; + + public static float Invoke(float x) => MathF.Exp(x); + + public static Vector128 Invoke(Vector128 x) + { + // Convert x to double precision + (Vector128 xl, Vector128 xu) = Vector128.Widen(x); + + // x * (64.0 / ln(2)) + Vector128 v_tbl_ln2 = Vector128.Create(V_TBL_LN2); + + Vector128 zl = xl * v_tbl_ln2; + Vector128 zu = xu * v_tbl_ln2; + + Vector128 v_expf_huge = Vector128.Create(V_EXPF_HUGE); + + Vector128 dnl = zl + v_expf_huge; + Vector128 dnu = zu + v_expf_huge; + + // n = int (z) + Vector128 nl = dnl.AsUInt64(); + Vector128 nu = dnu.AsUInt64(); + + // dn = double(n) + dnl -= v_expf_huge; + dnu -= v_expf_huge; + + // r = z - dn + Vector128 c1 = Vector128.Create(C1); + Vector128 c2 = Vector128.Create(C2); + Vector128 c3 = Vector128.Create(C3); + Vector128 c4 = Vector128.Create(C4); + Vector128 c5 = Vector128.Create(C5); + Vector128 c6 = Vector128.Create(C6); + + Vector128 rl = zl - dnl; + + Vector128 rl2 = rl * rl; + Vector128 rl4 = rl2 * rl2; + + Vector128 polyl = (c4 * rl + c3) * rl2 + + ((c6 * rl + c5) * rl4 + + (c2 * rl + c1)); + + + Vector128 ru = zu - dnu; + + Vector128 ru2 = ru * ru; + Vector128 ru4 = ru2 * ru2; + + Vector128 polyu = (c4 * ru + c3) * ru2 + + ((c6 * ru + c5) * ru4 + + (c2 * ru + c1)); + + // result = (float)[poly + (n << 52)] + Vector128 ret = Vector128.Narrow( + (polyl.AsUInt64() + Vector128.ShiftLeft(nl, 52)).AsDouble(), + (polyu.AsUInt64() + Vector128.ShiftLeft(nu, 52)).AsDouble() + ); + + // Check if -103 < |x| < 88 + if (Vector128.GreaterThanAny(x.AsUInt32() & Vector128.Create(V_MASK), Vector128.Create(V_ARG_MAX))) + { + // (x > V_EXPF_MAX) ? float.PositiveInfinity : x + Vector128 infinityMask = Vector128.GreaterThan(x, Vector128.Create(V_EXPF_MAX)); + + ret = Vector128.ConditionalSelect( + infinityMask, + Vector128.Create(float.PositiveInfinity), + ret + ); + + // (x < V_EXPF_MIN) ? 0 : x + ret = Vector128.AndNot(ret, Vector128.LessThan(x, Vector128.Create(V_EXPF_MIN))); + } + + return ret; + } + + public static Vector256 Invoke(Vector256 x) + { + // Convert x to double precision + (Vector256 xl, Vector256 xu) = Vector256.Widen(x); + + // x * (64.0 / ln(2)) + Vector256 v_tbl_ln2 = Vector256.Create(V_TBL_LN2); + + Vector256 zl = xl * v_tbl_ln2; + Vector256 zu = xu * v_tbl_ln2; + + Vector256 v_expf_huge = Vector256.Create(V_EXPF_HUGE); + + Vector256 dnl = zl + v_expf_huge; + Vector256 dnu = zu + v_expf_huge; + + // n = int (z) + Vector256 nl = dnl.AsUInt64(); + Vector256 nu = dnu.AsUInt64(); + + // dn = double(n) + dnl -= v_expf_huge; + dnu -= v_expf_huge; + + // r = z - dn + Vector256 c1 = Vector256.Create(C1); + Vector256 c2 = Vector256.Create(C2); + Vector256 c3 = Vector256.Create(C3); + Vector256 c4 = Vector256.Create(C4); + Vector256 c5 = Vector256.Create(C5); + Vector256 c6 = Vector256.Create(C6); + + Vector256 rl = zl - dnl; + + Vector256 rl2 = rl * rl; + Vector256 rl4 = rl2 * rl2; + + Vector256 polyl = (c4 * rl + c3) * rl2 + + ((c6 * rl + c5) * rl4 + + (c2 * rl + c1)); + + + Vector256 ru = zu - dnu; + + Vector256 ru2 = ru * ru; + Vector256 ru4 = ru2 * ru2; + + Vector256 polyu = (c4 * ru + c3) * ru2 + + ((c6 * ru + c5) * ru4 + + (c2 * ru + c1)); + + // result = (float)[poly + (n << 52)] + Vector256 ret = Vector256.Narrow( + (polyl.AsUInt64() + Vector256.ShiftLeft(nl, 52)).AsDouble(), + (polyu.AsUInt64() + Vector256.ShiftLeft(nu, 52)).AsDouble() + ); + + // Check if -103 < |x| < 88 + if (Vector256.GreaterThanAny(x.AsUInt32() & Vector256.Create(V_MASK), Vector256.Create(V_ARG_MAX))) + { + // (x > V_EXPF_MAX) ? float.PositiveInfinity : x + Vector256 infinityMask = Vector256.GreaterThan(x, Vector256.Create(V_EXPF_MAX)); + + ret = Vector256.ConditionalSelect( + infinityMask, + Vector256.Create(float.PositiveInfinity), + ret + ); + + // (x < V_EXPF_MIN) ? 0 : x + ret = Vector256.AndNot(ret, Vector256.LessThan(x, Vector256.Create(V_EXPF_MIN))); + } + + return ret; + } + +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) + { + // Convert x to double precision + (Vector512 xl, Vector512 xu) = Vector512.Widen(x); + + // x * (64.0 / ln(2)) + Vector512 v_tbl_ln2 = Vector512.Create(V_TBL_LN2); + + Vector512 zl = xl * v_tbl_ln2; + Vector512 zu = xu * v_tbl_ln2; + + Vector512 v_expf_huge = Vector512.Create(V_EXPF_HUGE); + + Vector512 dnl = zl + v_expf_huge; + Vector512 dnu = zu + v_expf_huge; + + // n = int (z) + Vector512 nl = dnl.AsUInt64(); + Vector512 nu = dnu.AsUInt64(); + + // dn = double(n) + dnl -= v_expf_huge; + dnu -= v_expf_huge; + + // r = z - dn + Vector512 c1 = Vector512.Create(C1); + Vector512 c2 = Vector512.Create(C2); + Vector512 c3 = Vector512.Create(C3); + Vector512 c4 = Vector512.Create(C4); + Vector512 c5 = Vector512.Create(C5); + Vector512 c6 = Vector512.Create(C6); + + Vector512 rl = zl - dnl; + + Vector512 rl2 = rl * rl; + Vector512 rl4 = rl2 * rl2; + + Vector512 polyl = (c4 * rl + c3) * rl2 + + ((c6 * rl + c5) * rl4 + + (c2 * rl + c1)); + + + Vector512 ru = zu - dnu; + + Vector512 ru2 = ru * ru; + Vector512 ru4 = ru2 * ru2; + + Vector512 polyu = (c4 * ru + c3) * ru2 + + ((c6 * ru + c5) * ru4 + + (c2 * ru + c1)); + + // result = (float)[poly + (n << 52)] + Vector512 ret = Vector512.Narrow( + (polyl.AsUInt64() + Vector512.ShiftLeft(nl, 52)).AsDouble(), + (polyu.AsUInt64() + Vector512.ShiftLeft(nu, 52)).AsDouble() + ); + + // Check if -103 < |x| < 88 + if (Vector512.GreaterThanAny(x.AsUInt32() & Vector512.Create(V_MASK), Vector512.Create(V_ARG_MAX))) + { + // (x > V_EXPF_MAX) ? float.PositiveInfinity : x + Vector512 infinityMask = Vector512.GreaterThan(x, Vector512.Create(V_EXPF_MAX)); + + ret = Vector512.ConditionalSelect( + infinityMask, + Vector512.Create(float.PositiveInfinity), + ret + ); + + // (x < V_EXPF_MIN) ? 0 : x + ret = Vector512.AndNot(ret, Vector512.LessThan(x, Vector512.Create(V_EXPF_MIN))); + } + + return ret; + } +#endif + } + + /// MathF.Cosh(x) + private readonly struct CoshOperator : IUnaryOperator + { + // This code is based on `vrs4_coshf` from amd/aocl-libm-ose + // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Spec: + // coshf(|x| > 89.415985107421875) = Infinity + // coshf(Infinity) = infinity + // coshf(-Infinity) = infinity + // + // cosh(x) = (exp(x) + exp(-x))/2 + // cosh(-x) = +cosh(x) + // + // checks for special cases + // if ( asint(x) > infinity) return x with overflow exception and + // return x. + // if x is NaN then raise invalid FP operation exception and return x. + // + // coshf = v/2 * exp(x - log(v)) where v = 0x1.0000e8p-1 + + private const float LOGV = 0.693161f; + private const float HALFV = 1.0000138f; + private const float INVV2 = 0.24999309f; + + public static float Invoke(float x) => MathF.Cosh(x); + + public static Vector128 Invoke(Vector128 x) + { + Vector128 y = Vector128.Abs(x); + Vector128 z = ExpOperator.Invoke(y - Vector128.Create(LOGV)); + return Vector128.Create(HALFV) * (z + (Vector128.Create(INVV2) / z)); + } - return; - } + public static Vector256 Invoke(Vector256 x) + { + Vector256 y = Vector256.Abs(x); + Vector256 z = ExpOperator.Invoke(y - Vector256.Create(LOGV)); + return Vector256.Create(HALFV) * (z + (Vector256.Create(INVV2) / z)); } - if (Vector128.IsHardwareAccelerated) +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) - { - Vector128 zVec = Vector128.Create(z); + Vector512 y = Vector512.Abs(x); + Vector512 z = ExpOperator.Invoke(y - Vector512.Create(LOGV)); + return Vector512.Create(HALFV) * (z + (Vector512.Create(INVV2) / z)); + } +#endif + } - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - Vector128.LoadUnsafe(ref yRef, (uint)i), - zVec).StoreUnsafe(ref dRef, (uint)i); + /// MathF.Sinh(x) + private readonly struct SinhOperator : IUnaryOperator + { + // Same as cosh, but with `z -` rather than `z +`, and with the sign + // flipped on the result based on the sign of the input. - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + private const uint SIGN_MASK = 0x7FFFFFFF; + private const float LOGV = 0.693161f; + private const float HALFV = 1.0000138f; + private const float INVV2 = 0.24999309f; - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - Vector128.LoadUnsafe(ref yRef, lastVectorIndex), - zVec).StoreUnsafe(ref dRef, lastVectorIndex); - } + public static float Invoke(float x) => MathF.Sinh(x); - return; - } + public static Vector128 Invoke(Vector128 x) + { + Vector128 y = Vector128.Abs(x); + Vector128 z = ExpOperator.Invoke(y - Vector128.Create(LOGV)); + Vector128 result = Vector128.Create(HALFV) * (z - (Vector128.Create(INVV2) / z)); + Vector128 sign = x.AsUInt32() & Vector128.Create(~SIGN_MASK); + return (sign ^ result.AsUInt32()).AsSingle(); } - while (i < x.Length) + public static Vector256 Invoke(Vector256 x) { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - z); + Vector256 y = Vector256.Abs(x); + Vector256 z = ExpOperator.Invoke(y - Vector256.Create(LOGV)); + Vector256 result = Vector256.Create(HALFV) * (z - (Vector256.Create(INVV2) / z)); + Vector256 sign = x.AsUInt32() & Vector256.Create(~SIGN_MASK); + return (sign ^ result.AsUInt32()).AsSingle(); + } - i++; +#if NET8_0_OR_GREATER + public static Vector512 Invoke(Vector512 x) + { + Vector512 y = Vector512.Abs(x); + Vector512 z = ExpOperator.Invoke(y - Vector512.Create(LOGV)); + Vector512 result = Vector512.Create(HALFV) * (z - (Vector512.Create(INVV2) / z)); + Vector512 sign = x.AsUInt32() & Vector512.Create(~SIGN_MASK); + return (sign ^ result.AsUInt32()).AsSingle(); } +#endif } - private static unsafe void InvokeSpanScalarSpanIntoSpan( - ReadOnlySpan x, float y, ReadOnlySpan z, Span destination) - where TTernaryOperator : ITernaryOperator + /// MathF.Tanh(x) + private readonly struct TanhOperator : IUnaryOperator { - if (x.Length != z.Length) + // This code is based on `vrs4_tanhf` from amd/aocl-libm-ose + // Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // To compute vrs4_tanhf(v_f32x4_t x) + // Let y = |x| + // If 0 <= y < 0x1.154246p3 + // Let z = e^(-2.0 * y) - 1 -(1) + // + // Using (1), tanhf(y) can be calculated as, + // tanhf(y) = -z / (z + 2.0) + // + // For other cases, call scalar tanhf() + // + // If x < 0, then we use the identity + // tanhf(-x) = -tanhf(x) + + private const uint SIGN_MASK = 0x7FFFFFFF; + + public static float Invoke(float x) => MathF.Tanh(x); + + public static Vector128 Invoke(Vector128 x) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + Vector128 y = Vector128.Abs(x); + Vector128 z = ExpOperator.Invoke(Vector128.Create(-2f) * y) - Vector128.Create(1f); + Vector128 sign = x.AsUInt32() & Vector128.Create(~SIGN_MASK); + return (sign ^ (-z / (z + Vector128.Create(2f))).AsUInt32()).AsSingle(); } - if (x.Length > destination.Length) + public static Vector256 Invoke(Vector256 x) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + Vector256 y = Vector256.Abs(x); + Vector256 z = ExpOperator.Invoke(Vector256.Create(-2f) * y) - Vector256.Create(1f); + Vector256 sign = x.AsUInt32() & Vector256.Create(~SIGN_MASK); + return (sign ^ (-z / (z + Vector256.Create(2f))).AsUInt32()).AsSingle(); } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; - #if NET8_0_OR_GREATER - if (Vector512.IsHardwareAccelerated) + public static Vector512 Invoke(Vector512 x) { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) - { - Vector512 yVec = Vector512.Create(y); + Vector512 y = Vector512.Abs(x); + Vector512 z = ExpOperator.Invoke(Vector512.Create(-2f) * y) - Vector512.Create(1f); + Vector512 sign = x.AsUInt32() & Vector512.Create(~SIGN_MASK); + return (sign ^ (-z / (z + Vector512.Create(2f))).AsUInt32()).AsSingle(); + } +#endif + } - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector512.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + /// MathF.Log(x) + private readonly struct LogOperator : IUnaryOperator + { + // This code is based on `vrs4_logf` from amd/aocl-libm-ose + // Copyright (C) 2018-2019 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Spec: + // logf(x) + // = logf(x) if x ∈ F and x > 0 + // = x if x = qNaN + // = 0 if x = 1 + // = -inf if x = (-0, 0} + // = NaN otherwise + // + // Assumptions/Expectations + // - ULP is derived to be << 4 (always) + // - Some FPU Exceptions may not be available + // - Performance is at least 3x + // + // Implementation Notes: + // 1. Range Reduction: + // x = 2^n*(1+f) .... (1) + // where n is exponent and is an integer + // (1+f) is mantissa ∈ [1,2). i.e., 1 ≤ 1+f < 2 .... (2) + // + // From (1), taking log on both sides + // log(x) = log(2^n * (1+f)) + // = log(2^n) + log(1+f) + // = n*log(2) + log(1+f) .... (3) + // + // let z = 1 + f + // log(z) = log(k) + log(z) - log(k) + // log(z) = log(kz) - log(k) + // + // From (2), range of z is [1, 2) + // by simply dividing range by 'k', z is in [1/k, 2/k) .... (4) + // Best choice of k is the one which gives equal and opposite values + // at extrema +- -+ + // 1 | 2 | + // --- - 1 = - |--- - 1 | + // k | k | .... (5) + // +- -+ + // + // Solving for k, k = 3/2, + // From (4), using 'k' value, range is therefore [-0.3333, 0.3333] + // + // 2. Polynomial Approximation: + // More information refer to tools/sollya/vrs4_logf.sollya + // + // 7th Deg - Error abs: 0x1.04c4ac98p-22 rel: 0x1.2216e6f8p-19 + // 6th Deg - Error abs: 0x1.179e97d8p-19 rel: 0x1.db676c1p-17 + + private const uint V_MIN = 0x00800000; + private const uint V_MAX = 0x7F800000; + private const uint V_MASK = 0x007FFFFF; + private const uint V_OFF = 0x3F2AAAAB; + + private const float V_LN2 = 0.6931472f; + + private const float C0 = 0.0f; + private const float C1 = 1.0f; + private const float C2 = -0.5000001f; + private const float C3 = 0.33332965f; + private const float C4 = -0.24999046f; + private const float C5 = 0.20018855f; + private const float C6 = -0.16700386f; + private const float C7 = 0.13902695f; + private const float C8 = -0.1197452f; + private const float C9 = 0.14401625f; + private const float C10 = -0.13657966f; + + public static float Invoke(float x) => MathF.Log(x); - i += Vector512.Count; - } - while (i <= oneVectorFromEnd); + public static Vector128 Invoke(Vector128 x) + { + Vector128 specialResult = x; - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - TTernaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector512.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + // x is subnormal or infinity or NaN + Vector128 specialMask = Vector128.GreaterThanOrEqual(x.AsUInt32() - Vector128.Create(V_MIN), Vector128.Create(V_MAX - V_MIN)); - return; + if (specialMask != Vector128.Zero) + { + // float.IsZero(x) ? float.NegativeInfinity : x + Vector128 zeroMask = Vector128.Equals(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + zeroMask, + Vector128.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector128 lessThanZeroMask = Vector128.LessThan(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + lessThanZeroMask, + Vector128.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector128 temp = zeroMask + | lessThanZeroMask + | ~Vector128.Equals(x, x) + | Vector128.Equals(x, Vector128.Create(float.PositiveInfinity)); + + // subnormal + Vector128 subnormalMask = Vector128.AndNot(specialMask.AsSingle(), temp); + + x = Vector128.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector128.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); } - } -#endif - if (Vector256.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) - { - Vector256 yVec = Vector256.Create(y); + Vector128 vx = x.AsUInt32() - Vector128.Create(V_OFF); + Vector128 n = Vector128.ConvertToSingle(Vector128.ShiftRightArithmetic(vx.AsInt32(), 23)); - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector256.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + vx = (vx & Vector128.Create(V_MASK)) + Vector128.Create(V_OFF); - i += Vector256.Count; - } - while (i <= oneVectorFromEnd); + Vector128 r = vx.AsSingle() - Vector128.Create(1.0f); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - TTernaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector256.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + Vector128 r2 = r * r; + Vector128 r4 = r2 * r2; + Vector128 r8 = r4 * r4; - return; - } + Vector128 q = (Vector128.Create(C10) * r2 + (Vector128.Create(C9) * r + Vector128.Create(C8))) + * r8 + (((Vector128.Create(C7) * r + Vector128.Create(C6)) + * r2 + (Vector128.Create(C5) * r + Vector128.Create(C4))) + * r4 + ((Vector128.Create(C3) * r + Vector128.Create(C2)) + * r2 + (Vector128.Create(C1) * r + Vector128.Create(C0)))); + + return Vector128.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n * Vector128.Create(V_LN2) + q + ); } - if (Vector128.IsHardwareAccelerated) + public static Vector256 Invoke(Vector256 x) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + Vector256 specialResult = x; + + // x is subnormal or infinity or NaN + Vector256 specialMask = Vector256.GreaterThanOrEqual(x.AsUInt32() - Vector256.Create(V_MIN), Vector256.Create(V_MAX - V_MIN)); + + if (specialMask != Vector256.Zero) { - Vector128 yVec = Vector128.Create(y); + // float.IsZero(x) ? float.NegativeInfinity : x + Vector256 zeroMask = Vector256.Equals(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + zeroMask, + Vector256.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector256 lessThanZeroMask = Vector256.LessThan(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + lessThanZeroMask, + Vector256.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector256 temp = zeroMask + | lessThanZeroMask + | ~Vector256.Equals(x, x) + | Vector256.Equals(x, Vector256.Create(float.PositiveInfinity)); + + // subnormal + Vector256 subnormalMask = Vector256.AndNot(specialMask.AsSingle(), temp); + + x = Vector256.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector256.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); + } - // Loop handling one vector at a time. - do - { - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i), - yVec, - Vector128.LoadUnsafe(ref zRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Vector256 vx = x.AsUInt32() - Vector256.Create(V_OFF); + Vector256 n = Vector256.ConvertToSingle(Vector256.ShiftRightArithmetic(vx.AsInt32(), 23)); - i += Vector128.Count; - } - while (i <= oneVectorFromEnd); + vx = (vx & Vector256.Create(V_MASK)) + Vector256.Create(V_OFF); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - TTernaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex), - yVec, - Vector128.LoadUnsafe(ref zRef, lastVectorIndex)).StoreUnsafe(ref dRef, lastVectorIndex); - } + Vector256 r = vx.AsSingle() - Vector256.Create(1.0f); - return; - } - } + Vector256 r2 = r * r; + Vector256 r4 = r2 * r2; + Vector256 r8 = r4 * r4; - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = TTernaryOperator.Invoke(Unsafe.Add(ref xRef, i), - y, - Unsafe.Add(ref zRef, i)); + Vector256 q = (Vector256.Create(C10) * r2 + (Vector256.Create(C9) * r + Vector256.Create(C8))) + * r8 + (((Vector256.Create(C7) * r + Vector256.Create(C6)) + * r2 + (Vector256.Create(C5) * r + Vector256.Create(C4))) + * r4 + ((Vector256.Create(C3) * r + Vector256.Create(C2)) + * r2 + (Vector256.Create(C1) * r + Vector256.Create(C0)))); - i++; + return Vector256.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n * Vector256.Create(V_LN2) + q + ); } - } - private readonly struct AddOperator : IBinaryOperator - { - public static float Invoke(float x, float y) => x + y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x + y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x + y; #if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y) => x + y; -#endif + public static Vector512 Invoke(Vector512 x) + { + Vector512 specialResult = x; - public static float Invoke(Vector128 x) => Vector128.Sum(x); - public static float Invoke(Vector256 x) => Vector256.Sum(x); -#if NET8_0_OR_GREATER - public static float Invoke(Vector512 x) => Vector512.Sum(x); -#endif - } + // x is subnormal or infinity or NaN + Vector512 specialMask = Vector512.GreaterThanOrEqual(x.AsUInt32() - Vector512.Create(V_MIN), Vector512.Create(V_MAX - V_MIN)); - private readonly struct SubtractOperator : IBinaryOperator - { - public static float Invoke(float x, float y) => x - y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x - y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x - y; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y) => x - y; -#endif + if (specialMask != Vector512.Zero) + { + // float.IsZero(x) ? float.NegativeInfinity : x + Vector512 zeroMask = Vector512.Equals(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + zeroMask, + Vector512.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector512 lessThanZeroMask = Vector512.LessThan(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + lessThanZeroMask, + Vector512.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector512 temp = zeroMask + | lessThanZeroMask + | ~Vector512.Equals(x, x) + | Vector512.Equals(x, Vector512.Create(float.PositiveInfinity)); + + // subnormal + Vector512 subnormalMask = Vector512.AndNot(specialMask.AsSingle(), temp); + + x = Vector512.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector512.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); + } - public static float Invoke(Vector128 x) => throw new NotSupportedException(); - public static float Invoke(Vector256 x) => throw new NotSupportedException(); -#if NET8_0_OR_GREATER - public static float Invoke(Vector512 x) => throw new NotSupportedException(); -#endif - } + Vector512 vx = x.AsUInt32() - Vector512.Create(V_OFF); + Vector512 n = Vector512.ConvertToSingle(Vector512.ShiftRightArithmetic(vx.AsInt32(), 23)); - private readonly struct SubtractSquaredOperator : IBinaryOperator - { - public static float Invoke(float x, float y) - { - float tmp = x - y; - return tmp * tmp; - } + vx = (vx & Vector512.Create(V_MASK)) + Vector512.Create(V_OFF); - public static Vector128 Invoke(Vector128 x, Vector128 y) - { - Vector128 tmp = x - y; - return tmp * tmp; - } + Vector512 r = vx.AsSingle() - Vector512.Create(1.0f); - public static Vector256 Invoke(Vector256 x, Vector256 y) - { - Vector256 tmp = x - y; - return tmp * tmp; - } + Vector512 r2 = r * r; + Vector512 r4 = r2 * r2; + Vector512 r8 = r4 * r4; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y) - { - Vector512 tmp = x - y; - return tmp * tmp; - } -#endif + Vector512 q = (Vector512.Create(C10) * r2 + (Vector512.Create(C9) * r + Vector512.Create(C8))) + * r8 + (((Vector512.Create(C7) * r + Vector512.Create(C6)) + * r2 + (Vector512.Create(C5) * r + Vector512.Create(C4))) + * r4 + ((Vector512.Create(C3) * r + Vector512.Create(C2)) + * r2 + (Vector512.Create(C1) * r + Vector512.Create(C0)))); - public static float Invoke(Vector128 x) => throw new NotSupportedException(); - public static float Invoke(Vector256 x) => throw new NotSupportedException(); -#if NET8_0_OR_GREATER - public static float Invoke(Vector512 x) => throw new NotSupportedException(); + return Vector512.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n * Vector512.Create(V_LN2) + q + ); + } #endif } - private readonly struct MultiplyOperator : IBinaryOperator + /// MathF.Log2(x) + private readonly struct Log2Operator : IUnaryOperator { - public static float Invoke(float x, float y) => x * y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x * y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x * y; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y) => x * y; -#endif + // This code is based on `vrs4_log2f` from amd/aocl-libm-ose + // Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved. + // + // Licensed under the BSD 3-Clause "New" or "Revised" License + // See THIRD-PARTY-NOTICES.TXT for the full license text + + // Spec: + // log2f(x) + // = log2f(x) if x ∈ F and x > 0 + // = x if x = qNaN + // = 0 if x = 1 + // = -inf if x = (-0, 0} + // = NaN otherwise + // + // Assumptions/Expectations + // - Maximum ULP is observed to be at 4 + // - Some FPU Exceptions may not be available + // - Performance is at least 3x + // + // Implementation Notes: + // 1. Range Reduction: + // x = 2^n*(1+f) .... (1) + // where n is exponent and is an integer + // (1+f) is mantissa ∈ [1,2). i.e., 1 ≤ 1+f < 2 .... (2) + // + // From (1), taking log on both sides + // log2(x) = log2(2^n * (1+f)) + // = n + log2(1+f) .... (3) + // + // let z = 1 + f + // log2(z) = log2(k) + log2(z) - log2(k) + // log2(z) = log2(kz) - log2(k) + // + // From (2), range of z is [1, 2) + // by simply dividing range by 'k', z is in [1/k, 2/k) .... (4) + // Best choice of k is the one which gives equal and opposite values + // at extrema +- -+ + // 1 | 2 | + // --- - 1 = - |--- - 1 | + // k | k | .... (5) + // +- -+ + // + // Solving for k, k = 3/2, + // From (4), using 'k' value, range is therefore [-0.3333, 0.3333] + // + // 2. Polynomial Approximation: + // More information refer to tools/sollya/vrs4_logf.sollya + // + // 7th Deg - Error abs: 0x1.04c4ac98p-22 rel: 0x1.2216e6f8p-19 + + private const uint V_MIN = 0x00800000; + private const uint V_MAX = 0x7F800000; + private const uint V_MASK = 0x007FFFFF; + private const uint V_OFF = 0x3F2AAAAB; + + private const float C0 = 0.0f; + private const float C1 = 1.4426951f; + private const float C2 = -0.72134554f; + private const float C3 = 0.48089063f; + private const float C4 = -0.36084408f; + private const float C5 = 0.2888971f; + private const float C6 = -0.23594281f; + private const float C7 = 0.19948183f; + private const float C8 = -0.22616665f; + private const float C9 = 0.21228963f; + + public static float Invoke(float x) => MathF.Log2(x); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(Vector128 x) + public static Vector128 Invoke(Vector128 x) { - float f = x[0]; - for (int i = 1; i < Vector128.Count; i++) + Vector128 specialResult = x; + + // x is subnormal or infinity or NaN + Vector128 specialMask = Vector128.GreaterThanOrEqual(x.AsUInt32() - Vector128.Create(V_MIN), Vector128.Create(V_MAX - V_MIN)); + + if (specialMask != Vector128.Zero) { - f *= x[i]; + // float.IsZero(x) ? float.NegativeInfinity : x + Vector128 zeroMask = Vector128.Equals(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + zeroMask, + Vector128.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector128 lessThanZeroMask = Vector128.LessThan(x, Vector128.Zero); + + specialResult = Vector128.ConditionalSelect( + lessThanZeroMask, + Vector128.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector128 temp = zeroMask + | lessThanZeroMask + | ~Vector128.Equals(x, x) + | Vector128.Equals(x, Vector128.Create(float.PositiveInfinity)); + + // subnormal + Vector128 subnormalMask = Vector128.AndNot(specialMask.AsSingle(), temp); + + x = Vector128.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector128.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); } - return f; + + Vector128 vx = x.AsUInt32() - Vector128.Create(V_OFF); + Vector128 n = Vector128.ConvertToSingle(Vector128.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector128.Create(V_MASK)) + Vector128.Create(V_OFF); + + Vector128 r = vx.AsSingle() - Vector128.Create(1.0f); + + Vector128 r2 = r * r; + Vector128 r4 = r2 * r2; + Vector128 r8 = r4 * r4; + + Vector128 poly = (Vector128.Create(C9) * r + Vector128.Create(C8)) * r8 + + (((Vector128.Create(C7) * r + Vector128.Create(C6)) * r2 + + (Vector128.Create(C5) * r + Vector128.Create(C4))) * r4 + + ((Vector128.Create(C3) * r + Vector128.Create(C2)) * r2 + + (Vector128.Create(C1) * r + Vector128.Create(C0)))); + + return Vector128.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n + poly + ); } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(Vector256 x) + public static Vector256 Invoke(Vector256 x) { - float f = x[0]; - for (int i = 1; i < Vector256.Count; i++) + Vector256 specialResult = x; + + // x is subnormal or infinity or NaN + Vector256 specialMask = Vector256.GreaterThanOrEqual(x.AsUInt32() - Vector256.Create(V_MIN), Vector256.Create(V_MAX - V_MIN)); + + if (specialMask != Vector256.Zero) { - f *= x[i]; + // float.IsZero(x) ? float.NegativeInfinity : x + Vector256 zeroMask = Vector256.Equals(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + zeroMask, + Vector256.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector256 lessThanZeroMask = Vector256.LessThan(x, Vector256.Zero); + + specialResult = Vector256.ConditionalSelect( + lessThanZeroMask, + Vector256.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector256 temp = zeroMask + | lessThanZeroMask + | ~Vector256.Equals(x, x) + | Vector256.Equals(x, Vector256.Create(float.PositiveInfinity)); + + // subnormal + Vector256 subnormalMask = Vector256.AndNot(specialMask.AsSingle(), temp); + + x = Vector256.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector256.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); } - return f; + + Vector256 vx = x.AsUInt32() - Vector256.Create(V_OFF); + Vector256 n = Vector256.ConvertToSingle(Vector256.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector256.Create(V_MASK)) + Vector256.Create(V_OFF); + + Vector256 r = vx.AsSingle() - Vector256.Create(1.0f); + + Vector256 r2 = r * r; + Vector256 r4 = r2 * r2; + Vector256 r8 = r4 * r4; + + Vector256 poly = (Vector256.Create(C9) * r + Vector256.Create(C8)) * r8 + + (((Vector256.Create(C7) * r + Vector256.Create(C6)) * r2 + + (Vector256.Create(C5) * r + Vector256.Create(C4))) * r4 + + ((Vector256.Create(C3) * r + Vector256.Create(C2)) * r2 + + (Vector256.Create(C1) * r + Vector256.Create(C0)))); + + return Vector256.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n + poly + ); } #if NET8_0_OR_GREATER - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static float Invoke(Vector512 x) + public static Vector512 Invoke(Vector512 x) { - float f = x[0]; - for (int i = 1; i < Vector512.Count; i++) + Vector512 specialResult = x; + + // x is subnormal or infinity or NaN + Vector512 specialMask = Vector512.GreaterThanOrEqual(x.AsUInt32() - Vector512.Create(V_MIN), Vector512.Create(V_MAX - V_MIN)); + + if (specialMask != Vector512.Zero) { - f *= x[i]; + // float.IsZero(x) ? float.NegativeInfinity : x + Vector512 zeroMask = Vector512.Equals(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + zeroMask, + Vector512.Create(float.NegativeInfinity), + specialResult + ); + + // (x < 0) ? float.NaN : x + Vector512 lessThanZeroMask = Vector512.LessThan(x, Vector512.Zero); + + specialResult = Vector512.ConditionalSelect( + lessThanZeroMask, + Vector512.Create(float.NaN), + specialResult + ); + + // float.IsZero(x) | (x < 0) | float.IsNaN(x) | float.IsPositiveInfinity(x) + Vector512 temp = zeroMask + | lessThanZeroMask + | ~Vector512.Equals(x, x) + | Vector512.Equals(x, Vector512.Create(float.PositiveInfinity)); + + // subnormal + Vector512 subnormalMask = Vector512.AndNot(specialMask.AsSingle(), temp); + + x = Vector512.ConditionalSelect( + subnormalMask, + ((x * 8388608.0f).AsUInt32() - Vector512.Create(23u << 23)).AsSingle(), + x + ); + + specialMask = temp.AsUInt32(); } - return f; + + Vector512 vx = x.AsUInt32() - Vector512.Create(V_OFF); + Vector512 n = Vector512.ConvertToSingle(Vector512.ShiftRightArithmetic(vx.AsInt32(), 23)); + + vx = (vx & Vector512.Create(V_MASK)) + Vector512.Create(V_OFF); + + Vector512 r = vx.AsSingle() - Vector512.Create(1.0f); + + Vector512 r2 = r * r; + Vector512 r4 = r2 * r2; + Vector512 r8 = r4 * r4; + + Vector512 poly = (Vector512.Create(C9) * r + Vector512.Create(C8)) * r8 + + (((Vector512.Create(C7) * r + Vector512.Create(C6)) * r2 + + (Vector512.Create(C5) * r + Vector512.Create(C4))) * r4 + + ((Vector512.Create(C3) * r + Vector512.Create(C2)) * r2 + + (Vector512.Create(C1) * r + Vector512.Create(C0)))); + + return Vector512.ConditionalSelect( + specialMask.AsSingle(), + specialResult, + n + poly + ); } #endif } - private readonly struct DivideOperator : IBinaryOperator + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 ElementWiseSelect(Vector128 mask, Vector128 left, Vector128 right) { - public static float Invoke(float x, float y) => x / y; - public static Vector128 Invoke(Vector128 x, Vector128 y) => x / y; - public static Vector256 Invoke(Vector256 x, Vector256 y) => x / y; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y) => x / y; -#endif + if (Sse41.IsSupported) + return Sse41.BlendVariable(left, right, ~mask); - public static float Invoke(Vector128 x) => throw new NotSupportedException(); - public static float Invoke(Vector256 x) => throw new NotSupportedException(); -#if NET8_0_OR_GREATER - public static float Invoke(Vector512 x) => throw new NotSupportedException(); -#endif + return Vector128.ConditionalSelect(mask, left, right); } - private readonly struct NegateOperator : IUnaryOperator + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 ElementWiseSelect(Vector128 mask, Vector128 left, Vector128 right) { - public static float Invoke(float x) => -x; - public static Vector128 Invoke(Vector128 x) => -x; - public static Vector256 Invoke(Vector256 x) => -x; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x) => -x; -#endif - } + if (Sse41.IsSupported) + return Sse41.BlendVariable(left, right, ~mask); - private readonly struct AddMultiplyOperator : ITernaryOperator - { - public static float Invoke(float x, float y, float z) => (x + y) * z; - public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x + y) * z; - public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x + y) * z; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x + y) * z; -#endif + return Vector128.ConditionalSelect(mask, left, right); } - private readonly struct MultiplyAddOperator : ITernaryOperator + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 ElementWiseSelect(Vector256 mask, Vector256 left, Vector256 right) { - public static float Invoke(float x, float y, float z) => (x * y) + z; - public static Vector128 Invoke(Vector128 x, Vector128 y, Vector128 z) => (x * y) + z; - public static Vector256 Invoke(Vector256 x, Vector256 y, Vector256 z) => (x * y) + z; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x, Vector512 y, Vector512 z) => (x * y) + z; -#endif + if (Avx2.IsSupported) + return Avx2.BlendVariable(left, right, ~mask); + + return Vector256.ConditionalSelect(mask, left, right); } - private readonly struct LoadIdentity : IUnaryOperator + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 ElementWiseSelect(Vector256 mask, Vector256 left, Vector256 right) { - public static float Invoke(float x) => x; - public static Vector128 Invoke(Vector128 x) => x; - public static Vector256 Invoke(Vector256 x) => x; -#if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x) => x; -#endif + if (Avx2.IsSupported) + return Avx2.BlendVariable(left, right, ~mask); + + return Vector256.ConditionalSelect(mask, left, right); } - private readonly struct LoadSquared : IUnaryOperator - { - public static float Invoke(float x) => x * x; - public static Vector128 Invoke(Vector128 x) => x * x; - public static Vector256 Invoke(Vector256 x) => x * x; #if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x) => x * x; -#endif + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 ElementWiseSelect(Vector512 mask, Vector512 left, Vector512 right) + { + if (Avx512F.IsSupported) + return Avx512F.BlendVariable(left, right, ~mask); + + return Vector512.ConditionalSelect(mask, left, right); } - private readonly struct LoadAbsolute : IUnaryOperator + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 ElementWiseSelect(Vector512 mask, Vector512 left, Vector512 right) { - public static float Invoke(float x) => MathF.Abs(x); + if (Avx512F.IsSupported) + return Avx512F.BlendVariable(left, right, ~mask); - public static Vector128 Invoke(Vector128 x) - { - Vector128 raw = x.AsUInt32(); - Vector128 mask = Vector128.Create((uint)0x7FFFFFFF); - return (raw & mask).AsSingle(); - } - - public static Vector256 Invoke(Vector256 x) - { - Vector256 raw = x.AsUInt32(); - Vector256 mask = Vector256.Create((uint)0x7FFFFFFF); - return (raw & mask).AsSingle(); - } + return Vector512.ConditionalSelect(mask, left, right); + } +#endif + /// 1f / (1f + MathF.Exp(-x)) + private readonly struct SigmoidOperator : IUnaryOperator + { + public static float Invoke(float x) => 1.0f / (1.0f + MathF.Exp(-x)); + public static Vector128 Invoke(Vector128 x) => Vector128.Create(1f) / (Vector128.Create(1f) + ExpOperator.Invoke(-x)); + public static Vector256 Invoke(Vector256 x) => Vector256.Create(1f) / (Vector256.Create(1f) + ExpOperator.Invoke(-x)); #if NET8_0_OR_GREATER - public static Vector512 Invoke(Vector512 x) - { - Vector512 raw = x.AsUInt32(); - Vector512 mask = Vector512.Create((uint)0x7FFFFFFF); - return (raw & mask).AsSingle(); - } + public static Vector512 Invoke(Vector512 x) => Vector512.Create(1f) / (Vector512.Create(1f) + ExpOperator.Invoke(-x)); #endif } + /// Operator that takes one input value and returns a single value. private interface IUnaryOperator { static abstract float Invoke(float x); @@ -1242,20 +11570,30 @@ private interface IUnaryOperator #endif } + /// Operator that takes two input values and returns a single value. private interface IBinaryOperator { static abstract float Invoke(float x, float y); - static abstract Vector128 Invoke(Vector128 x, Vector128 y); - static abstract float Invoke(Vector128 x); static abstract Vector256 Invoke(Vector256 x, Vector256 y); - static abstract float Invoke(Vector256 x); #if NET8_0_OR_GREATER static abstract Vector512 Invoke(Vector512 x, Vector512 y); +#endif + } + + /// that specializes horizontal aggregation of all elements in a vector. + private interface IAggregationOperator : IBinaryOperator + { + static abstract float Invoke(Vector128 x); + static abstract float Invoke(Vector256 x); +#if NET8_0_OR_GREATER static abstract float Invoke(Vector512 x); #endif + + static virtual float IdentityValue => throw new NotSupportedException(); } + /// Operator that takes three input values and returns a single value. private interface ITernaryOperator { static abstract float Invoke(float x, float y, float z); diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index ba3fc69bab527f..c0039be0a08e2b 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -8,8 +9,8 @@ namespace System.Numerics.Tensors { public static partial class TensorPrimitives { - private static unsafe bool IsNegative(float f) => *(int*)&f < 0; - + /// Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers. + /// Assumes arguments have already been validated to be non-empty and equal length. private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) { // Compute the same as: @@ -20,9 +21,9 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan= Vector.Count) + if (Vector.IsHardwareAccelerated && + Vector.Count <= 16 && // currently never greater than 8, but 16 would occur if/when AVX512 is supported, and logic in remainder handling assumes that maximum + x.Length >= Vector.Count) { ref float xRef = ref MemoryMarshal.GetReference(x); ref float yRef = ref MemoryMarshal.GetReference(y); @@ -33,6 +34,7 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan.Count; + int i = 0; do { Vector xVec = AsVector(ref xRef, i); @@ -46,6 +48,21 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan xVec = AsVector(ref xRef, x.Length - Vector.Count); + Vector yVec = AsVector(ref yRef, x.Length - Vector.Count); + + Vector remainderMask = CreateRemainderMaskSingleVector(x.Length - i); + xVec &= remainderMask; + yVec &= remainderMask; + + dotProductVector += xVec * yVec; + xSumOfSquaresVector += xVec * xVec; + ySumOfSquaresVector += yVec * yVec; + } + // Sum the vector lanes into the scalar result. for (int e = 0; e < Vector.Count; e++) { @@ -54,539 +71,3464 @@ private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan( - float identityValue, ReadOnlySpan x, TLoad load = default, TAggregate aggregate = default) - where TLoad : struct, IUnaryOperator - where TAggregate : struct, IBinaryOperator - { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + /// Performs an aggregation over all elements in to produce a single-precision floating-point value. + /// Specifies the transform operation that should be applied to each element loaded from . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied after the transform is applied to each element. + /// + private static unsafe float Aggregate( + ReadOnlySpan x, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + where TTransformOperator : struct, IUnaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated && transformOp.CanVectorize) + { + float result; + + if (remainder >= (uint)(Vector.Count)) + { + result = Vectorized(ref xRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = VectorizedSmall(ref xRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + return SoftwareFallback(ref xRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, nuint length, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, (nint)(i)))); + } + + return result; + } + + static float Vectorized(ref float xRef, nuint remainder, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + { + Vector vresult = new Vector(aggregationOp.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = transformOp.Invoke(AsVector(ref xRef)); + Vector end = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + { + float* xPtr = px; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector)) - ((nuint)(xPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))); + vector2 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))); + vector3 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))); + vector4 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))); + vector2 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))); + vector3 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))); + vector4 = transformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector.ConditionalSelect(CreateAlignmentMaskSingleVector((int)(misalignment)), beg, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + nuint blocks = remainder / (nuint)(Vector.Count); + nuint trailing = remainder - (blocks * (nuint)(Vector.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector vector = transformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 1))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector.ConditionalSelect(CreateRemainderMaskSingleVector((int)(trailing)), end, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, end); + break; + } + } + + float result = aggregationOp.IdentityValue; + + for (int i = 0; i < Vector.Count; i++) + { + result = aggregationOp.Invoke(result, vresult[i]); + } + + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float VectorizedSmall(ref float xRef, nuint remainder, TTransformOperator transformOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + switch (remainder) + { + case 7: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 6))); + goto case 6; + } + + case 6: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 5))); + goto case 5; + } + + case 5: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 4))); + goto case 4; + } + + case 4: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 3))); + goto case 3; + } + + case 3: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 2))); + goto case 2; + } + + case 2: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(Unsafe.Add(ref xRef, 1))); + goto case 1; + } + + case 1: + { + result = aggregationOp.Invoke(result, transformOp.Invoke(xRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + } + + /// Performs an aggregation over all pair-wise elements in and to produce a single-precision floating-point value. + /// Specifies the binary operation that should be applied to the pair-wise elements loaded from and . + /// + /// Specifies the aggregation binary operation that should be applied to multiple values to aggregate them into a single value. + /// The aggregation is applied to the results of the binary operations on the pair-wise values. + /// + private static unsafe float Aggregate( + ReadOnlySpan x, ReadOnlySpan y, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + where TBinaryOperator : struct, IBinaryOperator + where TAggregationOperator : struct, IAggregationOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + float result; + + if (remainder >= (uint)(Vector.Count)) + { + result = Vectorized(ref xRef, ref yRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + result = VectorizedSmall(ref xRef, ref yRef, remainder); + } + + return result; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + return SoftwareFallback(ref xRef, ref yRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float SoftwareFallback(ref float xRef, ref float yRef, nuint length, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + for (nuint i = 0; i < length; i++) + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)))); + } + + return result; + } + + static float Vectorized(ref float xRef, ref float yRef, nuint remainder, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + { + Vector vresult = new Vector(aggregationOp.IdentityValue); + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = binaryOp.Invoke(AsVector(ref xRef), + AsVector(ref yRef)); + Vector end = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count))); + + nuint misalignment = 0; + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + { + float* xPtr = px; + float* yPtr = py; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(xPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. This is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + misalignment = ((uint)(sizeof(Vector)) - ((nuint)(xPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + + Debug.Assert(((nuint)(xPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + // We only need to load, so there isn't a lot of benefit to doing non-temporal operations + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0))); + vector2 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1))); + vector3 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2))); + vector4 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We load, process, and store the next four vectors + + vector1 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4))); + vector2 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5))); + vector3 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6))); + vector4 = binaryOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7))); + + vresult = aggregationOp.Invoke(vresult, vector1); + vresult = aggregationOp.Invoke(vresult, vector2); + vresult = aggregationOp.Invoke(vresult, vector3); + vresult = aggregationOp.Invoke(vresult, vector4); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + } + } + + // Store the first block. Handling this separately simplifies the latter code as we know + // they come after and so we can relegate it to full blocks or the trailing elements + + beg = Vector.ConditionalSelect(CreateAlignmentMaskSingleVector((int)(misalignment)), beg, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, beg); + + // Process the remaining [0, Count * 7] elements via a jump table + // + // We end up handling any trailing elements in case 0 and in the + // worst case end up just doing the identity operation here if there + // were no trailing elements. + + nuint blocks = remainder / (nuint)(Vector.Count); + nuint trailing = remainder - (blocks * (nuint)(Vector.Count)); + blocks -= (misalignment == 0) ? 1u : 0u; + remainder -= trailing; + + switch (blocks) + { + case 7: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 6; + } + + case 6: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 5; + } + + case 5: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 4; + } + + case 4: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 3; + } + + case 3: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 2; + } + + case 2: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 1; + } + + case 1: + { + Vector vector = binaryOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 1)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 1))); + vresult = aggregationOp.Invoke(vresult, vector); + goto case 0; + } + + case 0: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end = Vector.ConditionalSelect(CreateRemainderMaskSingleVector((int)(trailing)), end, new Vector(aggregationOp.IdentityValue)); + vresult = aggregationOp.Invoke(vresult, end); + break; + } + } + + float result = aggregationOp.IdentityValue; + + for (int i = 0; i < Vector.Count; i++) + { + result = aggregationOp.Invoke(result, vresult[i]); + } + + return result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static float VectorizedSmall(ref float xRef, ref float yRef, nuint remainder, TBinaryOperator binaryOp = default, TAggregationOperator aggregationOp = default) + { + float result = aggregationOp.IdentityValue; + + switch (remainder) + { + case 7: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6))); + goto case 6; + } + + case 6: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5))); + goto case 5; + } + + case 5: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4))); + goto case 4; + } + + case 4: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3))); + goto case 3; + } + + case 3: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2))); + goto case 2; + } + + case 2: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1))); + goto case 1; + } + + case 1: + { + result = aggregationOp.Invoke(result, binaryOp.Invoke(xRef, yRef)); + goto case 0; + } + + case 0: + { + break; + } + } + + return result; + } + } + + /// + /// This is the same as + /// with an identity transform, except it early exits on NaN. + /// + private static float MinMaxCore(ReadOnlySpan x, TMinMaxOperator op = default) + where TMinMaxOperator : struct, IBinaryOperator + { + if (x.IsEmpty) + { + ThrowHelper.ThrowArgument_SpansMustBeNonEmpty(); + } + + // This matches the IEEE 754:2019 `maximum`/`minimum` functions. + // It propagates NaN inputs back to the caller and + // otherwise returns the greater of the inputs. + // It treats +0 as greater than -0 as per the specification. + + float result = x[0]; + int i = 0; + + if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector resultVector = AsVector(ref xRef, 0), current; + if (Vector.EqualsAll(resultVector, resultVector)) + { + int oneVectorFromEnd = x.Length - Vector.Count; + i = Vector.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = AsVector(ref xRef, i); + if (!Vector.EqualsAll(current, current)) + { + goto Scalar; + } + + resultVector = op.Invoke(resultVector, current); + i += Vector.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + current = AsVector(ref xRef, x.Length - Vector.Count); + if (!Vector.EqualsAll(current, current)) + { + goto Scalar; + } + + resultVector = op.Invoke(resultVector, current); + } + + // Aggregate the lanes in the vector to create the final scalar result. + for (int f = 0; f < Vector.Count; f++) + { + result = op.Invoke(result, resultVector[f]); + } + + return result; + } + } + + // Scalar path used when either vectorization is not supported, the input is too small to vectorize, + // or a NaN is encountered. + Scalar: + for (; (uint)i < (uint)x.Length; i++) + { + float current = x[i]; + + if (float.IsNaN(current)) + { + return current; + } + + result = op.Invoke(result, current); + } + + return result; + } + + private static readonly int[] s_0through7 = [0, 1, 2, 3, 4, 5, 6, 7]; + + private static int IndexOfMinMaxCore(ReadOnlySpan x, TIndexOfMinMaxOperator op = default) + where TIndexOfMinMaxOperator : struct, IIndexOfOperator + { + // This matches the IEEE 754:2019 `maximum`/`minimum` functions. + // It propagates NaN inputs back to the caller and + // otherwise returns the index of the greater of the inputs. + // It treats +0 as greater than -0 as per the specification. + + int result; + int i = 0; + + if (Vector.IsHardwareAccelerated && Vector.Count <= 8 && x.Length >= Vector.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + + Vector resultIndex = new Vector(s_0through7); + Vector curIndex = resultIndex; + Vector increment = new Vector(Vector.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector resultVector = AsVector(ref xRef, 0), current; + if (Vector.EqualsAll(resultVector, resultVector)) + { + int oneVectorFromEnd = x.Length - Vector.Count; + i = Vector.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = AsVector(ref xRef, i); + curIndex = Vector.Add(curIndex, increment); + + if (!Vector.EqualsAll(current, current)) + { + goto Scalar; + } + + op.Invoke(ref resultVector, current, ref resultIndex, curIndex); + i += Vector.Count; + } + + // If any elements remain, handle them in one final vector. + if (i != x.Length) + { + curIndex = Vector.Add(curIndex, new Vector(x.Length - i)); + + current = AsVector(ref xRef, x.Length - Vector.Count); + if (!Vector.EqualsAll(current, current)) + { + goto Scalar; + } + + op.Invoke(ref resultVector, current, ref resultIndex, curIndex); + } + + result = op.Invoke(resultVector, resultIndex); + + return result; + } + } + + // Scalar path used when either vectorization is not supported, the input is too small to vectorize, + // or a NaN is encountered. + Scalar: + float curResult = x[i]; + int curIn = i; + if (float.IsNaN(curResult)) + { + return curIn; + } + + for (; i < x.Length; i++) + { + float current = x[i]; + if (float.IsNaN(current)) + { + return i; + } + + curIn = op.Invoke(ref curResult, current, curIn, i); + } + + return curIn; + } + + /// Performs an element-wise operation on and writes the results to . + /// Specifies the operation to perform on each element loaded from . + private static unsafe void InvokeSpanIntoSpan( + ReadOnlySpan x, Span destination, TUnaryOperator op = default) + where TUnaryOperator : struct, IUnaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated && op.CanVectorize) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float dRef, nuint length, TUnaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float dRef, nuint remainder, TUnaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float dRef, nuint remainder, TUnaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and . + /// + private static unsafe void InvokeSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, Span destination, TBinaryOperator op = default) + where TBinaryOperator : struct, IBinaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref yRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float dRef, nuint length, TBinaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float yRef, ref float dRef, nuint remainder, TBinaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, ref float dRef, nuint remainder, TBinaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from with . + /// + private static void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination, TBinaryOperator op = default) + where TBinaryOperator : struct, IBinaryOperator => + InvokeSpanScalarIntoSpan(x, y, destination, default, op); + + /// + /// Performs an element-wise operation on and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on each element loaded from . + /// It is not used with . + /// + /// + /// Specifies the operation to perform on the transformed value from with . + /// + private static unsafe void InvokeSpanScalarIntoSpan( + ReadOnlySpan x, float y, Span destination, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + where TTransformOperator : struct, IUnaryOperator + where TBinaryOperator : struct, IBinaryOperator + { + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated && xTransformOp.CanVectorize) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, y, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, y, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float dRef, nuint length, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, (nint)(i))), + y); + } + } + + static void Vectorized(ref float xRef, float y, ref float dRef, nuint remainder, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector yVec = new Vector(y); + + Vector beg = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef)), + yVec); + Vector end = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))), + yVec); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))), + yVec); + vector2 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))), + yVec); + vector3 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))), + yVec); + vector4 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))), + yVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))), + yVec); + vector2 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))), + yVec); + vector3 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))), + yVec); + vector4 = binaryOp.Invoke(xTransformOp.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))), + yVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = binaryOp.Invoke(xTransformOp.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))), + yVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, float y, ref float dRef, nuint remainder, TTransformOperator xTransformOp = default, TBinaryOperator binaryOp = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 6)), + y); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 5)), + y); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 4)), + y); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 3)), + y); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 2)), + y); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = binaryOp.Invoke(xTransformOp.Invoke(Unsafe.Add(ref xRef, 1)), + y); + goto case 1; + } + + case 1: + { + dRef = binaryOp.Invoke(xTransformOp.Invoke(xRef), y); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from , , + /// and . + /// + private static unsafe void InvokeSpanSpanSpanIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length || x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)), + Unsafe.Add(ref zRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef), + AsVector(ref zRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count)), + AsVector(ref zRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0)), + *(Vector*)(zPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1)), + *(Vector*)(zPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2)), + *(Vector*)(zPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3)), + *(Vector*)(zPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4)), + *(Vector*)(zPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5)), + *(Vector*)(zPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6)), + *(Vector*)(zPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7)), + *(Vector*)(zPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + zPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref zRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6), + Unsafe.Add(ref zRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5), + Unsafe.Add(ref zRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4), + Unsafe.Add(ref zRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3), + Unsafe.Add(ref zRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef, zRef); + break; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise elements loaded from and + /// with . + /// + private static unsafe void InvokeSpanSpanScalarIntoSpan( + ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != y.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(y, destination); + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float yRef = ref MemoryMarshal.GetReference(y); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, ref yRef, z, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, ref yRef, z, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref yRef, z, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float yRef, float z, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + Unsafe.Add(ref yRef, (nint)(i)), + z); + } + } + + static void Vectorized(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector zVec = new Vector(z); + + Vector beg = op.Invoke(AsVector(ref xRef), + AsVector(ref yRef), + zVec); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + AsVector(ref yRef, remainder - (uint)(Vector.Count)), + zVec); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* py = &yRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* yPtr = py; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + yPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + *(Vector*)(yPtr + (uint)(Vector.Count * 0)), + zVec); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + *(Vector*)(yPtr + (uint)(Vector.Count * 1)), + zVec); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + *(Vector*)(yPtr + (uint)(Vector.Count * 2)), + zVec); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + *(Vector*)(yPtr + (uint)(Vector.Count * 3)), + zVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + *(Vector*)(yPtr + (uint)(Vector.Count * 4)), + zVec); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + *(Vector*)(yPtr + (uint)(Vector.Count * 5)), + zVec); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + *(Vector*)(yPtr + (uint)(Vector.Count * 6)), + zVec); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + *(Vector*)(yPtr + (uint)(Vector.Count * 7)), + zVec); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + yPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + yRef = ref *yPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 8)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 7)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 6)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 5)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 4)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 3)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + AsVector(ref yRef, remainder - (uint)(Vector.Count * 2)), + zVec); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float yRef, float z, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + Unsafe.Add(ref yRef, 6), + z); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + Unsafe.Add(ref yRef, 5), + z); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + Unsafe.Add(ref yRef, 4), + z); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + Unsafe.Add(ref yRef, 3), + z); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + Unsafe.Add(ref yRef, 2), + z); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + Unsafe.Add(ref yRef, 1), + z); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, yRef, z); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// + /// Performs an element-wise operation on , , and , + /// and writes the results to . + /// + /// + /// Specifies the operation to perform on the pair-wise element loaded from , with , + /// and the element loaded from . + /// + private static unsafe void InvokeSpanScalarSpanIntoSpan( + ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) + where TTernaryOperator : struct, ITernaryOperator + { + if (x.Length != z.Length) + { + ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + } + + if (x.Length > destination.Length) + { + ThrowHelper.ThrowArgument_DestinationTooShort(); + } + + ValidateInputOutputSpanNonOverlapping(x, destination); + ValidateInputOutputSpanNonOverlapping(z, destination); + + ref float xRef = ref MemoryMarshal.GetReference(x); + ref float zRef = ref MemoryMarshal.GetReference(z); + ref float dRef = ref MemoryMarshal.GetReference(destination); + + nuint remainder = (uint)(x.Length); + + if (Vector.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector.Count)) + { + Vectorized(ref xRef, y, ref zRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + VectorizedSmall(ref xRef, y, ref zRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, y, ref zRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, float y, ref float zRef, ref float dRef, nuint length, TTernaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i)), + y, + Unsafe.Add(ref zRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector yVec = new Vector(y); + + Vector beg = op.Invoke(AsVector(ref xRef), + yVec, + AsVector(ref zRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pz = &zRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* zPtr = pz; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + zPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7)), + yVec, + *(Vector*)(zPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector.Count * 8); + zPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + zRef = ref *zPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; + } + + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2)), + yVec, + AsVector(ref zRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto case 0; + } + + case 0: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, float y, ref float zRef, ref float dRef, nuint remainder, TTernaryOperator op = default) + { + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6), + y, + Unsafe.Add(ref zRef, 6)); + goto case 6; + } + + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5), + y, + Unsafe.Add(ref zRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4), + y, + Unsafe.Add(ref zRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3), + y, + Unsafe.Add(ref zRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2), + y, + Unsafe.Add(ref zRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1), + y, + Unsafe.Add(ref zRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef, y, zRef); + goto case 0; + } + + case 0: + { + break; + } + } + } + } + + /// Loads a from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start) => + ref Unsafe.As>(ref start); + + /// Loads a that begins at the specified from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start, int offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, offset)); + + /// Loads a that begins at the specified from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start, nuint offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, (nint)(offset))); + + /// Loads a that begins at the specified from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref int start, int offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, offset)); + + /// Gets whether the specified is positive. + private static bool IsPositive(float f) => !IsNegative(f); + + /// Gets whether each specified is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector IsPositive(Vector vector) => + ((Vector)Vector.GreaterThan(((Vector)vector), Vector.Zero)); + + /// Gets whether the specified is negative. + private static unsafe bool IsNegative(float f) => *(int*)&f < 0; + + /// Gets whether each specified is negative. + private static Vector IsNegative(Vector f) => + (Vector)Vector.LessThan((Vector)f, Vector.Zero); + + /// Gets the base 2 logarithm of . + private static float Log2(float x) => MathF.Log(x, 2); + + /// + /// Gets a vector mask that will be all-ones-set for the first elements + /// and zero for all other elements. + /// + private static Vector CreateAlignmentMaskSingleVector(int count) + { + Debug.Assert(Vector.Count is 4 or 8 or 16); + + return AsVector( + ref Unsafe.As(ref MemoryMarshal.GetReference(AlignmentUInt32Mask_16x16)), + (count * 16)); + } + + /// + /// Gets a vector mask that will be all-ones-set for the last elements + /// and zero for all other elements. + /// + private static Vector CreateRemainderMaskSingleVector(int count) + { + Debug.Assert(Vector.Count is 4 or 8 or 16); + + return AsVector( + ref Unsafe.As(ref MemoryMarshal.GetReference(RemainderUInt32Mask_16x16)), + (count * 16) + (16 - Vector.Count)); + } + + /// x + y + private readonly struct AddOperator : IAggregationOperator + { + public float Invoke(float x, float y) => x + y; + public Vector Invoke(Vector x, Vector y) => x + y; + public float IdentityValue => 0; + } + + /// x - y + private readonly struct SubtractOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x - y; + public Vector Invoke(Vector x, Vector y) => x - y; + } + + /// (x - y) * (x - y) + private readonly struct SubtractSquaredOperator : IBinaryOperator + { + public float Invoke(float x, float y) + { + float tmp = x - y; + return tmp * tmp; + } + + public Vector Invoke(Vector x, Vector y) + { + Vector tmp = x - y; + return tmp * tmp; + } + } + + /// x * y + private readonly struct MultiplyOperator : IAggregationOperator + { + public float Invoke(float x, float y) => x * y; + public Vector Invoke(Vector x, Vector y) => x * y; + public float IdentityValue => 1; + } + + /// x / y + private readonly struct DivideOperator : IBinaryOperator + { + public float Invoke(float x, float y) => x / y; + public Vector Invoke(Vector x, Vector y) => x / y; + } + + private interface IIndexOfOperator + { + int Invoke(ref float result, float current, int resultIndex, int curIndex); + int Invoke(Vector result, Vector resultIndex); + void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex); + } + + /// Returns the index of MathF.Max(x, y) + private readonly struct IndexOfMaxOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) + { + float curMax = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) + { + if (result[i] == curMax && IsNegative(curMax) && !IsNegative(result[i])) + { + curMax = result[i]; + curIn = resultIndex[i]; + } + else if (result[i] > curMax) + { + curMax = result[i]; + curIn = resultIndex[i]; + } + } + + return curIn; + } - if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count * 2) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) { - ref float xRef = ref MemoryMarshal.GetReference(x); + Vector lessThanMask = Vector.GreaterThan(result, current); - // Load the first vector as the initial set of results - Vector resultVector = load.Invoke(AsVector(ref xRef, 0)); - int oneVectorFromEnd = x.Length - Vector.Count; + Vector equalMask = Vector.Equals(result, current); - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector.Count; - do + if (equalMask != Vector.Zero) { - resultVector = aggregate.Invoke(resultVector, load.Invoke(AsVector(ref xRef, i))); - i += Vector.Count; - } - while (i <= oneVectorFromEnd); + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); - // Aggregate the lanes in the vector back into the scalar result - for (int f = 0; f < Vector.Count; f++) - { - result = aggregate.Invoke(result, resultVector[f]); + lessThanMask |= ((Vector)~negativeMask & equalMask) | ((Vector)IsNegative(result) & equalMask & lessThanIndexMask); } - } - - // Aggregate the remaining items in the input span. - for (; (uint)i < (uint)x.Length; i++) - { - result = aggregate.Invoke(result, load.Invoke(x[i])); - } - return result; - } + result = Vector.ConditionalSelect(lessThanMask, result, current); - private static float Aggregate( - float identityValue, ReadOnlySpan x, ReadOnlySpan y, TBinary binary = default, TAggregate aggregate = default) - where TBinary : struct, IBinaryOperator - where TAggregate : struct, IBinaryOperator - { - // Initialize the result to the identity value - float result = identityValue; - int i = 0; + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); + } - if (Vector.IsHardwareAccelerated && x.Length >= Vector.Count * 2) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) { - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - - // Load the first vector as the initial set of results - Vector resultVector = binary.Invoke(AsVector(ref xRef, 0), AsVector(ref yRef, 0)); - int oneVectorFromEnd = x.Length - Vector.Count; - - // Aggregate additional vectors into the result as long as there's at - // least one full vector left to process. - i = Vector.Count; - do + if (result == current) { - resultVector = aggregate.Invoke(resultVector, binary.Invoke(AsVector(ref xRef, i), AsVector(ref yRef, i))); - i += Vector.Count; + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } } - while (i <= oneVectorFromEnd); - - // Aggregate the lanes in the vector back into the scalar result - for (int f = 0; f < Vector.Count; f++) + else if (current > result) { - result = aggregate.Invoke(result, resultVector[f]); + result = current; + return curIndex; } - } - // Aggregate the remaining items in the input span. - for (; (uint)i < (uint)x.Length; i++) - { - result = aggregate.Invoke(result, binary.Invoke(x[i], y[i])); + return resultIndex; } - - return result; } - private static void InvokeSpanIntoSpan( - ReadOnlySpan x, Span destination, TUnaryOperator op = default) - where TUnaryOperator : struct, IUnaryOperator + private readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator { - if (x.Length > destination.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + float curMaxAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + if (curMaxAbs == currentAbs) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs > curMaxAbs) + { + result = current; + return curIndex; + } - if (Vector.IsHardwareAccelerated) + return resultIndex; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector maxIndex) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + float curMax = result[0]; + int curIn = maxIndex[0]; + for (int i = 1; i < Vector.Count; i++) { - // Loop handling one vector at a time. - do + if (MathF.Abs(result[i]) == MathF.Abs(curMax) && IsNegative(curMax) && !IsNegative(result[i])) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i)); - - i += Vector.Count; + curMax = result[i]; + curIn = maxIndex[i]; } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. - if (i != x.Length) + else if (MathF.Abs(result[i]) > MathF.Abs(curMax)) { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex)); + curMax = result[i]; + curIn = maxIndex[i]; } - - return; } - } - - // Loop handling one element at a time. - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i)); - i++; + return curIn; } - } - private static void InvokeSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, Span destination, TBinaryOperator op = default) - where TBinaryOperator : struct, IBinaryOperator - { - if (x.Length != y.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + Vector maxMag = Vector.Abs(result), currentMag = Vector.Abs(current); - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + Vector lessThanMask = Vector.GreaterThan(maxMag, currentMag); - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + Vector equalMask = Vector.Equals(result, current); - if (Vector.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + if (equalMask != Vector.Zero) { - // Loop handling one vector at a time. - do - { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i)); - - i += Vector.Count; - } - while (i <= oneVectorFromEnd); + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex)); - } - - return; + lessThanMask |= ((Vector)~negativeMask & equalMask) | ((Vector)IsNegative(result) & equalMask & lessThanIndexMask); } - } - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i)); + result = Vector.ConditionalSelect(lessThanMask, result, current); - i++; + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); } } - private static void InvokeSpanScalarIntoSpan( - ReadOnlySpan x, float y, Span destination, TBinaryOperator op = default) - where TBinaryOperator : struct, IBinaryOperator + private readonly struct IndexOfMinOperator : IIndexOfOperator { - if (x.Length > destination.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + if (result == current) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (current < result) + { + result = current; + return curIndex; + } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + return resultIndex; + } - if (Vector.IsHardwareAccelerated) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + float curMin = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) { - // Loop handling one vector at a time. - Vector yVec = new(y); - do + if (result[i] == curMin && IsPositive(curMin) && !IsPositive(result[i])) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - yVec); - - i += Vector.Count; + curMin = result[i]; + curIn = resultIndex[i]; } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. - if (i != x.Length) + else if (result[i] < curMin) { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - yVec); + curMin = result[i]; + curIn = resultIndex[i]; } - - return; } + + return curIn; } - // Loop handling one element at a time. - while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - y); + Vector lessThanMask = Vector.LessThan(result, current); + + Vector equalMask = Vector.Equals(result, current); + + if (equalMask != Vector.Zero) + { + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); + + lessThanMask |= ((Vector)negativeMask & equalMask) | (~(Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } + + result = Vector.ConditionalSelect(lessThanMask, result, current); - i++; + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); } } - private static void InvokeSpanSpanSpanIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) - where TTernaryOperator : struct, ITernaryOperator + private readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator { - if (x.Length != y.Length || x.Length != z.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + float curMinAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + if (curMinAbs == currentAbs) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs < curMinAbs) + { + result = current; + return curIndex; + } - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); + return resultIndex; } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; - - if (Vector.IsHardwareAccelerated) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + float curMin = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) { - // Loop handling one vector at a time. - do + if (MathF.Abs(result[i]) == MathF.Abs(curMin) && IsPositive(curMin) && !IsPositive(result[i])) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i), - AsVector(ref zRef, i)); - - i += Vector.Count; + curMin = result[i]; + curIn = resultIndex[i]; } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. - if (i != x.Length) + else if (MathF.Abs(result[i]) < MathF.Abs(curMin)) { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex), - AsVector(ref zRef, lastVectorIndex)); + curMin = result[i]; + curIn = resultIndex[i]; } - - return; } - } - // Loop handling one element at a time. - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - Unsafe.Add(ref zRef, i)); - - i++; + return curIn; } - } - private static void InvokeSpanSpanScalarIntoSpan( - ReadOnlySpan x, ReadOnlySpan y, float z, Span destination, TTernaryOperator op = default) - where TTernaryOperator : struct, ITernaryOperator - { - if (x.Length != y.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); - } + Vector minMag = Vector.Abs(result), currentMag = Vector.Abs(current); - if (x.Length > destination.Length) - { - ThrowHelper.ThrowArgument_DestinationTooShort(); - } + Vector lessThanMask = Vector.LessThan(minMag, currentMag); - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float yRef = ref MemoryMarshal.GetReference(y); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + Vector equalMask = Vector.Equals(result, current); - if (Vector.IsHardwareAccelerated) - { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + if (equalMask != Vector.Zero) { - Vector zVec = new(z); - - // Loop handling one vector at a time. - do - { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - AsVector(ref yRef, i), - zVec); + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); - i += Vector.Count; - } - while (i <= oneVectorFromEnd); + lessThanMask |= ((Vector)negativeMask & equalMask) | (~(Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - AsVector(ref yRef, lastVectorIndex), - zVec); - } + result = Vector.ConditionalSelect(lessThanMask, result, current); - return; - } + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); } + } - // Loop handling one element at a time. - while (i < x.Length) - { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - Unsafe.Add(ref yRef, i), - z); + /// MathF.Max(x, y) (but without guaranteed NaN propagation) + private readonly struct MaxOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => + x == y ? + (IsNegative(x) ? y : x) : + (y > x ? y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.Max(x, y)); + } - i++; - } + /// MathF.Max(x, y) + private readonly struct MaxPropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => MathF.Max(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.Max(x, y)), + y), + x); } - private static void InvokeSpanScalarSpanIntoSpan( - ReadOnlySpan x, float y, ReadOnlySpan z, Span destination, TTernaryOperator op = default) - where TTernaryOperator : struct, ITernaryOperator + /// Operator to get x or y based on which has the larger MathF.Abs (but NaNs may not be propagated) + private readonly struct MaxMagnitudeOperator : IBinaryOperator { - if (x.Length != z.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) { - ThrowHelper.ThrowArgument_SpansMustHaveSameLength(); + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return + yMag == xMag ? + (IsNegative(x) ? y : x) : + (xMag > yMag ? x : y); } - if (x.Length > destination.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) { - ThrowHelper.ThrowArgument_DestinationTooShort(); + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + return + Vector.ConditionalSelect(Vector.Equals(xMag, yMag), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.ConditionalSelect(Vector.GreaterThan(xMag, yMag), x, y)); } + } - ref float xRef = ref MemoryMarshal.GetReference(x); - ref float zRef = ref MemoryMarshal.GetReference(z); - ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; - - if (Vector.IsHardwareAccelerated) + /// Operator to get x or y based on which has the larger MathF.Abs + private readonly struct MaxMagnitudePropagateNaNOperator : IBinaryOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) - { - Vector yVec = new(y); - - // Loop handling one vector at a time. - do - { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i), - yVec, - AsVector(ref zRef, i)); - - i += Vector.Count; - } - while (i <= oneVectorFromEnd); - - // Handle any remaining elements with a final vector. - if (i != x.Length) - { - int lastVectorIndex = x.Length - Vector.Count; - AsVector(ref dRef, lastVectorIndex) = op.Invoke(AsVector(ref xRef, lastVectorIndex), - yVec, - AsVector(ref zRef, lastVectorIndex)); - } - - return; - } + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return xMag > yMag || float.IsNaN(xMag) || (xMag == yMag && !IsNegative(x)) ? x : y; } - // Loop handling one element at a time. - while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i), - y, - Unsafe.Add(ref zRef, i)); - - i++; + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + return + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(xMag, yMag), + Vector.ConditionalSelect(IsNegative(x), y, x), + Vector.ConditionalSelect(Vector.GreaterThan(xMag, yMag), x, y)), + y), + x); } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static ref Vector AsVector(ref float start, int offset) => - ref Unsafe.As>( - ref Unsafe.Add(ref start, offset)); - - private readonly struct AddOperator : IBinaryOperator + /// MathF.Min(x, y) (but NaNs may not be propagated) + private readonly struct MinOperator : IBinaryOperator { - public float Invoke(float x, float y) => x + y; - public Vector Invoke(Vector x, Vector y) => x + y; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => + x == y ? + (IsNegative(y) ? y : x) : + (y < x ? y : x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(y), y, x), + Vector.Min(x, y)); } - private readonly struct SubtractOperator : IBinaryOperator + /// MathF.Min(x, y) + private readonly struct MinPropagateNaNOperator : IBinaryOperator { - public float Invoke(float x, float y) => x - y; - public Vector Invoke(Vector x, Vector y) => x - y; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) => MathF.Min(x, y); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) => + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(x, y), + Vector.ConditionalSelect(IsNegative(x), x, y), + Vector.Min(x, y)), + y), + x); } - private readonly struct SubtractSquaredOperator : IBinaryOperator + /// Operator to get x or y based on which has the smaller MathF.Abs (but NaNs may not be propagated) + private readonly struct MinMagnitudeOperator : IBinaryOperator { + [MethodImpl(MethodImplOptions.AggressiveInlining)] public float Invoke(float x, float y) { - float tmp = x - y; - return tmp * tmp; + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return + yMag == xMag ? + (IsNegative(y) ? y : x) : + (yMag < xMag ? y : x); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] public Vector Invoke(Vector x, Vector y) { - Vector tmp = x - y; - return tmp * tmp; + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + return + Vector.ConditionalSelect(Vector.Equals(yMag, xMag), + Vector.ConditionalSelect(IsNegative(y), y, x), + Vector.ConditionalSelect(Vector.LessThan(yMag, xMag), y, x)); } } - private readonly struct MultiplyOperator : IBinaryOperator + /// Operator to get x or y based on which has the smaller MathF.Abs + private readonly struct MinMagnitudePropagateNaNOperator : IBinaryOperator { - public float Invoke(float x, float y) => x * y; - public Vector Invoke(Vector x, Vector y) => x * y; - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float Invoke(float x, float y) + { + float xMag = MathF.Abs(x), yMag = MathF.Abs(y); + return xMag < yMag || float.IsNaN(xMag) || (xMag == yMag && IsNegative(x)) ? x : y; + } - private readonly struct DivideOperator : IBinaryOperator - { - public float Invoke(float x, float y) => x / y; - public Vector Invoke(Vector x, Vector y) => x / y; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Vector Invoke(Vector x, Vector y) + { + Vector xMag = Vector.Abs(x), yMag = Vector.Abs(y); + + return + Vector.ConditionalSelect(Vector.Equals(x, x), + Vector.ConditionalSelect(Vector.Equals(y, y), + Vector.ConditionalSelect(Vector.Equals(yMag, xMag), + Vector.ConditionalSelect(IsNegative(x), x, y), + Vector.ConditionalSelect(Vector.LessThan(xMag, yMag), x, y)), + y), + x); + } } + /// -x private readonly struct NegateOperator : IUnaryOperator { + public bool CanVectorize => true; public float Invoke(float x) => -x; public Vector Invoke(Vector x) => -x; } + /// (x + y) * z private readonly struct AddMultiplyOperator : ITernaryOperator { public float Invoke(float x, float y, float z) => (x + y) * z; public Vector Invoke(Vector x, Vector y, Vector z) => (x + y) * z; } + /// (x * y) + z private readonly struct MultiplyAddOperator : ITernaryOperator { public float Invoke(float x, float y, float z) => (x * y) + z; public Vector Invoke(Vector x, Vector y, Vector z) => (x * y) + z; } - private readonly struct LoadIdentity : IUnaryOperator + /// x + private readonly struct IdentityOperator : IUnaryOperator { + public bool CanVectorize => true; public float Invoke(float x) => x; public Vector Invoke(Vector x) => x; } - private readonly struct LoadSquared : IUnaryOperator + /// x * x + private readonly struct SquaredOperator : IUnaryOperator { + public bool CanVectorize => true; public float Invoke(float x) => x * x; public Vector Invoke(Vector x) => x * x; } - private readonly struct LoadAbsolute : IUnaryOperator + /// MathF.Abs(x) + private readonly struct AbsoluteOperator : IUnaryOperator { + public bool CanVectorize => true; public float Invoke(float x) => MathF.Abs(x); + public Vector Invoke(Vector x) => Vector.Abs(x); + } - public Vector Invoke(Vector x) - { - Vector raw = Vector.AsVectorUInt32(x); - Vector mask = new Vector(0x7FFFFFFF); - return Vector.AsVectorSingle(raw & mask); - } + /// MathF.Exp(x) + private readonly struct ExpOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Exp(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Sinh(x) + private readonly struct SinhOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Sinh(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Cosh(x) + private readonly struct CoshOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Cosh(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Tanh(x) + private readonly struct TanhOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Tanh(x); + public Vector Invoke(Vector x) => + // requires ShiftLeft (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Log(x) + private readonly struct LogOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => MathF.Log(x); + public Vector Invoke(Vector x) => + // requires ShiftRightArithmetic (.NET 7+) + throw new NotImplementedException(); + } + + /// MathF.Log2(x) + private readonly struct Log2Operator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => Log2(x); + public Vector Invoke(Vector x) => + // requires ShiftRightArithmetic (.NET 7+) + throw new NotImplementedException(); + } + + /// 1f / (1f + MathF.Exp(-x)) + private readonly struct SigmoidOperator : IUnaryOperator + { + public bool CanVectorize => false; + public float Invoke(float x) => 1.0f / (1.0f + MathF.Exp(-x)); + public Vector Invoke(Vector x) => + // requires ShiftRightArithmetic (.NET 7+) + throw new NotImplementedException(); } + /// Operator that takes one input value and returns a single value. private interface IUnaryOperator { + bool CanVectorize { get; } float Invoke(float x); Vector Invoke(Vector x); } + /// Operator that takes two input values and returns a single value. private interface IBinaryOperator { float Invoke(float x, float y); Vector Invoke(Vector x, Vector y); } + /// that specializes horizontal aggregation of all elements in a vector. + private interface IAggregationOperator : IBinaryOperator + { + float IdentityValue { get; } + } + + /// Operator that takes three input values and returns a single value. private interface ITernaryOperator { float Invoke(float x, float y, float z); diff --git a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs index 902b27787e856c..272991aed44ab8 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/ThrowHelper.cs @@ -18,5 +18,9 @@ public static void ThrowArgument_SpansMustHaveSameLength() => [DoesNotReturn] public static void ThrowArgument_SpansMustBeNonEmpty() => throw new ArgumentException(SR.Argument_SpansMustBeNonEmpty); + + [DoesNotReturn] + public static void ThrowArgument_InputAndDestinationSpanMustNotOverlap() => + throw new ArgumentException(SR.Argument_InputAndDestinationSpanMustNotOverlap, "destination"); } } diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index 181d152e6ae979..09aa13ae35800f 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Runtime.InteropServices; using Xunit; +using Xunit.Sdk; #pragma warning disable xUnit1025 // reporting duplicate test cases due to not distinguishing 0.0 from -0.0 @@ -13,12 +14,25 @@ namespace System.Numerics.Tensors.Tests { public static partial class TensorPrimitivesTests { - private const double Tolerance = 0.0001; + #region Test Utilities + public static IEnumerable TensorLengthsIncluding0 => + TensorLengths.Concat(new object[][] { [0] }); public static IEnumerable TensorLengths => - from length in Enumerable.Range(1, 128) + from length in Enumerable.Range(1, 256) select new object[] { length }; + public static IEnumerable VectorLengthAndIteratedRange(float min, float max, float increment) + { + foreach (int length in new[] { 4, 8, 16 }) + { + for (float f = min; f <= max; f += increment) + { + yield return new object[] { length, f }; + } + } + } + private static readonly Random s_random = new Random(20230828); private static BoundedMemory CreateTensor(int size) => BoundedMemory.Allocate(size); @@ -38,300 +52,300 @@ private static void FillTensor(Span tensor) } } - private static float NextSingle() - { + private static float NextSingle() => // For testing purposes, get a mix of negative and positive values. - return (float)((s_random.NextDouble() * 2) - 1); - } + (float)((s_random.NextDouble() * 2) - 1); - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTwoTensors(int tensorLength) + private static void AssertEqualTolerance(double expected, double actual, double tolerance = 0.00001f) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - TensorPrimitives.Add(x, y, destination); - - for (int i = 0; i < tensorLength; i++) + double diff = Math.Abs(expected - actual); + if (diff > tolerance && + diff > Math.Max(Math.Abs(expected), Math.Abs(actual)) * tolerance) { - Assert.Equal(x[i] + y[i], destination[i]); + throw new EqualException(expected, actual); } } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTwoTensors_ThrowsForMismatchedLengths(int tensorLength) + private static unsafe float MathFMaxMagnitude(float x, float y) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); - - Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); + float ax = MathF.Abs(x), ay = MathF.Abs(y); + return (ax > ay) || float.IsNaN(ax) || (ax == ay && *(int*)&x >= 0) ? x : y; } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTwoTensors_ThrowsForTooShortDestination(int tensorLength) + private static unsafe float MathFMinMagnitude(float x, float y) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + float ax = MathF.Abs(x), ay = MathF.Abs(y); + return (ax < ay) || float.IsNaN(ax) || (ax == ay && *(int*)&x < 0) ? x : y; } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTensorAndScalar(int tensorLength) + private static unsafe float UInt32ToSingle(uint i) => *(float*)&i; + + private static unsafe float SingleToUInt32(float f) => *(uint*)&f; + + /// Gets a variety of special values (e.g. NaN). + private static IEnumerable GetSpecialValues() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength); + // NaN + yield return UInt32ToSingle(0xFFC0_0000); // -qNaN / float.NaN + yield return UInt32ToSingle(0xFFFF_FFFF); // -qNaN / all-bits-set + yield return UInt32ToSingle(0x7FC0_0000); // +qNaN + yield return UInt32ToSingle(0xFFA0_0000); // -sNaN + yield return UInt32ToSingle(0x7FA0_0000); // +sNaN - TensorPrimitives.Add(x, y, destination); + // +Infinity, -Infinity + yield return float.PositiveInfinity; + yield return float.NegativeInfinity; - for (int i = 0; i < tensorLength; i++) - { - Assert.Equal(x[i] + y, destination[i]); - } - } + // +Zero, -Zero + yield return +0.0f; + yield return -0.0f; - [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + // Subnormals + yield return +float.Epsilon; + yield return -float.Epsilon; + yield return UInt32ToSingle(0x007F_FFFF); + yield return UInt32ToSingle(0x807F_FFFF); - AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + // Normals + yield return UInt32ToSingle(0x0080_0000); + yield return UInt32ToSingle(0x8080_0000); + yield return UInt32ToSingle(0x7F7F_FFFF); // MaxValue + yield return UInt32ToSingle(0xFF7F_FFFF); // MinValue } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void SubtractTwoTensors(int tensorLength) + /// + /// Runs the specified action for each special value. Before the action is invoked, + /// the value is stored into a random position in , and the original + /// value is subsequently restored. + /// + private static void RunForEachSpecialValue(Action action, BoundedMemory x) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + foreach (float value in GetSpecialValues()) + { + int pos = s_random.Next(x.Length); + float orig = x[pos]; + x[pos] = value; - TensorPrimitives.Subtract(x, y, destination); + action(); - for (int i = 0; i < tensorLength; i++) - { - Assert.Equal(x[i] - y[i], destination[i]); + x[pos] = orig; } } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void SubtractTwoTensors_ThrowsForMismatchedLengths(int tensorLength) + /// + /// Loads a variety of special values (e.g. NaN) into random positions in + /// and related values into the corresponding positions in . + /// + private static void SetSpecialValues(Span x, Span y) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + int pos; - Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); - } + // NaNs + pos = s_random.Next(x.Length); + x[pos] = float.NaN; + y[pos] = UInt32ToSingle(0x7FC0_0000); - [Theory] - [MemberData(nameof(TensorLengths))] - public static void SubtractTwoTensors_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + // +Infinity, -Infinity + pos = s_random.Next(x.Length); + x[pos] = float.PositiveInfinity; + y[pos] = float.NegativeInfinity; - AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); + // +Zero, -Zero + pos = s_random.Next(x.Length); + x[pos] = +0.0f; + y[pos] = -0.0f; + + // +Epsilon, -Epsilon + pos = s_random.Next(x.Length); + x[pos] = +float.Epsilon; + y[pos] = -float.Epsilon; + + // Same magnitude, opposite sign + pos = s_random.Next(x.Length); + x[pos] = +5.0f; + y[pos] = -5.0f; } + #endregion + #region Abs [Theory] - [MemberData(nameof(TensorLengths))] - public static void SubtractTensorAndScalar(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Abs(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Subtract(x, y, destination); + TensorPrimitives.Abs(x, destination); - for (int i = 0; i < tensorLength; i++) + for (int i = 0; i < x.Length; i++) { - Assert.Equal(x[i] - y, destination[i]); + AssertEqualTolerance(MathF.Abs(x[i]), destination[i]); } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void SubtractTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); - } - - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensors(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Abs_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - TensorPrimitives.Multiply(x, y, destination); + TensorPrimitives.Abs(x, x); - for (int i = 0; i < tensorLength; i++) + for (int i = 0; i < x.Length; i++) { - Assert.Equal(x[i] * y[i], destination[i]); + AssertEqualTolerance(MathF.Abs(xOrig[i]), x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensors_ThrowsForMismatchedLengths(int tensorLength) + public static void Abs_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(x, destination)); } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensors_ThrowsForTooShortDestination(int tensorLength) + [Fact] + public static void Abs_ThrowsForOverlapppingInputsWithOutputs() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(array.AsSpan(1, 5), array.AsSpan(0, 5))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Abs(array.AsSpan(1, 5), array.AsSpan(2, 5))); } + #endregion + #region Add [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTensorAndScalar(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Add_TwoTensors(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); + using BoundedMemory y = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Multiply(x, y, destination); - + TensorPrimitives.Add(x, y, destination); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] * y, destination[i]); + AssertEqualTolerance(x[i] + y[i], destination[i]); } - } - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + float[] xOrig = x.Span.ToArray(); - AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + // Validate that the destination can be the same as an input. + TensorPrimitives.Add(x, x, x); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] + xOrig[i], x[i]); + } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void DivideTwoTensors(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Add_TwoTensors_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - TensorPrimitives.Divide(x, y, destination); + TensorPrimitives.Add(x, x, x); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] / y[i], destination[i]); + AssertEqualTolerance(xOrig[i] + xOrig[i], x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void DivideTwoTensors_ThrowsForMismatchedLengths(int tensorLength) + public static void Add_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Add(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Add(y, x, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void DivideTwoTensors_ThrowsForTooShortDestination(int tensorLength) + public static void Add_TwoTensors_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + } + + [Fact] + public static void Add_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(4, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), array.AsSpan(5, 2), array.AsSpan(6, 2))); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void DivideTensorAndScalar(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Add_TensorScalar(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); float y = NextSingle(); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Divide(x, y, destination); + TensorPrimitives.Add(x, y, destination); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(x[i] / y, destination[i]); + AssertEqualTolerance(x[i] + y, destination[i]); } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void DivideTensorAndScalar_ThrowsForTooShortDestination(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Add_TensorScalar_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); float y = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); - } - - [Theory] - [MemberData(nameof(TensorLengths))] - public static void NegateTensor(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Negate(x, destination); + TensorPrimitives.Add(x, y, x); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(-x[i], destination[i]); + AssertEqualTolerance(xOrig[i] + y, x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void NegateTensor_ThrowsForTooShortDestination(int tensorLength) + public static void Add_TensorScalar_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(x, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(x, y, destination)); + } + + [Fact] + public static void Add_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Add(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); } + #endregion + #region AddMultiply [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithThirdTensor(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_ThreeTensors(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); @@ -342,37 +356,42 @@ public static void AddTwoTensorsAndMultiplyWithThirdTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] + y[i]) * multiplier[i], destination[i]); + AssertEqualTolerance((x[i] + y[i]) * multiplier[i], destination[i]); } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_y(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_ThreeTensors_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory multiplier = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + TensorPrimitives.AddMultiply(x, x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] + xOrig[i]) * xOrig[i], x[i]); + } } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier(int tensorLength) + public static void AddMultiply_ThreeTensors_ThrowsForMismatchedLengths(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory multiplier = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, z, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(x, z, y, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(z, x, y, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForTooShortDestination(int tensorLength) + public static void AddMultiply_ThreeTensors_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); @@ -382,9 +401,21 @@ public static void AddTwoTensorsAndMultiplyWithThirdTensor_ThrowsForTooShortDest AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } + [Fact] + public static void AddMultiply_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); + } + [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithScalar(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorTensorScalar(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); @@ -395,13 +426,29 @@ public static void AddTwoTensorsAndMultiplyWithScalar(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] + y[i]) * multiplier, destination[i]); + AssertEqualTolerance((x[i] + y[i]) * multiplier, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorTensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float multiplier = NextSingle(); + + TensorPrimitives.AddMultiply(x, x, multiplier, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] + xOrig[i]) * multiplier, x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForMismatchedLengths_x_y(int tensorLength) + public static void AddMultiply_TensorTensorScalar_ThrowsForMismatchedLengths_x_y(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); @@ -409,11 +456,12 @@ public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForMismatchedLengths using BoundedMemory destination = CreateTensor(tensorLength); Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(y, x, multiplier, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForTooShortDestination(int tensorLength) + public static void AddMultiply_TensorTensorScalar_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); @@ -423,9 +471,19 @@ public static void AddTwoTensorsAndMultiplyWithScalar_ThrowsForTooShortDestinati AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } + [Fact] + public static void AddMultiply_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(5, 2))); + } + [Theory] - [MemberData(nameof(TensorLengths))] - public static void AddTensorAndScalarAndMultiplyWithTensor(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorScalarTensor(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); float y = NextSingle(); @@ -436,25 +494,42 @@ public static void AddTensorAndScalarAndMultiplyWithTensor(int tensorLength) for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] + y) * multiplier[i], destination[i]); + AssertEqualTolerance((x[i] + y) * multiplier[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void AddMultiply_TensorScalarTensor_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.AddMultiply(x, y, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] + y) * xOrig[i], x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForMismatchedLengths_x_z(int tensorLength) + public static void AddMultiply_TensorScalarTensor_ThrowsForMismatchedLengths_x_z(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); float y = NextSingle(); - using BoundedMemory multiplier = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(x, y, z, destination)); + Assert.Throws(() => TensorPrimitives.AddMultiply(z, y, x, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForTooShortDestination(int tensorLength) + public static void AddMultiply_TensorScalarTensor_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); float y = NextSingle(); @@ -464,330 +539,315 @@ public static void AddTensorAndScalarAndMultiplyWithTensor_ThrowsForTooShortDest AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(x, y, multiplier, destination)); } + [Fact] + public static void AddMultiply_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.AddMultiply(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Cosh [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensorsAndAddWithThirdTensor(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Cosh(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.MultiplyAdd(x, y, addend, destination); + TensorPrimitives.Cosh(x, destination); for (int i = 0; i < tensorLength; i++) { - Assert.Equal((x[i] * y[i]) + addend[i], destination[i]); + AssertEqualTolerance(MathF.Cosh(x[i]), destination[i]); } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_y(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Cosh_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); - } + float[] xOrig = x.Span.ToArray(); - [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForMismatchedLengths_x_multiplier(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory addend = CreateAndFillTensor(tensorLength - 1); - using BoundedMemory destination = CreateTensor(tensorLength); + TensorPrimitives.Cosh(x, x); - Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Cosh(xOrig[i]), x[i]); + } } [Theory] [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensorsAndAddWithThirdTensor_ThrowsForTooShortDestination(int tensorLength) + public static void Cosh_SpecialValues(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + RunForEachSpecialValue(() => + { + TensorPrimitives.Cosh(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Cosh(x[i]), destination[i]); + } + }, x); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensorsAndAddWithScalar(int tensorLength) + [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -100f, 100f, 3f })] + public static void Cosh_ValueRange(int vectorLength, float element) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - float addend = NextSingle(); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] x = new float[vectorLength]; + float[] dest = new float[vectorLength]; - TensorPrimitives.MultiplyAdd(x, y, addend, destination); + x.AsSpan().Fill(element); + TensorPrimitives.Cosh(x, dest); - for (int i = 0; i < tensorLength; i++) + float expected = MathF.Cosh(element); + foreach (float actual in dest) { - Assert.Equal((x[i] * y[i]) + addend, destination[i]); + AssertEqualTolerance(expected, actual); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void MultiplyTwoTensorsAndAddWithScalar_ThrowsForTooShortDestination(int tensorLength) + public static void Cosh_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); - float addend = NextSingle(); using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(x, destination)); + } + + [Fact] + public static void Cosh_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(array.AsSpan(1, 2), array.AsSpan(2, 2))); } + #endregion + #region CosineSimilarity [Theory] [MemberData(nameof(TensorLengths))] - public static void MultiplyTensorAndScalarAndAddWithTensor(int tensorLength) + public static void CosineSimilarity_ThrowsForMismatchedLengths(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - TensorPrimitives.MultiplyAdd(x, y, addend, destination); + Assert.Throws(() => TensorPrimitives.CosineSimilarity(x, y)); + Assert.Throws(() => TensorPrimitives.CosineSimilarity(y, x)); + } - for (int i = 0; i < tensorLength; i++) - { - Assert.Equal((x[i] * y) + addend[i], destination[i]); - } + [Fact] + public static void CosineSimilarity_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.CosineSimilarity(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.CosineSimilarity(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => TensorPrimitives.CosineSimilarity(CreateTensor(1), ReadOnlySpan.Empty)); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void MultiplyTensorAndScalarAndAddWithTensor_ThrowsForTooShortDestination(int tensorLength) + [InlineData(new float[] { 3, 2, 0, 5 }, new float[] { 1, 0, 0, 0 }, 0.48666f)] + [InlineData(new float[] { 1, 1, 1, 1, 1, 0 }, new float[] { 1, 1, 1, 1, 0, 1 }, 0.80f)] + public static void CosineSimilarity_KnownValues(float[] x, float[] y, float expectedResult) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - float y = NextSingle(); - using BoundedMemory addend = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + AssertEqualTolerance(expectedResult, TensorPrimitives.CosineSimilarity(x, y)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void ExpTensor(int tensorLength) + public static void CosineSimilarity(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - TensorPrimitives.Exp(x, destination); + using BoundedMemory y = CreateAndFillTensor(tensorLength); - for (int i = 0; i < tensorLength; i++) + float dot = 0f, squareX = 0f, squareY = 0f; + for (int i = 0; i < x.Length; i++) { - Assert.Equal(MathF.Exp(x[i]), destination[i]); + dot += x[i] * y[i]; + squareX += x[i] * x[i]; + squareY += y[i] * y[i]; } + + AssertEqualTolerance(dot / (MathF.Sqrt(squareX) * MathF.Sqrt(squareY)), TensorPrimitives.CosineSimilarity(x, y)); } + #endregion - [Theory] - [MemberData(nameof(TensorLengths))] - public static void ExpTensor_ThrowsForTooShortDestination(int tensorLength) + #region Distance + [Fact] + public static void Distance_ThrowsForEmpty() { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(x, destination)); + Assert.Throws(() => TensorPrimitives.Distance(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.Distance(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => TensorPrimitives.Distance(CreateTensor(1), ReadOnlySpan.Empty)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void LogTensor(int tensorLength) + public static void Distance_ThrowsForMismatchedLengths(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - TensorPrimitives.Log(x, destination); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); - for (int i = 0; i < tensorLength; i++) - { - Assert.Equal(MathF.Log(x[i]), destination[i]); - } + Assert.Throws(() => TensorPrimitives.Distance(x, y)); + Assert.Throws(() => TensorPrimitives.Distance(y, x)); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void LogTensor_ThrowsForTooShortDestination(int tensorLength) + [InlineData(new float[] { 3, 2 }, new float[] { 4, 1 }, 1.4142f)] + [InlineData(new float[] { 0, 4 }, new float[] { 6, 2 }, 6.3245f)] + [InlineData(new float[] { 1, 2, 3 }, new float[] { 4, 5, 6 }, 5.19615f)] + [InlineData(new float[] { 5, 1, 6, 10 }, new float[] { 7, 2, 8, 4 }, 6.7082f)] + public static void Distance_KnownValues(float[] x, float[] y, float expectedResult) { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Log(x, destination)); + AssertEqualTolerance(expectedResult, TensorPrimitives.Distance(x, y)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void CoshTensor(int tensorLength) + public static void Distance(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); - - TensorPrimitives.Cosh(x, destination); + using BoundedMemory y = CreateAndFillTensor(tensorLength); - for (int i = 0; i < tensorLength; i++) + float distance = 0f; + for (int i = 0; i < x.Length; i++) { - Assert.Equal(MathF.Cosh(x[i]), destination[i]); + distance += (x[i] - y[i]) * (x[i] - y[i]); } - } - - [Theory] - [MemberData(nameof(TensorLengths))] - public static void CoshTensor_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Cosh(x, destination)); + AssertEqualTolerance(MathF.Sqrt(distance), TensorPrimitives.Distance(x, y)); } + #endregion + #region Divide [Theory] - [MemberData(nameof(TensorLengths))] - public static void SinhTensor(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TwoTensors(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength); - TensorPrimitives.Sinh(x, destination); + TensorPrimitives.Divide(x, y, destination); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Sinh(x[i]), destination[i]); + AssertEqualTolerance(x[i] / y[i], destination[i]); } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void SinhTensor_ThrowsForTooShortDestination(int tensorLength) - { - using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(x, destination)); - } - - [Theory] - [MemberData(nameof(TensorLengths))] - public static void TanhTensor(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TwoTensors_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - TensorPrimitives.Tanh(x, destination); + TensorPrimitives.Divide(x, x, x); for (int i = 0; i < tensorLength; i++) { - Assert.Equal(MathF.Tanh(x[i]), destination[i]); + AssertEqualTolerance(xOrig[i] / xOrig[i], x[i]); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void TanhTensor_ThrowsForTooShortDestination(int tensorLength) + public static void Divide_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); - AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(x, destination)); + Assert.Throws(() => TensorPrimitives.Divide(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Divide(y, x, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void CosineSimilarity_ThrowsForMismatchedLengths_x_y(int tensorLength) + public static void Divide_TwoTensors_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - Assert.Throws(() => TensorPrimitives.CosineSimilarity(x, y)); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); } [Fact] - public static void CosineSimilarity_ThrowsForEmpty_x_y() - { - Assert.Throws(() => TensorPrimitives.CosineSimilarity(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.CosineSimilarity(ReadOnlySpan.Empty, CreateTensor(1))); - Assert.Throws(() => TensorPrimitives.CosineSimilarity(CreateTensor(1), ReadOnlySpan.Empty)); - } - - [Theory] - [InlineData(new float[] { 3, 2, 0, 5 }, new float[] { 1, 0, 0, 0 }, 0.48666f)] - [InlineData(new float[] { 1, 1, 1, 1, 1, 0 }, new float[] { 1, 1, 1, 1, 0, 1 }, 0.80f)] - public static void CosineSimilarity_KnownValues(float[] x, float[] y, float expectedResult) + public static void Divide_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() { - Assert.Equal(expectedResult, TensorPrimitives.CosineSimilarity(x, y), Tolerance); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void CosineSimilarity(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TensorScalar(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); - float dot = 0f, squareX = 0f, squareY = 0f; - for (int i = 0; i < x.Length; i++) + TensorPrimitives.Divide(x, y, destination); + + for (int i = 0; i < tensorLength; i++) { - dot += x[i] * y[i]; - squareX += x[i] * x[i]; - squareY += y[i] * y[i]; + AssertEqualTolerance(x[i] / y, destination[i]); } - - Assert.Equal(dot / (Math.Sqrt(squareX) * Math.Sqrt(squareY)), TensorPrimitives.CosineSimilarity(x, y), Tolerance); - } - - [Fact] - public static void Distance_ThrowsForEmpty_x_y() - { - Assert.Throws(() => TensorPrimitives.Distance(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.Distance(ReadOnlySpan.Empty, CreateTensor(1))); - Assert.Throws(() => TensorPrimitives.Distance(CreateTensor(1), ReadOnlySpan.Empty)); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void Distance_ThrowsForMismatchedLengths_x_y(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Divide_TensorScalar_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); - Assert.Throws(() => TensorPrimitives.Distance(x, y)); - } + TensorPrimitives.Divide(x, y, x); - [Theory] - [InlineData(new float[] { 3, 2 }, new float[] { 4, 1 }, 1.4142f)] - [InlineData(new float[] { 0, 4 }, new float[] { 6, 2 }, 6.3245f)] - [InlineData(new float[] { 1, 2, 3 }, new float[] { 4, 5, 6 }, 5.1961f)] - [InlineData(new float[] { 5, 1, 6, 10 }, new float[] { 7, 2, 8, 4 }, 6.7082f)] - public static void Distance_KnownValues(float[] x, float[] y, float expectedResult) - { - Assert.Equal(expectedResult, TensorPrimitives.Distance(x, y), Tolerance); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] / y, x[i]); + } } [Theory] [MemberData(nameof(TensorLengths))] - public static void Distance(int tensorLength) + public static void Divide_TensorScalar_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory y = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - float distance = 0f; - for (int i = 0; i < x.Length; i++) - { - distance += (x[i] - y[i]) * (x[i] - y[i]); - } + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(x, y, destination)); + } - Assert.Equal(Math.Sqrt(distance), TensorPrimitives.Distance(x, y), Tolerance); + [Fact] + public static void Divide_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Divide(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); } + #endregion + #region Dot [Theory] [MemberData(nameof(TensorLengths))] public static void Dot_ThrowsForMismatchedLengths_x_y(int tensorLength) @@ -796,6 +856,7 @@ public static void Dot_ThrowsForMismatchedLengths_x_y(int tensorLength) using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); Assert.Throws(() => TensorPrimitives.Dot(x, y)); + Assert.Throws(() => TensorPrimitives.Dot(y, x)); } [Theory] @@ -805,11 +866,11 @@ public static void Dot_ThrowsForMismatchedLengths_x_y(int tensorLength) [InlineData(new float[] { }, new float[] { }, 0)] public static void Dot_KnownValues(float[] x, float[] y, float expectedResult) { - Assert.Equal(expectedResult, TensorPrimitives.Dot(x, y), Tolerance); + AssertEqualTolerance(expectedResult, TensorPrimitives.Dot(x, y)); } [Theory] - [MemberData(nameof(TensorLengths))] + [MemberData(nameof(TensorLengthsIncluding0))] public static void Dot(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); @@ -821,129 +882,78 @@ public static void Dot(int tensorLength) dot += x[i] * y[i]; } - Assert.Equal(dot, TensorPrimitives.Dot(x, y), Tolerance); - } - - [Theory] - [InlineData(new float[] { 1, 2, 3 }, 3.7416575f)] - [InlineData(new float[] { 3, 4 }, 5)] - [InlineData(new float[] { 3 }, 3)] - [InlineData(new float[] { 3, 4, 1, 2 }, 5.477226)] - [InlineData(new float[] { }, 0f)] - public static void L2Normalize_KnownValues(float[] x, float expectedResult) - { - Assert.Equal(expectedResult, TensorPrimitives.L2Normalize(x), Tolerance); + AssertEqualTolerance(dot, TensorPrimitives.Dot(x, y)); } + #endregion + #region Exp [Theory] - [MemberData(nameof(TensorLengths))] - public static void L2Normalize(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Exp(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - float sumOfSquares = 0f; - for (int i = 0; i < x.Length; i++) + TensorPrimitives.Exp(x, destination); + + for (int i = 0; i < tensorLength; i++) { - sumOfSquares += x[i] * x[i]; + AssertEqualTolerance(MathF.Exp(x[i]), destination[i]); } - - Assert.Equal(Math.Sqrt(sumOfSquares), TensorPrimitives.L2Normalize(x), Tolerance); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void SoftMax_ThrowsForTooShortDestination(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Exp_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - using BoundedMemory destination = CreateTensor(tensorLength - 1); - - AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(x, destination)); - } + float[] xOrig = x.Span.ToArray(); - [Theory] - [InlineData(new float[] { 3, 1, .2f }, new float[] { 0.8360188f, 0.11314284f, 0.05083836f })] - [InlineData(new float[] { 3, 4, 1 }, new float[] { 0.2594f, 0.705384f, 0.0351f })] - [InlineData(new float[] { 5, 3 }, new float[] { 0.8807f, 0.1192f })] - [InlineData(new float[] { 4, 2, 1, 9 }, new float[] { 0.0066f, 9.04658e-4f, 3.32805e-4f, 0.9920f})] - public static void SoftMax(float[] x, float[] expectedResult) - { - using BoundedMemory dest = CreateTensor(x.Length); - TensorPrimitives.SoftMax(x, dest); + TensorPrimitives.Exp(x, x); - for (int i = 0; i < x.Length; i++) + for (int i = 0; i < tensorLength; i++) { - Assert.Equal(expectedResult[i], dest[i], Tolerance); + AssertEqualTolerance(MathF.Exp(xOrig[i]), x[i]); } } - [Fact] - public static void SoftMax_DestinationLongerThanSource() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Exp_SpecialValues(int tensorLength) { - float[] x = [3, 1, .2f]; - float[] expectedResult = [0.8360188f, 0.11314284f, 0.05083836f]; - using BoundedMemory dest = CreateTensor(x.Length + 1); - TensorPrimitives.SoftMax(x, dest); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - for (int i = 0; i < x.Length; i++) + RunForEachSpecialValue(() => { - Assert.Equal(expectedResult[i], dest[i], Tolerance); - } - } - - [Fact] - public static void SoftMax_ThrowsForEmptyInput() - { - AssertExtensions.Throws(() => TensorPrimitives.SoftMax(ReadOnlySpan.Empty, CreateTensor(1))); + TensorPrimitives.Exp(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Exp(x[i]), destination[i]); + } + }, x); } [Theory] [MemberData(nameof(TensorLengths))] - public static void Sigmoid_ThrowsForTooShortDestination(int tensorLength) + public static void Exp_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory destination = CreateTensor(tensorLength - 1); - AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(x, destination)); - } - - [Theory] - [InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })] - [InlineData(new float[] { 4.5f, 5 }, new float[] { 0.9890f, 0.9933f })] - [InlineData(new float[] { 0, -3, 3, .5f }, new float[] { 0.5f, 0.0474f, 0.9525f, 0.6224f })] - public static void Sigmoid(float[] x, float[] expectedResult) - { - using BoundedMemory dest = CreateTensor(x.Length); - TensorPrimitives.Sigmoid(x, dest); - - for (int i = 0; i < x.Length; i++) - { - Assert.Equal(expectedResult[i], dest[i], Tolerance); - } - } - - [Fact] - public static void Sigmoid_DestinationLongerThanSource() - { - float[] x = [-5, -4.5f, -4]; - float[] expectedResult = [0.0066f, 0.0109f, 0.0179f]; - using BoundedMemory dest = CreateTensor(x.Length + 1); - - TensorPrimitives.Sigmoid(x, dest); - - float originalLast = dest[dest.Length - 1]; - for (int i = 0; i < x.Length; i++) - { - Assert.Equal(expectedResult[i], dest[i], Tolerance); - } - Assert.Equal(originalLast, dest[dest.Length - 1]); + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(x, destination)); } [Fact] - public static void Sigmoid_ThrowsForEmptyInput() + public static void Exp_ThrowsForOverlapppingInputsWithOutputs() { - AssertExtensions.Throws(() => TensorPrimitives.Sigmoid(ReadOnlySpan.Empty, CreateTensor(1))); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Exp(array.AsSpan(1, 2), array.AsSpan(2, 2))); } + #endregion + #region IndexOfMax [Fact] public static void IndexOfMax_ReturnsNegative1OnEmpty() { @@ -985,90 +995,96 @@ public static void IndexOfMax_Negative0LesserThanPositive0() Assert.Equal(1, TensorPrimitives.IndexOfMax([-1, -0f])); Assert.Equal(2, TensorPrimitives.IndexOfMax([-1, -0f, 1])); } + #endregion + #region IndexOfMaxMagnitude [Fact] - public static void IndexOfMin_ReturnsNegative1OnEmpty() + public static void IndexOfMaxMagnitude_ReturnsNegative1OnEmpty() { - Assert.Equal(-1, TensorPrimitives.IndexOfMin(ReadOnlySpan.Empty)); + Assert.Equal(-1, TensorPrimitives.IndexOfMaxMagnitude(ReadOnlySpan.Empty)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void IndexOfMin(int tensorLength) + public static void IndexOfMaxMagnitude(int tensorLength) { foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - x[expected] = Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)) - 1; - Assert.Equal(expected, TensorPrimitives.IndexOfMin(x)); + x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory), Math.Abs) + 1; + Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x)); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void IndexOfMin_FirstNaNReturned(int tensorLength) + public static void IndexOfMaxMagnitude_FirstNaNReturned(int tensorLength) { foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { using BoundedMemory x = CreateAndFillTensor(tensorLength); x[expected] = float.NaN; x[tensorLength - 1] = float.NaN; - Assert.Equal(expected, TensorPrimitives.IndexOfMin(x)); + Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x)); } } [Fact] - public static void IndexOfMin_Negative0LesserThanPositive0() + public static void IndexOfMaxMagnitude_Negative0LesserThanPositive0() { - Assert.Equal(0, TensorPrimitives.IndexOfMin([-0f, +0f])); - Assert.Equal(1, TensorPrimitives.IndexOfMin([+0f, -0f])); - Assert.Equal(1, TensorPrimitives.IndexOfMin([+0f, -0f, -0f, -0f, -0f])); - Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f])); - Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f, 1])); + Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-0f, -0f, -0f, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f, +0f, +0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([+0f, -0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f])); + Assert.Equal(2, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f, 1])); } + #endregion + #region IndexOfMin [Fact] - public static void IndexOfMaxMagnitude_ReturnsNegative1OnEmpty() + public static void IndexOfMin_ReturnsNegative1OnEmpty() { - Assert.Equal(-1, TensorPrimitives.IndexOfMaxMagnitude(ReadOnlySpan.Empty)); + Assert.Equal(-1, TensorPrimitives.IndexOfMin(ReadOnlySpan.Empty)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void IndexOfMaxMagnitude(int tensorLength) + public static void IndexOfMin(int tensorLength) { foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory), Math.Abs) + 1; - Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x)); + x[expected] = Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)) - 1; + Assert.Equal(expected, TensorPrimitives.IndexOfMin(x)); } } [Theory] [MemberData(nameof(TensorLengths))] - public static void IndexOfMaxMagnitude_FirstNaNReturned(int tensorLength) + public static void IndexOfMin_FirstNaNReturned(int tensorLength) { foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { using BoundedMemory x = CreateAndFillTensor(tensorLength); x[expected] = float.NaN; x[tensorLength - 1] = float.NaN; - Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitude(x)); + Assert.Equal(expected, TensorPrimitives.IndexOfMin(x)); } } [Fact] - public static void IndexOfMaxMagnitude_Negative0LesserThanPositive0() + public static void IndexOfMin_Negative0LesserThanPositive0() { - Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-0f, -0f, -0f, -0f])); - Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f])); - Assert.Equal(1, TensorPrimitives.IndexOfMaxMagnitude([-0f, +0f, +0f, +0f])); - Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([+0f, -0f])); - Assert.Equal(0, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f])); - Assert.Equal(2, TensorPrimitives.IndexOfMaxMagnitude([-1, -0f, 1])); + Assert.Equal(0, TensorPrimitives.IndexOfMin([-0f, +0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMin([+0f, -0f])); + Assert.Equal(1, TensorPrimitives.IndexOfMin([+0f, -0f, -0f, -0f, -0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f])); + Assert.Equal(0, TensorPrimitives.IndexOfMin([-1, -0f, 1])); } + #endregion + #region IndexOfMinMagnitude [Fact] public static void IndexOfMinMagnitude_ReturnsNegative1OnEmpty() { @@ -1116,369 +1132,1861 @@ public static void IndexOfMinMagnitude_Negative0LesserThanPositive0() Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f])); Assert.Equal(1, TensorPrimitives.IndexOfMinMagnitude([-1, -0f, 1])); } + #endregion - [Fact] - public static void Max_ThrowsForEmpty() + #region Log + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log(int tensorLength) { - Assert.Throws(() => TensorPrimitives.Max(ReadOnlySpan.Empty)); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Log(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(x[i]), destination[i]); + } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void Max(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); - Assert.Equal(Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Max(x)); + TensorPrimitives.Log(x, x); - float max = float.NegativeInfinity; - foreach (float f in x.Span) + for (int i = 0; i < tensorLength; i++) { - max = Math.Max(max, f); + AssertEqualTolerance(MathF.Log(xOrig[i]), x[i]); } - Assert.Equal(max, TensorPrimitives.Max(x)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void Max_NanReturned(int tensorLength) + public static void Log_SpecialValues(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => { - x[expected] = float.NaN; - Assert.Equal(float.NaN, TensorPrimitives.Max(x)); - } + TensorPrimitives.Log(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(x[i]), destination[i]); + } + }, x); } - [Fact] - public static void Max_Negative0LesserThanPositive0() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Log_ThrowsForTooShortDestination(int tensorLength) { - Assert.Equal(+0f, TensorPrimitives.Max([-0f, +0f])); - Assert.Equal(+0f, TensorPrimitives.Max([+0f, -0f])); - Assert.Equal(-0f, TensorPrimitives.Max([-1, -0f])); - Assert.Equal(1, TensorPrimitives.Max([-1, -0f, 1])); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(x, destination)); } [Fact] - public static void MaxMagnitude_ThrowsForEmpty() + public static void Log_ThrowsForOverlapppingInputsWithOutputs() { - Assert.Throws(() => TensorPrimitives.MaxMagnitude(ReadOnlySpan.Empty)); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Log(array.AsSpan(1, 2), array.AsSpan(2, 2))); } + #endregion + #region Log2 [Theory] - [MemberData(nameof(TensorLengths))] - public static void MaxMagnitude(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log2(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Equal(Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory), MathF.Abs), TensorPrimitives.MaxMagnitude(x)); + TensorPrimitives.Log2(x, destination); - float max = 0; - foreach (float f in x.Span) + for (int i = 0; i < tensorLength; i++) { - max = Math.Max(max, MathF.Abs(f)); + AssertEqualTolerance(MathF.Log(x[i], 2), destination[i]); } - Assert.Equal(max, TensorPrimitives.MaxMagnitude(x)); } [Theory] - [MemberData(nameof(TensorLengths))] - public static void MaxMagnitude_NanReturned(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Log2_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) - { - x[expected] = float.NaN; - Assert.Equal(float.NaN, TensorPrimitives.MaxMagnitude(x)); - } - } + float[] xOrig = x.Span.ToArray(); - [Fact] - public static void MaxMagnitude_Negative0LesserThanPositive0() - { - Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([-0f, +0f])); - Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([+0f, -0f])); - Assert.Equal(1, TensorPrimitives.MaxMagnitude([-1, -0f])); - Assert.Equal(1, TensorPrimitives.MaxMagnitude([-1, -0f, 1])); - Assert.Equal(0f, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -0f, 0f])); - Assert.Equal(1, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -1, -0f, 0f, 1])); - } + TensorPrimitives.Log2(x, x); - [Fact] - public static void Min_ThrowsForEmpty() - { - Assert.Throws(() => TensorPrimitives.Min(ReadOnlySpan.Empty)); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(xOrig[i], 2), x[i]); + } } [Theory] [MemberData(nameof(TensorLengths))] - public static void Min(int tensorLength) + public static void Log2_SpecialValues(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - Assert.Equal(Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Min(x)); - - float min = float.PositiveInfinity; - foreach (float f in x.Span) + RunForEachSpecialValue(() => { - min = Math.Min(min, f); - } - Assert.Equal(min, TensorPrimitives.Min(x)); + TensorPrimitives.Log2(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Log(x[i], 2), destination[i]); + } + }, x); } [Theory] [MemberData(nameof(TensorLengths))] - public static void Min_NanReturned(int tensorLength) + public static void Log2_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) - { - x[expected] = float.NaN; - Assert.Equal(float.NaN, TensorPrimitives.Min(x)); - } + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(x, destination)); } [Fact] - public static void Min_Negative0LesserThanPositive0() + public static void Log2_ThrowsForOverlapppingInputsWithOutputs() { - Assert.Equal(-0f, TensorPrimitives.Min([-0f, +0f])); - Assert.Equal(-0f, TensorPrimitives.Min([+0f, -0f])); - Assert.Equal(-1, TensorPrimitives.Min([-1, -0f])); - Assert.Equal(-1, TensorPrimitives.Min([-1, -0f, 1])); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Log2(array.AsSpan(1, 2), array.AsSpan(2, 2))); } + #endregion + #region Max [Fact] - public static void MinMagnitude_ThrowsForEmpty() + public static void Max_Tensor_ThrowsForEmpty() { - Assert.Throws(() => TensorPrimitives.MinMagnitude(ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.Max(ReadOnlySpan.Empty)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void MinMagnitude(int tensorLength) + public static void Max_Tensor(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - Assert.Equal(Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory), MathF.Abs), TensorPrimitives.MinMagnitude(x)); + Assert.Equal(Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Max(x)); - float min = float.PositiveInfinity; + float max = float.NegativeInfinity; foreach (float f in x.Span) { - min = Math.Min(min, MathF.Abs(f)); + max = Math.Max(max, f); } - Assert.Equal(min, TensorPrimitives.MinMagnitude(x)); + + Assert.Equal(max, TensorPrimitives.Max(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMax(x)]), SingleToUInt32(TensorPrimitives.Max(x))); } [Theory] [MemberData(nameof(TensorLengths))] - public static void MinMagnitude_NanReturned(int tensorLength) + public static void Max_Tensor_SpecialValues(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + + RunForEachSpecialValue(() => + { + float max = float.NegativeInfinity; + foreach (float f in x.Span) + { + max = Math.Max(max, f); + } + + Assert.Equal(max, TensorPrimitives.Max(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMax(x)]), SingleToUInt32(TensorPrimitives.Max(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Max_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { + FillTensor(x); x[expected] = float.NaN; - Assert.Equal(float.NaN, TensorPrimitives.MinMagnitude(x)); + Assert.Equal(float.NaN, TensorPrimitives.Max(x)); } } [Fact] - public static void MinMagnitude_Negative0LesserThanPositive0() + public static void Max_Tensor_Negative0LesserThanPositive0() { - Assert.Equal(0, TensorPrimitives.MinMagnitude([-0f, +0f])); - Assert.Equal(0, TensorPrimitives.MinMagnitude([+0f, -0f])); - Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f])); - Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f, 1])); + Assert.Equal(+0f, TensorPrimitives.Max([-0f, +0f])); + Assert.Equal(+0f, TensorPrimitives.Max([+0f, -0f])); + Assert.Equal(-0f, TensorPrimitives.Max([-1, -0f])); + Assert.Equal(1, TensorPrimitives.Max([-1, -0f, 1])); } - [Fact] - public static void Product_ThrowsForEmpty() + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Max_TwoTensors(int tensorLength) { - Assert.Throws(() => TensorPrimitives.Product(ReadOnlySpan.Empty)); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Max(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(x[i], y[i]), destination[i]); + } } [Theory] - [MemberData(nameof(TensorLengths))] - public static void Product(int tensorLength) + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Max_TwoTensors_InPlace(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); - float f = x[0]; - for (int i = 1; i < x.Length; i++) + TensorPrimitives.Max(x, y, x); + + for (int i = 0; i < tensorLength; i++) { - f *= x[i]; + AssertEqualTolerance(MathF.Max(xOrig[i], y[i]), x[i]); } - Assert.Equal(f, TensorPrimitives.Product(x), Tolerance); - } + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); - [Fact] - public static void Product_KnownValues() - { - Assert.Equal(1, TensorPrimitives.Product([1])); - Assert.Equal(-2, TensorPrimitives.Product([1, -2])); - Assert.Equal(-6, TensorPrimitives.Product([1, -2, 3])); - Assert.Equal(24, TensorPrimitives.Product([1, -2, 3, -4])); - Assert.Equal(120, TensorPrimitives.Product([1, -2, 3, -4, 5])); - Assert.Equal(-720, TensorPrimitives.Product([1, -2, 3, -4, 5, -6])); - Assert.Equal(0, TensorPrimitives.Product([1, -2, 3, -4, 5, -6, 0])); - Assert.Equal(0, TensorPrimitives.Product([0, 1, -2, 3, -4, 5, -6])); - Assert.Equal(0, TensorPrimitives.Product([1, -2, 3, 0, -4, 5, -6])); - Assert.Equal(float.NaN, TensorPrimitives.Product([1, -2, 3, float.NaN, -4, 5, -6])); - } + TensorPrimitives.Max(x, y, y); - [Fact] - public static void ProductOfDifferences_ThrowsForEmptyAndMismatchedLengths() - { - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, CreateTensor(1))); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(1), ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(44), CreateTensor(43))); - Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(43), CreateTensor(44))); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(x[i], yOrig[i]), y[i]); + } } [Theory] [MemberData(nameof(TensorLengths))] - public static void ProductOfDifferences(int tensorLength) + public static void Max_TwoTensors_SpecialValues(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); - float f = x[0] - y[0]; - for (int i = 1; i < x.Length; i++) + SetSpecialValues(x, y); + + TensorPrimitives.Max(x, y, destination); + for (int i = 0; i < tensorLength; i++) { - f *= x[i] - y[i]; + AssertEqualTolerance(MathF.Max(x[i], y[i]), destination[i]); } - Assert.Equal(f, TensorPrimitives.ProductOfDifferences(x, y), Tolerance); - } - [Fact] - public static void ProductOfDifferences_KnownValues() - { - Assert.Equal(0, TensorPrimitives.ProductOfDifferences([0], [0])); - Assert.Equal(0, TensorPrimitives.ProductOfDifferences([1], [1])); - Assert.Equal(1, TensorPrimitives.ProductOfDifferences([1], [0])); - Assert.Equal(-1, TensorPrimitives.ProductOfDifferences([0], [1])); - Assert.Equal(-1, TensorPrimitives.ProductOfDifferences([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])); - Assert.Equal(120, TensorPrimitives.ProductOfDifferences([1, 2, 3, 4, 5], [0, 0, 0, 0, 0])); - Assert.Equal(-120, TensorPrimitives.ProductOfDifferences([0, 0, 0, 0, 0], [1, 2, 3, 4, 5])); - Assert.Equal(float.NaN, TensorPrimitives.ProductOfDifferences([1, 2, float.NaN, 4, 5], [0, 0, 0, 0, 0])); + TensorPrimitives.Max(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Max(y[i], x[i]), destination[i]); + } } - [Fact] - public static void ProductOfSums_ThrowsForEmptyAndMismatchedLengths() + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Max_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) { - Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, CreateTensor(1))); - Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(1), ReadOnlySpan.Empty)); - Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(44), CreateTensor(43))); - Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(43), CreateTensor(44))); + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Max(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Max(y, x, destination)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void ProductOfSums(int tensorLength) + public static void Max_TwoTensors_ThrowsForTooShortDestination(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); - float f = x[0] + y[0]; - for (int i = 1; i < x.Length; i++) - { - f *= x[i] + y[i]; - } - Assert.Equal(f, TensorPrimitives.ProductOfSums(x, y), Tolerance); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(x, y, destination)); + } + + [Fact] + public static void Max_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Max(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); } + #endregion + #region MaxMagnitude [Fact] - public static void ProductOfSums_KnownValues() + public static void MaxMagnitude_Tensor_ThrowsForEmpty() { - Assert.Equal(0, TensorPrimitives.ProductOfSums([0], [0])); - Assert.Equal(1, TensorPrimitives.ProductOfSums([0], [1])); - Assert.Equal(1, TensorPrimitives.ProductOfSums([1], [0])); - Assert.Equal(2, TensorPrimitives.ProductOfSums([1], [1])); - Assert.Equal(10395, TensorPrimitives.ProductOfSums([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])); - Assert.Equal(120, TensorPrimitives.ProductOfSums([1, 2, 3, 4, 5], [0, 0, 0, 0, 0])); - Assert.Equal(120, TensorPrimitives.ProductOfSums([0, 0, 0, 0, 0], [1, 2, 3, 4, 5])); - Assert.Equal(float.NaN, TensorPrimitives.ProductOfSums([1, 2, float.NaN, 4, 5], [0, 0, 0, 0, 0])); + Assert.Throws(() => TensorPrimitives.MaxMagnitude(ReadOnlySpan.Empty)); } [Theory] [MemberData(nameof(TensorLengths))] - public static void Sum(int tensorLength) + public static void MaxMagnitude_Tensor(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - Assert.Equal(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Sum(x), Tolerance); - - float sum = 0; + float maxMagnitude = x[0]; foreach (float f in x.Span) { - sum += f; + maxMagnitude = MathFMaxMagnitude(maxMagnitude, f); } - Assert.Equal(sum, TensorPrimitives.Sum(x), Tolerance); - } - [Fact] - public static void Sum_KnownValues() - { - Assert.Equal(0, TensorPrimitives.Sum([0])); - Assert.Equal(1, TensorPrimitives.Sum([0, 1])); - Assert.Equal(6, TensorPrimitives.Sum([1, 2, 3])); - Assert.Equal(0, TensorPrimitives.Sum([-3, 0, 3])); - Assert.Equal(float.NaN, TensorPrimitives.Sum([-3, float.NaN, 3])); + Assert.Equal(maxMagnitude, TensorPrimitives.MaxMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMaxMagnitude(x)]), SingleToUInt32(TensorPrimitives.MaxMagnitude(x))); } [Theory] [MemberData(nameof(TensorLengths))] - public static void SumOfSquares(int tensorLength) + public static void MaxMagnitude_Tensor_SpecialValues(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - Assert.Equal(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory), v => v * v), TensorPrimitives.SumOfSquares(x), Tolerance); + RunForEachSpecialValue(() => + { + float maxMagnitude = x[0]; + foreach (float f in x.Span) + { + maxMagnitude = MathFMaxMagnitude(maxMagnitude, f); + } - float sum = 0; - foreach (float f in x.Span) + Assert.Equal(maxMagnitude, TensorPrimitives.MaxMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMaxMagnitude(x)]), SingleToUInt32(TensorPrimitives.MaxMagnitude(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) { - sum += f * f; + FillTensor(x); + x[expected] = float.NaN; + Assert.Equal(float.NaN, TensorPrimitives.MaxMagnitude(x)); } - Assert.Equal(sum, TensorPrimitives.SumOfSquares(x), Tolerance); } [Fact] - public static void SumOfSquares_KnownValues() + public static void MaxMagnitude_Tensor_Negative0LesserThanPositive0() { - Assert.Equal(0, TensorPrimitives.SumOfSquares([0])); - Assert.Equal(1, TensorPrimitives.SumOfSquares([0, 1])); - Assert.Equal(14, TensorPrimitives.SumOfSquares([1, 2, 3])); - Assert.Equal(18, TensorPrimitives.SumOfSquares([-3, 0, 3])); - Assert.Equal(float.NaN, TensorPrimitives.SumOfSquares([-3, float.NaN, 3])); + Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([-0f, +0f])); + Assert.Equal(+0f, TensorPrimitives.MaxMagnitude([+0f, -0f])); + Assert.Equal(-1, TensorPrimitives.MaxMagnitude([-1, -0f])); + Assert.Equal(1, TensorPrimitives.MaxMagnitude([-1, -0f, 1])); + Assert.Equal(0f, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -0f, 0f])); + Assert.Equal(1, TensorPrimitives.MaxMagnitude([-0f, -0f, -0f, -0f, -1, -0f, 0f, 1])); } [Theory] - [MemberData(nameof(TensorLengths))] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MaxMagnitude_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MaxMagnitude(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(x[i], y[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MaxMagnitude_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.MaxMagnitude(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(xOrig[i], y[i]), x[i]); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.MaxMagnitude(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(x[i], yOrig[i]), y[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_TwoTensors_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + SetSpecialValues(x, y); + + TensorPrimitives.MaxMagnitude(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(x[i], y[i]), destination[i]); + } + + TensorPrimitives.MaxMagnitude(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMaxMagnitude(y[i], x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.MaxMagnitude(x, y, destination)); + Assert.Throws(() => TensorPrimitives.MaxMagnitude(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MaxMagnitude_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(x, y, destination)); + } + + [Fact] + public static void MaxMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MaxMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Min + [Fact] + public static void Min_Tensor_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.Min(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_Tensor(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + Assert.Equal(Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)), TensorPrimitives.Min(x)); + + float min = float.PositiveInfinity; + foreach (float f in x.Span) + { + min = Math.Min(min, f); + } + + Assert.Equal(min, TensorPrimitives.Min(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMin(x)]), SingleToUInt32(TensorPrimitives.Min(x))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_Tensor_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + RunForEachSpecialValue(() => + { + float min = float.PositiveInfinity; + foreach (float f in x.Span) + { + min = Math.Min(min, f); + } + + Assert.Equal(min, TensorPrimitives.Min(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMin(x)]), SingleToUInt32(TensorPrimitives.Min(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + FillTensor(x); + x[expected] = float.NaN; + Assert.Equal(float.NaN, TensorPrimitives.Min(x)); + } + } + + [Fact] + public static void Min_Tensor_Negative0LesserThanPositive0() + { + Assert.Equal(-0f, TensorPrimitives.Min([-0f, +0f])); + Assert.Equal(-0f, TensorPrimitives.Min([+0f, -0f])); + Assert.Equal(-1, TensorPrimitives.Min([-1, -0f])); + Assert.Equal(-1, TensorPrimitives.Min([-1, -0f, 1])); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Min_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Min(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(x[i], y[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Min_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.Min(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(xOrig[i], y[i]), x[i]); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.Min(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(x[i], yOrig[i]), y[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_TwoTensors_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + SetSpecialValues(x, y); + + TensorPrimitives.Min(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(x[i], y[i]), destination[i]); + } + + TensorPrimitives.Min(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Min(y[i], x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Min(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Min(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Min_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(x, y, destination)); + } + + [Fact] + public static void Min_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Min(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region MinMagnitude + [Fact] + public static void MinMagnitude_Tensor_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.MinMagnitude(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_Tensor(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + float minMagnitude = x[0]; + foreach (float f in x.Span) + { + minMagnitude = MathFMinMagnitude(minMagnitude, f); + } + + Assert.Equal(minMagnitude, TensorPrimitives.MinMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMinMagnitude(x)]), SingleToUInt32(TensorPrimitives.MinMagnitude(x))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_Tensor_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + RunForEachSpecialValue(() => + { + float minMagnitude = x[0]; + foreach (float f in x.Span) + { + minMagnitude = MathFMinMagnitude(minMagnitude, f); + } + + Assert.Equal(minMagnitude, TensorPrimitives.MinMagnitude(x)); + Assert.Equal(SingleToUInt32(x[TensorPrimitives.IndexOfMinMagnitude(x)]), SingleToUInt32(TensorPrimitives.MinMagnitude(x))); + }, x); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_Tensor_NanReturned(int tensorLength) + { + using BoundedMemory x = CreateTensor(tensorLength); + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + FillTensor(x); + x[expected] = float.NaN; + Assert.Equal(float.NaN, TensorPrimitives.MinMagnitude(x)); + } + } + + [Fact] + public static void MinMagnitude_Tensor_Negative0LesserThanPositive0() + { + Assert.Equal(0, TensorPrimitives.MinMagnitude([-0f, +0f])); + Assert.Equal(0, TensorPrimitives.MinMagnitude([+0f, -0f])); + Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f])); + Assert.Equal(0, TensorPrimitives.MinMagnitude([-1, -0f, 1])); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MinMagnitude_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MinMagnitude(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(x[i], y[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MinMagnitude_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(), yOrig = y.Span.ToArray(); + + TensorPrimitives.MinMagnitude(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(xOrig[i], y[i]), x[i]); + } + + xOrig.AsSpan().CopyTo(x.Span); + yOrig.AsSpan().CopyTo(y.Span); + + TensorPrimitives.MinMagnitude(x, y, y); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(x[i], yOrig[i]), y[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_TwoTensors_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + SetSpecialValues(x, y); + + TensorPrimitives.MinMagnitude(x, y, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(x[i], y[i]), destination[i]); + } + + TensorPrimitives.MinMagnitude(y, x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathFMinMagnitude(y[i], x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.MinMagnitude(x, y, destination)); + Assert.Throws(() => TensorPrimitives.MinMagnitude(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MinMagnitude_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(x, y, destination)); + } + + [Fact] + public static void MinMagnitude_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MinMagnitude(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Multiply + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Multiply(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] * y[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Multiply(x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] * xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Multiply_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Multiply(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Multiply(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Multiply_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + } + + [Fact] + public static void Multiply_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TensorScalar(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Multiply(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] * y, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Multiply_TensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.Multiply(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] * y, x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Multiply_TensorScalar_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(x, y, destination)); + } + + [Fact] + public static void Multiply_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Multiply(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); + } + #endregion + + #region MultiplyAdd + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_ThreeTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MultiplyAdd(x, y, addend, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] * y[i]) + addend[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_ThreeTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.MultiplyAdd(x, x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] * xOrig[i]) + xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_ThreeTensors_ThrowsForMismatchedLengths_x_y(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory z = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, y, z, destination)); + Assert.Throws(() => TensorPrimitives.MultiplyAdd(x, z, y, destination)); + Assert.Throws(() => TensorPrimitives.MultiplyAdd(z, x, y, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_ThreeTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + } + + [Fact] + public static void MultiplyAdd_ThreeTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(5, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(6, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(7, 2), array.AsSpan(8, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorTensorScalar(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float addend = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MultiplyAdd(x, y, addend, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] * y[i]) + addend, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorTensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float addend = NextSingle(); + + TensorPrimitives.MultiplyAdd(x, x, addend, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] * xOrig[i]) + addend, x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_TensorTensorScalar_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + float addend = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + } + + [Fact] + public static void MultiplyAdd_TensorTensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), array.AsSpan(4, 2), 42, array.AsSpan(5, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorScalarTensor(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.MultiplyAdd(x, y, addend, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((x[i] * y) + addend[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void MultiplyAdd_TensorScalarTensor_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.MultiplyAdd(x, y, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance((xOrig[i] * y) + xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void MultiplyAdd_TensorScalarTensor_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory addend = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(x, y, addend, destination)); + } + + [Fact] + public static void MultiplyAdd_TensorScalarTensor_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.MultiplyAdd(array.AsSpan(1, 2), 42, array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + #endregion + + #region Negate + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Negate(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Negate(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(-x[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Negate_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Negate(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(-xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Negate_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(x, destination)); + } + + [Fact] + public static void Negate_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Negate(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region Norm + [Theory] + [InlineData(new float[] { 1, 2, 3 }, 3.7416575f)] + [InlineData(new float[] { 3, 4 }, 5)] + [InlineData(new float[] { 3 }, 3)] + [InlineData(new float[] { 3, 4, 1, 2 }, 5.477226)] + [InlineData(new float[] { }, 0f)] + public static void Norm_KnownValues(float[] x, float expectedResult) + { + AssertEqualTolerance(expectedResult, TensorPrimitives.Norm(x)); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Norm(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + float sumOfSquares = 0f; + for (int i = 0; i < x.Length; i++) + { + sumOfSquares += x[i] * x[i]; + } + + AssertEqualTolerance(MathF.Sqrt(sumOfSquares), TensorPrimitives.Norm(x)); + } + #endregion + + #region Product + [Fact] + public static void Product_ThrowsForEmpty() + { + Assert.Throws(() => TensorPrimitives.Product(ReadOnlySpan.Empty)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Product(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + float f = x[0]; + for (int i = 1; i < x.Length; i++) + { + f *= x[i]; + } + + AssertEqualTolerance(f, TensorPrimitives.Product(x)); + } + + [Theory] + [InlineData(1, new float[] { 1 })] + [InlineData(-2, new float[] { 1, -2 })] + [InlineData(-6, new float[] { 1, -2, 3 })] + [InlineData(24, new float[] { 1, -2, 3, -4 })] + [InlineData(120, new float[] { 1, -2, 3, -4, 5 })] + [InlineData(-720, new float[] { 1, -2, 3, -4, 5, -6 })] + [InlineData(0, new float[] { 1, -2, 3, -4, 5, -6, 0 })] + [InlineData(0, new float[] { 0, 1, -2, 3, -4, 5, -6 })] + [InlineData(0, new float[] { 1, -2, 3, 0, -4, 5, -6 })] + [InlineData(float.NaN, new float[] { 1, -2, 3, float.NaN, -4, 5, -6 })] + public static void Product_KnownValues(float expected, float[] input) + { + Assert.Equal(expected, TensorPrimitives.Product(input)); + } + #endregion + + #region ProductOfDifferences + [Fact] + public static void ProductOfDifferences_ThrowsForEmptyAndMismatchedLengths() + { + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(1), ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(44), CreateTensor(43))); + Assert.Throws(() => TensorPrimitives.ProductOfDifferences(CreateTensor(43), CreateTensor(44))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ProductOfDifferences(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + float f = x[0] - y[0]; + for (int i = 1; i < x.Length; i++) + { + f *= x[i] - y[i]; + } + AssertEqualTolerance(f, TensorPrimitives.ProductOfDifferences(x, y)); + } + + [Theory] + [InlineData(0, new float[] {0 }, new float[] {0})] + [InlineData(0, new float[] {1 }, new float[] {1})] + [InlineData(1, new float[] {1 }, new float[] {0})] + [InlineData(-1, new float[] {0 }, new float[] {1})] + [InlineData(-1, new float[] {1, 2, 3, 4, 5 }, new float[] {2, 3, 4, 5, 6})] + [InlineData(120, new float[] {1, 2, 3, 4, 5 }, new float[] {0, 0, 0, 0, 0})] + [InlineData(-120, new float[] {0, 0, 0, 0, 0 }, new float[] {1, 2, 3, 4, 5})] + [InlineData(float.NaN, new float[] {1, 2, float.NaN, 4, 5 }, new float[] {0, 0, 0, 0, 0})] + public static void ProductOfDifferences_KnownValues(float expected, float[] x, float[] y) + { + Assert.Equal(expected, TensorPrimitives.ProductOfDifferences(x, y)); + + } + #endregion + + #region ProductOfSums + [Fact] + public static void ProductOfSums_ThrowsForEmptyAndMismatchedLengths() + { + Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfSums(ReadOnlySpan.Empty, CreateTensor(1))); + Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(1), ReadOnlySpan.Empty)); + Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(44), CreateTensor(43))); + Assert.Throws(() => TensorPrimitives.ProductOfSums(CreateTensor(43), CreateTensor(44))); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ProductOfSums(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + + float f = x[0] + y[0]; + for (int i = 1; i < x.Length; i++) + { + f *= x[i] + y[i]; + } + AssertEqualTolerance(f, TensorPrimitives.ProductOfSums(x, y)); + } + + [Theory] + [InlineData(0, new float[] {0 }, new float[] { 0 })] + [InlineData(1, new float[] {0 }, new float[] { 1 })] + [InlineData(1, new float[] {1 }, new float[] { 0 })] + [InlineData(2, new float[] {1 }, new float[] { 1 })] + [InlineData(10395, new float[] {1, 2, 3, 4, 5 }, new float[] { 2, 3, 4, 5, 6 })] + [InlineData(120, new float[] {1, 2, 3, 4, 5 }, new float[] { 0, 0, 0, 0, 0 })] + [InlineData(120, new float[] {0, 0, 0, 0, 0 }, new float[] { 1, 2, 3, 4, 5 })] + [InlineData(float.NaN, new float[] {1, 2, float.NaN, 4, 5 }, new float[] { 0, 0, 0, 0, 0 })] + public static void ProductOfSums_KnownValues(float expected, float[] x, float[] y) + { + Assert.Equal(expected, TensorPrimitives.ProductOfSums(x, y)); + } + #endregion + + #region Sigmoid + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Sigmoid(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(1f / (1f + MathF.Exp(-x[i])), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Sigmoid(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(1f / (1f + MathF.Exp(-xOrig[i])), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Sigmoid(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(1f / (1f + MathF.Exp(-x[i])), destination[i]); + } + }, x); + } + + [Theory] + [InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })] + [InlineData(new float[] { 4.5f, 5 }, new float[] { 0.9890f, 0.9933f })] + [InlineData(new float[] { 0, -3, 3, .5f }, new float[] { 0.5f, 0.0474f, 0.9525f, 0.6224f })] + public static void Sigmoid_KnownValues(float[] x, float[] expectedResult) + { + using BoundedMemory dest = CreateTensor(x.Length); + TensorPrimitives.Sigmoid(x, dest); + + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); + } + } + + [Theory] + [InlineData(new float[] { -5, -4.5f, -4 }, new float[] { 0.0066f, 0.0109f, 0.0179f })] + public static void Sigmoid_DestinationLongerThanSource(float[] x, float[] expectedResult) + { + using BoundedMemory dest = CreateTensor(x.Length + 1); + + TensorPrimitives.Sigmoid(x, dest); + + float originalLast = dest[dest.Length - 1]; + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); + } + Assert.Equal(originalLast, dest[dest.Length - 1]); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sigmoid_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(x, destination)); + } + + [Fact] + public static void Sigmoid_ThrowsForEmptyInput() + { + AssertExtensions.Throws(() => TensorPrimitives.Sigmoid(ReadOnlySpan.Empty, CreateTensor(1))); + } + + [Fact] + public static void Sigmoid_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Sigmoid(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region Sinh + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Sinh(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Sinh(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Sinh(x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Sinh_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Sinh(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Sinh(xOrig[i]), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sinh_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Sinh(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Sinh(x[i]), destination[i]); + } + }, x); + } + + [Theory] + [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -100f, 100f, 3f })] + public static void Sinh_ValueRange(int vectorLengths, float element) + { + float[] x = new float[vectorLengths]; + float[] dest = new float[vectorLengths]; + + x.AsSpan().Fill(element); + TensorPrimitives.Sinh(x, dest); + + float expected = MathF.Sinh(element); + foreach (float actual in dest) + { + AssertEqualTolerance(expected, actual); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sinh_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(x, destination)); + } + + [Fact] + public static void Sinh_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Sinh(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region SoftMax + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SoftMax(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.SoftMax(x, destination); + + float expSum = MemoryMarshal.ToEnumerable(x.Memory).Sum(MathF.Exp); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Exp(x[i]) / expSum, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SoftMax_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.SoftMax(x, x); + + float expSum = xOrig.Sum(MathF.Exp); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Exp(xOrig[i]) / expSum, x[i]); + } + } + + [Theory] + [InlineData(new float[] { 3, 1, .2f }, new float[] { 0.8360188f, 0.11314284f, 0.05083836f })] + [InlineData(new float[] { 3, 4, 1 }, new float[] { 0.2594f, 0.705384f, 0.0351f })] + [InlineData(new float[] { 5, 3 }, new float[] { 0.8807f, 0.1192f })] + [InlineData(new float[] { 4, 2, 1, 9 }, new float[] { 0.0066f, 9.04658e-4f, 3.32805e-4f, 0.9920f })] + public static void SoftMax_KnownValues(float[] x, float[] expectedResult) + { + using BoundedMemory dest = CreateTensor(x.Length); + TensorPrimitives.SoftMax(x, dest); + + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(expectedResult[i], dest[i], 0.0001f); + } + } + + [Fact] + public static void SoftMax_DestinationLongerThanSource() + { + float[] x = [3, 1, .2f]; + float[] expectedResult = [0.8360188f, 0.11314284f, 0.05083836f]; + using BoundedMemory dest = CreateTensor(x.Length + 1); + TensorPrimitives.SoftMax(x, dest); + + for (int i = 0; i < x.Length; i++) + { + AssertEqualTolerance(expectedResult[i], dest[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SoftMax_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(x, destination)); + } + + [Fact] + public static void SoftMax_ThrowsForEmptyInput() + { + AssertExtensions.Throws(() => TensorPrimitives.SoftMax(ReadOnlySpan.Empty, CreateTensor(1))); + } + + [Fact] + public static void SoftMax_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.SoftMax(array.AsSpan(1, 2), array.AsSpan(2, 2))); + } + #endregion + + #region Subtract + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TwoTensors(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Subtract(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] - y[i], destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TwoTensors_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Subtract(x, x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] - xOrig[i], x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Subtract_TwoTensors_ThrowsForMismatchedLengths(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength - 1); + using BoundedMemory destination = CreateTensor(tensorLength); + + Assert.Throws(() => TensorPrimitives.Subtract(x, y, destination)); + Assert.Throws(() => TensorPrimitives.Subtract(y, x, destination)); + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Subtract_TwoTensors_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory y = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); + } + + [Fact] + public static void Subtract_TwoTensors_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(2, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(3, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), array.AsSpan(4, 2), array.AsSpan(5, 2))); + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TensorScalar(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Subtract(x, y, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(x[i] - y, destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Subtract_TensorScalar_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + float y = NextSingle(); + + TensorPrimitives.Subtract(x, y, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(xOrig[i] - y, x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Subtract_TensorScalar_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float y = NextSingle(); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(x, y, destination)); + } + + [Fact] + public static void Subtract_TensorScalar_ThrowsForOverlapppingInputsWithOutputs() + { + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), 42, array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Subtract(array.AsSpan(1, 2), 42, array.AsSpan(2, 2))); + } + #endregion + + #region Sum + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Sum(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + AssertEqualTolerance(MemoryMarshal.ToEnumerable(x.Memory).Sum(), TensorPrimitives.Sum(x)); + + float sum = 0; + foreach (float f in x.Span) + { + sum += f; + } + AssertEqualTolerance(sum, TensorPrimitives.Sum(x)); + } + + [Theory] + [InlineData(0, new float[] { 0 })] + [InlineData(1, new float[] { 0, 1 })] + [InlineData(6, new float[] { 1, 2, 3 })] + [InlineData(0, new float[] { -3, 0, 3 })] + [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })] + public static void Sum_KnownValues(float expected, float[] x) + { + Assert.Equal(expected, TensorPrimitives.Sum(x)); + } + #endregion + + #region SumOfMagnitudes + [Theory] + [MemberData(nameof(TensorLengths))] public static void SumOfMagnitudes(int tensorLength) { using BoundedMemory x = CreateAndFillTensor(tensorLength); - Assert.Equal(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory), MathF.Abs), TensorPrimitives.SumOfMagnitudes(x), Tolerance); + AssertEqualTolerance(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory), MathF.Abs), TensorPrimitives.SumOfMagnitudes(x)); float sum = 0; foreach (float f in x.Span) { sum += MathF.Abs(f); } - Assert.Equal(sum, TensorPrimitives.SumOfMagnitudes(x), Tolerance); + AssertEqualTolerance(sum, TensorPrimitives.SumOfMagnitudes(x)); + } + + [Theory] + [InlineData(0, new float[] { 0 })] + [InlineData(1, new float[] { 0, 1 })] + [InlineData(6, new float[] { 1, 2, 3 })] + [InlineData(6, new float[] { -3, 0, 3 })] + [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })] + public static void SumOfMagnitudes_KnownValues(float expected, float[] x) + { + Assert.Equal(expected, TensorPrimitives.SumOfMagnitudes(x)); + } + #endregion + + #region SumOfSquares + [Theory] + [MemberData(nameof(TensorLengths))] + public static void SumOfSquares(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + AssertEqualTolerance(Enumerable.Sum(MemoryMarshal.ToEnumerable(x.Memory), v => v * v), TensorPrimitives.SumOfSquares(x)); + + float sum = 0; + foreach (float f in x.Span) + { + sum += f * f; + } + AssertEqualTolerance(sum, TensorPrimitives.SumOfSquares(x)); + } + + [Theory] + [InlineData(0, new float[] { 0 })] + [InlineData(1, new float[] { 0, 1 })] + [InlineData(14, new float[] { 1, 2, 3 })] + [InlineData(18, new float[] { -3, 0, 3 })] + [InlineData(float.NaN, new float[] { -3, float.NaN, 3 })] + public static void SumOfSquares_KnownValues(float expected, float[] x) + { + Assert.Equal(expected, TensorPrimitives.SumOfSquares(x)); + } + #endregion + + #region Tanh + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Tanh(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + TensorPrimitives.Tanh(x, destination); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Tanh(x[i]), destination[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengthsIncluding0))] + public static void Tanh_InPlace(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + float[] xOrig = x.Span.ToArray(); + + TensorPrimitives.Tanh(x, x); + + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Tanh(xOrig[i]), x[i]); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Tanh_SpecialValues(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength); + + RunForEachSpecialValue(() => + { + TensorPrimitives.Tanh(x, destination); + for (int i = 0; i < tensorLength; i++) + { + AssertEqualTolerance(MathF.Tanh(x[i]), destination[i]); + } + }, x); + } + + [Theory] + [MemberData(nameof(VectorLengthAndIteratedRange), new object[] { -11f, 11f, 0.2f })] + public static void Tanh_ValueRange(int vectorLengths, float element) + { + float[] x = new float[vectorLengths]; + float[] dest = new float[vectorLengths]; + + x.AsSpan().Fill(element); + TensorPrimitives.Tanh(x, dest); + + float expected = MathF.Tanh(element); + foreach (float actual in dest) + { + AssertEqualTolerance(expected, actual); + } + } + + [Theory] + [MemberData(nameof(TensorLengths))] + public static void Tanh_ThrowsForTooShortDestination(int tensorLength) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = CreateTensor(tensorLength - 1); + + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(x, destination)); } [Fact] - public static void SumOfMagnitudes_KnownValues() + public static void Tanh_ThrowsForOverlapppingInputsWithOutputs() { - Assert.Equal(0, TensorPrimitives.SumOfMagnitudes([0])); - Assert.Equal(1, TensorPrimitives.SumOfMagnitudes([0, 1])); - Assert.Equal(6, TensorPrimitives.SumOfMagnitudes([1, 2, 3])); - Assert.Equal(6, TensorPrimitives.SumOfMagnitudes([-3, 0, 3])); - Assert.Equal(float.NaN, TensorPrimitives.SumOfMagnitudes([-3, float.NaN, 3])); + float[] array = new float[10]; + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(array.AsSpan(1, 2), array.AsSpan(0, 2))); + AssertExtensions.Throws("destination", () => TensorPrimitives.Tanh(array.AsSpan(1, 2), array.AsSpan(2, 2))); } + #endregion } } diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs index 113f26048d352c..06ab341db16242 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.netcore.cs @@ -8,15 +8,16 @@ namespace System.Numerics.Tensors.Tests { public static partial class TensorPrimitivesTests { + #region ConvertToHalf [Theory] - [InlineData(0)] - [MemberData(nameof(TensorLengths))] + [MemberData(nameof(TensorLengthsIncluding0))] public static void ConvertToHalf(int tensorLength) { using BoundedMemory source = CreateAndFillTensor(tensorLength); foreach (int destLength in new[] { source.Length, source.Length + 1 }) { - Half[] destination = new Half[destLength]; + using BoundedMemory destination = BoundedMemory.Allocate(destLength); + destination.Span.Fill(Half.Zero); TensorPrimitives.ConvertToHalf(source, destination); @@ -35,6 +36,28 @@ public static void ConvertToHalf(int tensorLength) } } + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ConvertToHalf_SpecialValues(int tensorLength) + { + using BoundedMemory source = CreateAndFillTensor(tensorLength); + using BoundedMemory destination = BoundedMemory.Allocate(tensorLength); + + // NaN, infinities, and 0s + source[s_random.Next(source.Length)] = float.NaN; + source[s_random.Next(source.Length)] = float.PositiveInfinity; + source[s_random.Next(source.Length)] = float.NegativeInfinity; + source[s_random.Next(source.Length)] = 0; + source[s_random.Next(source.Length)] = float.NegativeZero; + + TensorPrimitives.ConvertToHalf(source, destination); + + for (int i = 0; i < source.Length; i++) + { + Assert.Equal((Half)source[i], destination[i]); + } + } + [Theory] [MemberData(nameof(TensorLengths))] public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength) @@ -44,13 +67,14 @@ public static void ConvertToHalf_ThrowsForTooShortDestination(int tensorLength) AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToHalf(source, destination)); } + #endregion + #region ConvertToSingle [Theory] - [InlineData(0)] - [MemberData(nameof(TensorLengths))] + [MemberData(nameof(TensorLengthsIncluding0))] public static void ConvertToSingle(int tensorLength) { - Half[] source = new Half[tensorLength]; + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); for (int i = 0; i < source.Length; i++) { source[i] = (Half)s_random.NextSingle(); @@ -77,6 +101,32 @@ public static void ConvertToSingle(int tensorLength) } } } + [Theory] + [MemberData(nameof(TensorLengths))] + public static void ConvertToSingle_SpecialValues(int tensorLength) + { + using BoundedMemory source = BoundedMemory.Allocate(tensorLength); + for (int i = 0; i < source.Length; i++) + { + source[i] = (Half)s_random.NextSingle(); + } + + using BoundedMemory destination = CreateTensor(tensorLength); + + // NaN, infinities, and 0s + source[s_random.Next(source.Length)] = Half.NaN; + source[s_random.Next(source.Length)] = Half.PositiveInfinity; + source[s_random.Next(source.Length)] = Half.NegativeInfinity; + source[s_random.Next(source.Length)] = Half.Zero; + source[s_random.Next(source.Length)] = Half.NegativeZero; + + TensorPrimitives.ConvertToSingle(source, destination); + + for (int i = 0; i < source.Length; i++) + { + Assert.Equal((float)source[i], destination[i]); + } + } [Theory] [MemberData(nameof(TensorLengths))] @@ -87,5 +137,6 @@ public static void ConvertToSingle_ThrowsForTooShortDestination(int tensorLength AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToSingle(source, destination)); } + #endregion } } diff --git a/src/libraries/System.Private.CoreLib/src/System/AppContextConfigHelper.cs b/src/libraries/System.Private.CoreLib/src/System/AppContextConfigHelper.cs index 3a8c3e8e7ea72d..efbe9e691625d7 100644 --- a/src/libraries/System.Private.CoreLib/src/System/AppContextConfigHelper.cs +++ b/src/libraries/System.Private.CoreLib/src/System/AppContextConfigHelper.cs @@ -65,6 +65,45 @@ internal static int GetInt32Config(string configName, int defaultValue, bool all } } + internal static int GetInt32Config(string configName, string envVariable, int defaultValue, bool allowNegative = true) + { + string? str = Environment.GetEnvironmentVariable(envVariable); + if (str != null) + { + try + { + int result; + if (str.StartsWith('0')) + { + if (str.Length >= 2 && str[1] == 'x') + { + result = Convert.ToInt32(str, 16); + } + else + { + result = Convert.ToInt32(str, 8); + } + } + else + { + result = int.Parse(str, NumberStyles.AllowLeadingSign, NumberFormatInfo.InvariantInfo); + } + + if (allowNegative || result >= 0) + { + return result; + } + } + catch (FormatException) + { + } + catch (OverflowException) + { + } + } + + return GetInt32Config(configName, defaultValue, allowNegative); + } internal static short GetInt16Config(string configName, short defaultValue, bool allowNegative = true) { @@ -112,5 +151,45 @@ internal static short GetInt16Config(string configName, short defaultValue, bool return defaultValue; } } + + internal static short GetInt16Config(string configName, string envVariable, short defaultValue, bool allowNegative = true) + { + string? str = Environment.GetEnvironmentVariable(envVariable); + if (str != null) + { + try + { + short result; + if (str.StartsWith('0')) + { + if (str.Length >= 2 && str[1] == 'x') + { + result = Convert.ToInt16(str, 16); + } + else + { + result = Convert.ToInt16(str, 8); + } + } + else + { + result = short.Parse(str, NumberStyles.AllowLeadingSign, NumberFormatInfo.InvariantInfo); + } + + if (allowNegative || result >= 0) + { + return result; + } + } + catch (FormatException) + { + } + catch (OverflowException) + { + } + } + + return GetInt16Config(configName, defaultValue, allowNegative); + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Diagnostics/Tracing/EventPipeEventDispatcher.cs b/src/libraries/System.Private.CoreLib/src/System/Diagnostics/Tracing/EventPipeEventDispatcher.cs index efd4b8cfb656dd..030560b2002145 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Diagnostics/Tracing/EventPipeEventDispatcher.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Diagnostics/Tracing/EventPipeEventDispatcher.cs @@ -103,8 +103,8 @@ private void CommitDispatchConfiguration() new EventPipeProviderConfiguration(NativeRuntimeEventSource.EventSourceName, (ulong)aggregatedKeywords, (uint)enableLevel, null) }; - m_sessionID = EventPipeInternal.Enable(null, EventPipeSerializationFormat.NetTrace, DefaultEventListenerCircularMBSize, providerConfiguration); - if (m_sessionID == 0) + ulong sessionID = EventPipeInternal.Enable(null, EventPipeSerializationFormat.NetTrace, DefaultEventListenerCircularMBSize, providerConfiguration); + if (sessionID == 0) { throw new EventSourceException(SR.EventSource_CouldNotEnableEventPipe); } @@ -113,7 +113,7 @@ private void CommitDispatchConfiguration() EventPipeSessionInfo sessionInfo; unsafe { - if (!EventPipeInternal.GetSessionInfo(m_sessionID, &sessionInfo)) + if (!EventPipeInternal.GetSessionInfo(sessionID, &sessionInfo)) { Debug.Fail("GetSessionInfo returned false."); } @@ -124,8 +124,11 @@ private void CommitDispatchConfiguration() long syncTimeQPC = sessionInfo.StartTimeStamp; long timeQPCFrequency = sessionInfo.TimeStampFrequency; + Debug.Assert(Volatile.Read(ref m_sessionID) == 0); + Volatile.Write(ref m_sessionID, sessionID); + // Start the dispatch task. - StartDispatchTask(m_sessionID, syncTimeUtc, syncTimeQPC, timeQPCFrequency); + StartDispatchTask(sessionID, syncTimeUtc, syncTimeQPC, timeQPCFrequency); } private void StartDispatchTask(ulong sessionID, DateTime syncTimeUtc, long syncTimeQPC, long timeQPCFrequency) @@ -142,12 +145,16 @@ private void SetStopDispatchTask() { Debug.Assert(Monitor.IsEntered(m_dispatchControlLock)); - if (m_dispatchTask != null) + if (m_dispatchTaskCancellationSource?.IsCancellationRequested ?? true) { - Debug.Assert(m_dispatchTaskCancellationSource != null); - m_dispatchTaskCancellationSource?.Cancel(); - EventPipeInternal.SignalSession(m_sessionID); + return; } + + ulong sessionID = Volatile.Read(ref m_sessionID); + Debug.Assert(sessionID != 0); + m_dispatchTaskCancellationSource.Cancel(); + EventPipeInternal.SignalSession(sessionID); + Volatile.Write(ref m_sessionID, 0); } private unsafe void DispatchEventsToEventListeners(ulong sessionID, DateTime syncTimeUtc, long syncTimeQPC, long timeQPCFrequency, Task? previousDispatchTask, CancellationToken token) @@ -187,7 +194,16 @@ private unsafe void DispatchEventsToEventListeners(ulong sessionID, DateTime syn } } - // Disable the old session. This can happen asynchronously since we aren't using the old session anymore + // Wait for SignalSession() to be called before we call disable, otherwise + // the SignalSession() call could be on a disabled session. + SpinWait sw = default; + while (Volatile.Read(ref m_sessionID) == sessionID) + { + sw.SpinOnce(); + } + + // Disable the old session. This can happen asynchronously since we aren't using the old session + // anymore. EventPipeInternal.Disable(sessionID); } diff --git a/src/libraries/System.Private.CoreLib/src/System/Globalization/CultureData.Icu.cs b/src/libraries/System.Private.CoreLib/src/System/Globalization/CultureData.Icu.cs index 0d4dad112249ca..12af791d392e24 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Globalization/CultureData.Icu.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Globalization/CultureData.Icu.cs @@ -420,7 +420,19 @@ private static CultureInfo[] IcuEnumCultures(CultureTypes types) return Array.Empty(); } - int bufferLength = Interop.Globalization.GetLocales(null, 0); + int bufferLength; +#if TARGET_MACCATALYST || TARGET_IOS || TARGET_TVOS + if (GlobalizationMode.Hybrid) + { + bufferLength = Interop.Globalization.GetLocalesNative(null, 0); + } + else + { + bufferLength = Interop.Globalization.GetLocales(null, 0); + } +#else + bufferLength = Interop.Globalization.GetLocales(null, 0); +#endif if (bufferLength <= 0) { return Array.Empty(); @@ -428,7 +440,18 @@ private static CultureInfo[] IcuEnumCultures(CultureTypes types) char [] chars = new char[bufferLength]; +#if TARGET_MACCATALYST || TARGET_IOS || TARGET_TVOS + if (GlobalizationMode.Hybrid) + { + bufferLength = Interop.Globalization.GetLocalesNative(chars, bufferLength); + } + else + { + bufferLength = Interop.Globalization.GetLocales(chars, bufferLength); + } +#else bufferLength = Interop.Globalization.GetLocales(chars, bufferLength); +#endif if (bufferLength <= 0) { return Array.Empty(); diff --git a/src/libraries/System.Private.CoreLib/src/System/Half.cs b/src/libraries/System.Private.CoreLib/src/System/Half.cs index 8daa37bbab576b..cd3e6ab3ed73c3 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Half.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Half.cs @@ -1044,7 +1044,7 @@ public static explicit operator float(Half value) // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13) const uint ExponentOffset = 0x3800_0000u; // Mask for sign bit in Single - const uint FloatSignMask = float.SignMask; + const uint SingleSignMask = float.SignMask; // Mask for exponent bits in Half const uint HalfExponentMask = BiasedExponentMask; // Mask for bits in Single converted from Half @@ -1052,7 +1052,7 @@ public static explicit operator float(Half value) // Extract the internal representation of value short valueInInt16Bits = BitConverter.HalfToInt16Bits(value); // Extract sign bit of value - uint sign = (uint)(int)valueInInt16Bits & FloatSignMask; + uint sign = (uint)(int)valueInInt16Bits & SingleSignMask; // Copy sign bit to upper bits uint bitValueInProcess = (uint)valueInInt16Bits; // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) diff --git a/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs b/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs index 5f049e69445381..b2a3134ae7501a 100644 --- a/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs +++ b/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs @@ -23,14 +23,9 @@ namespace System.IO * of the UnmanagedMemoryStream. * 3) You clean up the memory when appropriate. The UnmanagedMemoryStream * currently will do NOTHING to free this memory. - * 4) All calls to Write and WriteByte may not be threadsafe currently. - * - * It may become necessary to add in some sort of - * DeallocationMode enum, specifying whether we unmap a section of memory, - * call free, run a user-provided delegate to free the memory, etc. - * We'll suggest user write a subclass of UnmanagedMemoryStream that uses - * a SafeHandle subclass to hold onto the memory. - * + * 4) This type is not thread safe. However, the implementation should prevent buffer + * overruns or returning uninitialized memory when Reads and Writes are called + * concurrently in thread unsafe manner. */ /// @@ -40,10 +35,10 @@ public class UnmanagedMemoryStream : Stream { private SafeBuffer? _buffer; private unsafe byte* _mem; - private long _length; - private long _capacity; - private long _position; - private long _offset; + private nuint _capacity; + private nuint _offset; + private nuint _length; // nuint to guarantee atomic access on 32-bit platforms + private long _position; // long to allow seeking to any location beyond the length of the stream. private FileAccess _access; private bool _isOpen; private CachedCompletedInt32Task _lastReadTask; // The last successful task returned from ReadAsync @@ -123,10 +118,10 @@ protected void Initialize(SafeBuffer buffer, long offset, long length, FileAcces } } - _offset = offset; + _offset = (nuint)offset; _buffer = buffer; - _length = length; - _capacity = length; + _length = (nuint)length; + _capacity = (nuint)length; _access = access; _isOpen = true; } @@ -171,8 +166,8 @@ protected unsafe void Initialize(byte* pointer, long length, long capacity, File _mem = pointer; _offset = 0; - _length = length; - _capacity = capacity; + _length = (nuint)length; + _capacity = (nuint)capacity; _access = access; _isOpen = true; } @@ -259,7 +254,7 @@ public override long Length get { EnsureNotClosed(); - return Interlocked.Read(ref _length); + return (long)_length; } } @@ -271,7 +266,7 @@ public long Capacity get { EnsureNotClosed(); - return _capacity; + return (long)_capacity; } } @@ -283,14 +278,14 @@ public override long Position get { if (!CanSeek) ThrowHelper.ThrowObjectDisposedException_StreamClosed(null); - return Interlocked.Read(ref _position); + return _position; } set { ArgumentOutOfRangeException.ThrowIfNegative(value); if (!CanSeek) ThrowHelper.ThrowObjectDisposedException_StreamClosed(null); - Interlocked.Exchange(ref _position, value); + _position = value; } } @@ -308,11 +303,10 @@ public unsafe byte* PositionPointer EnsureNotClosed(); // Use a temp to avoid a race - long pos = Interlocked.Read(ref _position); - if (pos > _capacity) + long pos = _position; + if (pos > (long)_capacity) throw new IndexOutOfRangeException(SR.IndexOutOfRange_UMSPosition); - byte* ptr = _mem + pos; - return ptr; + return _mem + pos; } set { @@ -327,7 +321,7 @@ public unsafe byte* PositionPointer if (newPosition < 0) throw new ArgumentOutOfRangeException(nameof(value), SR.ArgumentOutOfRange_UnmanagedMemStreamLength); - Interlocked.Exchange(ref _position, newPosition); + _position = newPosition; } } @@ -367,8 +361,13 @@ internal int ReadCore(Span buffer) // Use a local variable to avoid a race where another thread // changes our position after we decide we can read some bytes. - long pos = Interlocked.Read(ref _position); - long len = Interlocked.Read(ref _length); + long pos = _position; + + // Use a volatile read to prevent reading of the uninitialized memory. This volatile read + // and matching volatile write that set _length avoids reordering of NativeMemory.Clear + // operations with reading of the buffer below. + long len = (long)Volatile.Read(ref _length); + long n = Math.Min(len - pos, buffer.Length); if (n <= 0) { @@ -407,7 +406,7 @@ internal int ReadCore(Span buffer) } } - Interlocked.Exchange(ref _position, pos + n); + _position = pos + n; return nInt; } @@ -484,11 +483,16 @@ public override int ReadByte() EnsureNotClosed(); EnsureReadable(); - long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition - long len = Interlocked.Read(ref _length); + long pos = _position; // Use a local to avoid a race condition + + // Use a volatile read to prevent reading of the uninitialized memory. This volatile read + // and matching volatile write that set _length avoids reordering of NativeMemory.Clear + // operations with reading of the buffer below. + long len = (long)Volatile.Read(ref _length); + if (pos >= len) return -1; - Interlocked.Exchange(ref _position, pos + 1); + _position = pos + 1; int result; if (_buffer != null) { @@ -529,35 +533,33 @@ public override long Seek(long offset, SeekOrigin loc) { EnsureNotClosed(); + long newPosition; switch (loc) { case SeekOrigin.Begin: - if (offset < 0) + newPosition = offset; + if (newPosition < 0) throw new IOException(SR.IO_SeekBeforeBegin); - Interlocked.Exchange(ref _position, offset); break; case SeekOrigin.Current: - long pos = Interlocked.Read(ref _position); - if (offset + pos < 0) + newPosition = _position + offset; + if (newPosition < 0) throw new IOException(SR.IO_SeekBeforeBegin); - Interlocked.Exchange(ref _position, offset + pos); break; case SeekOrigin.End: - long len = Interlocked.Read(ref _length); - if (len + offset < 0) + newPosition = (long)_length + offset; + if (newPosition < 0) throw new IOException(SR.IO_SeekBeforeBegin); - Interlocked.Exchange(ref _position, len + offset); break; default: throw new ArgumentException(SR.Argument_InvalidSeekOrigin); } - long finalPos = Interlocked.Read(ref _position); - Debug.Assert(finalPos >= 0, "_position >= 0"); - return finalPos; + _position = newPosition; + return newPosition; } /// @@ -573,11 +575,10 @@ public override void SetLength(long value) EnsureNotClosed(); EnsureWriteable(); - if (value > _capacity) + if (value > (long)_capacity) throw new IOException(SR.IO_FixedCapacity); - long pos = Interlocked.Read(ref _position); - long len = Interlocked.Read(ref _length); + long len = (long)_length; if (value > len) { unsafe @@ -585,10 +586,11 @@ public override void SetLength(long value) NativeMemory.Clear(_mem + len, (nuint)(value - len)); } } - Interlocked.Exchange(ref _length, value); - if (pos > value) + Volatile.Write(ref _length, (nuint)value); // volatile to prevent reading of uninitialized memory + + if (_position > value) { - Interlocked.Exchange(ref _position, value); + _position = value; } } @@ -625,8 +627,8 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) EnsureNotClosed(); EnsureWriteable(); - long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition - long len = Interlocked.Read(ref _length); + long pos = _position; // Use a local to avoid a race condition + long len = (long)_length; long n = pos + buffer.Length; // Check for overflow if (n < 0) @@ -634,7 +636,7 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) throw new IOException(SR.IO_StreamTooLong); } - if (n > _capacity) + if (n > (long)_capacity) { throw new NotSupportedException(SR.IO_FixedCapacity); } @@ -648,16 +650,16 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) NativeMemory.Clear(_mem + len, (nuint)(pos - len)); } - // set length after zeroing memory to avoid race condition of accessing unzeroed memory + // set length after zeroing memory to avoid race condition of accessing uninitialized memory if (n > len) { - Interlocked.Exchange(ref _length, n); + Volatile.Write(ref _length, (nuint)n); // volatile to prevent reading of uninitialized memory } } if (_buffer != null) { - long bytesLeft = _capacity - pos; + long bytesLeft = (long)_capacity - pos; if (bytesLeft < buffer.Length) { throw new ArgumentException(SR.Arg_BufferTooSmall); @@ -682,8 +684,7 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) Buffer.Memmove(ref *(_mem + pos), ref MemoryMarshal.GetReference(buffer), (nuint)buffer.Length); } - Interlocked.Exchange(ref _position, n); - return; + _position = n; } /// @@ -754,8 +755,8 @@ public override void WriteByte(byte value) EnsureNotClosed(); EnsureWriteable(); - long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition - long len = Interlocked.Read(ref _length); + long pos = _position; // Use a local to avoid a race condition + long len = (long)_length; long n = pos + 1; if (pos >= len) { @@ -763,7 +764,7 @@ public override void WriteByte(byte value) if (n < 0) throw new IOException(SR.IO_StreamTooLong); - if (n > _capacity) + if (n > (long)_capacity) throw new NotSupportedException(SR.IO_FixedCapacity); // Check to see whether we are now expanding the stream and must @@ -779,8 +780,7 @@ public override void WriteByte(byte value) } } - // set length after zeroing memory to avoid race condition of accessing unzeroed memory - Interlocked.Exchange(ref _length, n); + Volatile.Write(ref _length, (nuint)n); // volatile to prevent reading of uninitialized memory } } @@ -810,7 +810,7 @@ public override void WriteByte(byte value) _mem[pos] = value; } } - Interlocked.Exchange(ref _position, n); + _position = n; } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Numerics/INumberBase.cs b/src/libraries/System.Private.CoreLib/src/System/Numerics/INumberBase.cs index 13397687e2a1dc..4368605183d2e9 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Numerics/INumberBase.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Numerics/INumberBase.cs @@ -27,7 +27,7 @@ public interface INumberBase ISubtractionOperators, IUnaryPlusOperators, IUnaryNegationOperators, - // IUtf8SpanFormattable, + IUtf8SpanFormattable, IUtf8SpanParsable where TSelf : INumberBase? { @@ -457,9 +457,7 @@ static virtual bool TryParse(ReadOnlySpan utf8Text, NumberStyles style, IF return succeeded; } - // Workaround devdiv/#1851707: C++/CLI fails to compile when encountering a Default Interface Method implemented in a derived interface - // bool IUtf8SpanFormattable.TryFormat(Span utf8Destination, out int bytesWritten, ReadOnlySpan format, IFormatProvider? provider) - bool TryFormat(Span utf8Destination, out int bytesWritten, ReadOnlySpan format, IFormatProvider? provider) + bool IUtf8SpanFormattable.TryFormat(Span utf8Destination, out int bytesWritten, ReadOnlySpan format, IFormatProvider? provider) { char[]? utf16DestinationArray; scoped Span utf16Destination; diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.WorkerThread.NonBrowser.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.WorkerThread.NonBrowser.cs index 2e634fd469d9d8..c3b278019f6d92 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.WorkerThread.NonBrowser.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.WorkerThread.NonBrowser.cs @@ -7,11 +7,29 @@ namespace System.Threading { internal sealed partial class PortableThreadPool { + private int _numThreadsBeingKeptAlive; + /// /// The worker thread infastructure for the CLR thread pool. /// private static partial class WorkerThread { + private static readonly short ThreadsToKeepAlive = DetermineThreadsToKeepAlive(); + + private static short DetermineThreadsToKeepAlive() + { + const short DefaultThreadsToKeepAlive = 0; + + // The number of worker threads to keep alive after they are created. Set to -1 to keep all created worker + // threads alive. When the ThreadTimeoutMs config value is also set, for worker threads the timeout applies to + // worker threads that are in excess of the number configured for ThreadsToKeepAlive. + short threadsToKeepAlive = + AppContextConfigHelper.GetInt16Config( + "System.Threading.ThreadPool.ThreadsToKeepAlive", + "DOTNET_ThreadPool_ThreadsToKeepAlive", + DefaultThreadsToKeepAlive); + return threadsToKeepAlive >= -1 ? threadsToKeepAlive : DefaultThreadsToKeepAlive; + } /// /// Semaphore for controlling how many threads are currently working. @@ -50,10 +68,36 @@ private static void WorkerThreadStart() LowLevelLock threadAdjustmentLock = threadPoolInstance._threadAdjustmentLock; LowLevelLifoSemaphore semaphore = s_semaphore; + // Determine the idle timeout to use for this thread. Some threads may always be kept alive based on config. + int timeoutMs = ThreadPoolThreadTimeoutMs; + if (ThreadsToKeepAlive != 0) + { + if (ThreadsToKeepAlive < 0) + { + timeoutMs = Timeout.Infinite; + } + else + { + int count = threadPoolInstance._numThreadsBeingKeptAlive; + while (count < ThreadsToKeepAlive) + { + int countBeforeUpdate = + Interlocked.CompareExchange(ref threadPoolInstance._numThreadsBeingKeptAlive, count + 1, count); + if (countBeforeUpdate == count) + { + timeoutMs = Timeout.Infinite; + break; + } + + count = countBeforeUpdate; + } + } + } + while (true) { bool spinWait = true; - while (semaphore.Wait(ThreadPoolThreadTimeoutMs, spinWait)) + while (semaphore.Wait(timeoutMs, spinWait)) { WorkerDoWork(threadPoolInstance, ref spinWait); } @@ -65,7 +109,6 @@ private static void WorkerThreadStart() } } - private static void CreateWorkerThread() { // Thread pool threads must start in the default execution context without transferring the context, so diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.cs index 9ada201e134c5a..db49ea2c5a5a92 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/PortableThreadPool.cs @@ -13,7 +13,6 @@ namespace System.Threading /// internal sealed partial class PortableThreadPool { - private const int ThreadPoolThreadTimeoutMs = 20 * 1000; // If you change this make sure to change the timeout times in the tests. private const int SmallStackSizeBytes = 256 * 1024; private const short MaxPossibleThreadCount = short.MaxValue; @@ -40,6 +39,23 @@ internal sealed partial class PortableThreadPool private static readonly short ForcedMaxWorkerThreads = AppContextConfigHelper.GetInt16Config("System.Threading.ThreadPool.MaxThreads", 0, false); + private static readonly int ThreadPoolThreadTimeoutMs = DetermineThreadPoolThreadTimeoutMs(); + + private static int DetermineThreadPoolThreadTimeoutMs() + { + const int DefaultThreadPoolThreadTimeoutMs = 20 * 1000; // If you change this make sure to change the timeout times in the tests. + + // The amount of time in milliseconds a thread pool thread waits without having done any work before timing out and + // exiting. Set to -1 to disable the timeout. Applies to worker threads and wait threads. Also see the + // ThreadsToKeepAlive config value for relevant information. + int timeoutMs = + AppContextConfigHelper.GetInt32Config( + "System.Threading.ThreadPool.ThreadTimeoutMs", + "DOTNET_ThreadPool_ThreadTimeoutMs", + DefaultThreadPoolThreadTimeoutMs); + return timeoutMs >= -1 ? timeoutMs : DefaultThreadPoolThreadTimeoutMs; + } + [ThreadStatic] private static object? t_completionCountObject; diff --git a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs index 0a18241ef7c767..a00fc9e4024f76 100644 --- a/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices.JavaScript/gen/JSImportGenerator/JSExportCodeGenerator.cs @@ -59,7 +59,7 @@ public JSExportCodeGenerator( public BlockSyntax GenerateJSExportBody() { - StatementSyntax invoke = InvokeSyntax(); + List invoke = InvokeSyntax(); GeneratedStatements statements = GeneratedStatements.Create(_marshallers, _context); bool shouldInitializeVariables = !statements.GuaranteedUnmarshal.IsEmpty || !statements.CleanupCallerAllocated.IsEmpty || !statements.CleanupCalleeAllocated.IsEmpty; VariableDeclarations declarations = VariableDeclarations.GenerateDeclarationsForUnmanagedToManaged(_marshallers, _context, shouldInitializeVariables); @@ -79,7 +79,7 @@ public BlockSyntax GenerateJSExportBody() var tryStatements = new List(); tryStatements.AddRange(statements.Unmarshal); - tryStatements.Add(invoke); + tryStatements.AddRange(invoke); if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { @@ -93,6 +93,18 @@ public BlockSyntax GenerateJSExportBody() tryStatements.AddRange(statements.Marshal); List allStatements = setupStatements; + + // Wrap unmarshall, invocation and return value marshalling in try-catch. + // In case of exception, marshal exception instead of return value. + var tryInvokeAndMarshal = TryStatement(SingletonList(CatchClause() + .WithDeclaration(CatchDeclaration(IdentifierName(Constants.ExceptionGlobal)).WithIdentifier(Identifier("ex"))) + .WithBlock(Block(SingletonList( + ExpressionStatement(InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(Constants.ArgumentException), IdentifierName(Constants.ToJSMethod))) + .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("ex"))))))))))) + .WithBlock(Block(tryStatements)); + List finallyStatements = new List(); if (!(statements.GuaranteedUnmarshal.IsEmpty && statements.CleanupCalleeAllocated.IsEmpty)) { @@ -100,16 +112,14 @@ public BlockSyntax GenerateJSExportBody() } finallyStatements.AddRange(statements.CleanupCallerAllocated); + if (finallyStatements.Count > 0) { - allStatements.Add( - TryStatement(Block(tryStatements), default, FinallyClause(Block(finallyStatements)))); - } - else - { - allStatements.AddRange(tryStatements); + tryInvokeAndMarshal = TryStatement(Block(tryInvokeAndMarshal), default, FinallyClause(Block(finallyStatements))); } + allStatements.Add(tryInvokeAndMarshal); + return Block(allStatements); } @@ -175,7 +185,7 @@ private void SetupSyntax(List statementsToUpdate) Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(1))))))))))))); } - private TryStatementSyntax InvokeSyntax() + private List InvokeSyntax() { var statements = new List(); var arguments = new List(); @@ -205,16 +215,8 @@ private TryStatementSyntax InvokeSyntax() IdentifierName(nativeIdentifier), invocation)); statements.Add(statement); - statements.AddRange(_marshallers.ManagedReturnMarshaller.Generator.Generate(_marshallers.ManagedReturnMarshaller.TypeInfo, _context with { CurrentStage = StubCodeContext.Stage.Marshal })); } - return TryStatement(SingletonList(CatchClause() - .WithDeclaration(CatchDeclaration(IdentifierName(Constants.ExceptionGlobal)).WithIdentifier(Identifier("ex"))) - .WithBlock(Block(SingletonList( - ExpressionStatement(InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(Constants.ArgumentException), IdentifierName(Constants.ToJSMethod))) - .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("ex"))))))))))) - .WithBlock(Block(statements)); + return statements; } diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index c284894678b3bb..bc7e04a54f96b3 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -10801,7 +10801,7 @@ public partial interface IMultiplyOperators where TSelf static virtual TResult operator checked *(TSelf left, TOther right) { throw null; } static abstract TResult operator *(TSelf left, TOther right); } - public partial interface INumberBase : System.IEquatable, System.IFormattable, System.IParsable, System.ISpanFormattable, System.ISpanParsable, System.Numerics.IAdditionOperators, System.Numerics.IAdditiveIdentity, System.Numerics.IDecrementOperators, System.Numerics.IDivisionOperators, System.Numerics.IEqualityOperators, System.Numerics.IIncrementOperators, System.Numerics.IMultiplicativeIdentity, System.Numerics.IMultiplyOperators, System.Numerics.ISubtractionOperators, System.Numerics.IUnaryNegationOperators, System.Numerics.IUnaryPlusOperators, /* System.IUtf8SpanFormattable, */ System.IUtf8SpanParsable where TSelf : System.Numerics.INumberBase? + public partial interface INumberBase : System.IEquatable, System.IFormattable, System.IParsable, System.ISpanFormattable, System.ISpanParsable, System.Numerics.IAdditionOperators, System.Numerics.IAdditiveIdentity, System.Numerics.IDecrementOperators, System.Numerics.IDivisionOperators, System.Numerics.IEqualityOperators, System.Numerics.IIncrementOperators, System.Numerics.IMultiplicativeIdentity, System.Numerics.IMultiplyOperators, System.Numerics.ISubtractionOperators, System.Numerics.IUnaryNegationOperators, System.Numerics.IUnaryPlusOperators, System.IUtf8SpanFormattable, System.IUtf8SpanParsable where TSelf : System.Numerics.INumberBase? { static abstract TSelf One { get; } static abstract int Radix { get; } @@ -10843,9 +10843,7 @@ static virtual TSelf CreateTruncating(TOther value) static virtual TSelf Parse(System.ReadOnlySpan utf8Text, System.Globalization.NumberStyles style, System.IFormatProvider? provider) { throw null; } static abstract TSelf Parse(System.ReadOnlySpan s, System.Globalization.NumberStyles style, System.IFormatProvider? provider); static abstract TSelf Parse(string s, System.Globalization.NumberStyles style, System.IFormatProvider? provider); - // Workaround devdiv/#1851707: C++/CLI fails to compile when encountering a Default Interface Method implemented in a derived interface - // bool System.IUtf8SpanFormattable.TryFormat(System.Span utf8Destination, out int bytesWritten, System.ReadOnlySpan format, System.IFormatProvider? provider) { throw null; } - bool TryFormat(System.Span utf8Destination, out int bytesWritten, System.ReadOnlySpan format, System.IFormatProvider? provider) { throw null; } + bool System.IUtf8SpanFormattable.TryFormat(System.Span utf8Destination, out int bytesWritten, System.ReadOnlySpan format, System.IFormatProvider? provider) { throw null; } static TSelf System.IUtf8SpanParsable.Parse(System.ReadOnlySpan utf8Text, System.IFormatProvider? provider) { throw null; } static bool System.IUtf8SpanParsable.TryParse(System.ReadOnlySpan utf8Text, System.IFormatProvider? provider, [System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out TSelf result) { throw null; } protected static abstract bool TryConvertFromChecked(TOther value, [System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out TSelf result) diff --git a/src/libraries/System.Security.Cryptography.ProtectedData/src/PACKAGE.md b/src/libraries/System.Security.Cryptography.ProtectedData/src/PACKAGE.md new file mode 100644 index 00000000000000..7a0f751326f550 --- /dev/null +++ b/src/libraries/System.Security.Cryptography.ProtectedData/src/PACKAGE.md @@ -0,0 +1,72 @@ +## About + + + +System.Security.Cryptography.ProtectedData offers a simplified interface for utilizing Microsoft Windows DPAPI's [CryptProtectData](https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptprotectdata) and [CryptUnprotectData](https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptunprotectdata) functions. + +**Note**: Since it relies on Windows DPAPI, this package is only supported on Windows platforms. +For more complex cryptographic operations or cross-platform support, consider the [System.Security.Cryptography](https://learn.microsoft.com/dotnet/api/system.security.cryptography) namespace. + +## Key Features + + + +* Built upon the robust and secure Windows Data Protection API (DPAPI). +* Data can be protected either for current process or for any process on the machine. +* Scope of protection can be defined either to the current user or the local machine. + +## How to Use + + + +Utilizing this package is quite simple, and it mainly revolves around two methods: `Protect` and `Unprotect`. + +Here, `originalData` is the data you want to protect, `optionalEntropy` is an additional byte array used to increase encryption complexity, and `DataProtectionScope` specifies whether the data protection should apply to the current user or the machine. + +```csharp +using System.Security.Cryptography; +using System.Text; + +byte[] originalData = Encoding.UTF8.GetBytes("This is a secret"); +byte[] optionalEntropy = new byte[64]; +Random.Shared.NextBytes(optionalEntropy); + +// To protect: +byte[] encryptedData = ProtectedData.Protect( + originalData, + optionalEntropy, + DataProtectionScope.CurrentUser); + +// To unprotect: +byte[] decryptedData = ProtectedData.Unprotect( + encryptedData, + optionalEntropy, + DataProtectionScope.CurrentUser); +``` + +## Main Types + + + +The main type provided by this library is: + +* `System.Security.Cryptography.ProtectedData` + +## Additional Documentation + + + +* [Conceptual documentation](https://learn.microsoft.com/dotnet/standard/security/how-to-use-data-protection) +* [API documentation](https://learn.microsoft.com/dotnet/api/system.security.cryptography.protecteddata) + +## Related Packages + + + +* PKCS and CMS algorithms: [System.Security.Cryptography.Pkcs](https://www.nuget.org/packages/System.Security.Cryptography.Pkcs/) + +## Feedback & Contributing + + + +System.Security.Cryptography.ProtectedData is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/ChainPal.Apple.cs b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/ChainPal.Apple.cs index 7c3ab590319808..6df9aacc727856 100644 --- a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/ChainPal.Apple.cs +++ b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/ChainPal.Apple.cs @@ -348,16 +348,14 @@ private void BuildAndSetProperties((X509Certificate2, int)[] elementTuples) for (int i = 0; i < elementTuples.Length; i++) { - (X509Certificate2, int) tuple = elementTuples[i]; + (X509Certificate2 cert, int chainStatus) = elementTuples[i]; - elements[i] = BuildElement(tuple.Item1, tuple.Item2); - allStatus |= tuple.Item2; + elements[i] = new X509ChainElement(cert, BuildChainElementStatuses(cert, chainStatus), ""); + allStatus |= chainStatus; } ChainElements = elements; - - X509ChainElement rollupElement = BuildElement(null!, allStatus); - ChainStatus = rollupElement.ChainElementStatus; + ChainStatus = BuildChainElementStatuses(null, allStatus); } private static void FixupRevocationStatus( @@ -457,11 +455,11 @@ private static X509ChainStatusFlags FindUntrustedRootReason(X509Certificate2 cer return X509ChainStatusFlags.UntrustedRoot; } - private X509ChainElement BuildElement(X509Certificate2 cert, int dwStatus) + private X509ChainStatus[] BuildChainElementStatuses(X509Certificate2? cert, int dwStatus) { if (dwStatus == 0) { - return new X509ChainElement(cert, Array.Empty(), ""); + return Array.Empty(); } List statuses = new List(); @@ -499,7 +497,7 @@ private X509ChainElement BuildElement(X509Certificate2 cert, int dwStatus) } } - return new X509ChainElement(cert, statuses.ToArray(), ""); + return statuses.ToArray(); } private readonly struct X509ChainErrorMapping diff --git a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCachedSystemStoreProvider.cs b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCachedSystemStoreProvider.cs index 4c9643c01e2fcb..e66b3d1ad11022 100644 --- a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCachedSystemStoreProvider.cs +++ b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/X509Certificates/OpenSslCachedSystemStoreProvider.cs @@ -21,14 +21,14 @@ internal sealed class OpenSslCachedSystemStoreProvider : IStorePal private static readonly TimeSpan s_lastWriteRecheckInterval = TimeSpan.FromSeconds(5); private static readonly TimeSpan s_assumeInvalidInterval = TimeSpan.FromMinutes(5); private static readonly Stopwatch s_recheckStopwatch = new Stopwatch(); - private static DirectoryInfo? s_rootStoreDirectoryInfo = SafeOpenRootDirectoryInfo(); + private static string[]? s_rootStoreDirectories; private static bool s_defaultRootDir; - private static readonly FileInfo? s_rootStoreFileInfo = SafeOpenRootFileInfo(); + private static string? s_rootStoreFile; + private static DateTime[]? s_directoryLastWrite; + private static DateTime s_fileLastWrite; // Use non-Value-Tuple so that it's an atomic update. private static Tuple? s_nativeCollections; - private static DateTime s_directoryCertsLastWrite; - private static DateTime s_fileCertsLastWrite; private readonly bool _isRoot; @@ -93,18 +93,11 @@ private static Tuple GetCollections() { lock (s_recheckStopwatch) { - FileInfo? fileInfo = s_rootStoreFileInfo; - DirectoryInfo? dirInfo = s_rootStoreDirectoryInfo; - - fileInfo?.Refresh(); - dirInfo?.Refresh(); - if (ret == null || elapsed > s_assumeInvalidInterval || - (fileInfo != null && fileInfo.Exists && ContentWriteTime(fileInfo) != s_fileCertsLastWrite) || - (dirInfo != null && dirInfo.Exists && ContentWriteTime(dirInfo) != s_directoryCertsLastWrite)) + LastWriteTimesHaveChanged()) { - ret = LoadMachineStores(dirInfo, fileInfo); + ret = LoadMachineStores(); } } } @@ -113,9 +106,37 @@ private static Tuple GetCollections() return ret; } - private static Tuple LoadMachineStores( - DirectoryInfo? rootStorePath, - FileInfo? rootStoreFile) + private static bool LastWriteTimesHaveChanged() + { + Debug.Assert( + Monitor.IsEntered(s_recheckStopwatch), + "LastWriteTimesHaveChanged assumes a lock(s_recheckStopwatch)"); + + if (s_rootStoreFile != null) + { + _ = TryStatFile(s_rootStoreFile, out DateTime lastModified); + if (lastModified != s_fileLastWrite) + { + return true; + } + } + + if (s_rootStoreDirectories != null && s_directoryLastWrite != null) + { + for (int i = 0; i < s_rootStoreDirectories.Length; i++) + { + _ = TryStatDirectory(s_rootStoreDirectories[i], out DateTime lastModified); + if (lastModified != s_directoryLastWrite[i]) + { + return true; + } + } + } + + return false; + } + + private static Tuple LoadMachineStores() { Debug.Assert( Monitor.IsEntered(s_recheckStopwatch), @@ -126,61 +147,76 @@ private static Tuple LoadMachineStores SafeX509StackHandle intermedStore = Interop.Crypto.NewX509Stack(); Interop.Crypto.CheckValidOpenSslHandle(intermedStore); - DateTime newFileTime = default; - DateTime newDirTime = default; - var uniqueRootCerts = new HashSet(); var uniqueIntermediateCerts = new HashSet(); bool firstLoad = (s_nativeCollections == null); - if (rootStoreFile != null && rootStoreFile.Exists) + if (firstLoad) { - newFileTime = ContentWriteTime(rootStoreFile); - ProcessFile(rootStoreFile); + s_rootStoreDirectories = GetRootStoreDirectories(out s_defaultRootDir); + s_directoryLastWrite = new DateTime[s_rootStoreDirectories.Length]; + s_rootStoreFile = GetRootStoreFile(); + } + else + { + Debug.Assert(s_rootStoreDirectories is not null); + Debug.Assert(s_directoryLastWrite is not null); + } + + if (s_rootStoreFile != null) + { + ProcessFile(s_rootStoreFile, out s_fileLastWrite); } bool hasStoreData = false; - if (rootStorePath != null && rootStorePath.Exists) + for (int i = 0; i < s_rootStoreDirectories.Length; i++) { - newDirTime = ContentWriteTime(rootStorePath); - hasStoreData = ProcessDir(rootStorePath); + hasStoreData = ProcessDir(s_rootStoreDirectories[i], out s_directoryLastWrite[i]); } if (firstLoad && !hasStoreData && s_defaultRootDir) { - DirectoryInfo etcSslCerts = new DirectoryInfo("/etc/ssl/certs"); - - if (etcSslCerts.Exists) + const string DefaultCertDir = "/etc/ssl/certs"; + hasStoreData = ProcessDir(DefaultCertDir, out DateTime lastModified); + if (hasStoreData) { - DateTime tmpTime = ContentWriteTime(etcSslCerts); - hasStoreData = ProcessDir(etcSslCerts); - - if (hasStoreData) - { - newDirTime = tmpTime; - s_rootStoreDirectoryInfo = etcSslCerts; - } + s_rootStoreDirectories = new[] { DefaultCertDir }; + s_directoryLastWrite = new[] { lastModified }; } } - bool ProcessDir(DirectoryInfo dir) + bool ProcessDir(string dir, out DateTime lastModified) { + if (!TryStatDirectory(dir, out lastModified)) + { + return false; + } + bool hasStoreData = false; - foreach (FileInfo file in dir.EnumerateFiles()) + foreach (string file in Directory.EnumerateFiles(dir)) { - hasStoreData |= ProcessFile(file); + hasStoreData |= ProcessFile(file, out _, skipStat: true); } return hasStoreData; } - bool ProcessFile(FileInfo file) + bool ProcessFile(string file, out DateTime lastModified, bool skipStat = false) { bool readData = false; - using (SafeBioHandle fileBio = Interop.Crypto.BioNewFile(file.FullName, "rb")) + if (skipStat) + { + lastModified = default; + } + else if (!TryStatFile(file, out lastModified)) + { + return false; + } + + using (SafeBioHandle fileBio = Interop.Crypto.BioNewFile(file, "rb")) { // The handle may be invalid, for example when we don't have read permission for the file. if (fileBio.IsInvalid) @@ -274,114 +310,78 @@ bool ProcessFile(FileInfo file) // on every call. Volatile.Write(ref s_nativeCollections, newCollections); - s_directoryCertsLastWrite = newDirTime; - s_fileCertsLastWrite = newFileTime; s_recheckStopwatch.Restart(); return newCollections; } - private static FileInfo? SafeOpenRootFileInfo() + private static string? GetRootStoreFile() { string? rootFile = Interop.Crypto.GetX509RootStoreFile(); if (!string.IsNullOrEmpty(rootFile)) { - try - { - return new FileInfo(rootFile); - } - catch (ArgumentException) - { - // If SSL_CERT_FILE is set to the empty string, or anything else which gives - // "The path is not of a legal form", then the GetX509RootStoreFile value is ignored. - } + return Path.GetFullPath(rootFile); } return null; } - private static DirectoryInfo? SafeOpenRootDirectoryInfo() + private static string[] GetRootStoreDirectories(out bool isDefault) { - string? rootDirectory = Interop.Crypto.GetX509RootStorePath(out s_defaultRootDir); + string rootDirectory = Interop.Crypto.GetX509RootStorePath(out isDefault) ?? ""; - if (!string.IsNullOrEmpty(rootDirectory)) - { - try - { - return new DirectoryInfo(rootDirectory); - } - catch (ArgumentException) - { - // If SSL_CERT_DIR is set to the empty string, or anything else which gives - // "The path is not of a legal form", then the GetX509RootStoreFile value is ignored. - } - } - - return null; - } - - private static DateTime ContentWriteTime(FileInfo info) - { - string path = info.FullName; - string? target = Interop.Sys.ReadLink(path); - - if (string.IsNullOrEmpty(target)) - { - return info.LastWriteTimeUtc; - } + string[] directories = rootDirectory.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries); - if (target[0] != '/') + for (int i = 0; i < directories.Length; i++) { - target = Path.Join(info.Directory?.FullName, target); + directories[i] = Path.GetFullPath(directories[i]); } - try + // Remove duplicates. + if (directories.Length > 1) { - var targetInfo = new FileInfo(target); - - if (targetInfo.Exists) + var set = new HashSet(directories, StringComparer.Ordinal); + if (set.Count != directories.Length) { - return targetInfo.LastWriteTimeUtc; + // Preserve the original order. + string[] directoriesTrimmed = new string[set.Count]; + int j = 0; + for (int i = 0; i < directories.Length; i++) + { + string directory = directories[i]; + if (set.Remove(directory)) + { + directoriesTrimmed[j++] = directory; + } + } + Debug.Assert(set.Count == 0); + directories = directoriesTrimmed; } } - catch (ArgumentException) - { - // If we can't load information about the link path, just treat it as not a link. - } - return info.LastWriteTimeUtc; + return directories; } - private static DateTime ContentWriteTime(DirectoryInfo info) - { - string path = info.FullName; - string? target = Interop.Sys.ReadLink(path); - - if (string.IsNullOrEmpty(target)) - { - return info.LastWriteTimeUtc; - } + private static bool TryStatFile(string path, out DateTime lastModified) + => TryStat(path, Interop.Sys.FileTypes.S_IFREG, out lastModified); - if (target[0] != '/') - { - target = Path.Join(info.Parent?.FullName, target); - } + private static bool TryStatDirectory(string path, out DateTime lastModified) + => TryStat(path, Interop.Sys.FileTypes.S_IFDIR, out lastModified); - try - { - var targetInfo = new DirectoryInfo(target); + private static bool TryStat(string path, int fileType, out DateTime lastModified) + { + lastModified = default; - if (targetInfo.Exists) - { - return targetInfo.LastWriteTimeUtc; - } - } - catch (ArgumentException) + Interop.Sys.FileStatus status; + // Use Stat to follow links. + if (Interop.Sys.Stat(path, out status) < 0 || + (status.Mode & Interop.Sys.FileTypes.S_IFMT) != fileType) { - // If we can't load information about the link path, just treat it as not a link. + return false; } - return info.LastWriteTimeUtc; + lastModified = DateTime.UnixEpoch + TimeSpan.FromTicks(status.MTime * TimeSpan.TicksPerSecond + status.MTimeNsec / TimeSpan.NanosecondsPerTick); + return true; } } } diff --git a/src/libraries/System.Security.Cryptography/tests/DSATests.cs b/src/libraries/System.Security.Cryptography/tests/DSATests.cs index b995a5e0920893..8eca860fe4fb9b 100644 --- a/src/libraries/System.Security.Cryptography/tests/DSATests.cs +++ b/src/libraries/System.Security.Cryptography/tests/DSATests.cs @@ -171,7 +171,7 @@ protected override void Dispose(bool disposing) public override void ImportParameters(DSAParameters parameters) => _dsa.ImportParameters(parameters); public override bool VerifySignature(byte[] rgbHash, byte[] rgbSignature) => _dsa.VerifySignature(rgbHash, rgbSignature); protected override byte[] HashData(Stream data, HashAlgorithmName hashAlgorithm) => - (byte[])_dsa.GetType().GetMethod( + (byte[])typeof(DSA).GetMethod( nameof(HashData), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance, null, @@ -179,7 +179,7 @@ protected override byte[] HashData(Stream data, HashAlgorithmName hashAlgorithm) null) .Invoke(_dsa, new object[] { data, hashAlgorithm }); protected override byte[] HashData(byte[] data, int offset, int count, HashAlgorithmName hashAlgorithm) => - (byte[])_dsa.GetType().GetMethod( + (byte[])typeof(DSA).GetMethod( nameof(HashData), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance, null, diff --git a/src/libraries/System.Security.Cryptography/tests/ECDsaTests.cs b/src/libraries/System.Security.Cryptography/tests/ECDsaTests.cs index 5a871f35c2ef08..c858fd0866213d 100644 --- a/src/libraries/System.Security.Cryptography/tests/ECDsaTests.cs +++ b/src/libraries/System.Security.Cryptography/tests/ECDsaTests.cs @@ -169,7 +169,7 @@ public byte[] BaseHashData(byte[] data, int offset, int count, HashAlgorithmName base.HashData(data, offset, count, hashAlgorithm); protected override byte[] HashData(Stream data, HashAlgorithmName hashAlgorithm) => - (byte[])_ecdsa.GetType().GetMethod( + (byte[])typeof(ECDsa).GetMethod( nameof(HashData), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance, null, @@ -178,7 +178,7 @@ protected override byte[] HashData(Stream data, HashAlgorithmName hashAlgorithm) .Invoke(_ecdsa, new object[] { data, hashAlgorithm }); protected override byte[] HashData(byte[] data, int offset, int count, HashAlgorithmName hashAlgorithm) => - (byte[])_ecdsa.GetType().GetMethod( + (byte[])typeof(ECDsa).GetMethod( nameof(HashData), BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance, null, diff --git a/src/libraries/System.Security.Cryptography/tests/X509Certificates/X509StoreTests.Unix.cs b/src/libraries/System.Security.Cryptography/tests/X509Certificates/X509StoreTests.Unix.cs index 0efb6c12028fb9..f460d6b9bd6c69 100644 --- a/src/libraries/System.Security.Cryptography/tests/X509Certificates/X509StoreTests.Unix.cs +++ b/src/libraries/System.Security.Cryptography/tests/X509Certificates/X509StoreTests.Unix.cs @@ -10,7 +10,6 @@ namespace System.Security.Cryptography.X509Certificates.Tests { public partial class X509StoreTests { - [ConditionalFact(nameof(NotRunningAsRootAndRemoteExecutorSupported))] // root can read '2.pem' [PlatformSpecific(TestPlatforms.Linux)] // Windows/OSX doesn't use SSL_CERT_{DIR,FILE}. private void X509Store_MachineStoreLoadSkipsInvalidFiles() @@ -50,6 +49,47 @@ private void X509Store_MachineStoreLoadSkipsInvalidFiles() }, new RemoteInvokeOptions { StartInfo = psi }).Dispose(); } + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [PlatformSpecific(TestPlatforms.Linux)] // Windows/OSX doesn't use SSL_CERT_{DIR,FILE}. + private void X509Store_MachineStoreLoadsMutipleSslCertDirectories() + { + // Create 3 certificates and place them in two directories that will be passed + // using SSL_CERT_DIR. + string sslCertDir1 = GetTestFilePath(); + Directory.CreateDirectory(sslCertDir1); + File.WriteAllBytes(Path.Combine(sslCertDir1, "1.pem"), TestData.SelfSigned1PemBytes); + File.WriteAllBytes(Path.Combine(sslCertDir1, "2.pem"), TestData.SelfSigned2PemBytes); + string sslCertDir2 = GetTestFilePath(); + Directory.CreateDirectory(sslCertDir2); + File.WriteAllBytes(Path.Combine(sslCertDir2, "3.pem"), TestData.SelfSigned3PemBytes); + + // Add a non-existing directory after each valid directory to verify they are ignored. + string sslCertDir = string.Join(Path.PathSeparator, + new[] { + sslCertDir1, + sslCertDir2, + "", // empty string + sslCertDir2, // duplicate directory + "/invalid2", // path that does not exist + }); + + var psi = new ProcessStartInfo(); + psi.Environment.Add("SSL_CERT_DIR", sslCertDir); + // Set SSL_CERT_FILE to avoid loading the default bundle file. + psi.Environment.Add("SSL_CERT_FILE", "/nonexisting"); + RemoteExecutor.Invoke(() => + { + Assert.NotNull(Environment.GetEnvironmentVariable("SSL_CERT_DIR")); + using (var store = new X509Store(StoreName.Root, StoreLocation.LocalMachine)) + { + store.Open(OpenFlags.OpenExistingOnly); + + // Check nr of certificates in store. + Assert.Equal(3, store.Certificates.Count); + } + }, new RemoteInvokeOptions { StartInfo = psi }).Dispose(); + } + public static bool NotRunningAsRootAndRemoteExecutorSupported => !Environment.IsPrivilegedProcess && RemoteExecutor.IsSupported; } } diff --git a/src/libraries/System.Security.Permissions/src/CompatibilitySuppressions.xml b/src/libraries/System.Security.Permissions/src/CompatibilitySuppressions.xml index 10869bf91f9227..5ff9a82feffe7d 100644 --- a/src/libraries/System.Security.Permissions/src/CompatibilitySuppressions.xml +++ b/src/libraries/System.Security.Permissions/src/CompatibilitySuppressions.xml @@ -218,6 +218,27 @@ lib/netstandard2.0/System.Security.Permissions.dll true + + CP0014 + P:System.Security.Permissions.FileIOPermissionAttribute.All:[T:System.ObsoleteAttribute] + lib/netstandard2.0/System.Security.Permissions.dll + lib/netstandard2.0/System.Security.Permissions.dll + true + + + CP0014 + P:System.Security.Permissions.ReflectionPermissionAttribute.ReflectionEmit:[T:System.ObsoleteAttribute] + lib/netstandard2.0/System.Security.Permissions.dll + lib/netstandard2.0/System.Security.Permissions.dll + true + + + CP0014 + P:System.Security.Permissions.ReflectionPermissionAttribute.TypeInformation:[T:System.ObsoleteAttribute] + lib/netstandard2.0/System.Security.Permissions.dll + lib/netstandard2.0/System.Security.Permissions.dll + true + CP0014 P:System.Security.Permissions.RegistryPermissionAttribute.All:[T:System.ObsoleteAttribute] diff --git a/src/libraries/System.Text.Json/gen/Helpers/DiagnosticInfo.cs b/src/libraries/System.Text.Json/gen/Helpers/DiagnosticInfo.cs deleted file mode 100644 index 493f79191d4375..00000000000000 --- a/src/libraries/System.Text.Json/gen/Helpers/DiagnosticInfo.cs +++ /dev/null @@ -1,43 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Linq; -using System.Numerics.Hashing; -using Microsoft.CodeAnalysis; - -namespace System.Text.Json.SourceGeneration -{ - /// - /// Descriptor for diagnostic instances using structural equality comparison. - /// Provides a work-around for https://github.com/dotnet/roslyn/issues/68291. - /// - public readonly struct DiagnosticInfo : IEquatable - { - public required DiagnosticDescriptor Descriptor { get; init; } - public required object?[] MessageArgs { get; init; } - public required Location? Location { get; init; } - - public Diagnostic CreateDiagnostic() - => Diagnostic.Create(Descriptor, Location, MessageArgs); - - public override readonly bool Equals(object? obj) => obj is DiagnosticInfo info && Equals(info); - public readonly bool Equals(DiagnosticInfo other) - { - return Descriptor.Equals(other.Descriptor) && - MessageArgs.SequenceEqual(other.MessageArgs) && - Location == other.Location; - } - - public override readonly int GetHashCode() - { - int hashCode = Descriptor.GetHashCode(); - foreach (object? messageArg in MessageArgs) - { - hashCode = HashHelpers.Combine(hashCode, messageArg?.GetHashCode() ?? 0); - } - - hashCode = HashHelpers.Combine(hashCode, Location?.GetHashCode() ?? 0); - return hashCode; - } - } -} diff --git a/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs b/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs index 7d280ce7603c2d..3f3ecb506fd83d 100644 --- a/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs +++ b/src/libraries/System.Text.Json/gen/Helpers/RoslynExtensions.cs @@ -25,8 +25,6 @@ internal static class RoslynExtensions return compilation.GetBestTypeByMetadataName(type.FullName); } - public static string GetFullyQualifiedName(this ITypeSymbol type) => type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - public static Location? GetLocation(this ISymbol typeSymbol) => typeSymbol.Locations.Length > 0 ? typeSymbol.Locations[0] : null; @@ -36,12 +34,6 @@ internal static class RoslynExtensions return reference?.SyntaxTree.GetLocation(reference.Span); } - /// - /// Creates a copy of the Location instance that does not capture a reference to Compilation. - /// - public static Location GetTrimmedLocation(this Location location) - => Location.Create(location.SourceTree?.FilePath ?? "", location.SourceSpan, location.GetLineSpan().Span); - /// /// Returns true if the specified location is contained in one of the syntax trees in the compilation. /// diff --git a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs index 0f3b11b038bc93..594f7ad9770c3c 100644 --- a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs +++ b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Parser.cs @@ -60,12 +60,7 @@ public void ReportDiagnostic(DiagnosticDescriptor descriptor, Location? location location = _contextClassLocation; } - Diagnostics.Add(new DiagnosticInfo - { - Descriptor = descriptor, - Location = location.GetTrimmedLocation(), - MessageArgs = messageArgs ?? Array.Empty(), - }); + Diagnostics.Add(DiagnosticInfo.Create(descriptor, location, messageArgs)); } public Parser(KnownTypeSymbols knownSymbols) @@ -868,7 +863,7 @@ private List ParsePropertyGenerationSpecs( { Location? typeLocation = typeToGenerate.Location; List properties = new(); - PropertyHierarchyResolutionState state = new(); + PropertyHierarchyResolutionState state = new(options); hasExtensionDataProperty = false; // Walk the type hierarchy starting from the current type up to the base type(s) @@ -975,11 +970,10 @@ bool PropertyIsOverriddenAndIgnored(IPropertySymbol property, Dictionary Properties = new(); - public Dictionary AddedProperties = new(); + public Dictionary AddedProperties = new(options?.PropertyNameCaseInsensitive == true ? StringComparer.OrdinalIgnoreCase : StringComparer.Ordinal); public Dictionary? IgnoredMembers; public bool IsPropertyOrderSpecified; public bool HasInvalidConfigurationForFastPath; diff --git a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn3.11.cs b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn3.11.cs index 4c58a3d968ac54..7520f9bc75a6f5 100644 --- a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn3.11.cs +++ b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn3.11.cs @@ -8,6 +8,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn4.0.cs b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn4.0.cs index e3f8b4aacf6c5b..447f54c7f07821 100644 --- a/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn4.0.cs +++ b/src/libraries/System.Text.Json/gen/JsonSourceGenerator.Roslyn4.0.cs @@ -9,6 +9,7 @@ #if !ROSLYN4_4_OR_GREATER using Microsoft.CodeAnalysis.DotnetRuntime.Extensions; #endif +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/Model/ContextGenerationSpec.cs b/src/libraries/System.Text.Json/gen/Model/ContextGenerationSpec.cs index 1e2ee2d737e009..00c7192c3ae58c 100644 --- a/src/libraries/System.Text.Json/gen/Model/ContextGenerationSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/ContextGenerationSpec.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; -using System.Text.Json.Serialization; +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/Model/ParameterGenerationSpec.cs b/src/libraries/System.Text.Json/gen/Model/ParameterGenerationSpec.cs index 2945b20b730b15..68e32d01531569 100644 --- a/src/libraries/System.Text.Json/gen/Model/ParameterGenerationSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/ParameterGenerationSpec.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using SourceGenerators; + namespace System.Text.Json.SourceGeneration { /// diff --git a/src/libraries/System.Text.Json/gen/Model/PropertyGenerationSpec.cs b/src/libraries/System.Text.Json/gen/Model/PropertyGenerationSpec.cs index 56b42970f68893..214c32b4d19e21 100644 --- a/src/libraries/System.Text.Json/gen/Model/PropertyGenerationSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/PropertyGenerationSpec.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Text.Json.Serialization; +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/Model/PropertyInitializerGenerationSpec.cs b/src/libraries/System.Text.Json/gen/Model/PropertyInitializerGenerationSpec.cs index 9fc68a11928470..608ce8e887d725 100644 --- a/src/libraries/System.Text.Json/gen/Model/PropertyInitializerGenerationSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/PropertyInitializerGenerationSpec.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using SourceGenerators; + namespace System.Text.Json.SourceGeneration { /// diff --git a/src/libraries/System.Text.Json/gen/Model/SourceGenerationOptionsSpec.cs b/src/libraries/System.Text.Json/gen/Model/SourceGenerationOptionsSpec.cs index 7e94f824bae8cb..83b587fb962f7e 100644 --- a/src/libraries/System.Text.Json/gen/Model/SourceGenerationOptionsSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/SourceGenerationOptionsSpec.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Text.Json.Serialization; +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/Model/TypeGenerationSpec.cs b/src/libraries/System.Text.Json/gen/Model/TypeGenerationSpec.cs index 189295bcb971ca..9b71bf16438b89 100644 --- a/src/libraries/System.Text.Json/gen/Model/TypeGenerationSpec.cs +++ b/src/libraries/System.Text.Json/gen/Model/TypeGenerationSpec.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Text.Json.Serialization; using Microsoft.CodeAnalysis; +using SourceGenerators; namespace System.Text.Json.SourceGeneration { diff --git a/src/libraries/System.Text.Json/gen/System.Text.Json.SourceGeneration.targets b/src/libraries/System.Text.Json/gen/System.Text.Json.SourceGeneration.targets index 4020a05cb421db..23add6278d7c07 100644 --- a/src/libraries/System.Text.Json/gen/System.Text.Json.SourceGeneration.targets +++ b/src/libraries/System.Text.Json/gen/System.Text.Json.SourceGeneration.targets @@ -30,8 +30,11 @@ + + + @@ -54,9 +57,7 @@ - - @@ -74,6 +75,5 @@ - diff --git a/src/libraries/System.Text.Json/src/Resources/Strings.resx b/src/libraries/System.Text.Json/src/Resources/Strings.resx index f091984783b601..0ebab3e5d27d6f 100644 --- a/src/libraries/System.Text.Json/src/Resources/Strings.resx +++ b/src/libraries/System.Text.Json/src/Resources/Strings.resx @@ -696,6 +696,9 @@ JsonObjectCreationHandling.Populate is incompatible with reference handling. + + JsonObjectCreationHandling.Populate is currently not supported in types with parameterized constructors. + Either the JSON value is not in a supported format, or is out of bounds for an Int128. diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Object/ObjectWithParameterizedConstructorConverter.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Object/ObjectWithParameterizedConstructorConverter.cs index 304ca0a26e409c..7c5d6aae1c4051 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Object/ObjectWithParameterizedConstructorConverter.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Object/ObjectWithParameterizedConstructorConverter.cs @@ -25,10 +25,10 @@ internal sealed override bool OnTryRead(ref Utf8JsonReader reader, Type typeToCo { JsonTypeInfo jsonTypeInfo = state.Current.JsonTypeInfo; - if (jsonTypeInfo.CreateObject != null || state.Current.IsPopulating) + if (!jsonTypeInfo.UsesParameterizedConstructor || state.Current.IsPopulating) { // Fall back to default object converter in following cases: - // - if user has set a default constructor delegate with contract customization + // - if user configuration has invalidated the parameterized constructor // - we're continuing populating an object. return base.OnTryRead(ref reader, typeToConvert, options, ref state, out value); } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/DefaultJsonTypeInfoResolver.Helpers.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/DefaultJsonTypeInfoResolver.Helpers.cs index 8ee9f3db0283b4..3f929d87378485 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/DefaultJsonTypeInfoResolver.Helpers.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/DefaultJsonTypeInfoResolver.Helpers.cs @@ -38,12 +38,20 @@ internal static MemberAccessor MemberAccessor private static JsonTypeInfo CreateTypeInfoCore(Type type, JsonConverter converter, JsonSerializerOptions options) { JsonTypeInfo typeInfo = JsonTypeInfo.CreateJsonTypeInfo(type, converter, options); - typeInfo.NumberHandling = GetNumberHandlingForType(typeInfo.Type); - typeInfo.PreferredPropertyObjectCreationHandling = GetObjectCreationHandlingForType(typeInfo.Type); - if (typeInfo.Kind == JsonTypeInfoKind.Object) + if (GetNumberHandlingForType(typeInfo.Type) is { } numberHandling) { - typeInfo.UnmappedMemberHandling = GetUnmappedMemberHandling(typeInfo.Type); + typeInfo.NumberHandling = numberHandling; + } + + if (GetObjectCreationHandlingForType(typeInfo.Type) is { } creationHandling) + { + typeInfo.PreferredPropertyObjectCreationHandling = creationHandling; + } + + if (GetUnmappedMemberHandling(typeInfo.Type) is { } unmappedMemberHandling) + { + typeInfo.UnmappedMemberHandling = unmappedMemberHandling; } typeInfo.PopulatePolymorphismMetadata(); @@ -80,7 +88,7 @@ private static void PopulateProperties(JsonTypeInfo typeInfo) bool constructorHasSetsRequiredMembersAttribute = typeInfo.Converter.ConstructorInfo?.HasSetsRequiredMembersAttribute() ?? false; - JsonTypeInfo.PropertyHierarchyResolutionState state = new(); + JsonTypeInfo.PropertyHierarchyResolutionState state = new(typeInfo.Options); // Walk the type hierarchy starting from the current type up to the base type(s) foreach (Type currentType in typeInfo.Type.GetSortedTypeHierarchy()) diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonMetadataServices.Helpers.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonMetadataServices.Helpers.cs index 1b7113f9dd758a..965b4cea39570a 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonMetadataServices.Helpers.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonMetadataServices.Helpers.cs @@ -137,7 +137,7 @@ internal static void PopulateProperties(JsonTypeInfo typeInfo, JsonTypeInfo.Json // Regardless of the source generator we need to re-run the naming conflict resolution algorithm // at run time since it is possible that the naming policy or other configs can be different then. - JsonTypeInfo.PropertyHierarchyResolutionState state = new(); + JsonTypeInfo.PropertyHierarchyResolutionState state = new(typeInfo.Options); foreach (JsonPropertyInfo jsonPropertyInfo in properties) { diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonPropertyInfo.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonPropertyInfo.cs index 0ab5c08d7825b3..e2234093474e0d 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonPropertyInfo.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonPropertyInfo.cs @@ -495,9 +495,17 @@ private void DetermineEffectiveObjectCreationHandlingForProperty() Debug.Assert(ParentTypeInfo != null, "We should have ensured parent is assigned in JsonTypeInfo"); Debug.Assert(!IsConfigured, "Should not be called post-configuration."); + JsonObjectCreationHandling effectiveObjectCreationHandling = JsonObjectCreationHandling.Replace; if (ObjectCreationHandling == null) { - JsonObjectCreationHandling preferredCreationHandling = ParentTypeInfo.PreferredPropertyObjectCreationHandling ?? Options.PreferredObjectCreationHandling; + // Consult type-level configuration, then global configuration. + // Ignore global configuration if we're using a parameterized constructor. + JsonObjectCreationHandling preferredCreationHandling = + ParentTypeInfo.PreferredPropertyObjectCreationHandling + ?? (ParentTypeInfo.DetermineUsesParameterizedConstructor() + ? JsonObjectCreationHandling.Replace + : Options.PreferredObjectCreationHandling); + bool canPopulate = preferredCreationHandling == JsonObjectCreationHandling.Populate && EffectiveConverter.CanPopulate && @@ -506,7 +514,7 @@ private void DetermineEffectiveObjectCreationHandlingForProperty() !ParentTypeInfo.SupportsPolymorphicDeserialization && !(Set == null && IgnoreReadOnlyMember); - EffectiveObjectCreationHandling = canPopulate ? JsonObjectCreationHandling.Populate : JsonObjectCreationHandling.Replace; + effectiveObjectCreationHandling = canPopulate ? JsonObjectCreationHandling.Populate : JsonObjectCreationHandling.Replace; } else if (ObjectCreationHandling == JsonObjectCreationHandling.Populate) { @@ -537,18 +545,24 @@ private void DetermineEffectiveObjectCreationHandlingForProperty() ThrowHelper.ThrowInvalidOperationException_ObjectCreationHandlingPropertyCannotAllowReadOnlyMember(this); } - EffectiveObjectCreationHandling = JsonObjectCreationHandling.Populate; - } - else - { - Debug.Assert(EffectiveObjectCreationHandling == JsonObjectCreationHandling.Replace); + effectiveObjectCreationHandling = JsonObjectCreationHandling.Populate; } - if (EffectiveObjectCreationHandling == JsonObjectCreationHandling.Populate && - Options.ReferenceHandlingStrategy != ReferenceHandlingStrategy.None) + if (effectiveObjectCreationHandling is JsonObjectCreationHandling.Populate) { - ThrowHelper.ThrowInvalidOperationException_ObjectCreationHandlingPropertyCannotAllowReferenceHandling(); + if (ParentTypeInfo.DetermineUsesParameterizedConstructor()) + { + ThrowHelper.ThrowNotSupportedException_ObjectCreationHandlingPropertyDoesNotSupportParameterizedConstructors(); + } + + if (Options.ReferenceHandlingStrategy != ReferenceHandlingStrategy.None) + { + ThrowHelper.ThrowInvalidOperationException_ObjectCreationHandlingPropertyCannotAllowReferenceHandling(); + } } + + // Validation complete, commit configuration. + EffectiveObjectCreationHandling = effectiveObjectCreationHandling; } private bool NumberHandingIsApplicable() diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.Cache.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.Cache.cs index 86a7af256a78a8..5a901fbb80eae6 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.Cache.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.Cache.cs @@ -35,6 +35,14 @@ public abstract partial class JsonTypeInfo // All of the serializable parameters on a POCO constructor keyed on parameter name. // Only parameters which bind to properties are cached. internal JsonPropertyDictionary? ParameterCache { get; private set; } + internal bool UsesParameterizedConstructor + { + get + { + Debug.Assert(IsConfigured); + return ParameterCache != null; + } + } // All of the serializable properties on a POCO (except the optional extension property) keyed on property name. internal JsonPropertyDictionary? PropertyCache { get; private set; } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.cs index b9e9fe60d2b23f..668e0c7b15e1a1 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.cs @@ -552,17 +552,14 @@ public JsonObjectCreationHandling? PreferredPropertyObjectCreationHandling { VerifyMutable(); - if (value is not null) + if (Kind != JsonTypeInfoKind.Object) { - if (Kind != JsonTypeInfoKind.Object) - { - ThrowHelper.ThrowInvalidOperationException_JsonTypeInfoOperationNotPossibleForKind(Kind); - } + ThrowHelper.ThrowInvalidOperationException_JsonTypeInfoOperationNotPossibleForKind(Kind); + } - if (!JsonSerializer.IsValidCreationHandlingValue(value.Value)) - { - throw new ArgumentOutOfRangeException(nameof(value)); - } + if (value is not null && !JsonSerializer.IsValidCreationHandlingValue(value.Value)) + { + throw new ArgumentOutOfRangeException(nameof(value)); } _preferredPropertyObjectCreationHandling = value; @@ -684,7 +681,7 @@ private void Configure() { ConfigureProperties(); - if (Converter.ConstructorIsParameterized) + if (DetermineUsesParameterizedConstructor()) { ConfigureConstructorParameters(); } @@ -808,6 +805,12 @@ bool IsCurrentNodeCompatible() /// private bool IsCompatibleWithCurrentOptions { get; set; } = true; + /// + /// Determine if the current configuration is compatible with using a parameterized constructor. + /// + internal bool DetermineUsesParameterizedConstructor() + => Converter.ConstructorIsParameterized && CreateObject is null; + #if DEBUG internal string GetPropertyDebugInfo(ReadOnlySpan unescapedPropertyName) { @@ -989,10 +992,9 @@ public JsonPropertyInfo CreateJsonPropertyInfo(Type propertyType, string name) internal abstract ValueTask DeserializeAsObjectAsync(Stream utf8Json, CancellationToken cancellationToken); internal abstract object? DeserializeAsObject(Stream utf8Json); - internal ref struct PropertyHierarchyResolutionState + internal ref struct PropertyHierarchyResolutionState(JsonSerializerOptions options) { - public PropertyHierarchyResolutionState() { } - public Dictionary AddedProperties = new(); + public Dictionary AddedProperties = new(options.PropertyNameCaseInsensitive ? StringComparer.OrdinalIgnoreCase : StringComparer.Ordinal); public Dictionary? IgnoredProperties; public bool IsPropertyOrderSpecified; } @@ -1107,7 +1109,7 @@ internal void ConfigureProperties() internal void ConfigureConstructorParameters() { Debug.Assert(Kind == JsonTypeInfoKind.Object); - Debug.Assert(Converter.ConstructorIsParameterized); + Debug.Assert(DetermineUsesParameterizedConstructor()); Debug.Assert(PropertyCache is not null); Debug.Assert(ParameterCache is null); diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/ReadStack.cs b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/ReadStack.cs index 59a47bc3ac7bc6..25c067cc10930a 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/ReadStack.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/ReadStack.cs @@ -386,20 +386,20 @@ public JsonTypeInfo GetTopJsonTypeInfoWithParameterizedConstructor() for (int i = 0; i < _count - 1; i++) { - if (_stack[i].JsonTypeInfo.Converter.ConstructorIsParameterized) + if (_stack[i].JsonTypeInfo.UsesParameterizedConstructor) { return _stack[i].JsonTypeInfo; } } - Debug.Assert(Current.JsonTypeInfo.Converter.ConstructorIsParameterized); + Debug.Assert(Current.JsonTypeInfo.UsesParameterizedConstructor); return Current.JsonTypeInfo; } [MethodImpl(MethodImplOptions.AggressiveInlining)] private void SetConstructorArgumentState() { - if (Current.JsonTypeInfo.Converter.ConstructorIsParameterized) + if (Current.JsonTypeInfo.UsesParameterizedConstructor) { Current.CtorArgumentState ??= new(); } diff --git a/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs b/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs index 7072d9e3020085..5b05ff243a80a9 100644 --- a/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs +++ b/src/libraries/System.Text.Json/src/System/Text/Json/ThrowHelper.Serialization.cs @@ -99,6 +99,12 @@ public static void ThrowInvalidOperationException_ObjectCreationHandlingProperty throw new InvalidOperationException(SR.ObjectCreationHandlingPropertyCannotAllowReferenceHandling); } + [DoesNotReturn] + public static void ThrowNotSupportedException_ObjectCreationHandlingPropertyDoesNotSupportParameterizedConstructors() + { + throw new NotSupportedException(SR.ObjectCreationHandlingPropertyDoesNotSupportParameterizedConstructors); + } + [DoesNotReturn] public static void ThrowJsonException_SerializationConverterRead(JsonConverter? converter) { diff --git a/src/libraries/System.Text.Json/tests/Common/JsonCreationHandlingTests.Object.cs b/src/libraries/System.Text.Json/tests/Common/JsonCreationHandlingTests.Object.cs index aae9da6b2c628b..e25adffef53fa2 100644 --- a/src/libraries/System.Text.Json/tests/Common/JsonCreationHandlingTests.Object.cs +++ b/src/libraries/System.Text.Json/tests/Common/JsonCreationHandlingTests.Object.cs @@ -1165,4 +1165,52 @@ public class ClassWithInvalidPropertyAnnotation [JsonObjectCreationHandling((JsonObjectCreationHandling)(-1))] public List Property { get; } } + + [Theory] + [InlineData(typeof(ClassWithParameterizedConstructorWithPopulateProperty))] + [InlineData(typeof(ClassWithParameterizedConstructorWithPopulateType))] + public async Task ClassWithParameterizedCtor_UsingPopulateConfiguration_ThrowsNotSupportedException(Type type) + { + object instance = Activator.CreateInstance(type, "Jim"); + string json = """{"Username":"Jim","PhoneNumbers":["123456"]}"""; + + await Assert.ThrowsAsync(() => Serializer.SerializeWrapper(instance, type)); + await Assert.ThrowsAsync(() => Serializer.DeserializeWrapper(json, type)); + Assert.Throws(() => Serializer.GetTypeInfo(type)); + } + + public class ClassWithParameterizedConstructorWithPopulateProperty(string name) + { + public string Name { get; } = name; + + [JsonObjectCreationHandling(JsonObjectCreationHandling.Populate)] + public List PhoneNumbers { get; } = new(); + } + + [JsonObjectCreationHandling(JsonObjectCreationHandling.Populate)] + public class ClassWithParameterizedConstructorWithPopulateType(string name) + { + public string Name { get; } = name; + + public List PhoneNumbers { get; } = new(); + } + + [Fact] + public async Task ClassWithParameterizedCtor_NoPopulateConfiguration_WorksWithGlobalPopulateConfiguration() + { + string json = """{"Username":"Jim","PhoneNumbers":["123456"]}"""; + + JsonSerializerOptions options = Serializer.CreateOptions(makeReadOnly: false); + options.PreferredObjectCreationHandling = JsonObjectCreationHandling.Populate; + + ClassWithParameterizedConstructorNoPopulate result = await Serializer.DeserializeWrapper(json, options); + Assert.Empty(result.PhoneNumbers); + } + + public class ClassWithParameterizedConstructorNoPopulate(string name) + { + public string Name { get; } = name; + + public List PhoneNumbers { get; } = new(); + } } diff --git a/src/libraries/System.Text.Json/tests/Common/PropertyNameTests.cs b/src/libraries/System.Text.Json/tests/Common/PropertyNameTests.cs index 4295359c6f0380..021481ae5a1362 100644 --- a/src/libraries/System.Text.Json/tests/Common/PropertyNameTests.cs +++ b/src/libraries/System.Text.Json/tests/Common/PropertyNameTests.cs @@ -494,5 +494,34 @@ public class ClassWithSpecialCharacters [JsonPropertyName("\uA000_2")] // Valid C# property name: \uA000_2 public int YiIt_2 { get; set; } } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ClassWithIgnoredCaseInsensitiveConflict_RespectsIgnoredMember(bool propertyNameCaseInsensitive) + { + // Regression test for https://github.com/dotnet/runtime/issues/93903 + // specifically for propertyNameCaseInsensitive := true + + JsonSerializerOptions options = Serializer.CreateOptions(makeReadOnly: false); + options.PropertyNameCaseInsensitive = propertyNameCaseInsensitive; + + var value = new ClassWithIgnoredCaseInsensitiveConflict { name = "lowercase", Name = "uppercase" }; + string json = await Serializer.SerializeWrapper(value, options); + + Assert.Equal("""{"name":"lowercase"}""", json); + + value = await Serializer.DeserializeWrapper(json, options); + Assert.Equal("lowercase", value.name); + Assert.Null(value.Name); + } + + public class ClassWithIgnoredCaseInsensitiveConflict + { + public string name { get; set; } + + [JsonIgnore] + public string Name { get; set; } + } } } diff --git a/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.cs b/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.cs index 79859c1e73cc5f..470d624d3646fb 100644 --- a/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.cs +++ b/src/libraries/System.Text.Json/tests/Common/TestClasses/TestClasses.cs @@ -1913,69 +1913,81 @@ public override string ConvertName(string name) } } + public static class ReflectionExtensions + { +#if NET6_0_OR_GREATER + [return: System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] + public static Type WithConstructors( + [System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] + this Type type) => type; +#else + public static Type WithConstructors(this Type type) => type; +#endif + } + public static class CollectionTestTypes { public static IEnumerable EnumerableTypes() { - yield return typeof(TElement[]); // ArrayConverter - yield return typeof(ConcurrentQueue); // ConcurrentQueueOfTConverter - yield return typeof(GenericICollectionWrapper); // ICollectionOfTConverter - yield return typeof(WrapperForIEnumerable); // IEnumerableConverter - yield return typeof(WrapperForIReadOnlyCollectionOfT); // IEnumerableOfTConverter - yield return typeof(Queue); // IEnumerableWithAddMethodConverter - yield return typeof(WrapperForIList); // IListConverter - yield return typeof(Collection); // IListOfTConverter - yield return typeof(ImmutableList); // ImmutableEnumerableOfTConverter - yield return typeof(HashSet); // ISetOfTConverter - yield return typeof(List); // ListOfTConverter - yield return typeof(Queue); // QueueOfTConverter + yield return typeof(TElement[]).WithConstructors(); // ArrayConverter + yield return typeof(ConcurrentQueue).WithConstructors(); // ConcurrentQueueOfTConverter + yield return typeof(GenericICollectionWrapper).WithConstructors(); // ICollectionOfTConverter + yield return typeof(WrapperForIEnumerable).WithConstructors(); // IEnumerableConverter + yield return typeof(WrapperForIReadOnlyCollectionOfT).WithConstructors(); // IEnumerableOfTConverter + yield return typeof(Queue).WithConstructors(); // IEnumerableWithAddMethodConverter + yield return typeof(WrapperForIList).WithConstructors(); // IListConverter + yield return typeof(Collection).WithConstructors(); // IListOfTConverter + yield return typeof(ImmutableList).WithConstructors(); // ImmutableEnumerableOfTConverter + yield return typeof(HashSet).WithConstructors(); // ISetOfTConverter + yield return typeof(List).WithConstructors(); // ListOfTConverter + yield return typeof(Queue).WithConstructors(); // QueueOfTConverter } public static IEnumerable DeserializableGenericEnumerableTypes() { - yield return typeof(TElement[]); // ArrayConverter - yield return typeof(ConcurrentQueue); // ConcurrentQueueOfTConverter - yield return typeof(GenericICollectionWrapper); // ICollectionOfTConverter - yield return typeof(IEnumerable); // IEnumerableConverter - yield return typeof(Collection); // IListOfTConverter - yield return typeof(ImmutableList); // ImmutableEnumerableOfTConverter - yield return typeof(HashSet); // ISetOfTConverter - yield return typeof(List); // ListOfTConverter - yield return typeof(Queue); // QueueOfTConverter + yield return typeof(TElement[]).WithConstructors(); // ArrayConverter + yield return typeof(ConcurrentQueue).WithConstructors(); // ConcurrentQueueOfTConverter + yield return typeof(GenericICollectionWrapper).WithConstructors(); // ICollectionOfTConverter + yield return typeof(IEnumerable).WithConstructors(); // IEnumerableConverter + yield return typeof(Collection).WithConstructors(); // IListOfTConverter + yield return typeof(ImmutableList).WithConstructors(); // ImmutableEnumerableOfTConverter + yield return typeof(HashSet).WithConstructors(); // ISetOfTConverter + yield return typeof(List).WithConstructors(); // ListOfTConverter + yield return typeof(Queue).WithConstructors(); // QueueOfTConverter } public static IEnumerable DeserializableNonGenericEnumerableTypes() { - yield return typeof(Queue); // IEnumerableWithAddMethodConverter - yield return typeof(WrapperForIList); // IListConverter + yield return typeof(Queue).WithConstructors(); // IEnumerableWithAddMethodConverter + yield return typeof(WrapperForIList).WithConstructors(); // IListConverter } public static IEnumerable DictionaryTypes() { - yield return typeof(Dictionary); // DictionaryOfStringTValueConverter - yield return typeof(Hashtable); // IDictionaryConverter - yield return typeof(ConcurrentDictionary); // IDictionaryOfStringTValueConverter - yield return typeof(GenericIDictionaryWrapper); // IDictionaryOfStringTValueConverter - yield return typeof(ImmutableDictionary); // ImmutableDictionaryOfStringTValueConverter - yield return typeof(GenericIReadOnlyDictionaryWrapper); // IReadOnlyDictionaryOfStringTValueConverter + yield return typeof(Dictionary).WithConstructors(); // DictionaryOfStringTValueConverter + yield return typeof(Hashtable).WithConstructors(); // IDictionaryConverter + yield return typeof(ConcurrentDictionary).WithConstructors(); // IDictionaryOfStringTValueConverter + yield return typeof(GenericIDictionaryWrapper).WithConstructors(); // IDictionaryOfStringTValueConverter + yield return typeof(ImmutableDictionary).WithConstructors(); // ImmutableDictionaryOfStringTValueConverter + yield return typeof(GenericIReadOnlyDictionaryWrapper).WithConstructors(); // IReadOnlyDictionaryOfStringTValueConverter } public static IEnumerable DeserializableDictionaryTypes() { - yield return typeof(Dictionary); // DictionaryOfStringTValueConverter - yield return typeof(Hashtable); // IDictionaryConverter - yield return typeof(IDictionary); // IDictionaryConverter - yield return typeof(ConcurrentDictionary); // IDictionaryOfStringTValueConverter - yield return typeof(IDictionary); // IDictionaryOfStringTValueConverter - yield return typeof(GenericIDictionaryWrapper); // IDictionaryOfStringTValueConverter - yield return typeof(ImmutableDictionary); // ImmutableDictionaryOfStringTValueConverter - yield return typeof(IReadOnlyDictionary); // IReadOnlyDictionaryOfStringTValueConverter + yield return typeof(Dictionary).WithConstructors(); // DictionaryOfStringTValueConverter + yield return typeof(Hashtable).WithConstructors(); // IDictionaryConverter + yield return typeof(IDictionary).WithConstructors(); // IDictionaryConverter + yield return typeof(ConcurrentDictionary).WithConstructors(); // IDictionaryOfStringTValueConverter + yield return typeof(IDictionary).WithConstructors(); // IDictionaryOfStringTValueConverter + yield return typeof(GenericIDictionaryWrapper).WithConstructors(); // IDictionaryOfStringTValueConverter + yield return typeof(ImmutableDictionary).WithConstructors(); // ImmutableDictionaryOfStringTValueConverter + yield return typeof(IReadOnlyDictionary).WithConstructors(); // IReadOnlyDictionaryOfStringTValueConverter } public static IEnumerable DeserializableNonGenericDictionaryTypes() { - yield return typeof(Hashtable); // IDictionaryConverter - yield return typeof(SortedList); // IDictionaryConverter + yield return typeof(Hashtable).WithConstructors(); // IDictionaryConverter + yield return typeof(SortedList).WithConstructors(); // IDictionaryConverter } } diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/JsonCreationHandlingTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/JsonCreationHandlingTests.cs index eab6f939b93674..5862387200f6a4 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/JsonCreationHandlingTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/JsonCreationHandlingTests.cs @@ -278,6 +278,9 @@ public sealed class JsonCreationHandlingTests_AsyncStreamWithSmallBuffer() [JsonSerializable(typeof(SimpleClassWitNonPopulatableProperty))] [JsonSerializable(typeof(ClassWithInvalidTypeAnnotation))] [JsonSerializable(typeof(ClassWithInvalidPropertyAnnotation))] + [JsonSerializable(typeof(ClassWithParameterizedConstructorWithPopulateProperty))] + [JsonSerializable(typeof(ClassWithParameterizedConstructorWithPopulateType))] + [JsonSerializable(typeof(ClassWithParameterizedConstructorNoPopulate))] internal partial class CreationHandlingTestContext : JsonSerializerContext { } diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/PropertyNameTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/PropertyNameTests.cs index 82566bf7123ce7..e512451eed72bc 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/PropertyNameTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/PropertyNameTests.cs @@ -28,6 +28,7 @@ public PropertyNameTests_Metadata() [JsonSerializable(typeof(ObjectPropertyNamesDifferentByCaseOnly_TestClass))] [JsonSerializable(typeof(OverridePropertyNameDesignTime_TestClass))] [JsonSerializable(typeof(SimpleTestClass))] + [JsonSerializable(typeof(ClassWithIgnoredCaseInsensitiveConflict))] internal sealed partial class PropertyNameTestsContext_Metadata : JsonSerializerContext { } @@ -53,6 +54,7 @@ public PropertyNameTests_Default() [JsonSerializable(typeof(ObjectPropertyNamesDifferentByCaseOnly_TestClass))] [JsonSerializable(typeof(OverridePropertyNameDesignTime_TestClass))] [JsonSerializable(typeof(SimpleTestClass))] + [JsonSerializable(typeof(ClassWithIgnoredCaseInsensitiveConflict))] internal sealed partial class PropertyNameTestsContext_Default : JsonSerializerContext { } diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorIncrementalTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorIncrementalTests.cs index 7a38a7e5fb5128..daa6498cbc9b2d 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorIncrementalTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorIncrementalTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Reflection; using Microsoft.CodeAnalysis; +using SourceGenerators.Tests; using Xunit; namespace System.Text.Json.SourceGeneration.UnitTests @@ -29,7 +30,7 @@ public static void CompilingTheSameSourceResultsInEqualModels(Func ContextGenerationSpec ctx2 = result2.ContextGenerationSpecs[i]; Assert.NotSame(ctx1, ctx2); - AssertStructurallyEqual(ctx1, ctx2); + GeneratorTestHelpers.AssertStructurallyEqual(ctx1, ctx2); Assert.Equal(ctx1, ctx2); Assert.Equal(ctx1.GetHashCode(), ctx2.GetHashCode()); @@ -86,7 +87,7 @@ public partial class JsonContext : JsonSerializerContext { } ContextGenerationSpec ctx2 = result2.ContextGenerationSpecs[0]; Assert.NotSame(ctx1, ctx2); - AssertStructurallyEqual(ctx1, ctx2); + GeneratorTestHelpers.AssertStructurallyEqual(ctx1, ctx2); Assert.Equal(ctx1, ctx2); Assert.Equal(ctx1.GetHashCode(), ctx2.GetHashCode()); @@ -377,74 +378,5 @@ public static IEnumerable GetCompilationHelperFactories() .Where(m => m.ReturnType == typeof(Compilation) && m.GetParameters().Length == 0) .Select(m => new object[] { Delegate.CreateDelegate(typeof(Func), m) }); } - - /// - /// Asserts for structural equality, returning a path to the mismatching data when not equal. - /// - private static void AssertStructurallyEqual(T expected, T actual) - { - CheckAreEqualCore(expected, actual, new()); - static void CheckAreEqualCore(object expected, object actual, Stack path) - { - if (expected is null || actual is null) - { - if (expected is not null || actual is not null) - { - FailNotEqual(); - } - - return; - } - - Type type = expected.GetType(); - if (type != actual.GetType()) - { - FailNotEqual(); - return; - } - - if (expected is IEnumerable leftCollection) - { - if (actual is not IEnumerable rightCollection) - { - FailNotEqual(); - return; - } - - object?[] expectedValues = leftCollection.Cast().ToArray(); - object?[] actualValues = rightCollection.Cast().ToArray(); - - for (int i = 0; i < Math.Max(expectedValues.Length, actualValues.Length); i++) - { - object? expectedElement = i < expectedValues.Length ? expectedValues[i] : ""; - object? actualElement = i < actualValues.Length ? actualValues[i] : ""; - - path.Push($"[{i}]"); - CheckAreEqualCore(expectedElement, actualElement, path); - path.Pop(); - } - } - - if (type.GetProperty("EqualityContract", BindingFlags.Instance | BindingFlags.NonPublic, null, returnType: typeof(Type), types: Array.Empty(), null) != null) - { - // Type is a C# record, run pointwise equality comparison. - foreach (PropertyInfo property in type.GetProperties(BindingFlags.Public | BindingFlags.Instance)) - { - path.Push("." + property.Name); - CheckAreEqualCore(property.GetValue(expected), property.GetValue(actual), path); - path.Pop(); - } - - return; - } - - if (!expected.Equals(actual)) - { - FailNotEqual(); - } - - void FailNotEqual() => Assert.Fail($"Value not equal in ${string.Join("", path.Reverse())}: expected {expected}, but was {actual}."); - } - } } } diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/System.Text.Json.SourceGeneration.Unit.Tests.targets b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/System.Text.Json.SourceGeneration.Unit.Tests.targets index a700b2a9f3a385..56bf105dc1fddf 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/System.Text.Json.SourceGeneration.Unit.Tests.targets +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/System.Text.Json.SourceGeneration.Unit.Tests.targets @@ -12,6 +12,7 @@ + diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/MetadataTests/DefaultJsonTypeInfoResolverTests.JsonTypeInfo.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/MetadataTests/DefaultJsonTypeInfoResolverTests.JsonTypeInfo.cs index 06fd59bae037e9..bc6e3a28ff78f7 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/MetadataTests/DefaultJsonTypeInfoResolverTests.JsonTypeInfo.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/MetadataTests/DefaultJsonTypeInfoResolverTests.JsonTypeInfo.cs @@ -1440,11 +1440,10 @@ public static void PreferredPropertyObjectCreationHandling_NonObjectKind_ThrowsI { JsonTypeInfo jsonTypeInfo = JsonTypeInfo.CreateJsonTypeInfo(type, new()); - // Invalid kinds default to null and can be set to null. - Assert.Null(jsonTypeInfo.PreferredPropertyObjectCreationHandling); - jsonTypeInfo.PreferredPropertyObjectCreationHandling = null; + // Invalid kinds default to null. Assert.Null(jsonTypeInfo.PreferredPropertyObjectCreationHandling); + Assert.Throws(() => jsonTypeInfo.PreferredPropertyObjectCreationHandling = null); Assert.Throws(() => jsonTypeInfo.PreferredPropertyObjectCreationHandling = JsonObjectCreationHandling.Populate); Assert.Throws(() => jsonTypeInfo.PreferredPropertyObjectCreationHandling = JsonObjectCreationHandling.Replace); Assert.Null(jsonTypeInfo.PreferredPropertyObjectCreationHandling); diff --git a/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs index b3c6d50cafe2e7..48d76dff065eaa 100644 --- a/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs +++ b/src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs @@ -186,6 +186,7 @@ static bool CompareExchange(ref T location, T value, T comparand) => // If we're the last worker to complete, complete the operation. if (state.SignalWorkerCompletedIterating()) { + state.Dispose(); state.Complete(); } } @@ -745,7 +746,7 @@ public ValueTask DisposeAsync() /// Stores the state associated with an IAsyncEnumerable ForEachAsync operation, shared between all its workers. /// Specifies the type of data being enumerated. - private sealed class ForEachState : ForEachAsyncState + private sealed class ForEachState : ForEachAsyncState, IDisposable { public T NextAvailable; public readonly T ToExclusive; @@ -759,6 +760,8 @@ public ForEachState( NextAvailable = fromExclusive; ToExclusive = toExclusive; } + + public void Dispose() => _registration.Dispose(); } } } diff --git a/src/libraries/apicompat/ApiCompatBaseline.netstandard2.0.xml b/src/libraries/apicompat/ApiCompatBaseline.netstandard2.0.xml index f19da8b94090d0..12fce752d1b3d7 100644 --- a/src/libraries/apicompat/ApiCompatBaseline.netstandard2.0.xml +++ b/src/libraries/apicompat/ApiCompatBaseline.netstandard2.0.xml @@ -2773,10 +2773,22 @@ netstandard2.0/netstandard.dll net8.0/netstandard.dll + + CP0015 + T:System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute:[T:System.AttributeUsageAttribute] + netstandard2.0/netstandard.dll + net8.0/netstandard.dll + CP0015 P:System.Timers.Timer.Interval:[T:System.ComponentModel.DefaultValueAttribute] netstandard2.0/System.dll net8.0/System.dll - \ No newline at end of file + + CP0015 + T:System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute:[T:System.AttributeUsageAttribute] + netstandard2.0/System.dll + net8.0/System.dll + + diff --git a/src/libraries/apicompat/ApiCompatBaseline.netstandard2.1.xml b/src/libraries/apicompat/ApiCompatBaseline.netstandard2.1.xml index de4ff50c7dc0ed..a6009206387156 100644 --- a/src/libraries/apicompat/ApiCompatBaseline.netstandard2.1.xml +++ b/src/libraries/apicompat/ApiCompatBaseline.netstandard2.1.xml @@ -853,6 +853,12 @@ netstandard2.1/netstandard.dll net8.0/netstandard.dll + + CP0015 + T:System.Runtime.CompilerServices.AsyncMethodBuilderAttribute:[T:System.AttributeUsageAttribute] + netstandard2.1/netstandard.dll + net8.0/netstandard.dll + CP0015 T:System.Runtime.InteropServices.ManagedToNativeComInteropStubAttribute:[T:System.AttributeUsageAttribute] @@ -871,4 +877,4 @@ netstandard2.1/netstandard.dll net8.0/netstandard.dll - \ No newline at end of file + diff --git a/src/mono/mono.proj b/src/mono/mono.proj index 6be683f8e2eb2c..6f15a31d94dca4 100644 --- a/src/mono/mono.proj +++ b/src/mono/mono.proj @@ -254,9 +254,10 @@ - - <_MonoLLVMTargetArchitecture Condition="'$(TargetArchitecture)' == 'wasm' and '$(MonoUseLLVMPackage)' == 'true'">$(BuildArchitecture) - <_MonoLLVMTargetArchitecture Condition="'$(TargetArchitecture)' != 'wasm' and '$(MonoUseLLVMPackage)' == 'true'">$(TargetArchitecture) + + <_MonoLLVMTargetArchitecture Condition="'$(TargetArchitecture)' == 'wasm'">$(BuildArchitecture) + <_MonoLLVMTargetArchitecture Condition="'$(TargetArchitecture)' != 'wasm'">$(TargetArchitecture) + <_MonoLLVMHostArchitecture Condition="'$(AotHostArchitecture)' != ''">$(AotHostArchitecture) <_MonoCMakeArgs Condition="'$(_MonoUseNinja)' == 'true'" Include="-G Ninja"/> @@ -698,14 +699,14 @@ $(MonoCrossDir)/usr/lib/gcc/aarch64-linux-gnu/5 - - <_MonoLLVMTargetArchitecture Condition="'$(MonoUseLLVMPackage)' == 'true'">$(BuildArchitecture) - <_MonoLLVMTargetArchitecture Condition="'$(MonoUseLLVMPackage)' == 'true'">$(AotHostArchitecture) + + <_MonoLLVMTargetArchitecture>$(TargetArchitecture) + <_MonoLLVMHostArchitecture>$(AotHostArchitecture) - <_MonoAOTCXXFLAGS Include="-I$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\include\c++\v1" /> - <_MonoAOTCXXFLAGS Include="-L$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\lib" /> + <_MonoAOTCXXFLAGS Include="-I$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\include\c++\v1" /> + <_MonoAOTCXXFLAGS Include="-L$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\lib" /> <_MonoAOTCXXFLAGS Include="-stdlib=libc++" /> @@ -844,7 +845,7 @@ - + @@ -936,8 +937,7 @@ <_MonoAotCrossPdbFilePath>$(MonoObjCrossDir)out\bin\$(MonoAotCrossPdbFileName) - <_MonoLLVMTargetArchitecture>$(BuildArchitecture) - <_MonoLLVMTargetArchitecture Condition="'$(TargetArchitecture)' != 'wasm'">$(AotHostArchitecture) + <_MonoLLVMHostArchitecture>$(AotHostArchitecture) @@ -977,25 +977,25 @@ <_MonoRuntimeArtifacts Condition="'$(HostOS)' == 'Linux' and ('$(MonoBundleLLVMOptimizer)' == 'true' or '$(MonoEnableLLVM)' == 'true') and '$(TargetArchitecture)' != 'wasm' and '$(MonoUseLibCxx)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\lib\libc++abi.so.1"> $(RuntimeBinDir)libc++abi.so.1 - <_MonoRuntimeArtifacts Condition="'$(HostOS)' == 'Linux' and ((('$(MonoAOTBundleLLVMOptimizer)' == 'true' or '$(MonoAOTEnableLLVM)' == 'true') and '$(MonoUseLibCxx)' == 'true') or '$(TargetArchitecture)' == 'wasm')" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\lib\libc++.so.1"> + <_MonoRuntimeArtifacts Condition="'$(HostOS)' == 'Linux' and ((('$(MonoAOTBundleLLVMOptimizer)' == 'true' or '$(MonoAOTEnableLLVM)' == 'true') and '$(MonoUseLibCxx)' == 'true') or '$(TargetArchitecture)' == 'wasm')" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\lib\libc++.so.1"> $(RuntimeBinDir)cross\$(OutputRID)\libc++.so.1 - <_MonoRuntimeArtifacts Condition="'$(HostOS)' == 'Linux' and ((('$(MonoAOTBundleLLVMOptimizer)' == 'true' or '$(MonoAOTEnableLLVM)' == 'true') and '$(MonoUseLibCxx)' == 'true') or '$(TargetArchitecture)' == 'wasm')" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\lib\libc++abi.so.1"> + <_MonoRuntimeArtifacts Condition="'$(HostOS)' == 'Linux' and ((('$(MonoAOTBundleLLVMOptimizer)' == 'true' or '$(MonoAOTEnableLLVM)' == 'true') and '$(MonoUseLibCxx)' == 'true') or '$(TargetArchitecture)' == 'wasm')" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\lib\libc++abi.so.1"> $(RuntimeBinDir)cross\$(OutputRID)\libc++abi.so.1 <_MonoRuntimeArtifacts Include="$(_MonoAotCrossPdbFilePath)" Condition="Exists('$(_MonoAotCrossPdbFilePath)')"> $(RuntimeBinDir)cross\$(OutputRID)\$(MonoAotCrossPdbFileName) - <_MonoRuntimeArtifacts Condition="'$(MonoBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\bin\llc$(ExeSuffix)"> + <_MonoRuntimeArtifacts Condition="'$(MonoBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\bin\llc$(ExeSuffix)"> $(RuntimeBinDir)\llc$(ExeSuffix) - <_MonoRuntimeArtifacts Condition="'$(MonoBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\bin\opt$(ExeSuffix)"> + <_MonoRuntimeArtifacts Condition="'$(MonoBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\bin\opt$(ExeSuffix)"> $(RuntimeBinDir)\opt$(ExeSuffix) - <_MonoRuntimeArtifacts Condition="'$(MonoAOTBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\bin\llc$(ExeSuffix)"> + <_MonoRuntimeArtifacts Condition="'$(MonoAOTBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\bin\llc$(ExeSuffix)"> $(RuntimeBinDir)cross\$(OutputRID)\llc$(ExeSuffix) - <_MonoRuntimeArtifacts Condition="'$(MonoAOTBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMTargetArchitecture)\bin\opt$(ExeSuffix)"> + <_MonoRuntimeArtifacts Condition="'$(MonoAOTBundleLLVMOptimizer)' == 'true'" Include="$(MonoLLVMDir)\$(_MonoLLVMHostArchitecture)\bin\opt$(ExeSuffix)"> $(RuntimeBinDir)cross\$(OutputRID)\opt$(ExeSuffix) <_MonoIncludeArtifacts Include="$(MonoObjDir)out\include\**" /> diff --git a/src/mono/mono/component/debugger-agent.c b/src/mono/mono/component/debugger-agent.c index a210242c1eb691..9a6906f6e272ef 100644 --- a/src/mono/mono/component/debugger-agent.c +++ b/src/mono/mono/component/debugger-agent.c @@ -5258,6 +5258,13 @@ buffer_add_value_full (Buffer *buf, MonoType *t, void *addr, MonoDomain *domain, nfields = 0; iter = NULL; while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(error); + mono_field_resolve_type (f, error); + mono_error_cleanup (error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) @@ -5275,6 +5282,13 @@ buffer_add_value_full (Buffer *buf, MonoType *t, void *addr, MonoDomain *domain, iter = NULL; while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(error); + mono_field_resolve_type (f, error); + mono_error_cleanup (error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) @@ -5375,6 +5389,13 @@ decode_vtype (MonoType *t, MonoDomain *domain, gpointer void_addr, gpointer void nfields = decode_int (buf, &buf, limit); while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(error); + mono_field_resolve_type (f, error); + mono_error_cleanup (error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) @@ -5476,6 +5497,13 @@ decode_vtype_compute_size (MonoType *t, MonoDomain *domain, gpointer void_buf, g nfields = decode_int (buf, &buf, limit); while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(error); + mono_field_resolve_type (f, error); + mono_error_cleanup (error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) @@ -8481,6 +8509,13 @@ type_commands_internal (int command, MonoClass *klass, MonoDomain *domain, guint buffer_add_int (buf, nfields); while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(field_error); + mono_field_resolve_type (f, field_error); + mono_error_cleanup (field_error); + if (!f->type) + continue; + } buffer_add_fieldid (buf, domain, f); buffer_add_string (buf, f->name); buffer_add_typeid (buf, domain, mono_class_from_mono_type_internal (f->type)); @@ -8861,6 +8896,13 @@ type_commands_internal (int command, MonoClass *klass, MonoDomain *domain, guint int nfields = 0; gpointer iter = NULL; while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(field_error); + mono_field_resolve_type (f, field_error); + mono_error_cleanup (field_error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) @@ -8871,6 +8913,13 @@ type_commands_internal (int command, MonoClass *klass, MonoDomain *domain, guint iter = NULL; while ((f = mono_class_get_fields_internal (klass, &iter))) { + if (G_UNLIKELY (!f->type)) { + ERROR_DECL(field_error); + mono_field_resolve_type (f, field_error); + mono_error_cleanup (field_error); + if (!f->type) + continue; + } if (f->type->attrs & FIELD_ATTRIBUTE_STATIC) continue; if (mono_field_is_deleted (f)) diff --git a/src/mono/mono/metadata/class.c b/src/mono/mono/metadata/class.c index 284e3153a1a253..c5fcd2a8d7a18d 100644 --- a/src/mono/mono/metadata/class.c +++ b/src/mono/mono/metadata/class.c @@ -4100,7 +4100,7 @@ mono_class_is_assignable_from_general (MonoClass *klass, MonoClass *oklass, gboo return; } - if (m_class_is_array_special_interface (klass) && m_class_get_rank (oklass) == 1) { + if (m_class_is_array_special_interface (klass) && m_class_get_rank (oklass) == 1 && m_class_get_byval_arg (oklass)->type == MONO_TYPE_SZARRAY) { if (mono_class_is_gtd (klass)) { /* klass is an array special gtd like * IList`1<>, and oklass is X[] for some X. diff --git a/src/mono/mono/mini/aot-compiler.c b/src/mono/mono/mini/aot-compiler.c index e067d2714d8a37..c342ef5d007571 100644 --- a/src/mono/mono/mini/aot-compiler.c +++ b/src/mono/mono/mini/aot-compiler.c @@ -10708,6 +10708,18 @@ execute_system (const char * command) #ifdef ENABLE_LLVM +#ifdef HOST_WIN32 +#define OPT_NAME "opt.exe" +#else +#define OPT_NAME "opt" +#endif + +#ifdef HOST_WIN32 +#define LLC_NAME "llc.exe" +#else +#define LLC_NAME "llc" +#endif + /* * emit_llvm_file: * @@ -10776,11 +10788,11 @@ emit_llvm_file (MonoAotCompile *acfg) } else { #if LLVM_API_VERSION >= 1600 /* The safepoints pass requires new pass manager syntax*/ - opts = g_strdup ("-disable-tail-calls -passes='"); + opts = g_strdup ("-disable-tail-calls -passes=\""); if (!acfg->aot_opts.llvm_only) { opts = g_strdup_printf ("%sdefault,", opts); } - opts = g_strdup_printf ("%splace-safepoints' -spp-all-backedges", opts); + opts = g_strdup_printf ("%splace-safepoints\" -spp-all-backedges", opts); #elif LLVM_API_VERSION >= 1300 /* The safepoints pass requires the old pass manager */ opts = g_strdup ("-disable-tail-calls -place-safepoints -spp-all-backedges -enable-new-pm=0"); @@ -10810,7 +10822,7 @@ emit_llvm_file (MonoAotCompile *acfg) opts = g_strdup_printf ("%s -fp-contract=fast -enable-no-infs-fp-math -enable-no-nans-fp-math -enable-no-signed-zeros-fp-math -enable-no-trapping-fp-math -enable-unsafe-fp-math", opts); } - command = g_strdup_printf ("\"%sopt\" -f %s -o \"%s\" \"%s\"", acfg->aot_opts.llvm_path, opts, optbc, tempbc); + command = g_strdup_printf ("\"%s" OPT_NAME "\" -f %s -o \"%s\" \"%s\"", acfg->aot_opts.llvm_path, opts, optbc, tempbc); aot_printf (acfg, "Executing opt: %s\n", command); if (execute_system (command) != 0) return FALSE; @@ -10885,7 +10897,7 @@ emit_llvm_file (MonoAotCompile *acfg) g_string_append_printf (acfg->llc_args, " -mattr=%s", acfg->aot_opts.llvm_cpu_attr); } - command = g_strdup_printf ("\"%sllc\" %s -o \"%s\" \"%s.opt.bc\"", acfg->aot_opts.llvm_path, acfg->llc_args->str, output_fname, acfg->tmpbasename); + command = g_strdup_printf ("\"%s" LLC_NAME "\" %s -o \"%s\" \"%s.opt.bc\"", acfg->aot_opts.llvm_path, acfg->llc_args->str, output_fname, acfg->tmpbasename); g_free (output_fname); aot_printf (acfg, "Executing llc: %s\n", command); diff --git a/src/mono/mono/mini/llvm-intrinsics.h b/src/mono/mono/mini/llvm-intrinsics.h index 1bb09bf0388ae7..be73ec309dfa85 100644 --- a/src/mono/mono/mini/llvm-intrinsics.h +++ b/src/mono/mono/mini/llvm-intrinsics.h @@ -291,6 +291,8 @@ INTRINS_OVR_2_ARG(WASM_NARROW_UNSIGNED_V16, wasm_narrow_unsigned, Wasm, sse_i1_t INTRINS_OVR_2_ARG(WASM_NARROW_UNSIGNED_V8, wasm_narrow_unsigned, Wasm, sse_i2_t, sse_i4_t) INTRINS_OVR_2_ARG(WASM_CONV_R8_TO_I4, fptosi_sat, Generic, v64_i4_t, v128_r8_t) INTRINS_OVR_2_ARG(WASM_CONV_R8_TO_U4, fptoui_sat, Generic, v64_i4_t, v128_r8_t) +INTRINS_OVR_TAG(WASM_FMAX, maximum, Generic, V128 | R4 | R8) +INTRINS_OVR_TAG(WASM_FMIN, minimum, Generic, V128 | R4 | R8) INTRINS_OVR_TAG(WASM_PMAX, wasm_pmax, Wasm, V128 | R4 | R8) INTRINS_OVR_TAG(WASM_PMIN, wasm_pmin, Wasm, V128 | R4 | R8) INTRINS_OVR(WASM_PMAX_V4, fabs, Generic, sse_r4_t) diff --git a/src/mono/mono/mini/mini-generic-sharing.c b/src/mono/mono/mini/mini-generic-sharing.c index 6ad8dcb0075cfc..c131d51a6bd070 100644 --- a/src/mono/mono/mini/mini-generic-sharing.c +++ b/src/mono/mono/mini/mini-generic-sharing.c @@ -2886,7 +2886,8 @@ info_equal (gpointer data1, gpointer data2, MonoRgctxInfoType info_type) return data1 == data2; case MONO_RGCTX_INFO_VIRT_METHOD: case MONO_RGCTX_INFO_VIRT_METHOD_CODE: - case MONO_RGCTX_INFO_VIRT_METHOD_BOX_TYPE: { + case MONO_RGCTX_INFO_VIRT_METHOD_BOX_TYPE: + case MONO_RGCTX_INFO_GSHAREDVT_CONSTRAINED_CALL_INFO: { MonoJumpInfoVirtMethod *info1 = (MonoJumpInfoVirtMethod *)data1; MonoJumpInfoVirtMethod *info2 = (MonoJumpInfoVirtMethod *)data2; diff --git a/src/mono/mono/mini/mini-llvm.c b/src/mono/mono/mini/mini-llvm.c index a6c23aa59b67a2..657bb23d5cf63d 100644 --- a/src/mono/mono/mini/mini-llvm.c +++ b/src/mono/mono/mini/mini-llvm.c @@ -8183,9 +8183,13 @@ MONO_RESTORE_WARNING result = fcmp_and_select (builder, ins, l, r); } -#elif defined(TARGET_ARM64) +#elif defined(TARGET_ARM64) || defined(TARGET_WASM) LLVMValueRef min_max_args [] = { l, r }; +#ifdef TARGET_WASM + IntrinsicId iid = ins->inst_c0 == OP_FMAX ? INTRINS_WASM_FMAX : INTRINS_WASM_FMIN; +#else IntrinsicId iid = ins->inst_c0 == OP_FMAX ? INTRINS_AARCH64_ADV_SIMD_FMAX : INTRINS_AARCH64_ADV_SIMD_FMIN; +#endif llvm_ovr_tag_t ovr_tag = ovr_tag_from_mono_vector_class (ins->klass); result = call_overloaded_intrins (ctx, iid, ovr_tag, min_max_args, ""); #else diff --git a/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadManifest.json.in b/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadManifest.json.in index 5af8dcbd94e17e..076e642d2b6209 100644 --- a/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadManifest.json.in +++ b/src/mono/nuget/Microsoft.NET.Workload.Mono.Toolchain.Current.Manifest/WorkloadManifest.json.in @@ -5,7 +5,7 @@ }, "workloads": { "wasm-tools": { - "description": ".NET WebAssembly build tools", + "description": ".NET WebAssembly build tools for net8.0", "packs": [ "Microsoft.NET.Runtime.WebAssembly.Sdk", "Microsoft.NETCore.App.Runtime.Mono.browser-wasm", @@ -15,7 +15,7 @@ "platforms": [ "win-x64", "win-arm64", "linux-x64", "linux-arm64", "osx-x64", "osx-arm64"] }, "wasm-experimental": { - "description": ".NET WebAssembly experimental tooling", + "description": ".NET WebAssembly experimental tooling for net8.0", "packs": [ "Microsoft.NET.Runtime.WebAssembly.Templates", "Microsoft.NETCore.App.Runtime.Mono.multithread.browser-wasm", @@ -24,7 +24,7 @@ "platforms": [ "win-x64", "win-arm64", "linux-x64", "linux-arm64", "osx-x64", "osx-arm64" ] }, "wasi-experimental": { - "description": ".NET WASI experimental", + "description": ".NET WASI experimental for net8.0", "packs": [ "Microsoft.NET.Runtime.WebAssembly.Wasi.Sdk", "Microsoft.NETCore.App.Runtime.Mono.wasi-wasm", diff --git a/src/mono/wasm/README.md b/src/mono/wasm/README.md index 34c18bc8711468..cdace086f5603e 100644 --- a/src/mono/wasm/README.md +++ b/src/mono/wasm/README.md @@ -350,3 +350,12 @@ npm update --lockfile-version=1 | Multi-thread | linux: build only | none | * `high resource aot` runs a few specific library tests with AOT, that require more memory to AOT. + + +# Perf pipeline + +TBD + +## Updates needed + +- when the base OS is upgraded, check if the version of node installed in the `eng/pipelines/coreclr/templates/run-performance-job.yml` needs an upgrade too. diff --git a/src/mono/wasm/Wasm.Build.Tests/BuildTestBase.cs b/src/mono/wasm/Wasm.Build.Tests/BuildTestBase.cs index fa15c7bef6d817..9633d22701cf0d 100644 --- a/src/mono/wasm/Wasm.Build.Tests/BuildTestBase.cs +++ b/src/mono/wasm/Wasm.Build.Tests/BuildTestBase.cs @@ -588,8 +588,8 @@ internal BuildPaths GetBuildPaths(BuildArgs buildArgs, bool forPublish = true) } protected static string GetSkiaSharpReferenceItems() - => @" - + => @" + "; protected static string s_mainReturns42 = @" diff --git a/src/mono/wasm/Wasm.Build.Tests/WasmNativeDefaultsTests.cs b/src/mono/wasm/Wasm.Build.Tests/WasmNativeDefaultsTests.cs index 5efd84085b8863..4bbea5dff8a926 100644 --- a/src/mono/wasm/Wasm.Build.Tests/WasmNativeDefaultsTests.cs +++ b/src/mono/wasm/Wasm.Build.Tests/WasmNativeDefaultsTests.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.IO; +using System.Text.RegularExpressions; using Xunit; using Xunit.Abstractions; @@ -12,6 +13,7 @@ namespace Wasm.Build.Tests { public class WasmNativeDefaultsTests : TestMainJsTestBase { + private static Regex s_regex = new("\\*\\* WasmBuildNative:.*"); public WasmNativeDefaultsTests(ITestOutputHelper output, SharedBuildPerTestClassFixture buildContext) : base(output, buildContext) { @@ -39,19 +41,20 @@ public static TheoryData SettingDifferentFromV // Config=Release always causes relinking when publishing bool publishValue = forPublish && config == "Release" ? true : false; // Setting the default value from the runtime pack shouldn't trigger relinking - data.Add(config, $"<{defaultPair.propertyName}>{defaultPair.defaultValueInRuntimePack}", + data.Add(config, $"<{defaultPair.propertyName}>{defaultPair.defaultValueInRuntimePack.ToString().ToLower()}", /*aot*/ false, /*build*/ false, /*publish*/ publishValue); // Leaving the property unset, so checking the default data.Add(config, "", /*aot*/ false, /*build*/ false, /*publish*/ publishValue); // Setting the !default value should trigger relinking - data.Add(config, $"<{defaultPair.propertyName}>{!defaultPair.defaultValueInRuntimePack}", + data.Add(config, $"<{defaultPair.propertyName}>{(!defaultPair.defaultValueInRuntimePack).ToString().ToLower()}", /*aot*/ false, /*build*/ true, /*publish*/ true); } } return data; } + public static TheoryData DefaultsTestData(bool forPublish) { TheoryData data = new() @@ -93,45 +96,34 @@ public static TheoryData DefaultsTestData(bool return data; } +#pragma warning disable xUnit1026 // For unused *buildValue*, and *publishValue* parameters [Theory] [MemberData(nameof(DefaultsTestData), parameters: false)] [MemberData(nameof(SettingDifferentFromValuesInRuntimePack), parameters: false)] public void DefaultsWithBuild(string config, string extraProperties, bool aot, bool expectWasmBuildNativeForBuild, bool expectWasmBuildNativeForPublish) { - string output = CheckWasmNativeDefaultValue("native_defaults_build", config, extraProperties, aot, dotnetWasmFromRuntimePack: !expectWasmBuildNativeForPublish, publish: false); - - bool expectedWasmNativeStripValue = true; - if (/*isBuild && */ expectWasmBuildNativeForBuild && config == "Debug") - expectedWasmNativeStripValue = false; + (string output, string? line) = CheckWasmNativeDefaultValue("native_defaults_build", config, extraProperties, aot, dotnetWasmFromRuntimePack: !expectWasmBuildNativeForBuild, publish: false); - // bool expectedWasmNativeStripValue = !(wasmBuildNativeForBuild && config == "Debug"); - // for build - Assert.Contains($"** WasmBuildNative: '{expectWasmBuildNativeForBuild.ToString().ToLower()}', WasmNativeStrip: '{expectedWasmNativeStripValue.ToString().ToLower()}', WasmBuildingForNestedPublish: ''", output); - Assert.Contains("Stopping the build", output); + InferAndCheckPropertyValues(line, isPublish: false, wasmBuildNative: expectWasmBuildNativeForBuild, config: config); } -#pragma warning disable xUnit1026 // For unused *buildValue* parameter [Theory] [MemberData(nameof(DefaultsTestData), parameters: true)] [MemberData(nameof(SettingDifferentFromValuesInRuntimePack), parameters: true)] public void DefaultsWithPublish(string config, string extraProperties, bool aot, bool expectWasmBuildNativeForBuild, bool expectWasmBuildNativeForPublish) { - string output = CheckWasmNativeDefaultValue("native_defaults_publish", config, extraProperties, aot, dotnetWasmFromRuntimePack: !expectWasmBuildNativeForPublish, publish: true); + (string output, string? line) = CheckWasmNativeDefaultValue("native_defaults_publish", config, extraProperties, aot, dotnetWasmFromRuntimePack: !expectWasmBuildNativeForPublish, publish: true); - // for build - // Assert.DoesNotContain($"** WasmBuildNative: '{buildValue.ToString().ToLower()}', WasmNativeStrip: 'true', WasmBuildingForNestedPublish: ''", output); - // for publish - Assert.Contains($"** WasmBuildNative: '{expectWasmBuildNativeForPublish.ToString().ToLower()}', WasmNativeStrip: 'true', WasmBuildingForNestedPublish: 'true'", output); - Assert.Contains("Stopping the build", output); + InferAndCheckPropertyValues(line, isPublish: true, wasmBuildNative: expectWasmBuildNativeForPublish, config: config); } #pragma warning restore xunit1026 public static TheoryData SetWasmNativeStripExplicitlyTestData(bool publish) => new() { - {"Debug", "true", false, true }, - {"Release", "true", publish, true }, - {"Debug", "false", true, false }, - {"Release", "false", true, false } + {"Debug", "true", /*wasmBuildNative*/ false, /*wasmNativeStrip*/ true }, + {"Release", "true", /*wasmBuildNative*/ publish, /*wasmNativeStrip*/ true }, + {"Debug", "false", /*wasmBuildNative*/ true, /*wasmNativeStrip*/ false }, + {"Release", "false", /*wasmBuildNative*/ true, /*wasmNativeStrip*/ false } }; public static TheoryData SetWasmNativeStripExplicitlyWithWasmBuildNativeTestData() => new() @@ -147,10 +139,13 @@ public void DefaultsWithPublish(string config, string extraProperties, bool aot, [MemberData(nameof(SetWasmNativeStripExplicitlyWithWasmBuildNativeTestData))] public void WasmNativeStripDefaultWithBuild(string config, string extraProperties, bool expectedWasmBuildNativeValue, bool expectedWasmNativeStripValue) { - string output = CheckWasmNativeDefaultValue("native_strip_defaults", config, extraProperties, aot: false, dotnetWasmFromRuntimePack: !expectedWasmBuildNativeValue, publish: false); + (string output, string? line) = CheckWasmNativeDefaultValue("native_strip_defaults", config, extraProperties, aot: false, dotnetWasmFromRuntimePack: !expectedWasmBuildNativeValue, publish: false); - Assert.Contains($"** WasmBuildNative: '{expectedWasmBuildNativeValue.ToString().ToLower()}', WasmNativeStrip: '{expectedWasmNativeStripValue.ToString().ToLower()}', WasmBuildingForNestedPublish: ''", output); - Assert.Contains("Stopping the build", output); + CheckPropertyValues(line, + wasmBuildNative: expectedWasmBuildNativeValue, + wasmNativeStrip: expectedWasmNativeStripValue, + wasmNativeDebugSymbols: true, + wasmBuildingForNestedPublish: null); } [Theory] @@ -158,37 +153,38 @@ public void WasmNativeStripDefaultWithBuild(string config, string extraPropertie [MemberData(nameof(SetWasmNativeStripExplicitlyWithWasmBuildNativeTestData))] public void WasmNativeStripDefaultWithPublish(string config, string extraProperties, bool expectedWasmBuildNativeValue, bool expectedWasmNativeStripValue) { - string output = CheckWasmNativeDefaultValue("native_strip_defaults", config, extraProperties, aot: false, dotnetWasmFromRuntimePack: !expectedWasmBuildNativeValue, publish: true); + (string output, string? line) = CheckWasmNativeDefaultValue("native_strip_defaults", config, extraProperties, aot: false, dotnetWasmFromRuntimePack: !expectedWasmBuildNativeValue, publish: true); - Assert.Contains($"** WasmBuildNative: '{expectedWasmBuildNativeValue.ToString().ToLower()}', WasmNativeStrip: '{expectedWasmNativeStripValue.ToString().ToLower()}', WasmBuildingForNestedPublish: 'true'", output); - Assert.Contains("Stopping the build", output); + CheckPropertyValues(line, + wasmBuildNative: expectedWasmBuildNativeValue, + wasmNativeStrip: expectedWasmNativeStripValue, + wasmNativeDebugSymbols: true, + wasmBuildingForNestedPublish: true); } [Theory] /* always relink */ - [InlineData("Debug", "", /*build*/ true, /*publish*/ true)] - [InlineData("Release", "", /*build*/ true, /*publish*/ true)] - [InlineData("Release", "false", /*build*/ true, /*publish*/ true)] - public void WithNativeReference(string config, string extraProperties, bool buildValue, bool publishValue) + [InlineData("Debug", "", /*publish*/ false)] + [InlineData("Debug", "", /*publish*/ true)] + [InlineData("Release", "", /*publish*/ false)] + [InlineData("Release", "", /*publish*/ true)] + [InlineData("Release", "false", /*publish*/ true)] + public void WithNativeReference(string config, string extraProperties, bool publish) { string nativeLibPath = Path.Combine(BuildEnvironment.TestAssetsPath, "native-libs", "native-lib.o"); string nativeRefItem = @$""; - string output = CheckWasmNativeDefaultValue("native_defaults_publish", + (string output, string? line) = CheckWasmNativeDefaultValue("native_defaults_publish", config, extraProperties, aot: false, - dotnetWasmFromRuntimePack: !publishValue, - publish: true, + dotnetWasmFromRuntimePack: !publish, + publish: publish, extraItems: nativeRefItem); - // for build - FIXME: - Assert.DoesNotContain($"** WasmBuildNative: '{buildValue.ToString().ToLower()}', WasmBuildingForNestedPublish: ''", output); - // for publish - Assert.Contains($"** WasmBuildNative: '{publishValue.ToString().ToLower()}', WasmNativeStrip: 'true', WasmBuildingForNestedPublish: 'true'", output); - Assert.Contains("Stopping the build", output); + InferAndCheckPropertyValues(line, isPublish: publish, wasmBuildNative: true, config: config); } - private string CheckWasmNativeDefaultValue(string projectName, + private (string, string?) CheckWasmNativeDefaultValue(string projectName, string config, string extraProperties, bool aot, @@ -201,7 +197,7 @@ private string CheckWasmNativeDefaultValue(string projectName, string printValueTarget = @" - + " + (publish ? @"" : @"") @@ -223,7 +219,32 @@ private string CheckWasmNativeDefaultValue(string projectName, BuildOnlyAfterPublish: false, Publish: publish)); - return output; + Assert.Contains("Stopping the build", output); + + Match m = s_regex.Match(output); + Assert.Equal(1, m.Groups.Count); + return (output, m.Success ? m.Groups[0]?.ToString() : null); + } + + private void InferAndCheckPropertyValues(string? line, bool isPublish, bool wasmBuildNative, string config) + { + bool expectedWasmNativeStripValue; + if (!isPublish && wasmBuildNative && config == "Debug") + expectedWasmNativeStripValue = false; + else + expectedWasmNativeStripValue = true; + + CheckPropertyValues(line, wasmBuildNative, expectedWasmNativeStripValue, /*wasmNativeDebugSymbols*/true, isPublish); + } + + private void CheckPropertyValues(string? line, bool wasmBuildNative, bool wasmNativeStrip, bool wasmNativeDebugSymbols, bool? wasmBuildingForNestedPublish) + { + Assert.NotNull(line); + Assert.Contains($"** WasmBuildNative: '{wasmBuildNative.ToString().ToLower()}', " + + $"WasmNativeStrip: '{wasmNativeStrip.ToString().ToLower()}', " + + $"WasmNativeDebugSymbols: '{wasmNativeDebugSymbols.ToString().ToLower()}', " + + $"WasmBuildingForNestedPublish: '{(wasmBuildingForNestedPublish.HasValue && wasmBuildingForNestedPublish == true ? "true" : "")}'", + line); } } } diff --git a/src/mono/wasm/build/WasmApp.Native.targets b/src/mono/wasm/build/WasmApp.Native.targets index 295474566a14a0..73e5720abb01ed 100644 --- a/src/mono/wasm/build/WasmApp.Native.targets +++ b/src/mono/wasm/build/WasmApp.Native.targets @@ -22,7 +22,6 @@ $(_BeforeWasmBuildAppDependsOn); _SetupEmscripten; _SetWasmBuildNativeDefaults; - _SetWasmNativeStripDefault; _ReadEmccProps @@ -119,6 +118,7 @@ <_BoolPropertiesThatTriggerRelinking Include="InvariantTimezone" DefaultValueInRuntimePack="false" /> <_BoolPropertiesThatTriggerRelinking Include="InvariantGlobalization" DefaultValueInRuntimePack="false" /> <_BoolPropertiesThatTriggerRelinking Include="WasmNativeStrip" DefaultValueInRuntimePack="true" /> + @@ -133,7 +133,6 @@ true true - false @@ -147,10 +146,23 @@ true + + false + + + true + false + + + + true + true + + @@ -160,14 +172,6 @@ - - - - false - true - - - @@ -178,7 +182,6 @@ <_MonoAotCrossCompilerPath>@(MonoAotCrossCompiler->WithMetadataValue('RuntimeIdentifier','browser-wasm')) <_EmccDefaultFlagsRsp>$([MSBuild]::NormalizePath($(_WasmRuntimePackSrcDir), 'emcc-default.rsp')) <_EmccDefaultLinkFlagsRsp>$([MSBuild]::NormalizePath($(_WasmRuntimePackSrcDir), 'emcc-link.rsp')) - true $(WasmBuildNative) <_WasmICallTablePath>$(_WasmIntermediateOutputPath)icall-table.h @@ -221,7 +224,7 @@ <_EmccCommonFlags Include="$(_DefaultEmccFlags)" /> <_EmccCommonFlags Include="$(EmccFlags)" /> - <_EmccCommonFlags Include="-g" Condition="'$(WasmNativeDebugSymbols)' == 'true'" /> + <_EmccCommonFlags Include="-g" Condition="'$(WasmNativeStrip)' == 'false'" /> <_EmccCommonFlags Include="-v" Condition="'$(EmccVerbose)' != 'false'" /> <_EmccCommonFlags Include="-s DISABLE_EXCEPTION_CATCHING=0" Condition="'$(WasmEnableExceptionHandling)' == 'false'" /> <_EmccCommonFlags Include="-fwasm-exceptions" Condition="'$(WasmEnableExceptionHandling)' == 'true'" /> @@ -249,6 +252,7 @@ <_EmccCFlags Include="-emit-llvm" /> <_EmccCFlags Include=""-I%(_EmccIncludePaths.Identity)"" /> + <_EmccCFlags Include="-g" Condition="'$(WasmNativeDebugSymbols)' == 'true'" /> <_EmccLDFlags Include="$(EmccLinkOptimizationFlag)" /> diff --git a/src/mono/wasm/debugger/BrowserDebugProxy/EvaluateExpression.cs b/src/mono/wasm/debugger/BrowserDebugProxy/EvaluateExpression.cs index c384f74e9c7223..45d700e122ea7f 100644 --- a/src/mono/wasm/debugger/BrowserDebugProxy/EvaluateExpression.cs +++ b/src/mono/wasm/debugger/BrowserDebugProxy/EvaluateExpression.cs @@ -387,12 +387,14 @@ private static async Task> ResolveElementAccess(ExpressionSyntaxR { var values = new List(); JObject index = null; + List nestedIndexers = new(); IEnumerable elementAccesses = replacer.elementAccess; foreach (ElementAccessExpressionSyntax elementAccess in elementAccesses.Reverse()) { - index = await resolver.Resolve(elementAccess, replacer.memberAccessValues, index, replacer.variableDefinitions, token); + index = await resolver.Resolve(elementAccess, replacer.memberAccessValues, nestedIndexers, replacer.variableDefinitions, token); if (index == null) throw new ReturnAsErrorException($"Failed to resolve element access for {elementAccess}", "ReferenceError"); + nestedIndexers.Add(index); } values.Add(index); return values; diff --git a/src/mono/wasm/debugger/BrowserDebugProxy/MemberReferenceResolver.cs b/src/mono/wasm/debugger/BrowserDebugProxy/MemberReferenceResolver.cs index 650583a9dc7bf2..e1b9583ddbe3e1 100644 --- a/src/mono/wasm/debugger/BrowserDebugProxy/MemberReferenceResolver.cs +++ b/src/mono/wasm/debugger/BrowserDebugProxy/MemberReferenceResolver.cs @@ -366,7 +366,12 @@ async Task ResolveAsInstanceMember(ArraySegment parts, JObject } } - public async Task Resolve(ElementAccessExpressionSyntax elementAccess, Dictionary memberAccessValues, JObject indexObject, List variableDefinitions, CancellationToken token) + public async Task Resolve( + ElementAccessExpressionSyntax elementAccess, + Dictionary memberAccessValues, + List nestedIndexObject, + List variableDefinitions, + CancellationToken token) { try { @@ -376,12 +381,13 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, if (rootObject == null) { - // it might be a jagged array where indexObject should be treated as a new rootObject - rootObject = indexObject; - indexObject = null; + // it might be a jagged array where the previously added nestedIndexObject should be treated as a new rootObject + rootObject = nestedIndexObject.LastOrDefault(); + if (rootObject != null) + nestedIndexObject.RemoveAt(nestedIndexObject.Count - 1); } - ElementIndexInfo elementIdxInfo = await GetElementIndexInfo(); + ElementIndexInfo elementIdxInfo = await GetElementIndexInfo(nestedIndexObject); if (elementIdxInfo is null) return null; @@ -394,6 +400,7 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, if (!DotnetObjectId.TryParse(rootObject?["objectId"]?.Value(), out DotnetObjectId objectId)) throw new InvalidOperationException($"Cannot apply indexing with [] to a primitive object of type '{type}'"); + bool isMultidimensional = elementIdxInfo.DimensionsCount != 1; switch (objectId.Scheme) { case "valuetype": //can be an inlined array @@ -407,7 +414,7 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, } case "array": rootObject["value"] = await context.SdbAgent.GetArrayValues(objectId.Value, token); - if (!elementIdxInfo.IsMultidimensional) + if (!isMultidimensional) { int.TryParse(elementIdxInfo.ElementIdxStr, out elementIdx); return (JObject)rootObject["value"][elementIdx]["value"]; @@ -417,10 +424,8 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, return (JObject)(((JArray)rootObject["value"]).FirstOrDefault(x => x["name"].Value() == elementIdxInfo.ElementIdxStr)["value"]); } case "object": - if (elementIdxInfo.IsMultidimensional) - throw new InvalidOperationException($"Cannot apply indexing with [,] to an object of type '{type}'"); // ToDo: try to use the get_Item for string as well - if (type == "string") + if (!isMultidimensional && type == "string") { var eaExpressionFormatted = elementAccessStrExpression.Replace('.', '_'); // instance_str variableDefinitions.Add(new (eaExpressionFormatted, rootObject, ExpressionEvaluator.ConvertJSToCSharpLocalVariableAssignment(eaExpressionFormatted, rootObject))); @@ -428,7 +433,7 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, var variableDef = await ExpressionEvaluator.GetVariableDefinitions(this, variableDefinitions, invokeToStringInObject: false, token); return await ExpressionEvaluator.EvaluateSimpleExpression(this, eaFormatted, elementAccessStr, variableDef, logger, token); } - if (indexObject is null && elementIdxInfo.IndexingExpression is null) + if (elementIdxInfo.Indexers is null || elementIdxInfo.Indexers.Count == 0) throw new InternalErrorException($"Unable to write index parameter to invoke the method in the runtime."); var typeIds = await context.SdbAgent.GetTypeIdsForObject(objectId.Value, true, token); @@ -441,15 +446,13 @@ public async Task Resolve(ElementAccessExpressionSyntax elementAccess, { MethodInfoWithDebugInformation methodInfo = await context.SdbAgent.GetMethodInfo(methodIds[i], token); ParameterInfo[] paramInfo = methodInfo.GetParametersInfo(); - if (paramInfo.Length == 1) + if (paramInfo.Length == elementIdxInfo.DimensionsCount) { try { - if (indexObject != null && !CheckParametersCompatibility(paramInfo[0].TypeCode, indexObject)) + if (!CheckParametersCompatibility(paramInfo, elementIdxInfo.Indexers)) continue; - ArraySegment buffer = indexObject is null ? - await WriteLiteralExpressionAsIndex(objectId, elementIdxInfo.IndexingExpression, elementIdxInfo.ElementIdxStr) : - await WriteJObjectAsIndex(objectId, indexObject, elementIdxInfo.ElementIdxStr, paramInfo[0].TypeCode); + ArraySegment buffer = await WriteIndexObjectAsIndices(objectId, elementIdxInfo.Indexers, paramInfo); JObject getItemRetObj = await context.SdbAgent.InvokeMethod(buffer, methodIds[i], token); return (JObject)getItemRetObj["value"]; } @@ -470,31 +473,32 @@ await WriteLiteralExpressionAsIndex(objectId, elementIdxInfo.IndexingExpression, throw new ReturnAsErrorException($"Unable to evaluate element access '{elementAccess}': {ex.Message}", ex.GetType().Name); } - async Task GetElementIndexInfo() + async Task GetElementIndexInfo(List nestedIndexers) { - // e.g. x[a[0]], x[a[b[1]]] etc. - if (indexObject is not null) - return new ElementIndexInfo(ElementIdxStr: indexObject["value"].ToString() ); - if (elementAccess.ArgumentList is null) return null; - StringBuilder elementIdxStr = new StringBuilder(); - var multiDimensionalArray = false; + int dimCnt = elementAccess.ArgumentList.Arguments.Count; LiteralExpressionSyntax indexingExpression = null; - for (int i = 0; i < elementAccess.ArgumentList.Arguments.Count; i++) + StringBuilder elementIdxStr = new StringBuilder(); + List indexers = new(); + // nesting should be resolved in reverse order + int nestedIndexersCnt = nestedIndexers.Count - 1; + for (int i = 0; i < dimCnt; i++) { + JObject indexObject; var arg = elementAccess.ArgumentList.Arguments[i]; if (i != 0) { elementIdxStr.Append(", "); - multiDimensionalArray = true; } // e.g. x[1] if (arg.Expression is LiteralExpressionSyntax) { indexingExpression = arg.Expression as LiteralExpressionSyntax; - elementIdxStr.Append(indexingExpression.ToString()); + string expression = indexingExpression.ToString(); + elementIdxStr.Append(expression); + indexers.Add(indexingExpression); } // e.g. x[a] or x[a.b] @@ -508,6 +512,18 @@ async Task GetElementIndexInfo() // x[a] indexObject ??= await Resolve(argParm.Identifier.Text, token); elementIdxStr.Append(indexObject["value"].ToString()); + indexers.Add(indexObject); + } + // nested indexing, e.g. x[a[0]], x[a[b[1]]], x[a[0], b[1]] + else if (arg.Expression is ElementAccessExpressionSyntax) + { + if (nestedIndexers == null || nestedIndexersCnt < 0) + throw new InvalidOperationException($"Cannot resolve nested indexing"); + JObject nestedIndexObject = nestedIndexers[nestedIndexersCnt]; + nestedIndexers.RemoveAt(nestedIndexersCnt); + elementIdxStr.Append(nestedIndexObject["value"].ToString()); + indexers.Add(nestedIndexObject); + nestedIndexersCnt--; } // indexing with expressions, e.g. x[a + 1] else @@ -519,36 +535,57 @@ async Task GetElementIndexInfo() if (idxType != "number") throw new InvalidOperationException($"Cannot index with an object of type '{idxType}'"); elementIdxStr.Append(indexObject["value"].ToString()); + indexers.Add(indexObject); } } return new ElementIndexInfo( + DimensionsCount: dimCnt, ElementIdxStr: elementIdxStr.ToString(), - IsMultidimensional: multiDimensionalArray, - IndexingExpression: indexingExpression); + Indexers: indexers); } - async Task> WriteJObjectAsIndex(DotnetObjectId rootObjId, JObject indexObject, string elementIdxStr, ElementType? expectedType) + async Task> WriteIndexObjectAsIndices(DotnetObjectId rootObjId, List indexObjects, ParameterInfo[] paramInfo) { using var writer = new MonoBinaryWriter(); writer.WriteObj(rootObjId, context.SdbAgent); - writer.Write(1); // number of method args - if (!await writer.WriteJsonValue(indexObject, context.SdbAgent, expectedType, token)) - throw new InternalErrorException($"Parsing index of type {indexObject["type"].Value()} to write it into the buffer failed."); + writer.Write(indexObjects.Count); // number of method args + foreach ((ParameterInfo pi, object indexObject) in paramInfo.Zip(indexObjects)) + { + if (indexObject is JObject indexJObject) + { + // indexed by an identifier name syntax + if (!await writer.WriteJsonValue(indexJObject, context.SdbAgent, pi.TypeCode, token)) + throw new InternalErrorException($"Parsing index of type {indexJObject["type"].Value()} to write it into the buffer failed."); + } + else if (indexObject is LiteralExpressionSyntax expression) + { + // indexed by a literal expression syntax + if (!await writer.WriteConst(expression, context.SdbAgent, token)) + throw new InternalErrorException($"Parsing literal expression index = {expression} to write it into the buffer failed."); + } + else + { + throw new InternalErrorException($"Unexpected index type."); + } + } return writer.GetParameterBuffer(); } + } - async Task> WriteLiteralExpressionAsIndex(DotnetObjectId rootObjId, LiteralExpressionSyntax indexingExpression, string elementIdxStr) + private static bool CheckParametersCompatibility(ParameterInfo[] paramInfos, List indexObjects) + { + if (paramInfos.Length != indexObjects.Count) + return false; + foreach ((ParameterInfo paramInfo, object indexObj) in paramInfos.Zip(indexObjects)) { - using var writer = new MonoBinaryWriter(); - writer.WriteObj(rootObjId, context.SdbAgent); - writer.Write(1); // number of method args - if (!await writer.WriteConst(indexingExpression, context.SdbAgent, token)) - throw new InternalErrorException($"Parsing index of type {indexObject["type"].Value()} to write it into the buffer failed."); - return writer.GetParameterBuffer(); + // shouldn't we check LiteralExpressionSyntax for compatibility as well? + if (indexObj is JObject indexJObj && !CheckParameterCompatibility(paramInfo.TypeCode, indexJObj)) + return false; } + return true; } - private static bool CheckParametersCompatibility(ElementType? paramTypeCode, JObject value) + private static bool CheckParameterCompatibility(ElementType? paramTypeCode, JObject value) { if (!paramTypeCode.HasValue) return true; @@ -871,7 +908,8 @@ public JObject TryGetEvaluationResult(string id) private sealed record ElementIndexInfo( string ElementIdxStr, - bool IsMultidimensional = false, - LiteralExpressionSyntax IndexingExpression = null); + // keeps JObjects and LiteralExpressionSyntaxes: + List Indexers, + int DimensionsCount = 1); } } diff --git a/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrame2Tests.cs b/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrame2Tests.cs index 051da33469ce2d..b1a79b28ceeefe 100644 --- a/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrame2Tests.cs +++ b/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrame2Tests.cs @@ -731,5 +731,23 @@ await CheckEvaluateFail(id, ("dt+1", "Cannot evaluate '(dt+1\n)': (2,9): error CS0019: Operator '+' cannot be applied to operands of type 'object' and 'int'") ); }); + + [Fact] + public async Task EvaluateObjectIndexingMultidimensional() => await CheckInspectLocalsAtBreakpointSite( + "DebuggerTests.EvaluateLocalsWithIndexingTests", "EvaluateLocals", 12, "DebuggerTests.EvaluateLocalsWithIndexingTests.EvaluateLocals", + "window.setTimeout(function() { invoke_static_method ('[debugger-test] DebuggerTests.EvaluateLocalsWithIndexingTests:EvaluateLocals'); })", + wait_for_event_fn: async (pause_location) => + { + var id = pause_location["callFrames"][0]["callFrameId"].Value(); + await EvaluateOnCallFrameAndCheck(id, + ("f[j, aDouble]", TNumber("3.34")), //only IdentifierNameSyntaxes + ("f[1, aDouble]", TNumber("3.34")), //IdentifierNameSyntax with LiteralExpressionSyntax + ("f[aChar, \"&\", longString]", TString("9-&-longString")), + ("f[f.numArray[j], aDouble]", TNumber("4.34")), //ElementAccessExpressionSyntax + ("f[f.numArray[j], f.numArray[0]]", TNumber("3")), //multiple ElementAccessExpressionSyntaxes + ("f[f.numArray[f.numList[0]], f.numArray[i]]", TNumber("3")), + ("f[f.numArray[f.numList[0]], f.numArray[f.numArray[i]]]", TNumber("4")) + ); + }); } } diff --git a/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrameTests.cs b/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrameTests.cs index 2ff9bd26a28272..2d0fb87822758a 100644 --- a/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrameTests.cs +++ b/src/mono/wasm/debugger/DebuggerTestSuite/EvaluateOnCallFrameTests.cs @@ -585,7 +585,7 @@ public async Task EvaluateIndexingNegative() => await CheckInspectLocalsAtBreakp Assert.Equal("Unable to evaluate element access 'f.idx0[2]': Cannot apply indexing with [] to a primitive object of type 'number'", res.Error["result"]?["description"]?.Value()); var exceptionDetailsStack = res.Error["exceptionDetails"]?["stackTrace"]?["callFrames"]?[0]; Assert.Equal("DebuggerTests.EvaluateLocalsWithIndexingTests.EvaluateLocals", exceptionDetailsStack?["functionName"]?.Value()); - Assert.Equal(556, exceptionDetailsStack?["lineNumber"]?.Value()); + Assert.Equal(558, exceptionDetailsStack?["lineNumber"]?.Value()); Assert.Equal(12, exceptionDetailsStack?["columnNumber"]?.Value()); (_, res) = await EvaluateOnCallFrame(id, "f[1]", expect_ok: false ); Assert.Equal( "Unable to evaluate element access 'f[1]': Cannot apply indexing with [] to an object of type 'DebuggerTests.EvaluateLocalsWithIndexingTests.TestEvaluate'", res.Error["result"]?["description"]?.Value()); @@ -722,7 +722,7 @@ public async Task EvaluateIndexingByExpressionNegative() => await CheckInspectLo Assert.Equal("Unable to evaluate element access 'f.numList[\"a\" + 1]': Cannot index with an object of type 'string'", res.Error["result"]?["description"]?.Value()); var exceptionDetailsStack = res.Error["exceptionDetails"]?["stackTrace"]?["callFrames"]?[0]; Assert.Equal("DebuggerTests.EvaluateLocalsWithIndexingTests.EvaluateLocals", exceptionDetailsStack?["functionName"]?.Value()); - Assert.Equal(556, exceptionDetailsStack?["lineNumber"]?.Value()); + Assert.Equal(558, exceptionDetailsStack?["lineNumber"]?.Value()); Assert.Equal(12, exceptionDetailsStack?["columnNumber"]?.Value()); }); diff --git a/src/mono/wasm/debugger/DebuggerTestSuite/HotReloadTests.cs b/src/mono/wasm/debugger/DebuggerTestSuite/HotReloadTests.cs index bfeb9ec9b4d5ea..ade5c6adc8bbd5 100644 --- a/src/mono/wasm/debugger/DebuggerTestSuite/HotReloadTests.cs +++ b/src/mono/wasm/debugger/DebuggerTestSuite/HotReloadTests.cs @@ -596,7 +596,8 @@ await SendCommandAndCheck (JObject.FromObject(new { }), "Debugger.resume", scrip await CheckProps (c, new { Field1 = TNumber(123), Field2 = TString("spqr"), - }, "c", num_fields: 2); + Field3 = TString(null), + }, "c", num_fields: 3); }); } diff --git a/src/mono/wasm/debugger/tests/ApplyUpdateReferencedAssembly3/MethodBody2_v2.cs b/src/mono/wasm/debugger/tests/ApplyUpdateReferencedAssembly3/MethodBody2_v2.cs index 379adbe266908e..cbde71233f4b45 100644 --- a/src/mono/wasm/debugger/tests/ApplyUpdateReferencedAssembly3/MethodBody2_v2.cs +++ b/src/mono/wasm/debugger/tests/ApplyUpdateReferencedAssembly3/MethodBody2_v2.cs @@ -22,6 +22,7 @@ public C() } public double Field1; public string Field2; + public string Field3; } } diff --git a/src/mono/wasm/debugger/tests/debugger-test/debugger-evaluate-test.cs b/src/mono/wasm/debugger/tests/debugger-test/debugger-evaluate-test.cs index 480dc30115c430..e46177ecc925b5 100644 --- a/src/mono/wasm/debugger/tests/debugger-test/debugger-evaluate-test.cs +++ b/src/mono/wasm/debugger/tests/debugger-test/debugger-evaluate-test.cs @@ -522,7 +522,6 @@ public class TestEvaluate public int idx0; public int idx1; - // ToDo: add 2d indexing - https://github.com/dotnet/runtime/issues/76062 public string this[char key] => "res_" + key; public string this[bool key] => key.ToString(); public bool this[string key] => key.Length > 3; @@ -530,11 +529,14 @@ public class TestEvaluate public int this[float key] => (int)key; public int this[decimal key] => (int)key; + public double this[int key1, double key2] => key1 + key2; + public string this[char key1, string key2, string key3] => $"{key1}-{key2}-{key3}"; + public void run() { numList = new List { 1, 2 }; textList = new List { "1", "2" }; - numArray = new int[] { 1, 2 }; + numArray = new int[] { 1, 2, 0 }; textArray = new string[] { "1", "2" }; numArrayOfArrays = new int[][] { numArray, numArray }; numListOfLists = new List> { numList, numList }; diff --git a/src/mono/wasm/host/BrowserHost.cs b/src/mono/wasm/host/BrowserHost.cs index adb24f81f88dac..cbe160eeb0214a 100644 --- a/src/mono/wasm/host/BrowserHost.cs +++ b/src/mono/wasm/host/BrowserHost.cs @@ -74,7 +74,7 @@ private async Task RunAsync(ILoggerFactory loggerFactory, CancellationToken toke debugging: _args.CommonConfig.Debugging); runArgsJson.Save(Path.Combine(_args.CommonConfig.AppPath, "runArgs.json")); - string[] urls = envVars.TryGetValue("ASPNETCORE_URLS", out string? aspnetUrls) + string[] urls = (envVars.TryGetValue("ASPNETCORE_URLS", out string? aspnetUrls) && aspnetUrls.Length > 0) ? aspnetUrls.Split(';', StringSplitOptions.RemoveEmptyEntries) : new string[] { $"http://127.0.0.1:{_args.CommonConfig.HostProperties.WebServerPort}", "https://127.0.0.1:0" }; diff --git a/src/mono/wasm/host/DevServer/DevServer.cs b/src/mono/wasm/host/DevServer/DevServer.cs index 1acbe6954eeb3b..9a5a079cee695e 100644 --- a/src/mono/wasm/host/DevServer/DevServer.cs +++ b/src/mono/wasm/host/DevServer/DevServer.cs @@ -46,7 +46,8 @@ internal static class DevServer services.AddSingleton(Options.Create(options)); services.AddSingleton(realUrlsAvailableTcs); services.AddRouting(); - }); + }) + .UseUrls(options.Urls); IWebHost? host = builder.Build(); diff --git a/src/mono/wasm/host/RuntimeConfigJson.cs b/src/mono/wasm/host/RuntimeConfigJson.cs index ed698ed8fb3725..3ad30dd88015ae 100644 --- a/src/mono/wasm/host/RuntimeConfigJson.cs +++ b/src/mono/wasm/host/RuntimeConfigJson.cs @@ -24,7 +24,7 @@ internal sealed record WasmHostProperties( int? FirefoxDebuggingPort, int? ChromeProxyPort, int? ChromeDebuggingPort, - int WebServerPort = 9000) + int WebServerPort = 0) { // using an explicit property because the deserializer doesn't like // extension data in the record constructor diff --git a/src/mono/wasm/runtime/es6/dotnet.es6.pre.js b/src/mono/wasm/runtime/es6/dotnet.es6.pre.js index 9eb9b1c6b99e7d..490935d5ca0284 100644 --- a/src/mono/wasm/runtime/es6/dotnet.es6.pre.js +++ b/src/mono/wasm/runtime/es6/dotnet.es6.pre.js @@ -1,3 +1,5 @@ if (_nativeModuleLoaded) throw new Error("Native module already loaded"); _nativeModuleLoaded = true; -createDotnetRuntime = Module = createDotnetRuntime(Module); \ No newline at end of file +createDotnetRuntime = Module = createDotnetRuntime(Module); +Module["getWasmIndirectFunctionTable"] = function () { return wasmTable; } +Module["getMemory"] = function () { return wasmMemory; } diff --git a/src/mono/wasm/runtime/globals.ts b/src/mono/wasm/runtime/globals.ts index 66948642bbc55c..88be75543ab4e5 100644 --- a/src/mono/wasm/runtime/globals.ts +++ b/src/mono/wasm/runtime/globals.ts @@ -66,6 +66,12 @@ export function setRuntimeGlobals(globalObjects: GlobalObjects) { beforeOnRuntimeInitialized: createPromiseController(), afterOnRuntimeInitialized: createPromiseController(), afterPostRun: createPromiseController(), + mono_wasm_exit: () => { + throw new Error("Mono shutdown"); + }, + abort: (reason: any) => { + throw reason; + } }); Object.assign(globalObjects.module.config!, {}) as any; diff --git a/src/mono/wasm/runtime/jiterpreter-jit-call.ts b/src/mono/wasm/runtime/jiterpreter-jit-call.ts index eba38843de57ef..af918bedb0cf9c 100644 --- a/src/mono/wasm/runtime/jiterpreter-jit-call.ts +++ b/src/mono/wasm/runtime/jiterpreter-jit-call.ts @@ -281,7 +281,7 @@ export function mono_jiterp_do_jit_call_indirect( jit_call_cb: jitCallCb, }, m: { - h: (Module).asm.memory + h: (Module).getMemory() }, }); const impl = instance.exports.do_jit_call_indirect; diff --git a/src/mono/wasm/runtime/jiterpreter-support.ts b/src/mono/wasm/runtime/jiterpreter-support.ts index 5ce5759b2a6a4b..0589c28d5bd06e 100644 --- a/src/mono/wasm/runtime/jiterpreter-support.ts +++ b/src/mono/wasm/runtime/jiterpreter-support.ts @@ -239,9 +239,12 @@ export class WasmBuilder { } getWasmImports(): WebAssembly.Imports { + const memory = (Module).getMemory(); + mono_assert(memory instanceof WebAssembly.Memory, () => `expected heap import to be WebAssembly.Memory but was ${memory}`); + const result: any = { c: this.getConstants(), - m: { h: (Module).asm.memory }, + m: { h: memory }, // f: { f: getWasmFunctionTable() }, }; @@ -1589,7 +1592,7 @@ export function copyIntoScratchBuffer(src: NativePointer, size: number): NativeP export function getWasmFunctionTable() { if (!wasmTable) - wasmTable = (Module)["asm"]["__indirect_function_table"]; + wasmTable = Module.getWasmIndirectFunctionTable(); if (!wasmTable) throw new Error("Module did not export the indirect function table"); return wasmTable; diff --git a/src/mono/wasm/runtime/startup.ts b/src/mono/wasm/runtime/startup.ts index e6f8d80ef82f48..c385844dfd30bb 100644 --- a/src/mono/wasm/runtime/startup.ts +++ b/src/mono/wasm/runtime/startup.ts @@ -5,7 +5,7 @@ import MonoWasmThreads from "consts:monoWasmThreads"; import WasmEnableLegacyJsInterop from "consts:wasmEnableLegacyJsInterop"; import { DotnetModuleInternal, CharPtrNull } from "./types/internal"; -import { linkerDisableLegacyJsInterop, ENVIRONMENT_IS_PTHREAD, exportedRuntimeAPI, INTERNAL, loaderHelpers, Module, runtimeHelpers, createPromiseController, mono_assert, linkerWasmEnableSIMD, linkerWasmEnableEH } from "./globals"; +import { linkerDisableLegacyJsInterop, ENVIRONMENT_IS_PTHREAD, exportedRuntimeAPI, INTERNAL, loaderHelpers, Module, runtimeHelpers, createPromiseController, mono_assert, linkerWasmEnableSIMD, linkerWasmEnableEH, ENVIRONMENT_IS_NODE, ENVIRONMENT_IS_WORKER } from "./globals"; import cwraps, { init_c_exports } from "./cwraps"; import { mono_wasm_raise_debug_event, mono_wasm_runtime_ready } from "./debug"; import { toBase64StringImpl } from "./base64"; @@ -40,15 +40,7 @@ import { assertNoProxies } from "./gc-handles"; const MONO_PTHREAD_POOL_SIZE = 4; export async function configureRuntimeStartup(): Promise { - if (linkerWasmEnableSIMD) { - mono_assert(await loaderHelpers.simd(), "This browser/engine doesn't support WASM SIMD. Please use a modern version. See also https://aka.ms/dotnet-wasm-features"); - } - if (linkerWasmEnableEH) { - mono_assert(await loaderHelpers.exceptions(), "This browser/engine doesn't support WASM exception handling. Please use a modern version. See also https://aka.ms/dotnet-wasm-features"); - } - await init_polyfills_async(); - await checkMemorySnapshotSize(); } @@ -240,6 +232,15 @@ async function onRuntimeInitializedAsync(userOnRuntimeInitialized: () => void) { // wait for previous stage await runtimeHelpers.afterPreRun.promise; mono_log_debug("onRuntimeInitialized"); + + runtimeHelpers.mono_wasm_exit = cwraps.mono_wasm_exit; + runtimeHelpers.abort = (reason: any) => { + if (!loaderHelpers.is_exited()) { + cwraps.mono_wasm_abort(); + } + throw reason; + }; + const mark = startMeasure(); // signal this stage, this will allow pending assets to allocate memory runtimeHelpers.beforeOnRuntimeInitialized.promise_control.resolve(); @@ -272,6 +273,10 @@ async function onRuntimeInitializedAsync(userOnRuntimeInitialized: () => void) { bindings_init(); runtimeHelpers.runtimeReady = true; + if (ENVIRONMENT_IS_NODE && !ENVIRONMENT_IS_WORKER) { + Module.runtimeKeepalivePush(); + } + if (MonoWasmThreads) { runtimeHelpers.javaScriptExports.install_synchronization_context(); runtimeHelpers.jsSynchronizationContextInstalled = true; @@ -361,13 +366,6 @@ function mono_wasm_pre_init_essential(isWorker: boolean): void { } init_c_exports(); - runtimeHelpers.mono_wasm_exit = cwraps.mono_wasm_exit; - runtimeHelpers.abort = (reason: any) => { - if (!loaderHelpers.is_exited()) { - cwraps.mono_wasm_abort(); - } - throw reason; - }; cwraps_internal(INTERNAL); if (WasmEnableLegacyJsInterop && !linkerDisableLegacyJsInterop) { cwraps_mono_api(MONO); @@ -386,7 +384,6 @@ async function mono_wasm_pre_init_essential_async(): Promise { mono_log_debug("mono_wasm_pre_init_essential_async"); Module.addRunDependency("mono_wasm_pre_init_essential_async"); - if (MonoWasmThreads) { preAllocatePThreadWorkerPool(MONO_PTHREAD_POOL_SIZE, runtimeHelpers.config); } @@ -466,8 +463,12 @@ async function instantiate_wasm_module( await runtimeHelpers.beforePreInit.promise; Module.addRunDependency("instantiate_wasm_module"); + const wasmFeaturePromise = ensureUsedWasmFeatures(); + replace_linker_placeholders(imports); const assetToLoad = await loaderHelpers.wasmDownloadPromise.promise; + + await wasmFeaturePromise; await instantiate_wasm_asset(assetToLoad, imports, successCallback); assetToLoad.pendingDownloadInternal = null as any; // GC assetToLoad.pendingDownload = null as any; // GC @@ -499,6 +500,15 @@ async function instantiate_wasm_module( Module.removeRunDependency("instantiate_wasm_module"); } +async function ensureUsedWasmFeatures() { + if (linkerWasmEnableSIMD) { + mono_assert(await loaderHelpers.simd(), "This browser/engine doesn't support WASM SIMD. Please use a modern version. See also https://aka.ms/dotnet-wasm-features"); + } + if (linkerWasmEnableEH) { + mono_assert(await loaderHelpers.exceptions(), "This browser/engine doesn't support WASM exception handling. Please use a modern version. See also https://aka.ms/dotnet-wasm-features"); + } +} + async function mono_wasm_before_memory_snapshot() { const mark = startMeasure(); if (runtimeHelpers.loadedMemorySnapshotSize) { diff --git a/src/mono/wasm/runtime/types/internal.ts b/src/mono/wasm/runtime/types/internal.ts index aedf803aceff1f..a91b21973c8d5e 100644 --- a/src/mono/wasm/runtime/types/internal.ts +++ b/src/mono/wasm/runtime/types/internal.ts @@ -454,6 +454,8 @@ export declare interface EmscriptenModuleInternal { ready: Promise; asm: { memory?: WebAssembly.Memory }; wasmMemory?: WebAssembly.Memory; + getWasmIndirectFunctionTable: any; + getMemory: WebAssembly.Memory; getWasmTableEntry(index: number): any; removeRunDependency(id: string): void; addRunDependency(id: string): void; diff --git a/src/mono/wasm/testassets/WasmBasicTestApp/WasmBasicTestApp.csproj b/src/mono/wasm/testassets/WasmBasicTestApp/WasmBasicTestApp.csproj index 33858b9a6a755f..761ac6354ce861 100644 --- a/src/mono/wasm/testassets/WasmBasicTestApp/WasmBasicTestApp.csproj +++ b/src/mono/wasm/testassets/WasmBasicTestApp/WasmBasicTestApp.csproj @@ -4,6 +4,7 @@ browser-wasm Exe true + true diff --git a/src/mono/wasm/wasm.proj b/src/mono/wasm/wasm.proj index 66bc509910571a..3b75fa9feeb1f0 100644 --- a/src/mono/wasm/wasm.proj +++ b/src/mono/wasm/wasm.proj @@ -33,7 +33,6 @@ <_EmccDefaultsRspPath>$(NativeBinDir)src\emcc-default.rsp <_EmccCompileRspPath>$(NativeBinDir)src\emcc-compile.rsp <_EmccLinkRspPath>$(NativeBinDir)src\emcc-link.rsp - false $(EMSDK_PATH)\upstream\bin\llvm-ar $(EmSdkLLVMAr).exe diff --git a/src/native/libs/System.Globalization.Native/entrypoints.c b/src/native/libs/System.Globalization.Native/entrypoints.c index cffad72a023721..84d2177d558841 100644 --- a/src/native/libs/System.Globalization.Native/entrypoints.c +++ b/src/native/libs/System.Globalization.Native/entrypoints.c @@ -69,6 +69,7 @@ static const Entry s_globalizationNative[] = DllImportEntry(GlobalizationNative_GetLocaleInfoSecondaryGroupingSizeNative) DllImportEntry(GlobalizationNative_GetLocaleInfoStringNative) DllImportEntry(GlobalizationNative_GetLocaleNameNative) + DllImportEntry(GlobalizationNative_GetLocalesNative) DllImportEntry(GlobalizationNative_GetLocaleTimeFormatNative) DllImportEntry(GlobalizationNative_IndexOfNative) DllImportEntry(GlobalizationNative_StartsWithNative) diff --git a/src/native/libs/System.Globalization.Native/pal_locale.h b/src/native/libs/System.Globalization.Native/pal_locale.h index 7fe89f667f2132..4a1fe0768e4fda 100644 --- a/src/native/libs/System.Globalization.Native/pal_locale.h +++ b/src/native/libs/System.Globalization.Native/pal_locale.h @@ -21,4 +21,6 @@ PALEXPORT int32_t GlobalizationNative_GetLocaleTimeFormat(const UChar* localeNam PALEXPORT const char* GlobalizationNative_GetLocaleNameNative(const char* localeName); PALEXPORT const char* GlobalizationNative_GetLocaleTimeFormatNative(const char* localeName, int shortFormat); + +PALEXPORT int32_t GlobalizationNative_GetLocalesNative(UChar* locales, int32_t length); #endif diff --git a/src/native/libs/System.Globalization.Native/pal_locale.m b/src/native/libs/System.Globalization.Native/pal_locale.m index d8ab7da1fbee0c..4789ac89691da2 100644 --- a/src/native/libs/System.Globalization.Native/pal_locale.m +++ b/src/native/libs/System.Globalization.Native/pal_locale.m @@ -97,7 +97,7 @@ static void GetParent(const char* localeID, char* parent, int32_t parentCapacity { @autoreleasepool { - const char* value; + NSString *value; NSString *locName = [NSString stringWithFormat:@"%s", localeName]; NSLocale *currentLocale = [[NSLocale alloc] initWithLocaleIdentifier:locName]; NSNumberFormatter *numberFormatter = [[NSNumberFormatter alloc] init]; @@ -112,35 +112,35 @@ static void GetParent(const char* localeID, char* parent, int32_t parentCapacity case LocaleString_LocalizedDisplayName: /// Display name (language + country usually) in English, eg "German (Germany)" (corresponds to LOCALE_SENGLISHDISPLAYNAME) case LocaleString_EnglishDisplayName: - value = [[gbLocale displayNameForKey:NSLocaleIdentifier value:currentLocale.localeIdentifier] UTF8String]; - break; + value = [gbLocale displayNameForKey:NSLocaleIdentifier value:currentLocale.localeIdentifier]; + break; /// Display name in native locale language, eg "Deutsch (Deutschland) (corresponds to LOCALE_SNATIVEDISPLAYNAME) case LocaleString_NativeDisplayName: - value = [[currentLocale displayNameForKey:NSLocaleIdentifier value:currentLocale.localeIdentifier] UTF8String]; + value = [currentLocale displayNameForKey:NSLocaleIdentifier value:currentLocale.localeIdentifier]; break; /// Language Display Name for a language, eg "German" in UI language (corresponds to LOCALE_SLOCALIZEDLANGUAGENAME) case LocaleString_LocalizedLanguageName: /// English name of language, eg "German" (corresponds to LOCALE_SENGLISHLANGUAGENAME) case LocaleString_EnglishLanguageName: - value = [[gbLocale localizedStringForLanguageCode:currentLocale.languageCode] UTF8String]; + value = [gbLocale localizedStringForLanguageCode:currentLocale.languageCode]; break; /// native name of language, eg "Deutsch" (corresponds to LOCALE_SNATIVELANGUAGENAME) case LocaleString_NativeLanguageName: - value = [[currentLocale localizedStringForLanguageCode:currentLocale.languageCode] UTF8String]; + value = [currentLocale localizedStringForLanguageCode:currentLocale.languageCode]; break; /// English name of country, eg "Germany" (corresponds to LOCALE_SENGLISHCOUNTRYNAME) case LocaleString_EnglishCountryName: - value = [[gbLocale localizedStringForCountryCode:currentLocale.countryCode] UTF8String]; + value = [gbLocale localizedStringForCountryCode:currentLocale.countryCode]; break; /// native name of country, eg "Deutschland" (corresponds to LOCALE_SNATIVECOUNTRYNAME) case LocaleString_NativeCountryName: - value = [[currentLocale localizedStringForCountryCode:currentLocale.countryCode] UTF8String]; + value = [currentLocale localizedStringForCountryCode:currentLocale.countryCode]; break; case LocaleString_ThousandSeparator: - value = [currentLocale.groupingSeparator UTF8String]; + value = currentLocale.groupingSeparator; break; case LocaleString_DecimalSeparator: - value = [currentLocale.decimalSeparator UTF8String]; + value = currentLocale.decimalSeparator; // or value = [[currentLocale objectForKey:NSLocaleDecimalSeparator] UTF8String]; break; case LocaleString_Digits: @@ -150,87 +150,84 @@ static void GetParent(const char* localeID, char* parent, int32_t parentCapacity [nf1 setLocale:currentLocale]; NSNumber *newNum = [nf1 numberFromString:digitsString]; - value = [[newNum stringValue] UTF8String]; + value = [newNum stringValue]; break; } case LocaleString_MonetarySymbol: - value = [currentLocale.currencySymbol UTF8String]; + value = currentLocale.currencySymbol; break; case LocaleString_Iso4217MonetarySymbol: // check if this is correct, check currencyISOCode - value = [currentLocale.currencySymbol UTF8String]; + value = currentLocale.currencyCode; break; case LocaleString_CurrencyEnglishName: - value = [[gbLocale localizedStringForCurrencyCode:currentLocale.currencyCode] UTF8String]; + value = [gbLocale localizedStringForCurrencyCode:currentLocale.currencyCode]; break; case LocaleString_CurrencyNativeName: - value = [[currentLocale localizedStringForCurrencyCode:currentLocale.currencyCode] UTF8String]; + value = [currentLocale localizedStringForCurrencyCode:currentLocale.currencyCode]; break; case LocaleString_MonetaryDecimalSeparator: - value = [numberFormatter.currencyDecimalSeparator UTF8String]; + value = numberFormatter.currencyDecimalSeparator; break; case LocaleString_MonetaryThousandSeparator: - value = [numberFormatter.currencyGroupingSeparator UTF8String]; + value = numberFormatter.currencyGroupingSeparator; break; case LocaleString_AMDesignator: - value = [dateFormatter.AMSymbol UTF8String]; + value = dateFormatter.AMSymbol; break; case LocaleString_PMDesignator: - value = [dateFormatter.PMSymbol UTF8String]; + value = dateFormatter.PMSymbol; break; case LocaleString_PositiveSign: - value = [numberFormatter.plusSign UTF8String]; + value = numberFormatter.plusSign; break; case LocaleString_NegativeSign: - value = [numberFormatter.minusSign UTF8String]; + value = numberFormatter.minusSign; break; case LocaleString_Iso639LanguageTwoLetterName: - value = [[currentLocale objectForKey:NSLocaleLanguageCode] UTF8String]; + value = [currentLocale objectForKey:NSLocaleLanguageCode]; break; case LocaleString_Iso639LanguageThreeLetterName: { NSString *iso639_2 = [currentLocale objectForKey:NSLocaleLanguageCode]; - value = uloc_getISO3LanguageByLangCode([iso639_2 UTF8String]); - break; + return iso639_2 == nil ? strdup("") : strdup(uloc_getISO3LanguageByLangCode([iso639_2 UTF8String])); } case LocaleString_Iso3166CountryName: - value = [[currentLocale objectForKey:NSLocaleCountryCode] UTF8String]; + value = [currentLocale objectForKey:NSLocaleCountryCode]; break; case LocaleString_Iso3166CountryName2: { - const char *countryCode = strdup([[currentLocale objectForKey:NSLocaleCountryCode] UTF8String]); - value = uloc_getISO3CountryByCountryCode(countryCode); - break; + NSString* countryCode = [currentLocale objectForKey:NSLocaleCountryCode]; + return countryCode == nil ? strdup("") : strdup(uloc_getISO3CountryByCountryCode([countryCode UTF8String])); } case LocaleString_NaNSymbol: - value = [numberFormatter.notANumberSymbol UTF8String]; + value = numberFormatter.notANumberSymbol; break; case LocaleString_PositiveInfinitySymbol: - value = [numberFormatter.positiveInfinitySymbol UTF8String]; + value = numberFormatter.positiveInfinitySymbol; break; case LocaleString_NegativeInfinitySymbol: - value = [numberFormatter.negativeInfinitySymbol UTF8String]; + value = numberFormatter.negativeInfinitySymbol; break; case LocaleString_PercentSymbol: - value = [numberFormatter.percentSymbol UTF8String]; + value = numberFormatter.percentSymbol; break; case LocaleString_PerMilleSymbol: - value = [numberFormatter.perMillSymbol UTF8String]; + value = numberFormatter.perMillSymbol; break; case LocaleString_ParentName: { char localeNameTemp[FULLNAME_CAPACITY]; const char* lName = [currentLocale.localeIdentifier UTF8String]; GetParent(lName, localeNameTemp, FULLNAME_CAPACITY); - value = strdup(localeNameTemp); - break; + return strdup(localeNameTemp); } default: - value = ""; + value = nil; break; } - return value ? strdup(value) : ""; + return value == nil ? strdup("") : strdup([value UTF8String]); } } @@ -667,6 +664,54 @@ Returns time format information (in native format, it needs to be converted to . } } +// GlobalizationNative_GetLocalesNative gets all locale names and store it in the value buffer +// in case of success, it returns the count of the characters stored in value buffer +// in case of failure, it returns negative number. +// if the input value buffer is null, it returns the length needed to store the +// locale names list. +// if the value is not null, it fills the value with locale names separated by the length +// of each name. +int32_t GlobalizationNative_GetLocalesNative(UChar* value, int32_t length) +{ + @autoreleasepool + { + NSArray* availableLocaleIdentifiers = [NSLocale availableLocaleIdentifiers]; + int32_t index = 0; + int32_t totalLength = 0; + int32_t availableLength = (int32_t)[availableLocaleIdentifiers count]; + + if (availableLength <= 0) + return -1; // failed + + for (NSInteger i = 0; i < availableLength; i++) + { + NSString *localeIdentifier = availableLocaleIdentifiers[i]; + int32_t localeNameLength = localeIdentifier.length; + totalLength += localeNameLength + 1; // add 1 for the name length + if (value != NULL) + { + if (totalLength > length) + return -3; + + value[index++] = (UChar) localeNameLength; + + for (int j = 0; j < localeNameLength; j++) + { + if ((UChar)[localeIdentifier characterAtIndex:j] == '_') + { + value[index++] = (UChar) '-'; + } + else + { + value[index++] = (UChar) [localeIdentifier characterAtIndex:j]; + } + } + } + } + return totalLength; + } +} + #endif #if defined(TARGET_MACCATALYST) || defined(TARGET_IOS) || defined(TARGET_TVOS) diff --git a/src/tasks/Microsoft.NET.WebAssembly.Webcil/Microsoft.NET.WebAssembly.Webcil.csproj b/src/tasks/Microsoft.NET.WebAssembly.Webcil/Microsoft.NET.WebAssembly.Webcil.csproj index c35eb57e80686b..d09ae4a569a598 100644 --- a/src/tasks/Microsoft.NET.WebAssembly.Webcil/Microsoft.NET.WebAssembly.Webcil.csproj +++ b/src/tasks/Microsoft.NET.WebAssembly.Webcil/Microsoft.NET.WebAssembly.Webcil.csproj @@ -16,6 +16,7 @@ + diff --git a/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilConverter.cs b/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilConverter.cs index a38af7270a2dad..13c34bde4b8ea1 100644 --- a/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilConverter.cs +++ b/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilConverter.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers.Binary; using System.IO; using System.Collections.Immutable; using System.Reflection.PortableExecutable; @@ -174,16 +175,23 @@ public unsafe void GatherInfo(PEReader peReader, out WCFileInfo wcInfo, out PEFi SectionStart: firstWCSection); } - private static void WriteHeader(Stream s, WebcilHeader header) + private static void WriteHeader(Stream s, WebcilHeader webcilHeader) { - WriteStructure(s, header); + if (!BitConverter.IsLittleEndian) + { + webcilHeader.version_major = BinaryPrimitives.ReverseEndianness(webcilHeader.version_major); + webcilHeader.version_minor = BinaryPrimitives.ReverseEndianness(webcilHeader.version_minor); + webcilHeader.coff_sections = BinaryPrimitives.ReverseEndianness(webcilHeader.coff_sections); + webcilHeader.pe_cli_header_rva = BinaryPrimitives.ReverseEndianness(webcilHeader.pe_cli_header_rva); + webcilHeader.pe_cli_header_size = BinaryPrimitives.ReverseEndianness(webcilHeader.pe_cli_header_size); + webcilHeader.pe_debug_rva = BinaryPrimitives.ReverseEndianness(webcilHeader.pe_debug_rva); + webcilHeader.pe_debug_size = BinaryPrimitives.ReverseEndianness(webcilHeader.pe_debug_size); + } + WriteStructure(s, webcilHeader); } private static void WriteSectionHeaders(Stream s, ImmutableArray sectionsHeaders) { - // FIXME: fixup endianness - if (!BitConverter.IsLittleEndian) - throw new NotImplementedException(); foreach (var sectionHeader in sectionsHeaders) { WriteSectionHeader(s, sectionHeader); @@ -192,6 +200,16 @@ private static void WriteSectionHeaders(Stream s, ImmutableArray(Stream s, T structure) where T : unmanaged { - // FIXME: fixup endianness - if (!BitConverter.IsLittleEndian) - throw new NotImplementedException(); unsafe { byte* p = (byte*)&structure; @@ -212,9 +227,6 @@ private static void WriteStructure(Stream s, T structure) private static void WriteStructure(Stream s, T structure) where T : unmanaged { - // FIXME: fixup endianness - if (!BitConverter.IsLittleEndian) - throw new NotImplementedException(); int size = Marshal.SizeOf(); byte[] buffer = new byte[size]; IntPtr ptr = IntPtr.Zero; diff --git a/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilReader.cs b/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilReader.cs index 4f42f827986643..ac4f9d86095a90 100644 --- a/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilReader.cs +++ b/src/tasks/Microsoft.NET.WebAssembly.Webcil/WebcilReader.cs @@ -6,7 +6,7 @@ using System.IO; using System.Reflection; using System.Runtime.InteropServices; - +using System.Buffers.Binary; using System.Reflection.Metadata; using System.Reflection.PortableExecutable; @@ -63,14 +63,20 @@ private unsafe bool ReadHeader() { return false; } - if (!BitConverter.IsLittleEndian) - { - throw new NotImplementedException("TODO: implement big endian support"); - } fixed (byte* p = buffer) { header = *(WebcilHeader*)p; } + if (!BitConverter.IsLittleEndian) + { + header.version_major = BinaryPrimitives.ReverseEndianness(header.version_major); + header.version_minor = BinaryPrimitives.ReverseEndianness(header.version_minor); + header.coff_sections = BinaryPrimitives.ReverseEndianness(header.coff_sections); + header.pe_cli_header_rva = BinaryPrimitives.ReverseEndianness(header.pe_cli_header_rva); + header.pe_cli_header_size = BinaryPrimitives.ReverseEndianness(header.pe_cli_header_size); + header.pe_debug_rva = BinaryPrimitives.ReverseEndianness(header.pe_debug_rva); + header.pe_debug_rva = BinaryPrimitives.ReverseEndianness(header.pe_debug_size); + } if (header.id[0] != 'W' || header.id[1] != 'b' || header.id[2] != 'I' || header.id[3] != 'L' || header.version_major != Internal.Constants.WC_VERSION_MAJOR @@ -346,6 +352,7 @@ private long TranslateRVA(uint rva) private unsafe ImmutableArray ReadSections() { + WebcilSectionHeader secheader; var sections = ImmutableArray.CreateBuilder(_header.coff_sections); var buffer = new byte[Marshal.SizeOf()]; _stream.Seek(SectionDirectoryOffset + _webcilInWasmOffset, SeekOrigin.Begin); @@ -357,8 +364,24 @@ private unsafe ImmutableArray ReadSections() } fixed (byte* p = buffer) { - // FIXME endianness - sections.Add(*(WebcilSectionHeader*)p); + secheader = (*(WebcilSectionHeader*)p); + } + if (!BitConverter.IsLittleEndian) + { + sections.Add + ( + new WebcilSectionHeader + ( + virtualSize: BinaryPrimitives.ReverseEndianness(secheader.VirtualSize), + virtualAddress: BinaryPrimitives.ReverseEndianness(secheader.VirtualAddress), + sizeOfRawData: BinaryPrimitives.ReverseEndianness(secheader.SizeOfRawData), + pointerToRawData: BinaryPrimitives.ReverseEndianness(secheader.PointerToRawData) + ) + ); + } + else + { + sections.Add(secheader); } } return sections.MoveToImmutable(); diff --git a/src/tasks/WorkloadBuildTasks/InstallWorkloadFromArtifacts.cs b/src/tasks/WorkloadBuildTasks/InstallWorkloadFromArtifacts.cs index e520057d5b3bdf..21170ea2152843 100644 --- a/src/tasks/WorkloadBuildTasks/InstallWorkloadFromArtifacts.cs +++ b/src/tasks/WorkloadBuildTasks/InstallWorkloadFromArtifacts.cs @@ -49,7 +49,7 @@ public partial class InstallWorkloadFromArtifacts : Task private string _tempDir = string.Empty; private string _nugetCachePath = string.Empty; - [GeneratedRegex(@"^\d+\.\d+\.\d+(-[A-z]*\.*\d*)?")] + [GeneratedRegex(@"^\d+\.\d+\.\d+(-rtm|-[A-z]*\.*\d*)?")] private static partial Regex bandVersionRegex(); public override bool Execute() @@ -215,7 +215,7 @@ private bool InstallPacks(InstallWorkloadRequest req, string nugetConfigContents (int exitCode, string output) = Utils.TryRunProcess( Log, Path.Combine(req.TargetPath, "dotnet"), - $"workload install --skip-manifest-update --configfile \"{nugetConfigPath}\" --temp-dir \"{_tempDir}/workload-install-temp\" {req.WorkloadId}", + $"workload install --skip-manifest-update --skip-sign-check --configfile \"{nugetConfigPath}\" --temp-dir \"{_tempDir}/workload-install-temp\" {req.WorkloadId}", workingDir: _tempDir, envVars: new Dictionary () { ["NUGET_PACKAGES"] = _nugetCachePath @@ -301,8 +301,8 @@ private bool InstallWorkloadManifest(ITaskItem workloadId, string name, string v string packagePreleaseVersion = bandVersionRegex().Match(version).Groups[1].Value; string bandPreleaseVersion = bandVersionRegex().Match(bandVersion).Groups[1].Value; - if (packagePreleaseVersion != bandPreleaseVersion && packagePreleaseVersion != "-dev" && packagePreleaseVersion != "-ci") - bandVersion = bandVersion.Replace (bandPreleaseVersion, packagePreleaseVersion); + if (packagePreleaseVersion != bandPreleaseVersion && packagePreleaseVersion != "-dev" && packagePreleaseVersion != "-ci" && bandPreleaseVersion != "") + bandVersion = bandVersion.Replace(bandPreleaseVersion, packagePreleaseVersion); PackageReference pkgRef = new(Name: $"{name}.Manifest-{bandVersion}", Version: version, diff --git a/src/tests/Common/external/external.csproj b/src/tests/Common/external/external.csproj index d2541b5ae4835b..71aa7a42b5661c 100644 --- a/src/tests/Common/external/external.csproj +++ b/src/tests/Common/external/external.csproj @@ -13,7 +13,7 @@ --> $(TargetingPackPath) $(NetCoreAppToolCurrent) - win7-x86;win7-x64 + win-x86;win-x64 SharedLibrary false false diff --git a/src/tests/Interop/COM/NETClients/IDispatch/Program.cs b/src/tests/Interop/COM/NETClients/IDispatch/Program.cs index 71839d9afffb8f..fc9144575f0d91 100644 --- a/src/tests/Interop/COM/NETClients/IDispatch/Program.cs +++ b/src/tests/Interop/COM/NETClients/IDispatch/Program.cs @@ -139,6 +139,21 @@ static void Validate_Exception() Assert.Equal(GetErrorCodeFromHResult(e.HResult), errorCode); // Failing HRESULT exceptions contain CLR generated messages } + + // Calling methods through IDispatch::Invoke() (i.e., late-bound) doesn't + // propagate the HRESULT when marked with PreserveSig. It is always 0. + { + Console.WriteLine($"Calling {nameof(DispatchTesting.TriggerException)} (PreserveSig) with {nameof(IDispatchTesting_Exception.Int)} {errorCode}..."); + var dispatchTesting2 = (IDispatchTestingPreserveSig1)dispatchTesting; + Assert.Equal(0, dispatchTesting2.TriggerException(IDispatchTesting_Exception.Int, errorCode)); + } + + { + // Validate the HRESULT as a value type construct works for IDispatch. + Console.WriteLine($"Calling {nameof(DispatchTesting.TriggerException)} (PreserveSig, ValueType) with {nameof(IDispatchTesting_Exception.Int)} {errorCode}..."); + var dispatchTesting3 = (IDispatchTestingPreserveSig2)dispatchTesting; + Assert.Equal(0, dispatchTesting3.TriggerException(IDispatchTesting_Exception.Int, errorCode).Value); + } } static void Validate_StructNotSupported() diff --git a/src/tests/Interop/COM/NETServer/DispatchTesting.cs b/src/tests/Interop/COM/NETServer/DispatchTesting.cs index 477e5751f69e73..66461b8c7e47f2 100644 --- a/src/tests/Interop/COM/NETServer/DispatchTesting.cs +++ b/src/tests/Interop/COM/NETServer/DispatchTesting.cs @@ -57,6 +57,7 @@ public void TriggerException(IDispatchTesting_Exception excep, int errorCode) case IDispatchTesting_Exception.Disp: throw new Exception(); case IDispatchTesting_Exception.HResult: + case IDispatchTesting_Exception.Int: throw new System.ComponentModel.Win32Exception(errorCode); } } diff --git a/src/tests/Interop/COM/NativeServer/DispatchTesting.h b/src/tests/Interop/COM/NativeServer/DispatchTesting.h index 927439fe03dc48..fbe7db6c1bad7f 100644 --- a/src/tests/Interop/COM/NativeServer/DispatchTesting.h +++ b/src/tests/Interop/COM/NativeServer/DispatchTesting.h @@ -243,6 +243,8 @@ class DispatchTesting : public UnknownImpl, public IDispatchTesting return DISP_E_EXCEPTION; case IDispatchTesting_Exception_HResult: return HRESULT_FROM_WIN32(errorCode); + case IDispatchTesting_Exception_Int: + return errorCode; default: return S_FALSE; // Return a success case to indicate failure to trigger a failure. } diff --git a/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs b/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs index 758c200acaabae..0bac21e66ee17e 100644 --- a/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs +++ b/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs @@ -209,6 +209,7 @@ public enum IDispatchTesting_Exception { Disp, HResult, + Int, } [StructLayout(LayoutKind.Sequential)] @@ -220,6 +221,12 @@ public struct HFA_4 public float w; } + [StructLayout(LayoutKind.Sequential)] + public struct HRESULT + { + public int Value; + } + [ComVisible(true)] [Guid("a5e04c1c-474e-46d2-bbc0-769d04e12b54")] [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] @@ -257,6 +264,32 @@ void DoubleNumeric_ReturnByRef ( System.Collections.IEnumerator GetEnumerator(); } + [ComVisible(true)] + [Guid("a5e04c1c-474e-46d2-bbc0-769d04e12b54")] + [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] + public interface IDispatchTestingPreserveSig1 + { + void Reserved1(); + void Reserved2(); + void Reserved3(); + + [PreserveSig] + int TriggerException(IDispatchTesting_Exception excep, int errorCode); + } + + [ComVisible(true)] + [Guid("a5e04c1c-474e-46d2-bbc0-769d04e12b54")] + [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] + public interface IDispatchTestingPreserveSig2 + { + void Reserved1(); + void Reserved2(); + void Reserved3(); + + [PreserveSig] + HRESULT TriggerException(IDispatchTesting_Exception excep, int errorCode); + } + [ComVisible(true)] [Guid("83AFF8E4-C46A-45DB-9D91-2ADB5164545E")] [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] diff --git a/src/tests/Interop/COM/ServerContracts/Server.Contracts.h b/src/tests/Interop/COM/ServerContracts/Server.Contracts.h index 3c9a1fcb06cbe1..1eb0528aae4b78 100644 --- a/src/tests/Interop/COM/ServerContracts/Server.Contracts.h +++ b/src/tests/Interop/COM/ServerContracts/Server.Contracts.h @@ -385,6 +385,7 @@ enum IDispatchTesting_Exception { IDispatchTesting_Exception_Disp, IDispatchTesting_Exception_HResult, + IDispatchTesting_Exception_Int, }; struct __declspec(uuid("a5e04c1c-474e-46d2-bbc0-769d04e12b54")) diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.cs b/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.cs new file mode 100644 index 00000000000000..cb70acd1677573 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using Xunit; + +public class Runtime_83387 +{ + [MethodImpl(MethodImplOptions.NoOptimization)] + [Fact] + public static int TestEntryPoint() + { + (ushort A, ushort R) c = (1, 65535); + Vector128 v1 = Vector128.Create((uint)100); + v1 = v1 * c.A; + return (int)v1.ToScalar(); + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_83387/Runtime_83387.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.cs b/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.cs new file mode 100644 index 00000000000000..d4035f3de978fc --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// +// This test verifies if we correctly value number the operation of +// x ^ x to zero. +// +// Found by Antigen + +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using Xunit; + +public class Issue_91252 +{ + static Vector64 s_v64_int_22 = Vector64.Create(-5); + Vector64 v64_int_72 = Vector64.Create(-1); + + [MethodImpl(MethodImplOptions.NoInlining)] + public int Repro() + { + s_v64_int_22 = v64_int_72; + return Check(v64_int_72 ^ v64_int_72); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public int Check(Vector64 a) + { + return (a == Vector64.Zero) ? 100 : 101; + } + + [Fact] + public static int EntryPoint() + { + var obj = new Issue_91252(); + return obj.Repro(); + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.csproj new file mode 100644 index 00000000000000..de6d5e08882e86 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_91252/Runtime_91252.csproj @@ -0,0 +1,8 @@ + + + True + + + + + diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.cs b/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.cs new file mode 100644 index 00000000000000..9b4696e31fc16c --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Threading; +using Xunit; + +public struct MutableStruct +{ + private long _internalValue; + + public long InternalValue + { + get => Volatile.Read(ref _internalValue); + private set => Volatile.Write(ref _internalValue, value); + } + + public void Add(long value) => AddInternal(value); + private void AddInternal(long value) => InternalValue += value; + public MutableStruct(long value) => InternalValue = value; +} + +public static class Runtime_92218 +{ + [Fact] + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public static void Problem() + { + var test = new MutableStruct(420); + var from = new MutableStruct(42); + + var wrapper = -new TimeSpan(3); + + while (test.InternalValue >= from.InternalValue) + { + test.Add(wrapper.Ticks); + } + } +} \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92218/Runtime_92218.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.cs b/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.cs new file mode 100644 index 00000000000000..5de0a28895b268 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.Intrinsics.X86; +using System.Runtime.Intrinsics; +using System.Runtime.CompilerServices; +using System.Threading; +using Xunit; + +public static class Runtime_92349 +{ + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + unsafe static void Test(byte* pValue) + { + *pValue = (byte)Sse2.ConvertToInt32(Vector128.Create(-10, 0, 0, 0)); + } + + [Fact] + public unsafe static void EntryPoint() + { + if (Sse2.IsSupported) + { + ulong value = 0; + Test((byte*)Unsafe.AsPointer(ref value)); + Assert.True(value == 246); + } + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.csproj new file mode 100644 index 00000000000000..6bb210527e0797 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92349/Runtime_92349.csproj @@ -0,0 +1,9 @@ + + + True + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.cs b/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.cs new file mode 100644 index 00000000000000..4704441bacce6c --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +using Xunit; + +public static class Runtime_92357 +{ + [Fact] + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public static void Problem() + { + if (!Avx2.IsSupported) + { + return; + } + + int y1 = 5; + + Vector256 actual1 = Test1(Vector256.Create((short)1), ref y1); + Vector256 expected1 = Vector256.Create(10, 0, 10, 0, 10, 0, 10, 0, 10, 0, 10, 0, 10, 0, 10, 0); + + Assert.Equal(expected1, actual1); + + long y2 = 5; + + Vector256 actual2 = Test2(Vector256.Create((int)1), ref y2); + Vector256 expected2 = Vector256.Create(10, 0, 10, 0, 10, 0, 10, 0); + + Assert.Equal(expected2, actual2); + } + + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)] + public static Vector256 Test1(Vector256 x, ref int y) + { + return Avx2.MultiplyLow(x + x, Vector256.Create(y).AsInt16()); + } + + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)] + public static Vector256 Test2(Vector256 x, ref long y) + { + return Avx2.MultiplyLow(x + x, Vector256.Create(y).AsInt32()); + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92357/Runtime_92357.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.cs b/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.cs new file mode 100644 index 00000000000000..99a5ef2ee5d18d --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using Xunit; + +public class Runtime_92590 +{ + [Fact] + public static void TestEntryPoint() + { + Span bytes = stackalloc byte[4]; + bytes.Fill(0xff); + TestByteByte(ref bytes[0], 0, Vector256.Create((byte)1)); + + Assert.True(bytes.SequenceEqual(stackalloc byte[] { 0x2, 0xff, 0xff, 0xff })); + + bytes.Fill(0xff); + TestByteInt(ref bytes[0], 0, Vector256.Create(1)); + + Assert.True(bytes.SequenceEqual(stackalloc byte[] { 0x2, 0xff, 0xff, 0xff })); + + int i = int.MaxValue; + TestIntByte(ref i, 0, Vector256.Create((byte)1)); + + Assert.Equal(2, i); + + i = int.MaxValue; + TestIntInt(ref i, 0, Vector256.Create(1)); + + Assert.Equal(2, i); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void TestByteByte(ref byte b, int x, Vector256 vin) + { + Vector256 v = vin + vin; + Unsafe.Add(ref b, x) = v[3]; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void TestByteInt(ref byte b, int x, Vector256 vin) + { + Vector256 v = vin + vin; + Unsafe.Add(ref b, x) = (byte)v[3]; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void TestIntByte(ref int b, int x, Vector256 vin) + { + Vector256 v = vin + vin; + Unsafe.Add(ref b, x) = v[3]; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void TestIntInt(ref int b, int x, Vector256 vin) + { + Vector256 v = vin + vin; + Unsafe.Add(ref b, x) = v[3]; + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_92590/Runtime_92590.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.cs b/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.cs new file mode 100644 index 00000000000000..6a68d7f5650797 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using Xunit; + +public class Runtime_93342 +{ + private int foo; + private int bar; + private int baz; + + [Fact] + public static void TestEntryPoint() + { + new Runtime_93342().Run(); + } + + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + private void Run() + { + if (foo == 1) + { + bar += 11; + baz += 11; + } + if (foo == 2) + bar += 12; + if (foo == 3) + bar += 13; + if (foo == 4) + bar += 14; + if (foo == 5) + bar += 15; + if (foo == 6) + bar += 16; + if (foo == 7) + bar += 17; + if (foo == 8) + bar += 18; + if (foo == 9) + bar += 19; + if (foo == 10) + bar += 20; + if (foo == 11) + bar += 21; + if (foo == 12) + bar += 22; + if (foo == 13) + bar += 23; + if (foo == 14) + bar += 24; + if (foo == 15) + bar += 25; + if (foo == 16) + bar += 26; + if (foo == 17) + bar += 27; + if (foo == 18) + bar += 28; + if (foo == 19) + bar += 29; + if (foo == 20) + bar += 30; + if (foo == 21) + bar += 31; + if (foo == 22) + bar += 32; + if (foo == 23) + bar += 33; + if (foo == 24) + bar += 34; + if (foo == 25) + bar += 35; + if (foo == 26) + bar += 36; + if (foo == 27) + bar += 37; + if (foo == 28) + bar += 38; + if (foo == 29) + bar += 39; + if (foo == 30) + bar += 40; + if (foo == 31) + bar += 41; + if (foo == 32) + bar += 42; + if (foo == 33) + bar += 43; + if (foo == 34) + bar += 44; + if (foo == 35) + bar += 45; + if (foo == 36) + bar += 46; + if (foo == 37) + bar += 47; + + bar = baz; + } +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_93342/Runtime_93342.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.cs b/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.cs new file mode 100644 index 00000000000000..b424dd6c2f3c98 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Text; +using Xunit; + +public struct Holder +{ + internal StringBuilder.AppendInterpolatedStringHandler _h; + public Holder() => _h = new(0, 0, new()); + + internal StringBuilder GetBuilder() => Unsafe.As(ref _h); +} + +public static class Runtime_93650 +{ + static int N = 1; + + [Fact] + public static int Problem() + { + var sb = new Holder(); + for (int i = 0; i < N; i++) + { + var s = Bind(ref sb); + if (s.Length != 0) + { + Console.WriteLine("FAILED: StringBuilder.ToString() returned: " + s); + return -1; + } + } + + return 100; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public static string Bind(ref Holder parameters) => GetString(parameters.GetBuilder()); + + public static string GetString(StringBuilder sb) => sb.ToString(); +} diff --git a/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.csproj b/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.csproj new file mode 100644 index 00000000000000..15edd99711a1a4 --- /dev/null +++ b/src/tests/JIT/Regression/JitBlue/Runtime_93650/Runtime_93650.csproj @@ -0,0 +1,8 @@ + + + True + + + + + \ No newline at end of file diff --git a/src/tests/Loader/classloader/StaticVirtualMethods/NegativeTestCases/MethodBodyOnUnrelatedType.ilproj b/src/tests/Loader/classloader/StaticVirtualMethods/NegativeTestCases/MethodBodyOnUnrelatedType.ilproj index 397521a4d40047..294b3c3a66827e 100644 --- a/src/tests/Loader/classloader/StaticVirtualMethods/NegativeTestCases/MethodBodyOnUnrelatedType.ilproj +++ b/src/tests/Loader/classloader/StaticVirtualMethods/NegativeTestCases/MethodBodyOnUnrelatedType.ilproj @@ -4,6 +4,9 @@ false + + + true Full diff --git a/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.cs b/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.cs new file mode 100644 index 00000000000000..122ec91663b37b --- /dev/null +++ b/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; + +public class ReproGH93597 { + public static int Main() { + var expected = new int[] {5,4,3,2,1}; + + const int LowerBound = 5; + + var expectedNzlba = NonZeroLowerBoundArray(expected, LowerBound); + + return Helper(expectedNzlba); + return 100; + } + [MethodImpl(MethodImplOptions.NoInlining)] + private static int Helper(Array a) { + IEnumerable ie = null; + try { + ie = (IEnumerable)a; + } catch (InvalidCastException) { + Console.WriteLine ("caught ICE, good"); + return 100; + } + ie.GetEnumerator(); // mono crashes here + return 101; + } + + + private static Array NonZeroLowerBoundArray(Array szArrayContents, int lowerBound) + { + Array array = Array.CreateInstance(szArrayContents.GetType().GetElementType(), new int[] { szArrayContents.Length }, new int[] { lowerBound }); + for (int i = 0; i < szArrayContents.Length; i++) + { + array.SetValue(szArrayContents.GetValue(i), i + lowerBound); + } + return array; + } + +} + diff --git a/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.csproj b/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.csproj new file mode 100644 index 00000000000000..a6b761d37bc58b --- /dev/null +++ b/src/tests/Loader/classloader/regressions/GitHub_93597/GitHub_93597.csproj @@ -0,0 +1,5 @@ + + + + + diff --git a/src/tests/issues.targets b/src/tests/issues.targets index d950862b7fdbf2..5c9868e334b088 100644 --- a/src/tests/issues.targets +++ b/src/tests/issues.targets @@ -1162,6 +1162,7 @@ + diff --git a/src/tests/nativeaot/SmokeTests/UnitTests/Interfaces.cs b/src/tests/nativeaot/SmokeTests/UnitTests/Interfaces.cs index 7ad4f974bbb09c..e4b7f437382360 100644 --- a/src/tests/nativeaot/SmokeTests/UnitTests/Interfaces.cs +++ b/src/tests/nativeaot/SmokeTests/UnitTests/Interfaces.cs @@ -40,6 +40,7 @@ public static int Run() TestDefaultInterfaceVariance.Run(); TestVariantInterfaceOptimizations.Run(); TestSharedInterfaceMethods.Run(); + TestGenericAnalysis.Run(); TestCovariantReturns.Run(); TestDynamicInterfaceCastable.Run(); TestStaticInterfaceMethodsAnalysis.Run(); @@ -653,6 +654,54 @@ public static void Run() } } + class TestGenericAnalysis + { + interface IInterface + { + string Method(object p); + } + + interface IInterface + { + string Method(T p); + } + + class C1 : IInterface, IInterface + { + public string Method(object p) => "Method(object)"; + public string Method(T p) => "Method(T)"; + } + + class C2 : IInterface, IInterface + { + public string Method(object p) => "Method(object)"; + public string Method(T p) => "Method(T)"; + } + + class C3 : IInterface, IInterface + { + public string Method(object p) => "Method(object)"; + public string Method(T p) => "Method(T)"; + } + + static IInterface s_c1 = new C1(); + static IInterface s_c2 = new C2(); + static IInterface s_c3a = new C3(); + static IInterface s_c3b = new C3(); + + public static void Run() + { + if (s_c1.Method(null) != "Method(object)") + throw new Exception(); + if (s_c2.Method(null) != "Method(T)") + throw new Exception(); + if (s_c3a.Method(null) != "Method(T)") + throw new Exception(); + if (s_c3b.Method(null) != "Method(object)") + throw new Exception(); + } + } + class TestCovariantReturns { interface IFoo diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases.Expectations/Assertions/BaseInAssemblyAttribute.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases.Expectations/Assertions/BaseInAssemblyAttribute.cs index 1d8ed24b3645c1..5ca7d2eeef5e4f 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases.Expectations/Assertions/BaseInAssemblyAttribute.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases.Expectations/Assertions/BaseInAssemblyAttribute.cs @@ -5,5 +5,11 @@ namespace Mono.Linker.Tests.Cases.Expectations.Assertions { public abstract class BaseInAssemblyAttribute : BaseExpectedLinkedBehaviorAttribute { + /// + /// By default the behavior should be preserved by all platforms + /// This property can override that by setting only the platforms + /// which are expected to preserve the desired behavior. + /// + public Tool Tool { get; set; } = Tool.TrimmerAnalyzerAndNativeAot; } } diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CanLinkPublicApisOfLibrary.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CanLinkPublicApisOfLibrary.cs index 33353a8f16bd02..c1148c3c9675a9 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CanLinkPublicApisOfLibrary.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CanLinkPublicApisOfLibrary.cs @@ -3,6 +3,8 @@ namespace Mono.Linker.Tests.Cases.Libraries { + [IgnoreTestCase("NativeAOT doesn't implement library trimming the same way", IgnoredBy = Tool.NativeAot)] + [KeptAttributeAttribute (typeof (IgnoreTestCaseAttribute), By = Tool.Trimmer)] [SetupLinkerLinkPublicAndFamily] [SetupCompileAsLibrary] [Kept] @@ -37,4 +39,4 @@ private void UnusedPrivateMethod () { } } -} \ No newline at end of file +} diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CopyUsedAssemblyWithMainEntryRoot.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CopyUsedAssemblyWithMainEntryRoot.cs index 1aa6bdace5fde1..bd410c15382034 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CopyUsedAssemblyWithMainEntryRoot.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CopyUsedAssemblyWithMainEntryRoot.cs @@ -4,6 +4,8 @@ namespace Mono.Linker.Tests.Cases.Libraries { + [IgnoreTestCase ("NativeAOT doesn't implement copy used behavior", IgnoredBy = Tool.NativeAot)] + [KeptAttributeAttribute (typeof (IgnoreTestCaseAttribute), By = Tool.Trimmer)] [Kept] [KeptMember (".ctor()")] [SetupLinkerAction ("copyused", "test")] @@ -30,4 +32,4 @@ private void UnusedPrivateMethod () CopyUsedAssemblyWithMainEntryRoot_Lib.Unused (); } } -} \ No newline at end of file +} diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CopyUsedAssemblyWithPublicRoots.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CopyUsedAssemblyWithPublicRoots.cs index b3e7bca6b401a7..e4558ef5e18ad2 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CopyUsedAssemblyWithPublicRoots.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/CopyUsedAssemblyWithPublicRoots.cs @@ -3,6 +3,8 @@ namespace Mono.Linker.Tests.Cases.Libraries { + [IgnoreTestCase ("NativeAOT doesn't implement copy used behavior", IgnoredBy = Tool.NativeAot)] + [KeptAttributeAttribute (typeof (IgnoreTestCaseAttribute), By = Tool.Trimmer)] [Kept] [KeptMember (".ctor()")] [SetupLinkerAction ("copyused", "test")] @@ -24,4 +26,4 @@ private void UnusedPrivateMethod () { } } -} \ No newline at end of file +} diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/DefaultLibraryLinkBehavior.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/DefaultLibraryLinkBehavior.cs index fd0c557816221a..93ab6940440ae6 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/DefaultLibraryLinkBehavior.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/DefaultLibraryLinkBehavior.cs @@ -3,6 +3,8 @@ namespace Mono.Linker.Tests.Cases.Libraries { + [IgnoreTestCase ("NativeAOT doesn't implement library trimming the same way", IgnoredBy = Tool.NativeAot)] + [KeptAttributeAttribute (typeof (IgnoreTestCaseAttribute), By = Tool.Trimmer)] [SetupCompileAsLibrary] [SetupLinkerArgument ("-a", "test.dll")] [Kept] @@ -26,4 +28,4 @@ private void UnusedPrivateMethod () { } } -} \ No newline at end of file +} diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/Dependencies/UserAssemblyActionWorks_ChildLib.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/Dependencies/UserAssemblyActionWorks_ChildLib.cs new file mode 100644 index 00000000000000..f05284ac282d2b --- /dev/null +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/Dependencies/UserAssemblyActionWorks_ChildLib.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Mono.Linker.Tests.Cases.Libraries.Dependencies +{ + public abstract class UserAssemblyActionWorks_ChildLib + { + public abstract void MustOverride (); + + public static void ChildUnusedMethod (InputType input) { } + + private static void ChildUnusedPrivateMethod () { } + + public void ChildUnusedInstanceMethod () { } + + public int UnusedProperty { get; set; } + + public static int UnusedField; + } + + public class InputType { } +} diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/Dependencies/UserAssemblyActionWorks_Lib.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/Dependencies/UserAssemblyActionWorks_Lib.cs index d46e34942d2fa7..16a86c9f112fb3 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/Dependencies/UserAssemblyActionWorks_Lib.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/Dependencies/UserAssemblyActionWorks_Lib.cs @@ -1,7 +1,9 @@ namespace Mono.Linker.Tests.Cases.Libraries.Dependencies { - public class UserAssemblyActionWorks_Lib + public class UserAssemblyActionWorks_Lib : UserAssemblyActionWorks_ChildLib { + public override void MustOverride () { } + public static void Used () { } @@ -10,4 +12,4 @@ public static void Unused () { } } -} \ No newline at end of file +} diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/LibraryWithUnresolvedInterfaces.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/LibraryWithUnresolvedInterfaces.cs index d7f015495350f1..d97b0ee0c515d4 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/LibraryWithUnresolvedInterfaces.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/LibraryWithUnresolvedInterfaces.cs @@ -7,6 +7,8 @@ namespace Mono.Linker.Tests.Cases.Libraries { + [IgnoreTestCase ("NativeAOT doesn't implement library trimming the same way", IgnoredBy = Tool.NativeAot)] + [KeptAttributeAttribute (typeof (IgnoreTestCaseAttribute), By = Tool.Trimmer)] [SetupCompileBefore ("copylibrary.dll", new[] { "Dependencies/CopyLibrary.cs" }, removeFromLinkerInput: true)] [SetupLinkerArgument ("--skip-unresolved", "true")] [SetupLinkerArgument ("-a", "test.exe", "library")] diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibrary.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibrary.cs index 1317ee910d1b18..34c38504d088b0 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibrary.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibrary.cs @@ -11,6 +11,8 @@ namespace Mono.Linker.Tests.Cases.Libraries { + [IgnoreTestCase ("NativeAOT doesn't implement library trimming the same way", IgnoredBy = Tool.NativeAot)] + [KeptAttributeAttribute (typeof (IgnoreTestCaseAttribute), By = Tool.Trimmer)] [SetupCompileBefore ("copylibrary.dll", new[] { "Dependencies/CopyLibrary.cs" })] [SetupLinkerAction ("copy", "copylibrary")] [SetupLinkerArgument ("-a", "test.exe", "library")] diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryInternalsWithIVT.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryInternalsWithIVT.cs index 4e0fa2deb509b6..6f3e2aaac459d5 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryInternalsWithIVT.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryInternalsWithIVT.cs @@ -10,6 +10,8 @@ namespace Mono.Linker.Tests.Cases.Libraries { + [IgnoreTestCase ("NativeAOT doesn't implement library trimming the same way", IgnoredBy = Tool.NativeAot)] + [KeptAttributeAttribute (typeof (IgnoreTestCaseAttribute), By = Tool.Trimmer)] [Kept] [KeptMember (".ctor()")] [SetupLinkerLinkPublicAndFamily] diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleAndDescriptor.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleAndDescriptor.cs index 0f3309358c9549..15c63b08f95aca 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleAndDescriptor.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleAndDescriptor.cs @@ -4,6 +4,8 @@ namespace Mono.Linker.Tests.Cases.Libraries { + [IgnoreTestCase ("NativeAOT doesn't implement library trimming the same way", IgnoredBy = Tool.NativeAot)] + [KeptAttributeAttribute (typeof (IgnoreTestCaseAttribute), By = Tool.Trimmer)] [Kept] [KeptMember (".ctor()")] [SetupLinkerLinkPublicAndFamily] diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleForwarders.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleForwarders.cs index 09e3fed90f879e..3e02e47b508d2a 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleForwarders.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleForwarders.cs @@ -8,6 +8,9 @@ namespace Mono.Linker.Tests.Cases.Libraries { + [IgnoreTestCase ("NativeAOT doesn't implement library trimming the same way", IgnoredBy = Tool.NativeAot)] + [KeptAttributeAttribute (typeof (IgnoreTestCaseAttribute), By = Tool.Trimmer)] + [SetupCompileBefore ("library.dll", new[] { "Dependencies/RootLibraryVisibleForwarders_Lib.cs" })] [SetupLinkerLinkPublicAndFamily] [Define ("RootLibraryVisibleForwarders")] diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleForwardersWithoutReference.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleForwardersWithoutReference.cs index bf566ead64abd7..137bd15fa357cd 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleForwardersWithoutReference.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/RootLibraryVisibleForwardersWithoutReference.cs @@ -8,6 +8,9 @@ namespace Mono.Linker.Tests.Cases.Libraries { + [IgnoreTestCase ("NativeAOT doesn't implement library trimming the same way", IgnoredBy = Tool.NativeAot)] + [KeptAttributeAttribute (typeof (IgnoreTestCaseAttribute), By = Tool.Trimmer)] + [SetupCompileBefore ("library.dll", new[] { "Dependencies/RootLibraryVisibleForwarders_Lib.cs" }, outputSubFolder: "isolated")] [SetupLinkerLinkPublicAndFamily] [SetupLinkerArgument ("-a", "isolated/library.dll", "visible")] // Checks for no-eager exported type resolving diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/UserAssemblyActionWorks.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/UserAssemblyActionWorks.cs index ee367654a71951..21cabeece3775e 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/UserAssemblyActionWorks.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/Libraries/UserAssemblyActionWorks.cs @@ -7,10 +7,23 @@ namespace Mono.Linker.Tests.Cases.Libraries /// /// We have to check another assembly because the test exe is included with -a and that will cause it to be linked /// - [SetupLinkerDefaultAction ("copy")] - [SetupCompileBefore ("lib.dll", new[] { "Dependencies/UserAssemblyActionWorks_Lib.cs" })] - [KeptAllTypesAndMembersInAssembly ("lib.dll")] + [SetupCompileBefore ("childlib.dll", new[] { "Dependencies/UserAssemblyActionWorks_ChildLib.cs" })] + [SetupCompileBefore ("lib.dll", new[] { "Dependencies/UserAssemblyActionWorks_Lib.cs" }, new[] { "childlib.dll" })] + [SetupLinkerAction ("link", "childlib")] + [SetupLinkerAction ("copy", "lib")] [SetupLinkerAction ("link", "test")] + + [KeptAllTypesAndMembersInAssembly ("lib.dll")] + [KeptTypeInAssembly("childlib", "Mono.Linker.Tests.Cases.Libraries.Dependencies.UserAssemblyActionWorks_ChildLib")] + + [KeptMemberInAssembly ("childlib", "Mono.Linker.Tests.Cases.Libraries.Dependencies.UserAssemblyActionWorks_ChildLib", "MustOverride()")] + + [RemovedMemberInAssembly ("childlib", "Mono.Linker.Tests.Cases.Libraries.Dependencies.UserAssemblyActionWorks_ChildLib", "ChildUnusedMethod(Mono.Linker.Tests.Cases.Libraries.Dependencies.InputType)")] + [RemovedMemberInAssembly ("childlib", "Mono.Linker.Tests.Cases.Libraries.Dependencies.UserAssemblyActionWorks_ChildLib", "ChildUnusedPrivateMethod()")] + [RemovedMemberInAssembly ("childlib", "Mono.Linker.Tests.Cases.Libraries.Dependencies.UserAssemblyActionWorks_ChildLib", "ChildUnusedInstanceMethod()")] + [RemovedMemberInAssembly ("childlib", "Mono.Linker.Tests.Cases.Libraries.Dependencies.UserAssemblyActionWorks_ChildLib", "UnusedProperty")] + [RemovedMemberInAssembly ("childlib", "Mono.Linker.Tests.Cases.Libraries.Dependencies.UserAssemblyActionWorks_ChildLib", "UnusedField")] + [RemovedTypeInAssembly ("childlib", "Mono.Linker.Tests.Cases.Libraries.Dependencies.InputType")] public class UserAssemblyActionWorks { public static void Main () @@ -18,4 +31,4 @@ public static void Main () UserAssemblyActionWorks_Lib.Used (); } } -} \ No newline at end of file +} diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/CanPreserveAnExportedType.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/CanPreserveAnExportedType.cs index f3d82cc39d0e53..b9868f3f7c9ddd 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/CanPreserveAnExportedType.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/CanPreserveAnExportedType.cs @@ -8,9 +8,10 @@ namespace Mono.Linker.Tests.Cases.LinkXml // Add another assembly in that uses the forwarder just to make things a little more complex [SetupCompileBefore ("Forwarder.dll", new[] { "Dependencies/CanPreserveAnExportedType_Forwarder.cs" }, references: new[] { "Library.dll" })] - [KeptMemberInAssembly ("Library.dll", typeof (CanPreserveAnExportedType_Library), "Field1", "Method()", ".ctor()")] + // NativeAOT doesn't have a concept of type forwarders in the compiled app, everything is fully resolved + [KeptMemberInAssembly ("Library.dll", typeof (CanPreserveAnExportedType_Library), "Field1", "Method()", ".ctor()", Tool = Tool.Trimmer)] + [KeptTypeInAssembly ("Forwarder.dll", typeof (CanPreserveAnExportedType_Library), Tool = Tool.Trimmer)] [SetupLinkerDescriptorFile ("CanPreserveAnExportedType.xml")] - [KeptTypeInAssembly ("Forwarder.dll", typeof (CanPreserveAnExportedType_Library))] class CanPreserveAnExportedType { public static void Main () diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/CanPreserveExportedTypesUsingRegex.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/CanPreserveExportedTypesUsingRegex.cs index 5ebb38bde4a6f3..7772e77be9d36d 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/CanPreserveExportedTypesUsingRegex.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/CanPreserveExportedTypesUsingRegex.cs @@ -8,9 +8,10 @@ namespace Mono.Linker.Tests.Cases.LinkXml // Add another assembly in that uses the forwarder just to make things a little more complex [SetupCompileBefore ("Forwarder.dll", new[] { "Dependencies/CanPreserveAnExportedType_Forwarder.cs" }, references: new[] { "Library.dll" })] - [KeptMemberInAssembly ("Library.dll", typeof (CanPreserveAnExportedType_Library), "Field1", "Method()", ".ctor()")] + // NativeAOT doesn't have a concept of type forwarders in the compiled app, everything is fully resolved + [KeptMemberInAssembly ("Library.dll", typeof (CanPreserveAnExportedType_Library), "Field1", "Method()", ".ctor()", Tool = Tool.Trimmer)] + [KeptTypeInAssembly ("Forwarder.dll", typeof (CanPreserveAnExportedType_Library), Tool = Tool.Trimmer)] [SetupLinkerDescriptorFile ("CanPreserveExportedTypesUsingRegex.xml")] - [KeptTypeInAssembly ("Forwarder.dll", typeof (CanPreserveAnExportedType_Library))] class CanPreserveExportedTypesUsingRegex { public static void Main () diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/EmbeddedLinkXmlFromCopyAssemblyIsProcessed.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/EmbeddedLinkXmlFromCopyAssemblyIsProcessed.cs index 96d7e941f6a45a..1f5f9456981462 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/EmbeddedLinkXmlFromCopyAssemblyIsProcessed.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/EmbeddedLinkXmlFromCopyAssemblyIsProcessed.cs @@ -12,8 +12,10 @@ namespace Mono.Linker.Tests.Cases.LinkXml [IgnoreDescriptors (false)] [SetupLinkerAction ("copy", "CopyLibrary")] - [KeptTypeInAssembly ("CopyLibrary.dll", typeof (CopyLibrary))] - [KeptTypeInAssembly ("Library.dll", typeof (OtherLibrary))] + // NativeAOT doesn't support reading embedded descriptors from a resource called "AssemblyName" + // It only supports "ILLink.Descriptor.xml" name + [KeptTypeInAssembly ("CopyLibrary.dll", typeof (CopyLibrary), Tool = Tool.Trimmer)] + [KeptTypeInAssembly ("Library.dll", typeof (OtherLibrary), Tool = Tool.Trimmer)] public class EmbeddedLinkXmlFromCopyAssemblyIsProcessed { public static void Main () @@ -22,4 +24,4 @@ public static void Main () tmp.Method (); } } -} \ No newline at end of file +} diff --git a/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/UsedNonRequiredExportedTypeIsKeptWhenRooted.cs b/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/UsedNonRequiredExportedTypeIsKeptWhenRooted.cs index 96630532a60b37..c8668aaa0c9a88 100644 --- a/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/UsedNonRequiredExportedTypeIsKeptWhenRooted.cs +++ b/src/tools/illink/test/Mono.Linker.Tests.Cases/LinkXml/UsedNonRequiredExportedTypeIsKeptWhenRooted.cs @@ -3,6 +3,9 @@ namespace Mono.Linker.Tests.Cases.LinkXml { + [IgnoreTestCase ("NativeAOT doesn't implement 'visible' rooting behavior", IgnoredBy = Tool.NativeAot)] + [KeptAttributeAttribute (typeof (IgnoreTestCaseAttribute), By = Tool.Trimmer)] + [SetupLinkerDescriptorFile ("UsedNonRequiredExportedTypeIsKeptWhenRooted.xml")] [SetupLinkerArgument ("-a", "libfwd.dll", "visible")] @@ -21,4 +24,4 @@ public static void Main () var tmp = typeof (UsedNonRequiredExportedTypeIsKeptWhenRooted_Used).ToString (); } } -} \ No newline at end of file +} diff --git a/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs b/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs index 95888eaa6a233b..17c94b3998cc5f 100644 --- a/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs +++ b/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/AssemblyChecker.cs @@ -957,7 +957,6 @@ protected static IEnumerable GetExpectedAttributes (ICustomAttributeProv foreach (var additionalExpectedAttributesFromFixedField in GetCustomAttributeCtorValues (fixedField, nameof (KeptAttributeOnFixedBufferTypeAttribute))) yield return additionalExpectedAttributesFromFixedField.ToString (); - } } diff --git a/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/ResultChecker.cs b/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/ResultChecker.cs index 641036229a215a..692985443ae6b3 100644 --- a/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/ResultChecker.cs +++ b/src/tools/illink/test/Mono.Linker.Tests/TestCasesRunner/ResultChecker.cs @@ -746,7 +746,7 @@ void VerifyKeptAllTypesAndMembersInAssembly (AssemblyDefinition linked) var missingInLinked = originalTypes.Keys.Except (linkedTypes.Keys); - Assert.That (missingInLinked, Is.Empty, $"Expected all types to exist in the linked assembly, but one or more were missing"); + Assert.That (missingInLinked, Is.Empty, $"Expected all types to exist in the linked assembly {linked.Name}, but one or more were missing"); foreach (var originalKvp in originalTypes) { var linkedType = linkedTypes[originalKvp.Key]; @@ -1162,6 +1162,11 @@ Dictionary> BuildOtherAssemblyCheckTable (Assembly foreach (var typeWithRemoveInAssembly in original.AllDefinedTypes ()) { foreach (var attr in typeWithRemoveInAssembly.CustomAttributes.Where (IsTypeInOtherAssemblyAssertion)) { var assemblyName = (string) attr.ConstructorArguments[0].Value; + + Tool? toolTarget = (Tool?) (int?) attr.GetPropertyValue ("Tool"); + if (toolTarget is not null && !toolTarget.Value.HasFlag (Tool.Trimmer)) + continue; + if (!checks.TryGetValue (assemblyName, out List checksForAssembly)) checks[assemblyName] = checksForAssembly = new List ();