1
0
Fork 0
mirror of https://github.com/VSadov/Satori.git synced 2025-06-10 18:11:04 +09:00

[release/9.0-rc2] NRBF Fuzzer and bug fixes (#107788)

* [NRBF] Don't use Unsafe.As when decoding DateTime(s) (#105749)

* Add NrbfDecoder Fuzzer (#107385)

* [NRBF] Fix bugs discovered by the fuzzer (#107368)

* bug #1: don't allow for values out of the SerializationRecordType enum range

* bug #2: throw SerializationException rather than KeyNotFoundException when the referenced record is missing or it points to a record of different type

* bug #3: throw SerializationException rather than FormatException when it's being thrown by BinaryReader (or sth else that we use)

* bug #4: document the fact that IOException can be thrown

* bug #5: throw SerializationException rather than OverflowException when parsing the decimal fails

* bug #6: 0 and 17 are illegal values for PrimitiveType enum

* bug #7: throw SerializationException when a surrogate character is read (so far an ArgumentException was thrown)
# Conflicts:
#	src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs

* [NRBF] throw SerializationException when a surrogate character is read (#107532)

 (so far an ArgumentException was thrown)

* [NRBF] Fuzzing non-seekable stream input (#107605)

* [NRBF] More bug fixes (#107682)

- Don't use `Debug.Fail` not followed by an exception (it may cause problems for apps deployed in Debug)
- avoid Int32 overflow
- throw for unexpected enum values just in case parsing has not rejected them
- validate the number of chars read by BinaryReader.ReadChars
- pass serialization record id to ex message
- return false rather than throw EndOfStreamException when provided Stream has not enough data
- don't restore the position in finally 
- limit max SZ and MD array length to Array.MaxLength, stop using LinkedList<T> as List<T> will be able to hold all elements now
- remove internal enum values that were always illegal, but needed to be handled everywhere
- Fix DebuggerDisplay

* [NRBF] Comments and bug fixes from internal code review (#107735)

* copy comments and asserts from Levis internal code review

* apply Levis suggestion: don't store Array.MaxLength as a const, as it may change in the future

* add missing and fix some of the existing comments

* first bug fix: SerializationRecord.TypeNameMatches should throw ArgumentNullException for null Type argument

* second bug fix: SerializationRecord.TypeNameMatches should know the difference between SZArray and single-dimension, non-zero offset arrays (example: int[] and int[*])

* third bug fix: don't cast bytes to booleans

* fourth bug fix: don't cast bytes to DateTimes

* add one test case that I've forgot in previous PR
# Conflicts:
#	src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecord.cs

* [NRBF] Address issues discovered by Threat Model  (#106629)

* introduce ArrayRecord.FlattenedLength

* do not include invalid Type or Assembly names in the exception messages, as it's most likely corrupted/tampered/malicious data and could be used as a vector of attack.

* It is possible to have binary array records have an element type of array without being marked as jagged

---------

Co-authored-by: Buyaa Namnan <bunamnan@microsoft.com>
This commit is contained in:
Adam Sitnik 2024-09-18 01:06:23 +02:00 committed by GitHub
parent fc781c3e2e
commit fde8a3b8ea
Signed by: github
GPG key ID: B5690EEEBB952194
54 changed files with 1179 additions and 239 deletions

View file

@ -97,6 +97,14 @@ extends:
SYSTEM_ACCESSTOKEN: $(System.AccessToken)
displayName: Send JsonDocumentFuzzer to OneFuzz
- task: onefuzz-task@0
inputs:
onefuzzOSes: 'Windows'
env:
onefuzzDropDirectory: $(fuzzerProject)/deployment/NrbfDecoderFuzzer
SYSTEM_ACCESSTOKEN: $(System.AccessToken)
displayName: Send NrbfDecoderFuzzer to OneFuzz
- task: onefuzz-task@0
inputs:
onefuzzOSes: 'Windows'

View file

@ -18,6 +18,17 @@ internal static class Assert
throw new Exception($"Expected={expected} Actual={actual}");
}
public static void NotNull<T>(T value)
{
if (value == null)
{
ThrowNull();
}
static void ThrowNull() =>
throw new Exception("Value is null");
}
public static void SequenceEqual<T>(ReadOnlySpan<T> expected, ReadOnlySpan<T> actual)
{
if (!expected.SequenceEqual(actual))

File diff suppressed because one or more lines are too long

View file

@ -30,4 +30,8 @@
</None>
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\System.Formats.Nrbf\src\System.Formats.Nrbf.csproj" />
</ItemGroup>
</Project>

View file

@ -24,15 +24,15 @@ namespace DotnetFuzzing.Fuzzers
using PooledBoundedMemory<char> inputPoisonedBefore = PooledBoundedMemory<char>.Rent(chars, PoisonPagePlacement.Before);
using PooledBoundedMemory<char> inputPoisonedAfter = PooledBoundedMemory<char>.Rent(chars, PoisonPagePlacement.After);
Test(inputPoisonedBefore);
Test(inputPoisonedAfter);
Test(inputPoisonedBefore.Span);
Test(inputPoisonedAfter.Span);
}
private static void Test(PooledBoundedMemory<char> inputPoisoned)
private static void Test(Span<char> span)
{
if (AssemblyNameInfo.TryParse(inputPoisoned.Span, out AssemblyNameInfo? fromTryParse))
if (AssemblyNameInfo.TryParse(span, out AssemblyNameInfo? fromTryParse))
{
AssemblyNameInfo fromParse = AssemblyNameInfo.Parse(inputPoisoned.Span);
AssemblyNameInfo fromParse = AssemblyNameInfo.Parse(span);
Assert.Equal(fromTryParse.Name, fromParse.Name);
Assert.Equal(fromTryParse.FullName, fromParse.FullName);
@ -66,7 +66,7 @@ namespace DotnetFuzzing.Fuzzers
{
try
{
_ = AssemblyNameInfo.Parse(inputPoisoned.Span);
_ = AssemblyNameInfo.Parse(span);
}
catch (ArgumentException)
{

View file

@ -0,0 +1,126 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Buffers;
using System.Formats.Nrbf;
using System.Runtime.Serialization;
using System.Text;
namespace DotnetFuzzing.Fuzzers
{
internal sealed class NrbfDecoderFuzzer : IFuzzer
{
public string[] TargetAssemblies { get; } = ["System.Formats.Nrbf"];
public string[] TargetCoreLibPrefixes => [];
public string Dictionary => "nrbfdecoder.dict";
public void FuzzTarget(ReadOnlySpan<byte> bytes)
{
Test(bytes, PoisonPagePlacement.Before);
Test(bytes, PoisonPagePlacement.After);
}
private static void Test(ReadOnlySpan<byte> bytes, PoisonPagePlacement poisonPagePlacement)
{
using PooledBoundedMemory<byte> inputPoisoned = PooledBoundedMemory<byte>.Rent(bytes, poisonPagePlacement);
using MemoryStream seekableStream = new(inputPoisoned.Memory.ToArray());
Test(inputPoisoned.Span, seekableStream);
// NrbfDecoder has few code paths dedicated to non-seekable streams, let's test them as well.
using NonSeekableStream nonSeekableStream = new(inputPoisoned.Memory.ToArray());
Test(inputPoisoned.Span, nonSeekableStream);
}
private static void Test(Span<byte> testSpan, Stream stream)
{
if (NrbfDecoder.StartsWithPayloadHeader(testSpan))
{
try
{
SerializationRecord record = NrbfDecoder.Decode(stream, out IReadOnlyDictionary<SerializationRecordId, SerializationRecord> recordMap);
switch (record.RecordType)
{
case SerializationRecordType.ArraySingleObject:
SZArrayRecord<object?> arrayObj = (SZArrayRecord<object?>)record;
object?[] objArray = arrayObj.GetArray();
Assert.Equal(arrayObj.Length, objArray.Length);
Assert.Equal(1, arrayObj.Rank);
break;
case SerializationRecordType.ArraySingleString:
SZArrayRecord<string?> arrayString = (SZArrayRecord<string?>)record;
string?[] array = arrayString.GetArray();
Assert.Equal(arrayString.Length, array.Length);
Assert.Equal(1, arrayString.Rank);
Assert.Equal(true, arrayString.TypeNameMatches(typeof(string[])));
break;
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.BinaryArray:
ArrayRecord arrayBinary = (ArrayRecord)record;
Assert.NotNull(arrayBinary.TypeName);
break;
case SerializationRecordType.BinaryObjectString:
_ = ((PrimitiveTypeRecord<string>)record).Value;
break;
case SerializationRecordType.ClassWithId:
case SerializationRecordType.ClassWithMembersAndTypes:
case SerializationRecordType.SystemClassWithMembersAndTypes:
ClassRecord classRecord = (ClassRecord)record;
Assert.NotNull(classRecord.TypeName);
foreach (string name in classRecord.MemberNames)
{
Assert.Equal(true, classRecord.HasMember(name));
}
break;
case SerializationRecordType.MemberPrimitiveTyped:
PrimitiveTypeRecord primitiveType = (PrimitiveTypeRecord)record;
Assert.NotNull(primitiveType.Value);
break;
case SerializationRecordType.MemberReference:
Assert.NotNull(record.TypeName);
break;
case SerializationRecordType.BinaryLibrary:
Assert.Equal(false, record.Id.Equals(default));
break;
case SerializationRecordType.ObjectNull:
case SerializationRecordType.ObjectNullMultiple:
case SerializationRecordType.ObjectNullMultiple256:
Assert.Equal(default, record.Id);
break;
case SerializationRecordType.MessageEnd:
case SerializationRecordType.SerializedStreamHeader:
// case SerializationRecordType.ClassWithMembers: will cause NotSupportedException
// case SerializationRecordType.SystemClassWithMembers: will cause NotSupportedException
default:
throw new Exception("Unexpected RecordType");
}
}
catch (SerializationException) { /* Reading from the stream encountered invalid NRBF data.*/ }
catch (NotSupportedException) { /* Reading from the stream encountered unsupported records */ }
catch (DecoderFallbackException) { /* Reading from the stream encountered an invalid UTF8 sequence. */ }
catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ }
catch (IOException) { /* An I/O error occurred. */ }
}
else
{
try
{
NrbfDecoder.Decode(stream);
throw new Exception("Decoding supposed to fail!");
}
catch (SerializationException) { /* Everything has to start with a header */ }
catch (NotSupportedException) { /* Reading from the stream encountered unsupported records */ }
catch (EndOfStreamException) { /* The end of the stream was reached before reading SerializationRecordType.MessageEnd record. */ }
}
}
private class NonSeekableStream : MemoryStream
{
public NonSeekableStream(byte[] buffer) : base(buffer) { }
public override bool CanSeek => false;
}
}
}

View file

@ -3,8 +3,6 @@
using System.Buffers;
using System.Reflection.Metadata;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using System.Text;
namespace DotnetFuzzing.Fuzzers
@ -55,7 +53,7 @@ namespace DotnetFuzzing.Fuzzers
try
{
TypeName.Parse(testSpan);
Assert.Equal(true, false); // should never succeed
throw new Exception("Parsing was supposed to fail!");
}
catch (ArgumentException) { }
catch (InvalidOperationException) { }

View file

@ -11,6 +11,7 @@ namespace System.Formats.Nrbf
internal ArrayRecord() { }
public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } }
public abstract System.ReadOnlySpan<int> Lengths { get; }
public virtual long FlattenedLength { get; }
public int Rank { get { throw null; } }
[System.Diagnostics.CodeAnalysis.RequiresDynamicCode("The code for an array of the specified type might not be available.")]
public System.Array GetArray(System.Type expectedArrayType, bool allowNulls = true) { throw null; }

View file

@ -126,26 +126,23 @@
<data name="Serialization_UnexpectedNullRecordCount" xml:space="preserve">
<value>Unexpected Null Record count.</value>
</data>
<data name="Serialization_MaxArrayLength" xml:space="preserve">
<value>The serialized array length ({0}) was larger than the configured limit {1}.</value>
</data>
<data name="NotSupported_RecordType" xml:space="preserve">
<value>{0} Record Type is not supported by design.</value>
</data>
<data name="Serialization_InvalidReference" xml:space="preserve">
<value>Member reference was pointing to a record of unexpected type.</value>
<value>Invalid member reference.</value>
</data>
<data name="Serialization_InvalidTypeName" xml:space="preserve">
<value>Invalid type name: `{0}`.</value>
<value>Invalid type name.</value>
</data>
<data name="Serialization_TypeMismatch" xml:space="preserve">
<value>Expected the array to be of type {0}, but its element type was {1}.</value>
</data>
<data name="Serialization_InvalidTypeOrAssemblyName" xml:space="preserve">
<value>Invalid type or assembly name: `{0},{1}`.</value>
<value>Invalid type or assembly name.</value>
</data>
<data name="Serialization_DuplicateMemberName" xml:space="preserve">
<value>Duplicate member name: `{0}`.</value>
<value>Duplicate member name.</value>
</data>
<data name="Argument_NonSeekableStream" xml:space="preserve">
<value>Stream does not support seeking.</value>
@ -160,6 +157,12 @@
<value>Only arrays with zero offsets are supported.</value>
</data>
<data name="Serialization_InvalidAssemblyName" xml:space="preserve">
<value>Invalid assembly name: `{0}`.</value>
<value>Invalid assembly name.</value>
</data>
<data name="Serialization_InvalidFormat" xml:space="preserve">
<value>Invalid format.</value>
</data>
<data name="Serialization_SurrogateCharacter" xml:space="preserve">
<value>A surrogate character was read.</value>
</data>
</root>

View file

@ -3,6 +3,9 @@
namespace System.Formats.Nrbf;
// See [MS-NRBF] Sec. 2.7 for more information.
// https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/ca3ad2bc-777b-413a-a72a-9ba6ced76bc3
[Flags]
internal enum AllowedRecordTypes : uint
{

View file

@ -13,22 +13,26 @@ namespace System.Formats.Nrbf;
/// <remarks>
/// ArrayInfo structures are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/8fac763f-e46d-43a1-b360-80eb83d2c5fb">[MS-NRBF] 2.4.2.1</see>.
/// </remarks>
[DebuggerDisplay("Length={Length}, {ArrayType}, rank={Rank}")]
[DebuggerDisplay("{ArrayType}, rank={Rank}")]
internal readonly struct ArrayInfo
{
internal const int MaxArrayLength = 2147483591; // Array.MaxLength
#if NET8_0_OR_GREATER
internal static int MaxArrayLength => Array.MaxLength; // dynamic lookup in case the value changes in a future runtime
#else
internal const int MaxArrayLength = 2147483591; // hardcode legacy Array.MaxLength for downlevel runtimes
#endif
internal ArrayInfo(SerializationRecordId id, long totalElementsCount, BinaryArrayType arrayType = BinaryArrayType.Single, int rank = 1)
{
Id = id;
TotalElementsCount = totalElementsCount;
FlattenedLength = totalElementsCount;
ArrayType = arrayType;
Rank = rank;
}
internal SerializationRecordId Id { get; }
internal long TotalElementsCount { get; }
internal long FlattenedLength { get; }
internal BinaryArrayType ArrayType { get; }
@ -36,8 +40,8 @@ internal readonly struct ArrayInfo
internal int GetSZArrayLength()
{
Debug.Assert(TotalElementsCount <= MaxArrayLength);
return (int)TotalElementsCount;
Debug.Assert(FlattenedLength <= MaxArrayLength);
return (int)FlattenedLength;
}
internal static ArrayInfo Decode(BinaryReader reader)
@ -47,7 +51,7 @@ internal readonly struct ArrayInfo
{
int length = reader.ReadInt32();
if (length is < 0 or > MaxArrayLength)
if (length < 0 || length > MaxArrayLength)
{
ThrowHelper.ThrowInvalidValue(length);
}

View file

@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;
namespace System.Formats.Nrbf;
@ -54,6 +55,7 @@ internal sealed class ArrayOfClassesRecord : SZArrayRecord<ClassRecord>
}
int nullCount = ((NullsRecord)actual).NullCount;
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
do
{
result[resultIndex++] = null;
@ -63,6 +65,8 @@ internal sealed class ArrayOfClassesRecord : SZArrayRecord<ClassRecord>
}
}
Debug.Assert(resultIndex == result.Length, "We should have traversed the entirety of the newly created array.");
return result;
}

View file

@ -18,7 +18,7 @@ public abstract class ArrayRecord : SerializationRecord
private protected ArrayRecord(ArrayInfo arrayInfo)
{
ArrayInfo = arrayInfo;
ValuesToRead = arrayInfo.TotalElementsCount;
ValuesToRead = arrayInfo.FlattenedLength;
}
/// <summary>
@ -27,6 +27,12 @@ public abstract class ArrayRecord : SerializationRecord
/// <value>A buffer of integers that represent the number of elements in every dimension.</value>
public abstract ReadOnlySpan<int> Lengths { get; }
/// <summary>
/// When overridden in a derived class, gets the total number of all elements in every dimension.
/// </summary>
/// <value>A number that represent the total number of all elements in every dimension.</value>
public virtual long FlattenedLength => ArrayInfo.FlattenedLength;
/// <summary>
/// Gets the rank of the array.
/// </summary>
@ -44,7 +50,12 @@ public abstract class ArrayRecord : SerializationRecord
internal long ValuesToRead { get; private protected set; }
private protected ArrayInfo ArrayInfo { get; }
internal ArrayInfo ArrayInfo { get; }
internal bool IsJagged
=> ArrayInfo.ArrayType == BinaryArrayType.Jagged
// It is possible to have binary array records have an element type of array without being marked as jagged.
|| TypeName.GetElementType().IsArray;
/// <summary>
/// Allocates an array and fills it with the data provided in the serialized records (in case of primitive types like <see cref="string"/> or <see cref="int"/>) or the serialized records themselves.

View file

@ -5,6 +5,7 @@ using System.Collections.Generic;
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;
namespace System.Formats.Nrbf;
@ -33,13 +34,15 @@ internal sealed class ArraySingleObjectRecord : SZArrayRecord<object?>
{
object?[] values = new object?[Length];
for (int recordIndex = 0, valueIndex = 0; recordIndex < Records.Count; recordIndex++)
int valueIndex = 0;
for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++)
{
SerializationRecord record = Records[recordIndex];
int nullCount = record is NullsRecord nullsRecord ? nullsRecord.NullCount : 0;
if (nullCount == 0)
{
// "new object[] { <SELF> }" is special cased because it allows for storing reference to itself.
values[valueIndex++] = record is MemberReferenceRecord referenceRecord && referenceRecord.Reference.Equals(Id)
? values // a reference to self, and a way to get StackOverflow exception ;)
: record.GetValue();
@ -59,6 +62,8 @@ internal sealed class ArraySingleObjectRecord : SZArrayRecord<object?>
while (nullCount > 0);
}
Debug.Assert(valueIndex == values.Length, "We should have traversed the entirety of the newly created array.");
return values;
}

View file

@ -41,17 +41,9 @@ internal sealed class ArraySinglePrimitiveRecord<T> : SZArrayRecord<T>
public override T[] GetArray(bool allowNulls = true)
=> (T[])(_arrayNullsNotAllowed ??= (Values is T[] array ? array : Values.ToArray()));
internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
{
Debug.Fail("GetAllowedRecordType should never be called on ArraySinglePrimitiveRecord");
throw new InvalidOperationException();
}
internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() => throw new InvalidOperationException();
private protected override void AddValue(object value)
{
Debug.Fail("AddValue should never be called on ArraySinglePrimitiveRecord");
throw new InvalidOperationException();
}
private protected override void AddValue(object value) => throw new InvalidOperationException();
internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int count)
{
@ -61,8 +53,32 @@ internal sealed class ArraySinglePrimitiveRecord<T> : SZArrayRecord<T>
return (List<T>)(object)DecodeDecimals(reader, count);
}
// char[] has a unique representation in NRBF streams. Typical strings are transcoded
// to UTF-8 and prefixed with the number of bytes in the UTF-8 representation. char[]
// is also serialized as UTF-8, but it is instead prefixed with the number of chars
// in the UTF-16 representation, not the number of bytes in the UTF-8 representation.
// This number doesn't directly precede the UTF-8 contents in the NRBF stream; it's
// instead contained within the ArrayInfo structure (passed to this method as the
// 'count' argument).
//
// The practical consequence of this is that we don't actually know how many UTF-8
// bytes we need to consume in order to ensure we've read 'count' chars. We know that
// an n-length UTF-16 string turns into somewhere between [n .. 3n] UTF-8 bytes.
// The best we can do is that when reading an n-element char[], we'll ensure that
// there are at least n bytes remaining in the input stream. We'll still need to
// account for that even with this check, we might hit EOF before fully populating
// the char[]. But from a safety perspective, it does appropriately limit our
// allocations to be proportional to the amount of data present in the input stream,
// which is a sufficient defense against DoS.
long requiredBytes = count;
if (typeof(T) != typeof(char)) // the input is UTF8
if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
{
// We can't assume DateTime as represented by the runtime is 8 bytes.
// The only assumption we can make is that it's 8 bytes on the wire.
requiredBytes *= 8;
}
else if (typeof(T) != typeof(char))
{
requiredBytes *= Unsafe.SizeOf<T>();
}
@ -85,7 +101,11 @@ internal sealed class ArraySinglePrimitiveRecord<T> : SZArrayRecord<T>
}
else if (typeof(T) == typeof(char))
{
return (T[])(object)reader.ReadChars(count);
return (T[])(object)reader.ParseChars(count);
}
else if (typeof(T) == typeof(TimeSpan) || typeof(T) == typeof(DateTime))
{
return DecodeTime(reader, count);
}
// It's safe to pre-allocate, as we have ensured there is enough bytes in the stream.
@ -94,7 +114,7 @@ internal sealed class ArraySinglePrimitiveRecord<T> : SZArrayRecord<T>
#if NET
reader.BaseStream.ReadExactly(resultAsBytes);
#else
byte[] bytes = ArrayPool<byte>.Shared.Rent(Math.Min(count * Unsafe.SizeOf<T>(), 256_000));
byte[] bytes = ArrayPool<byte>.Shared.Rent((int)Math.Min(requiredBytes, 256_000));
while (!resultAsBytes.IsEmpty)
{
@ -138,8 +158,7 @@ internal sealed class ArraySinglePrimitiveRecord<T> : SZArrayRecord<T>
}
#endif
}
else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double)
|| typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
else if (typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || typeof(T) == typeof(double))
{
Span<long> span = MemoryMarshal.Cast<T, long>(result);
#if NET
@ -153,37 +172,62 @@ internal sealed class ArraySinglePrimitiveRecord<T> : SZArrayRecord<T>
}
}
if (typeof(T) == typeof(bool))
{
// See DontCastBytesToBooleans test to see what could go wrong.
bool[] booleans = (bool[])(object)result;
resultAsBytes = MemoryMarshal.AsBytes<T>(result);
for (int i = 0; i < booleans.Length; i++)
{
// We don't use the bool array to get the value, as an optimizing compiler or JIT could elide this.
if (resultAsBytes[i] != 0) // it can be any byte different than 0
{
booleans[i] = true; // set it to 1 in explicit way
}
}
}
return result;
}
private static List<decimal> DecodeDecimals(BinaryReader reader, int count)
{
List<decimal> values = new();
#if NET
Span<byte> buffer = stackalloc byte[256];
for (int i = 0; i < count; i++)
{
int stringLength = reader.Read7BitEncodedInt();
if (!(stringLength > 0 && stringLength <= buffer.Length))
values.Add(reader.ParseDecimal());
}
return values;
}
private static T[] DecodeTime(BinaryReader reader, int count)
{
T[] values = new T[count];
for (int i = 0; i < values.Length; i++)
{
if (typeof(T) == typeof(DateTime))
{
ThrowHelper.ThrowInvalidValue(stringLength);
values[i] = (T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64());
}
else if (typeof(T) == typeof(TimeSpan))
{
values[i] = (T)(object)new TimeSpan(reader.ReadInt64());
}
else
{
throw new InvalidOperationException();
}
reader.BaseStream.ReadExactly(buffer.Slice(0, stringLength));
values.Add(decimal.Parse(buffer.Slice(0, stringLength), CultureInfo.InvariantCulture));
}
#else
for (int i = 0; i < count; i++)
{
values.Add(decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture));
}
#endif
return values;
}
private static List<T> DecodeFromNonSeekableStream(BinaryReader reader, int count)
{
// The count arg could originate from untrusted input, so we shouldn't
// pass it as-is to the ctor's capacity arg. We'll instead rely on
// List<T>.Add's O(1) amortization to keep the entire loop O(count).
List<T> values = new List<T>(Math.Min(count, 4));
for (int i = 0; i < count; i++)
{
@ -201,7 +245,7 @@ internal sealed class ArraySinglePrimitiveRecord<T> : SZArrayRecord<T>
}
else if (typeof(T) == typeof(char))
{
values.Add((T)(object)reader.ReadChar());
values.Add((T)(object)reader.ParseChar());
}
else if (typeof(T) == typeof(short))
{
@ -237,13 +281,15 @@ internal sealed class ArraySinglePrimitiveRecord<T> : SZArrayRecord<T>
}
else if (typeof(T) == typeof(DateTime))
{
values.Add((T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadInt64()));
values.Add((T)(object)Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64()));
}
else if (typeof(T) == typeof(TimeSpan))
{
values.Add((T)(object)new TimeSpan(reader.ReadInt64()));
}
else
{
Debug.Assert(typeof(T) == typeof(TimeSpan));
values.Add((T)(object)new TimeSpan(reader.ReadInt64()));
throw new InvalidOperationException();
}
}

View file

@ -5,6 +5,7 @@ using System.Collections.Generic;
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;
namespace System.Formats.Nrbf;
@ -21,7 +22,7 @@ internal sealed class ArraySingleStringRecord : SZArrayRecord<string?>
public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleString;
/// <inheritdoc />
public override TypeName TypeName => TypeNameHelpers.GetPrimitiveSZArrayTypeName(PrimitiveType.String);
public override TypeName TypeName => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.StringPrimitiveType);
private List<SerializationRecord> Records { get; }
@ -47,7 +48,8 @@ internal sealed class ArraySingleStringRecord : SZArrayRecord<string?>
{
string?[] values = new string?[Length];
for (int recordIndex = 0, valueIndex = 0; recordIndex < Records.Count; recordIndex++)
int valueIndex = 0;
for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++)
{
SerializationRecord record = Records[recordIndex];
@ -73,6 +75,7 @@ internal sealed class ArraySingleStringRecord : SZArrayRecord<string?>
}
int nullCount = ((NullsRecord)record).NullCount;
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
do
{
values[valueIndex++] = null;
@ -81,6 +84,8 @@ internal sealed class ArraySingleStringRecord : SZArrayRecord<string?>
while (nullCount > 0);
}
Debug.Assert(valueIndex == values.Length, "We should have traversed the entirety of the newly created array.");
return values;
}
}

View file

@ -6,6 +6,7 @@ using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;
namespace System.Formats.Nrbf;
@ -27,12 +28,15 @@ internal sealed class BinaryArrayRecord : ArrayRecord
];
private TypeName? _typeName;
private long _totalElementsCount;
private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo)
: base(arrayInfo)
{
MemberTypeInfo = memberTypeInfo;
Values = [];
// We need to parse all elements of the jagged array to obtain total elements count.
_totalElementsCount = -1;
}
public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;
@ -40,6 +44,22 @@ internal sealed class BinaryArrayRecord : ArrayRecord
/// <inheritdoc/>
public override ReadOnlySpan<int> Lengths => new int[1] { Length };
/// <inheritdoc/>
public override long FlattenedLength
{
get
{
if (_totalElementsCount < 0)
{
_totalElementsCount = IsJagged
? GetJaggedArrayFlattenedLength(this)
: ArrayInfo.FlattenedLength;
}
return _totalElementsCount;
}
}
public override TypeName TypeName
=> _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo);
@ -84,6 +104,10 @@ internal sealed class BinaryArrayRecord : ArrayRecord
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.ArraySingleObject:
case SerializationRecordType.ArraySingleString:
// Recursion depth is bounded by the depth of arrayType, which is
// a trustworthy Type instance. Don't need to worry about stack overflow.
ArrayRecord nestedArrayRecord = (ArrayRecord)record;
Array nestedArray = nestedArrayRecord.GetArray(actualElementType, allowNulls);
array.SetValue(nestedArray, resultIndex++);
@ -97,6 +121,7 @@ internal sealed class BinaryArrayRecord : ArrayRecord
}
int nullCount = ((NullsRecord)item).NullCount;
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
do
{
array.SetValue(null, resultIndex++);
@ -110,6 +135,8 @@ internal sealed class BinaryArrayRecord : ArrayRecord
}
}
Debug.Assert(resultIndex == array.Length, "We should have traversed the entirety of the newly created array.");
return array;
}
@ -122,6 +149,7 @@ internal sealed class BinaryArrayRecord : ArrayRecord
bool isRectangular = arrayType is BinaryArrayType.Rectangular;
// It is an arbitrary limit in the current CoreCLR type loader.
// Don't change this value without reviewing the loop a few lines below.
const int MaxSupportedArrayRank = 32;
if (rank < 1 || rank > MaxSupportedArrayRank
@ -132,18 +160,26 @@ internal sealed class BinaryArrayRecord : ArrayRecord
}
int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32
long totalElementCount = 1;
long totalElementCount = 1; // to avoid integer overflow during the multiplication below
for (int i = 0; i < lengths.Length; i++)
{
lengths[i] = ArrayInfo.ParseValidArrayLength(reader);
totalElementCount *= lengths[i];
if (totalElementCount > uint.MaxValue)
// n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]"
// but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But
// that's the same behavior that newarr and Array.CreateInstance exhibit, so at least
// we're consistent.
if (totalElementCount > ArrayInfo.MaxArrayLength)
{
ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded
}
}
// Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so
// we don't need to read the NRBF stream 'LowerBounds' field here.
MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap);
ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank);
@ -157,6 +193,65 @@ internal sealed class BinaryArrayRecord : ArrayRecord
: new BinaryArrayRecord(arrayInfo, memberTypeInfo);
}
private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayRecord)
{
long result = 0;
Queue<BinaryArrayRecord>? jaggedArrayRecords = null;
do
{
if (jaggedArrayRecords is not null)
{
jaggedArrayRecord = jaggedArrayRecords.Dequeue();
}
Debug.Assert(jaggedArrayRecord.IsJagged);
// In theory somebody could create a payload that would represent
// a very nested array with total elements count > long.MaxValue.
// That is why this method is using checked arithmetic.
result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves
foreach (object value in jaggedArrayRecord.Values)
{
if (value is not SerializationRecord record)
{
continue;
}
if (record.RecordType == SerializationRecordType.MemberReference)
{
record = ((MemberReferenceRecord)record).GetReferencedRecord();
}
switch (record.RecordType)
{
case SerializationRecordType.ArraySinglePrimitive:
case SerializationRecordType.ArraySingleObject:
case SerializationRecordType.ArraySingleString:
case SerializationRecordType.BinaryArray:
ArrayRecord nestedArrayRecord = (ArrayRecord)record;
if (nestedArrayRecord.IsJagged)
{
(jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord);
}
else
{
// Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion,
// just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value.
result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength);
}
break;
default:
break;
}
}
}
while (jaggedArrayRecords is not null && jaggedArrayRecords.Count > 0);
return result;
}
private protected override void AddValue(object value) => Values.Add(value);
internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
@ -186,6 +281,9 @@ internal sealed class BinaryArrayRecord : ArrayRecord
Type elementType = arrayType;
int arrayNestingDepth = 0;
// Loop iteration counts are bound by the nesting depth of arrayType,
// which is a trustworthy input. No DoS concerns.
while (elementType.IsArray)
{
elementType = elementType.GetElementType()!;

View file

@ -30,14 +30,7 @@ internal sealed class BinaryLibraryRecord : SerializationRecord
public override SerializationRecordType RecordType => SerializationRecordType.BinaryLibrary;
public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on BinaryLibraryRecord");
return TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(BinaryLibraryRecord).AsSpan());
internal string? RawLibraryName { get; }
@ -57,7 +50,7 @@ internal sealed class BinaryLibraryRecord : SerializationRecord
}
else if (!options.UndoTruncatedTypeNames)
{
ThrowHelper.ThrowInvalidAssemblyName(rawName);
ThrowHelper.ThrowInvalidAssemblyName();
}
return new BinaryLibraryRecord(id, rawName);

View file

@ -50,7 +50,8 @@ internal sealed class ClassInfo
// Use Dictionary instead of List so that searching for member IDs by name
// is O(n) instead of O(m * n), where m = memberCount and n = memberNameLength,
// in degenerate cases.
// in degenerate cases. Since memberCount may be hostile, don't allow it to be
// used as the initial capacity in the collection instance.
Dictionary<string, int> memberNames = new(StringComparer.Ordinal);
for (int i = 0; i < memberCount; i++)
{
@ -70,7 +71,7 @@ internal sealed class ClassInfo
continue;
}
#endif
throw new SerializationException(SR.Format(SR.Serialization_DuplicateMemberName, memberName));
ThrowHelper.ThrowDuplicateMemberName();
}
return new ClassInfo(id, typeName, memberNames);

View file

@ -9,7 +9,7 @@ using System.Formats.Nrbf.Utils;
namespace System.Formats.Nrbf;
/// <summary>
/// Identifies a class by it's name and library id.
/// Identifies a class by its name and library id.
/// </summary>
/// <remarks>
/// ClassTypeInfo structures are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/844b24dd-9f82-426e-9b98-05334307a239">[MS-NRBF] 2.1.1.8</see>.
@ -26,7 +26,7 @@ internal sealed class ClassTypeInfo
string rawName = reader.ReadString();
SerializationRecordId libraryId = SerializationRecordId.Decode(reader);
BinaryLibraryRecord library = (BinaryLibraryRecord)recordMap[libraryId];
BinaryLibraryRecord library = recordMap.GetRecord<BinaryLibraryRecord>(libraryId);
return new ClassTypeInfo(rawName.ParseNonSystemClassRecordTypeName(library, options));
}

View file

@ -34,10 +34,7 @@ internal sealed class ClassWithIdRecord : ClassRecord
SerializationRecordId id = SerializationRecordId.Decode(reader);
SerializationRecordId metadataId = SerializationRecordId.Decode(reader);
if (recordMap[metadataId] is not ClassRecord referencedRecord)
{
throw new SerializationException(SR.Serialization_InvalidReference);
}
ClassRecord referencedRecord = recordMap.GetRecord<ClassRecord>(metadataId);
return new ClassWithIdRecord(id, referencedRecord);
}

View file

@ -27,7 +27,7 @@ internal sealed class ClassWithMembersAndTypesRecord : ClassRecord
MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, classInfo.MemberNames.Count, options, recordMap);
SerializationRecordId libraryId = SerializationRecordId.Decode(reader);
BinaryLibraryRecord library = (BinaryLibraryRecord)recordMap[libraryId];
BinaryLibraryRecord library = recordMap.GetRecord<BinaryLibraryRecord>(libraryId);
classInfo.LoadTypeName(library, options);
return new ClassWithMembersAndTypesRecord(classInfo, memberTypeInfo);

View file

@ -38,5 +38,5 @@ internal sealed class MemberReferenceRecord : SerializationRecord
internal static MemberReferenceRecord Decode(BinaryReader reader, RecordMap recordMap)
=> new(SerializationRecordId.Decode(reader), recordMap);
internal SerializationRecord GetReferencedRecord() => RecordMap[Reference];
internal SerializationRecord GetReferencedRecord() => RecordMap.GetRecord(Reference);
}

View file

@ -53,10 +53,14 @@ internal readonly struct MemberTypeInfo
case BinaryType.Class:
info[i] = (type, ClassTypeInfo.Decode(reader, options, recordMap));
break;
default:
// Other types have no additional data.
Debug.Assert(type is BinaryType.String or BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.Object);
case BinaryType.String:
case BinaryType.StringArray:
case BinaryType.Object:
case BinaryType.ObjectArray:
// These types have no additional data.
break;
default:
throw new InvalidOperationException();
}
}
@ -97,7 +101,8 @@ internal readonly struct MemberTypeInfo
BinaryType.PrimitiveArray => (PrimitiveArray, default),
BinaryType.Class => (NonSystemClass, default),
BinaryType.SystemClass => (SystemClass, default),
_ => (ObjectArray, default)
BinaryType.ObjectArray => (ObjectArray, default),
_ => throw new InvalidOperationException()
};
}
@ -105,7 +110,7 @@ internal readonly struct MemberTypeInfo
{
// This library tries to minimize the number of concepts the users need to learn to use it.
// Since SZArrays are most common, it provides an SZArrayRecord<T> abstraction.
// Every other array (jagged, multi-dimensional etc) is represented using SZArrayRecord.
// Every other array (jagged, multi-dimensional etc) is represented using ArrayRecord.
// The goal of this method is to determine whether given array can be represented as SZArrayRecord<ClassRecord>.
(BinaryType binaryType, object? additionalInfo) = Infos[0];
@ -144,15 +149,15 @@ internal readonly struct MemberTypeInfo
TypeName elementTypeName = binaryType switch
{
BinaryType.String => TypeNameHelpers.GetPrimitiveTypeName(PrimitiveType.String),
BinaryType.StringArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(PrimitiveType.String),
BinaryType.String => TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.StringPrimitiveType),
BinaryType.StringArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.StringPrimitiveType),
BinaryType.Primitive => TypeNameHelpers.GetPrimitiveTypeName((PrimitiveType)additionalInfo!),
BinaryType.PrimitiveArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName((PrimitiveType)additionalInfo!),
BinaryType.Object => TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.ObjectPrimitiveType),
BinaryType.ObjectArray => TypeNameHelpers.GetPrimitiveSZArrayTypeName(TypeNameHelpers.ObjectPrimitiveType),
BinaryType.SystemClass => (TypeName)additionalInfo!,
BinaryType.Class => ((ClassTypeInfo)additionalInfo!).TypeName,
_ => throw new ArgumentOutOfRangeException(paramName: nameof(binaryType), actualValue: binaryType, message: null)
_ => throw new InvalidOperationException()
};
// In general, arrayRank == 1 may have two different meanings:

View file

@ -24,12 +24,5 @@ internal sealed class MessageEndRecord : SerializationRecord
public override SerializationRecordId Id => SerializationRecordId.NoId;
public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on MessageEndRecord");
return TypeName.Parse(nameof(MessageEndRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(MessageEndRecord).AsSpan());
}

View file

@ -27,7 +27,5 @@ internal readonly struct NextInfo
internal PrimitiveType PrimitiveType { get; }
internal NextInfo With(AllowedRecordTypes allowed, PrimitiveType primitiveType)
=> allowed == Allowed && primitiveType == PrimitiveType
? this // previous record was of the same type
: new(allowed, Parent, Stack, primitiveType);
=> new(allowed, Parent, Stack, primitiveType);
}

View file

@ -22,7 +22,7 @@ public static class NrbfDecoder
// The header consists of:
// - a byte that describes the record type (SerializationRecordType.SerializedStreamHeader)
// - four 32 bit integers:
// - root Id (every value is valid)
// - root Id (every value except of 0 is valid)
// - header Id (value is ignored)
// - major version, it has to be equal 1.
// - minor version, it has to be equal 0.
@ -46,6 +46,7 @@ public static class NrbfDecoder
/// <exception cref="ArgumentNullException"><paramref name="stream" /> is <see langword="null" />.</exception>
/// <exception cref="NotSupportedException">The stream does not support reading or seeking.</exception>
/// <exception cref="ObjectDisposedException">The stream was closed.</exception>
/// <exception cref="IOException">An I/O error occurred.</exception>
/// <remarks><para>When this method returns, <paramref name="stream" /> will be restored to its original position.</para></remarks>
public static bool StartsWithPayloadHeader(Stream stream)
{
@ -68,28 +69,22 @@ public static class NrbfDecoder
return false;
}
try
byte[] buffer = new byte[SerializedStreamHeaderRecord.Size];
int offset = 0;
while (offset < buffer.Length)
{
#if NET
Span<byte> buffer = stackalloc byte[SerializedStreamHeaderRecord.Size];
stream.ReadExactly(buffer);
#else
byte[] buffer = new byte[SerializedStreamHeaderRecord.Size];
int offset = 0;
while (offset < buffer.Length)
int read = stream.Read(buffer, offset, buffer.Length - offset);
if (read == 0)
{
int read = stream.Read(buffer, offset, buffer.Length - offset);
if (read == 0)
throw new EndOfStreamException();
offset += read;
stream.Position = beginning;
return false;
}
#endif
return StartsWithPayloadHeader(buffer);
}
finally
{
stream.Position = beginning;
offset += read;
}
bool result = StartsWithPayloadHeader(buffer);
stream.Position = beginning;
return result;
}
/// <summary>
@ -107,6 +102,7 @@ public static class NrbfDecoder
/// <exception cref="ArgumentNullException"><paramref name="payload"/> is <see langword="null" />.</exception>
/// <exception cref="ArgumentException"><paramref name="payload"/> does not support reading or is already closed.</exception>
/// <exception cref="SerializationException">Reading from <paramref name="payload"/> encounters invalid NRBF data.</exception>
/// <exception cref="IOException">An I/O error occurred.</exception>
/// <exception cref="NotSupportedException">
/// Reading from <paramref name="payload"/> encounters not supported records.
/// For example, arrays with non-zero offset or not supported record types
@ -142,7 +138,14 @@ public static class NrbfDecoder
#endif
using BinaryReader reader = new(payload, ThrowOnInvalidUtf8Encoding, leaveOpen: leaveOpen);
return Decode(reader, options ?? new(), out recordMap);
try
{
return Decode(reader, options ?? new(), out recordMap);
}
catch (FormatException) // can be thrown by various BinaryReader methods
{
throw new SerializationException(SR.Serialization_InvalidFormat);
}
}
/// <summary>
@ -213,12 +216,7 @@ public static class NrbfDecoder
private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap recordMap,
AllowedRecordTypes allowed, PayloadOptions options, out SerializationRecordType recordType)
{
byte nextByte = reader.ReadByte();
if (((uint)allowed & (1u << nextByte)) == 0)
{
ThrowHelper.ThrowForUnexpectedRecordType(nextByte);
}
recordType = (SerializationRecordType)nextByte;
recordType = reader.ReadSerializationRecordType(allowed);
SerializationRecord record = recordType switch
{
@ -237,7 +235,8 @@ public static class NrbfDecoder
SerializationRecordType.ObjectNullMultiple => ObjectNullMultipleRecord.Decode(reader),
SerializationRecordType.ObjectNullMultiple256 => ObjectNullMultiple256Record.Decode(reader),
SerializationRecordType.SerializedStreamHeader => SerializedStreamHeaderRecord.Decode(reader),
_ => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
SerializationRecordType.SystemClassWithMembersAndTypes => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options),
_ => throw new InvalidOperationException()
};
recordMap.Add(record);
@ -254,7 +253,7 @@ public static class NrbfDecoder
PrimitiveType.Boolean => new MemberPrimitiveTypedRecord<bool>(reader.ReadBoolean()),
PrimitiveType.Byte => new MemberPrimitiveTypedRecord<byte>(reader.ReadByte()),
PrimitiveType.SByte => new MemberPrimitiveTypedRecord<sbyte>(reader.ReadSByte()),
PrimitiveType.Char => new MemberPrimitiveTypedRecord<char>(reader.ReadChar()),
PrimitiveType.Char => new MemberPrimitiveTypedRecord<char>(reader.ParseChar()),
PrimitiveType.Int16 => new MemberPrimitiveTypedRecord<short>(reader.ReadInt16()),
PrimitiveType.UInt16 => new MemberPrimitiveTypedRecord<ushort>(reader.ReadUInt16()),
PrimitiveType.Int32 => new MemberPrimitiveTypedRecord<int>(reader.ReadInt32()),
@ -263,10 +262,10 @@ public static class NrbfDecoder
PrimitiveType.UInt64 => new MemberPrimitiveTypedRecord<ulong>(reader.ReadUInt64()),
PrimitiveType.Single => new MemberPrimitiveTypedRecord<float>(reader.ReadSingle()),
PrimitiveType.Double => new MemberPrimitiveTypedRecord<double>(reader.ReadDouble()),
PrimitiveType.Decimal => new MemberPrimitiveTypedRecord<decimal>(decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture)),
PrimitiveType.DateTime => new MemberPrimitiveTypedRecord<DateTime>(Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadInt64())),
// String is handled with a record, never on it's own
_ => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
PrimitiveType.Decimal => new MemberPrimitiveTypedRecord<decimal>(reader.ParseDecimal()),
PrimitiveType.DateTime => new MemberPrimitiveTypedRecord<DateTime>(Utils.BinaryReaderExtensions.CreateDateTimeFromData(reader.ReadUInt64())),
PrimitiveType.TimeSpan => new MemberPrimitiveTypedRecord<TimeSpan>(new TimeSpan(reader.ReadInt64())),
_ => throw new InvalidOperationException()
};
}
@ -291,7 +290,8 @@ public static class NrbfDecoder
PrimitiveType.Double => Decode<double>(info, reader),
PrimitiveType.Decimal => Decode<decimal>(info, reader),
PrimitiveType.DateTime => Decode<DateTime>(info, reader),
_ => Decode<TimeSpan>(info, reader),
PrimitiveType.TimeSpan => Decode<TimeSpan>(info, reader),
_ => throw new InvalidOperationException()
};
static SerializationRecord Decode<T>(ArrayInfo info, BinaryReader reader) where T : unmanaged

View file

@ -12,12 +12,5 @@ internal abstract class NullsRecord : SerializationRecord
public override SerializationRecordId Id => SerializationRecordId.NoId;
public override TypeName TypeName
{
get
{
Debug.Fail($"TypeName should never be called on {GetType().Name}");
return TypeName.Parse(GetType().Name.AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(GetType().Name.AsSpan());
}

View file

@ -25,10 +25,17 @@ public sealed class PayloadOptions
/// </summary>
/// <value><see langword="true" /> if truncated type names should be reassembled; otherwise, <see langword="false" />.</value>
/// <remarks>
/// <para>
/// Example:
/// TypeName: "Namespace.TypeName`1[[Namespace.GenericArgName"
/// LibraryName: "AssemblyName]]"
/// Is combined into "Namespace.TypeName`1[[Namespace.GenericArgName, AssemblyName]]"
/// </para>
/// <para>
/// Setting this to <see langword="true" /> can render <see cref="NrbfDecoder"/> susceptible to Denial of Service
/// attacks when parsing or handling malicious input.
/// </para>
/// <para>The default value is <see langword="false" />.</para>
/// </remarks>
public bool UndoTruncatedTypeNames { get; set; }
}

View file

@ -11,10 +11,6 @@ namespace System.Formats.Nrbf;
/// </remarks>
internal enum PrimitiveType : byte
{
/// <summary>
/// Used internally to express no value
/// </summary>
None = 0,
Boolean = 1,
Byte = 2,
Char = 3,
@ -30,7 +26,19 @@ internal enum PrimitiveType : byte
DateTime = 13,
UInt16 = 14,
UInt32 = 15,
UInt64 = 16,
Null = 17,
String = 18
UInt64 = 16
// This internal enum no longer contains Null and String as they were always illegal:
// - In case of BinaryArray (NRBF 2.4.3.1):
// "If the BinaryTypeEnum value is Primitive, the PrimitiveTypeEnumeration
// value in AdditionalTypeInfo MUST NOT be Null (17) or String (18)."
// - In case of MemberPrimitiveTyped (NRBF 2.5.1):
// "PrimitiveTypeEnum (1 byte): A PrimitiveTypeEnumeration
// value that specifies the Primitive Type of data that is being transmitted.
// This field MUST NOT contain a value of 17 (Null) or 18 (String)."
// - In case of ArraySinglePrimitive (NRBF 2.4.3.3):
// "A PrimitiveTypeEnumeration value that identifies the Primitive Type
// of the items of the Array. The value MUST NOT be 17 (Null) or 18 (String)."
// - In case of MemberTypeInfo (NRBF 2.3.1.2):
// "When the BinaryTypeEnum value is Primitive, the PrimitiveTypeEnumeration
// value in AdditionalInfo MUST NOT be Null (17) or String (18)."
}

View file

@ -56,14 +56,15 @@ internal sealed class RecordMap : IReadOnlyDictionary<SerializationRecordId, Ser
return;
}
#endif
throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id));
throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id._id));
}
}
}
internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header)
{
SerializationRecord rootRecord = _map[header.RootId];
SerializationRecord rootRecord = GetRecord(header.RootId);
if (rootRecord is SystemClassWithMembersAndTypesRecord systemClass)
{
// update the record map, so it's visible also to those who access it via Id
@ -72,4 +73,14 @@ internal sealed class RecordMap : IReadOnlyDictionary<SerializationRecordId, Ser
return rootRecord;
}
internal SerializationRecord GetRecord(SerializationRecordId recordId)
=> _map.TryGetValue(recordId, out SerializationRecord? record)
? record
: throw new SerializationException(SR.Serialization_InvalidReference);
internal T GetRecord<T>(SerializationRecordId recordId) where T : SerializationRecord
=> _map.TryGetValue(recordId, out SerializationRecord? record) && record is T casted
? casted
: throw new SerializationException(SR.Serialization_InvalidReference);
}

View file

@ -8,13 +8,14 @@ using System.Reflection.Metadata;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;
namespace System.Formats.Nrbf;
internal sealed class RectangularArrayRecord : ArrayRecord
{
private readonly int[] _lengths;
private readonly ICollection<object> _values;
private readonly List<object> _values;
private TypeName? _typeName;
private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
@ -24,18 +25,8 @@ internal sealed class RectangularArrayRecord : ArrayRecord
MemberTypeInfo = memberTypeInfo;
_lengths = lengths;
// A List<T> can hold as many objects as an array, so for multi-dimensional arrays
// with more elements than Array.MaxLength we use LinkedList.
// Testing that many elements takes a LOT of time, so to ensure that both code paths are tested,
// we always use LinkedList code path for Debug builds.
#if DEBUG
_values = new LinkedList<object>();
#else
_values = arrayInfo.TotalElementsCount <= ArrayInfo.MaxArrayLength
? new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()))
: new LinkedList<object>();
#endif
// ArrayInfo.GetSZArrayLength ensures to return a value <= Array.MaxLength
_values = new List<object>(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()));
}
public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;
@ -65,10 +56,12 @@ internal sealed class RectangularArrayRecord : ArrayRecord
#if !NET8_0_OR_GREATER
int[] indices = new int[_lengths.Length];
nuint numElementsWritten = 0; // only for debugging; not used in release builds
foreach (object value in _values)
{
result.SetValue(GetActualValue(value), indices);
numElementsWritten++;
int dimension = indices.Length - 1;
while (dimension >= 0)
@ -88,6 +81,9 @@ internal sealed class RectangularArrayRecord : ArrayRecord
}
}
Debug.Assert(numElementsWritten == (uint)_values.Count, "We should have traversed the entirety of the source values collection.");
Debug.Assert(numElementsWritten == (ulong)result.LongLength, "We should have traversed the entirety of the destination array.");
return result;
#else
// Idea from Array.CoreCLR that maps an array of int indices into
@ -108,6 +104,7 @@ internal sealed class RectangularArrayRecord : ArrayRecord
else if (ElementType == typeof(TimeSpan)) CopyTo<TimeSpan>(_values, result);
else if (ElementType == typeof(DateTime)) CopyTo<DateTime>(_values, result);
else if (ElementType == typeof(decimal)) CopyTo<decimal>(_values, result);
else throw new InvalidOperationException();
}
else
{
@ -116,7 +113,7 @@ internal sealed class RectangularArrayRecord : ArrayRecord
return result;
static void CopyTo<T>(ICollection<object> list, Array array)
static void CopyTo<T>(List<object> list, Array array)
{
ref byte arrayDataRef = ref MemoryMarshal.GetArrayDataReference(array);
ref T firstElementRef = ref Unsafe.As<byte, T>(ref arrayDataRef);
@ -127,6 +124,8 @@ internal sealed class RectangularArrayRecord : ArrayRecord
targetElement = (T)GetActualValue(value)!;
flattenedIndex++;
}
Debug.Assert(flattenedIndex == (ulong)array.LongLength, "We should have traversed the entirety of the array.");
}
#endif
}
@ -167,7 +166,7 @@ internal sealed class RectangularArrayRecord : ArrayRecord
PrimitiveType.Boolean => sizeof(bool),
PrimitiveType.Byte => sizeof(byte),
PrimitiveType.SByte => sizeof(sbyte),
PrimitiveType.Char => sizeof(byte), // it's UTF8
PrimitiveType.Char => sizeof(byte), // it's UTF8 (see comment below)
PrimitiveType.Int16 => sizeof(short),
PrimitiveType.UInt16 => sizeof(ushort),
PrimitiveType.Int32 => sizeof(int),
@ -176,12 +175,29 @@ internal sealed class RectangularArrayRecord : ArrayRecord
PrimitiveType.Int64 => sizeof(long),
PrimitiveType.UInt64 => sizeof(ulong),
PrimitiveType.Double => sizeof(double),
_ => -1
PrimitiveType.TimeSpan => sizeof(ulong),
PrimitiveType.DateTime => sizeof(ulong),
PrimitiveType.Decimal => -1, // represented as variable-length string
_ => throw new InvalidOperationException()
};
if (sizeOfSingleValue > 0)
{
long size = arrayInfo.TotalElementsCount * sizeOfSingleValue;
// NRBF encodes rectangular char[,,,...] by converting each standalone UTF-16 code point into
// its UTF-8 encoding. This means that surrogate code points (including adjacent surrogate
// pairs) occurring within a char[,,,...] cannot be encoded by NRBF. BinaryReader will detect
// that they're ill-formed and reject them on read.
//
// Per the comment in ArraySinglePrimitiveRecord.DecodePrimitiveTypes, we'll assume best-case
// encoding where 1 UTF-16 char encodes as a single UTF-8 byte, even though this might lead
// to encountering an EOF if we realize later that we actually need to read more bytes in
// order to fully populate the char[,,,...] array. Any such allocation is still linearly
// proportional to the length of the incoming payload, so it's not a DoS vector.
// The multiplication below is guaranteed not to overflow because FlattenedLength is bounded
// to <= Array.MaxLength (see BinaryArrayRecord.Decode) and sizeOfSingleValue is at most 8.
Debug.Assert(arrayInfo.FlattenedLength >= 0 && arrayInfo.FlattenedLength <= long.MaxValue / sizeOfSingleValue);
long size = arrayInfo.FlattenedLength * sizeOfSingleValue;
bool? isDataAvailable = reader.IsDataAvailable(size);
if (isDataAvailable.HasValue)
{
@ -215,7 +231,8 @@ internal sealed class RectangularArrayRecord : ArrayRecord
PrimitiveType.DateTime => typeof(DateTime),
PrimitiveType.UInt16 => typeof(ushort),
PrimitiveType.UInt32 => typeof(uint),
_ => typeof(ulong)
PrimitiveType.UInt64 => typeof(ulong),
_ => throw new InvalidOperationException()
};
private static Type MapPrimitiveArray(PrimitiveType primitiveType)
@ -235,7 +252,8 @@ internal sealed class RectangularArrayRecord : ArrayRecord
PrimitiveType.DateTime => typeof(DateTime[]),
PrimitiveType.UInt16 => typeof(ushort[]),
PrimitiveType.UInt32 => typeof(uint[]),
_ => typeof(ulong[]),
PrimitiveType.UInt64 => typeof(ulong[]),
_ => throw new InvalidOperationException()
};
private static object? GetActualValue(object value)

View file

@ -13,7 +13,7 @@ namespace System.Formats.Nrbf;
/// <remarks>
/// <para>
/// Every instance returned to the end user can be either <see cref="PrimitiveTypeRecord{T}"/>,
/// a <see cref="ClassRecord"/> or an <see cref="ArrayRecord"/>.
/// a <see cref="ClassRecord"/>, or an <see cref="ArrayRecord"/>.
/// </para>
/// </remarks>
[DebuggerDisplay("{RecordType}, {Id}")]
@ -50,7 +50,20 @@ public abstract class SerializationRecord
/// </remarks>
/// <param name="type">The type to compare against.</param>
/// <returns><see langword="true" /> if the serialized type name match provided type; otherwise, <see langword="false" />.</returns>
public bool TypeNameMatches(Type type) => Matches(type, TypeName);
/// <exception cref="ArgumentNullException"><paramref name="type" /> is <see langword="null" />.</exception>
public bool TypeNameMatches(Type type)
{
#if NET
ArgumentNullException.ThrowIfNull(type);
#else
if (type is null)
{
throw new ArgumentNullException(nameof(type));
}
#endif
return Matches(type, TypeName);
}
private static bool Matches(Type type, TypeName typeName)
{
@ -61,10 +74,38 @@ public abstract class SerializationRecord
return false;
}
// The TypeName.FullName property getter is recursive and backed by potentially hostile
// input. See comments in that property getter for more information, including what defenses
// are in place to prevent attacks.
//
// Note that the equality comparison below is worst-case O(n) since the adversary could ensure
// that only the last char differs. Even if the strings have equal contents, we should still
// expect the comparison to take O(n) time since RuntimeType.FullName and TypeName.FullName
// will never reference the same string instance with current runtime implementations.
//
// Since a call to Matches could take place within a loop, and since TypeName.FullName could
// be arbitrarily long (it's attacker-controlled and the NRBF protocol allows backtracking via
// the ClassWithId record, providing a form of compression), this presents opportunity
// for an algorithmic complexity attack, where a (2 * l)-length payload has an l-length type
// name and an array with l elements, resulting in O(l^2) total work factor. Protection against
// such attack is provided by the fact that the System.Type object is fully under the app's
// control and is assumed to be trusted and a reasonable length. This brings the cumulative loop
// work factor back down to O(l * RuntimeType.FullName), which is acceptable.
//
// The above statement assumes that "(string)m == (string)n" has worst-case complexity
// O(min(m.Length, n.Length)). This is not stated in string's public docs, but it is
// a guaranteed behavior for all built-in Ordinal string comparisons.
// At first, check the non-allocating properties for mismatch.
if (type.IsArray != typeName.IsArray || type.IsConstructedGenericType != typeName.IsConstructedGenericType
|| type.IsNested != typeName.IsNested
|| (type.IsArray && type.GetArrayRank() != typeName.GetArrayRank()))
|| (type.IsArray && type.GetArrayRank() != typeName.GetArrayRank())
#if NET
|| type.IsSZArray != typeName.IsSZArray // int[] vs int[*]
#else
|| (type.IsArray && type.Name != typeName.Name)
#endif
)
{
return false;
}
@ -111,11 +152,16 @@ public abstract class SerializationRecord
/// For reference records, it returns the referenced record.
/// For other records, it returns the records themselves.
/// </summary>
/// <remarks>
/// Overrides of this method should take care not to allow
/// the introduction of cycles, even in the face of adversarial
/// edges in the object graph.
/// </remarks>
internal virtual object? GetValue() => this;
internal virtual void HandleNextRecord(SerializationRecord nextRecord, NextInfo info)
=> Debug.Fail($"HandleNextRecord should not have been called for '{GetType().Name}'");
=> throw new InvalidOperationException();
internal virtual void HandleNextValue(object value, NextInfo info)
=> Debug.Fail($"HandleNextValue should not have been called for '{GetType().Name}'");
=> throw new InvalidOperationException();
}

View file

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Formats.Nrbf.Utils;
using System.IO;
using System.Linq;
@ -15,6 +16,7 @@ namespace System.Formats.Nrbf;
/// <summary>
/// The ID of <see cref="SerializationRecord" />.
/// </summary>
[DebuggerDisplay("{_id}")]
public readonly struct SerializationRecordId : IEquatable<SerializationRecordId>
{
#pragma warning disable CS0649 // the default value is used on purpose
@ -29,6 +31,15 @@ public readonly struct SerializationRecordId : IEquatable<SerializationRecordId>
{
int id = reader.ReadInt32();
// Many object ids are required to be positive. See:
// - https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/8fac763f-e46d-43a1-b360-80eb83d2c5fb
// - https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/eb503ca5-e1f6-4271-a7ee-c4ca38d07996
// - https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/7fcf30e1-4ad4-4410-8f1a-901a4a1ea832 (for library id)
//
// Exception: https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/0a192be0-58a1-41d0-8a54-9c91db0ab7bf may be negative
// The problem is that input generated with FormatterTypeStyle.XsdString ends up generating negative Ids anyway.
// That information is not reflected in payload in anyway, so we just always allow for negative Ids.
if (id == 0)
{
ThrowHelper.ThrowInvalidValue(id);

View file

@ -6,6 +6,9 @@ namespace System.Formats.Nrbf;
/// <summary>
/// Record type.
/// </summary>
/// <remarks>
/// SerializationRecordType enumeration is described in <see href="https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/954a0657-b901-4813-9398-4ec732fe8b32">[MS-NRBF] 2.1.2.1</see>.
/// </remarks>
public enum SerializationRecordType
{
/// <summary>

View file

@ -24,14 +24,8 @@ internal sealed class SerializedStreamHeaderRecord : SerializationRecord
public override SerializationRecordType RecordType => SerializationRecordType.SerializedStreamHeader;
public override TypeName TypeName
{
get
{
Debug.Fail("TypeName should never be called on SerializedStreamHeaderRecord");
return TypeName.Parse(nameof(SerializedStreamHeaderRecord).AsSpan());
}
}
public override TypeName TypeName => TypeName.Parse(nameof(SerializedStreamHeaderRecord).AsSpan());
public override SerializationRecordId Id => SerializationRecordId.NoId;
internal SerializationRecordId RootId { get; }

View file

@ -75,29 +75,30 @@ internal sealed class SystemClassWithMembersAndTypesRecord : ClassRecord
_ => this
};
}
else if (HasMember("_ticks") && MemberValues[0] is long ticks && TypeNameMatches(typeof(TimeSpan)))
else if (HasMember("_ticks") && GetRawValue("_ticks") is long ticks && TypeNameMatches(typeof(TimeSpan)))
{
return Create(new TimeSpan(ticks));
}
}
else if (MemberValues.Count == 2
&& HasMember("ticks") && HasMember("dateData")
&& MemberValues[0] is long value && MemberValues[1] is ulong
&& GetRawValue("ticks") is long && GetRawValue("dateData") is ulong dateData
&& TypeNameMatches(typeof(DateTime)))
{
return Create(Utils.BinaryReaderExtensions.CreateDateTimeFromData(value));
return Create(Utils.BinaryReaderExtensions.CreateDateTimeFromData(dateData));
}
else if(MemberValues.Count == 4
else if (MemberValues.Count == 4
&& HasMember("lo") && HasMember("mid") && HasMember("hi") && HasMember("flags")
&& MemberValues[0] is int && MemberValues[1] is int && MemberValues[2] is int && MemberValues[3] is int
&& GetRawValue("lo") is int lo && GetRawValue("mid") is int mid
&& GetRawValue("hi") is int hi && GetRawValue("flags") is int flags
&& TypeNameMatches(typeof(decimal)))
{
int[] bits =
[
GetInt32("lo"),
GetInt32("mid"),
GetInt32("hi"),
GetInt32("flags")
lo,
mid,
hi,
flags
];
return Create(new decimal(bits));

View file

@ -1,18 +1,41 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Reflection;
using System.Reflection.Metadata;
using System.Runtime.CompilerServices;
using System.Runtime.Serialization;
using System.Threading;
namespace System.Formats.Nrbf.Utils;
internal static class BinaryReaderExtensions
{
private static object? s_baseAmbiguousDstDateTime;
internal static SerializationRecordType ReadSerializationRecordType(this BinaryReader reader, AllowedRecordTypes allowed)
{
byte nextByte = reader.ReadByte();
if (nextByte > (byte)SerializationRecordType.MethodReturn // MethodReturn is the last defined value.
|| (nextByte > (byte)SerializationRecordType.ArraySingleString && nextByte < (byte)SerializationRecordType.MethodCall) // not part of the spec
|| ((uint)allowed & (1u << nextByte)) == 0) // valid, but not allowed
{
ThrowHelper.ThrowForUnexpectedRecordType(nextByte);
}
return (SerializationRecordType)nextByte;
}
internal static BinaryArrayType ReadArrayType(this BinaryReader reader)
{
// To simplify the behavior and security review of the BinaryArrayRecord type, we
// do not support reading non-zero-offset arrays. If this should change in the
// future, the BinaryArrayRecord.Decode method and supporting infrastructure
// will need re-review.
byte arrayType = reader.ReadByte();
// Rectangular is the last defined value.
if (arrayType > (byte)BinaryArrayType.Rectangular)
@ -43,8 +66,8 @@ internal static class BinaryReaderExtensions
internal static PrimitiveType ReadPrimitiveType(this BinaryReader reader)
{
byte primitiveType = reader.ReadByte();
// String is the last defined value, 4 is not used at all.
if (primitiveType is 4 or > (byte)PrimitiveType.String)
// Boolean is the first valid value (1), UInt64 (16) is the last one. 4 is not used at all.
if (primitiveType is 4 or < (byte)PrimitiveType.Boolean or > (byte)PrimitiveType.UInt64)
{
ThrowHelper.ThrowInvalidValue(primitiveType);
}
@ -60,7 +83,7 @@ internal static class BinaryReaderExtensions
PrimitiveType.Boolean => reader.ReadBoolean(),
PrimitiveType.Byte => reader.ReadByte(),
PrimitiveType.SByte => reader.ReadSByte(),
PrimitiveType.Char => reader.ReadChar(),
PrimitiveType.Char => reader.ParseChar(),
PrimitiveType.Int16 => reader.ReadInt16(),
PrimitiveType.UInt16 => reader.ReadUInt16(),
PrimitiveType.Int32 => reader.ReadInt32(),
@ -69,41 +92,130 @@ internal static class BinaryReaderExtensions
PrimitiveType.UInt64 => reader.ReadUInt64(),
PrimitiveType.Single => reader.ReadSingle(),
PrimitiveType.Double => reader.ReadDouble(),
PrimitiveType.Decimal => decimal.Parse(reader.ReadString(), CultureInfo.InvariantCulture),
PrimitiveType.DateTime => CreateDateTimeFromData(reader.ReadInt64()),
_ => new TimeSpan(reader.ReadInt64()),
PrimitiveType.Decimal => reader.ParseDecimal(),
PrimitiveType.DateTime => CreateDateTimeFromData(reader.ReadUInt64()),
PrimitiveType.TimeSpan => new TimeSpan(reader.ReadInt64()),
_ => throw new InvalidOperationException(),
};
// TODO: fix https://github.com/dotnet/runtime/issues/102826
// BinaryFormatter serializes decimals as strings and we can't BinaryReader.ReadDecimal.
internal static decimal ParseDecimal(this BinaryReader reader)
{
// The spec (MS NRBF 2.1.1.6, https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/10b218f5-9b2b-4947-b4b7-07725a2c8127)
// says that the length of LengthPrefixedString must be of optimal size (using as few bytes as possible).
// BinaryReader.ReadString does not enforce that and we are OK with that,
// as it takes care of handling multiple edge cases and we don't want to re-implement it.
string text = reader.ReadString();
if (!decimal.TryParse(text, NumberStyles.Number, CultureInfo.InvariantCulture, out decimal result))
{
ThrowHelper.ThrowInvalidFormat();
}
return result;
}
internal static char ParseChar(this BinaryReader reader)
{
try
{
return reader.ReadChar();
}
catch (ArgumentException) // A surrogate character was read.
{
throw new SerializationException(SR.Serialization_SurrogateCharacter);
}
}
internal static char[] ParseChars(this BinaryReader reader, int count)
{
char[]? result;
try
{
result = reader.ReadChars(count);
}
catch (ArgumentException) // A surrogate character was read.
{
throw new SerializationException(SR.Serialization_SurrogateCharacter);
}
if (result.Length != count)
{
// We might hit EOF before fully reading the requested
// number of chars. This means that ReadChars(count) could return a char[] with
// *fewer* than 'count' elements.
ThrowHelper.ThrowEndOfStreamException();
}
return result;
}
/// <summary>
/// Creates a <see cref="DateTime"/> object from raw data with validation.
/// </summary>
/// <exception cref="SerializationException"><paramref name="data"/> was invalid.</exception>
internal static DateTime CreateDateTimeFromData(long data)
/// <exception cref="SerializationException"><paramref name="dateData"/> was invalid.</exception>
internal static DateTime CreateDateTimeFromData(ulong dateData)
{
// Copied from System.Runtime.Serialization.Formatters.Binary.BinaryParser
// Use DateTime's public constructor to validate the input, but we
// can't return that result as it strips off the kind. To address
// that, store the value directly into a DateTime via an unsafe cast.
// See BinaryFormatterWriter.WriteDateTime for details.
ulong ticks = dateData & 0x3FFFFFFF_FFFFFFFFUL;
DateTimeKind kind = (DateTimeKind)(dateData >> 62);
try
{
const long TicksMask = 0x3FFFFFFFFFFFFFFF;
_ = new DateTime(data & TicksMask);
return ((uint)kind <= (uint)DateTimeKind.Local) ? new DateTime((long)ticks, kind) : CreateFromAmbiguousDst(ticks);
}
catch (ArgumentException ex)
{
// Bad data
throw new SerializationException(ex.Message, ex);
}
return Unsafe.As<long, DateTime>(ref data);
[MethodImpl(MethodImplOptions.NoInlining)]
static DateTime CreateFromAmbiguousDst(ulong ticks)
{
// There's no public API to create a DateTime from an ambiguous DST, and we
// can't use private reflection to access undocumented .NET Framework APIs.
// However, the ISerializable pattern *is* a documented protocol, so we can
// use DateTime's serialization ctor to create a zero-tick "ambiguous" instance,
// then keep reusing it as the base to which we can add our tick offsets.
if (s_baseAmbiguousDstDateTime is not DateTime baseDateTime)
{
#pragma warning disable SYSLIB0050 // Type or member is obsolete
SerializationInfo si = new(typeof(DateTime), new FormatterConverter());
// We don't know the value of "ticks", so we don't specify it.
// If the code somehow runs on a very old runtime that does not know the concept of "dateData"
// (it should not be possible as the library targets .NET Standard 2.0)
// the ctor is going to throw rather than silently return an invalid value.
si.AddValue("dateData", 0xC0000000_00000000UL); // new value (serialized as ulong)
#if NET
baseDateTime = CallPrivateSerializationConstructor(si, new StreamingContext(StreamingContextStates.All));
#else
ConstructorInfo ci = typeof(DateTime).GetConstructor(
BindingFlags.Instance | BindingFlags.NonPublic,
binder: null,
new Type[] { typeof(SerializationInfo), typeof(StreamingContext) },
modifiers: null);
baseDateTime = (DateTime)ci.Invoke(new object[] { si, new StreamingContext(StreamingContextStates.All) });
#endif
#pragma warning restore SYSLIB0050 // Type or member is obsolete
Volatile.Write(ref s_baseAmbiguousDstDateTime, baseDateTime); // it's ok if two threads race here
}
return baseDateTime.AddTicks((long)ticks);
}
#if NET
[UnsafeAccessor(UnsafeAccessorKind.Constructor)]
extern static DateTime CallPrivateSerializationConstructor(SerializationInfo si, StreamingContext ct);
#endif
}
internal static bool? IsDataAvailable(this BinaryReader reader, long requiredBytes)
{
Debug.Assert(requiredBytes >= 0);
if (!reader.BaseStream.CanSeek)
{
return null;

View file

@ -6,28 +6,33 @@ using System.Runtime.Serialization;
namespace System.Formats.Nrbf.Utils;
// The exception messages do not contain member/type/assembly names on purpose,
// as it's most likely corrupted/tampered/malicious data.
internal static class ThrowHelper
{
internal static void ThrowInvalidValue(object value)
internal static void ThrowDuplicateMemberName()
=> throw new SerializationException(SR.Serialization_DuplicateMemberName);
internal static void ThrowInvalidValue(int value)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidValue, value));
internal static void ThrowInvalidReference()
=> throw new SerializationException(SR.Serialization_InvalidReference);
internal static void ThrowInvalidTypeName(string name)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeName, name));
internal static void ThrowInvalidTypeName()
=> throw new SerializationException(SR.Serialization_InvalidTypeName);
internal static void ThrowUnexpectedNullRecordCount()
=> throw new SerializationException(SR.Serialization_UnexpectedNullRecordCount);
internal static void ThrowMaxArrayLength(long limit, long actual)
=> throw new SerializationException(SR.Format(SR.Serialization_MaxArrayLength, actual, limit));
internal static void ThrowArrayContainedNulls()
=> throw new SerializationException(SR.Serialization_ArrayContainedNulls);
internal static void ThrowInvalidAssemblyName(string rawName)
=> throw new SerializationException(SR.Format(SR.Serialization_InvalidAssemblyName, rawName));
internal static void ThrowInvalidAssemblyName()
=> throw new SerializationException(SR.Serialization_InvalidAssemblyName);
internal static void ThrowInvalidFormat()
=> throw new SerializationException(SR.Serialization_InvalidFormat);
internal static void ThrowEndOfStreamException()
=> throw new EndOfStreamException();

View file

@ -12,7 +12,8 @@ namespace System.Formats.Nrbf.Utils;
internal static class TypeNameHelpers
{
// PrimitiveType does not define Object, IntPtr or UIntPtr
// PrimitiveType does not define Object, IntPtr or UIntPtr.
internal const PrimitiveType StringPrimitiveType = (PrimitiveType)18;
internal const PrimitiveType ObjectPrimitiveType = (PrimitiveType)19;
internal const PrimitiveType IntPtrPrimitiveType = (PrimitiveType)20;
internal const PrimitiveType UIntPtrPrimitiveType = (PrimitiveType)21;
@ -22,8 +23,6 @@ internal static class TypeNameHelpers
internal static TypeName GetPrimitiveTypeName(PrimitiveType primitiveType)
{
Debug.Assert(primitiveType is not (PrimitiveType.None or PrimitiveType.Null));
TypeName? typeName = s_primitiveTypeNames[(int)primitiveType];
if (typeName is null)
{
@ -44,11 +43,11 @@ internal static class TypeNameHelpers
PrimitiveType.Decimal => "System.Decimal",
PrimitiveType.TimeSpan => "System.TimeSpan",
PrimitiveType.DateTime => "System.DateTime",
PrimitiveType.String => "System.String",
StringPrimitiveType => "System.String",
ObjectPrimitiveType => "System.Object",
IntPtrPrimitiveType => "System.IntPtr",
UIntPtrPrimitiveType => "System.UIntPtr",
_ => throw new ArgumentOutOfRangeException(paramName: nameof(primitiveType), actualValue: primitiveType, message: null)
_ => throw new InvalidOperationException()
};
s_primitiveTypeNames[(int)primitiveType] = typeName = TypeName.Parse(fullName.AsSpan()).WithCoreLibAssemblyName();
@ -99,7 +98,7 @@ internal static class TypeNameHelpers
else if (typeof(T) == typeof(TimeSpan))
return PrimitiveType.TimeSpan;
else if (typeof(T) == typeof(string))
return PrimitiveType.String;
return StringPrimitiveType;
else if (typeof(T) == typeof(IntPtr))
return IntPtrPrimitiveType;
else if (typeof(T) == typeof(UIntPtr))
@ -118,6 +117,17 @@ internal static class TypeNameHelpers
Debug.Assert(payloadOptions.UndoTruncatedTypeNames);
Debug.Assert(libraryRecord.RawLibraryName is not null);
// This is potentially a DoS vector, as somebody could submit:
// [1] BinaryLibraryRecord = <really long string>
// [2] ClassRecord (lib = [1])
// [3] ClassRecord (lib = [1])
// ...
// [n] ClassRecord (lib = [1])
//
// Which means somebody submits a payload of length O(long + n) and tricks us into
// performing O(long * n) work. For this reason, we have marked the UndoTruncatedTypeNames
// property as "keep this disabled unless you trust the input."
// Combining type and library allows us for handling truncated generic type names that may be present in resources.
ArraySegment<char> assemblyQualifiedName = RentAssemblyQualifiedName(rawName, libraryRecord.RawLibraryName);
TypeName.TryParse(assemblyQualifiedName.AsSpan(), out TypeName? typeName, payloadOptions.TypeNameParseOptions);
@ -125,7 +135,7 @@ internal static class TypeNameHelpers
if (typeName is null)
{
throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeOrAssemblyName, rawName, libraryRecord.RawLibraryName));
throw new SerializationException(SR.Serialization_InvalidTypeOrAssemblyName);
}
if (typeName.AssemblyName is null)
@ -149,6 +159,10 @@ internal static class TypeNameHelpers
private static TypeName With(this TypeName typeName, AssemblyNameInfo assemblyName)
{
// This is a recursive method over potentially hostile TypeName arguments.
// We assume the complexity of the TypeName arg was appropriately bounded.
// See comment in TypeName.FullName property getter for more info.
if (!typeName.IsSimple)
{
if (typeName.IsArray)
@ -169,7 +183,7 @@ internal static class TypeNameHelpers
else
{
// BinaryFormatter can not serialize pointers or references.
ThrowHelper.ThrowInvalidTypeName(typeName.FullName);
ThrowHelper.ThrowInvalidTypeName();
}
}
@ -187,6 +201,7 @@ internal static class TypeNameHelpers
return typeName;
}
// Complexity is O(typeName.Length + libraryName.Length)
private static ArraySegment<char> RentAssemblyQualifiedName(string typeName, string libraryName)
{
int length = typeName.Length + 1 + libraryName.Length;

View file

@ -3,6 +3,8 @@
using System.Collections.Generic;
using System.IO;
using System.Runtime.Serialization;
using System.Text;
using Xunit;
namespace System.Formats.Nrbf.Tests;
@ -24,6 +26,51 @@ public class ArraySinglePrimitiveRecordTests : ReadTests
}
}
[Fact]
public void DontCastBytesToBooleans()
{
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);
WriteSerializedStreamHeader(writer);
writer.Write((byte)SerializationRecordType.ArraySinglePrimitive);
writer.Write(1); // object ID
writer.Write(2); // length
writer.Write((byte)PrimitiveType.Boolean); // element type
writer.Write((byte)0x01);
writer.Write((byte)0x02);
writer.Write((byte)SerializationRecordType.MessageEnd);
stream.Position = 0;
SZArrayRecord<bool> serializationRecord = (SZArrayRecord<bool>)NrbfDecoder.Decode(stream);
bool[] bools = serializationRecord.GetArray();
bool a = bools[0];
Assert.True(a);
bool b = bools[1];
Assert.True(b);
bool c = a && b;
Assert.True(c);
}
[Fact]
public void DontCastBytesToDateTimes()
{
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);
WriteSerializedStreamHeader(writer);
writer.Write((byte)SerializationRecordType.ArraySinglePrimitive);
writer.Write(1); // object ID
writer.Write(1); // length
writer.Write((byte)PrimitiveType.DateTime); // element type
writer.Write(ulong.MaxValue); // un-representable DateTime
writer.Write((byte)SerializationRecordType.MessageEnd);
stream.Position = 0;
Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}
[Theory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_Bool(int size, bool canSeek) => Test<bool>(size, canSeek);
@ -94,6 +141,7 @@ public class ArraySinglePrimitiveRecordTests : ReadTests
SZArrayRecord<T> arrayRecord = (SZArrayRecord<T>)NrbfDecoder.Decode(stream);
Assert.Equal(size, arrayRecord.Length);
Assert.Equal(size, arrayRecord.FlattenedLength);
T?[] output = arrayRecord.GetArray();
Assert.Equal(input, output);
Assert.Same(output, arrayRecord.GetArray());

View file

@ -154,7 +154,7 @@ public class AttackTests : ReadTests
writer.Write((byte)SerializationRecordType.ArraySinglePrimitive);
writer.Write(1); // object ID
writer.Write(Array.MaxLength); // length
writer.Write((byte)2); // PrimitiveType.Byte
writer.Write((byte)PrimitiveType.Byte);
writer.Write((byte)SerializationRecordType.MessageEnd);
stream.Position = 0;

View file

@ -1,4 +1,5 @@
using System.IO;
using System.Collections.Generic;
using System.IO;
using System.Runtime.Serialization.Formatters;
using System.Runtime.Serialization.Formatters.Binary;
using Microsoft.DotNet.XUnitExtensions;
@ -103,4 +104,44 @@ public class EdgeCaseTests : ReadTests
Assert.Throws<NotSupportedException>(() => NrbfDecoder.Decode(ms));
}
public static IEnumerable<object[]> CanReadAllKindsOfDateTimes_Arguments
{
get
{
yield return new object[] { new DateTime(1990, 11, 24, 0, 0, 0, DateTimeKind.Local) };
yield return new object[] { new DateTime(1990, 11, 25, 0, 0, 0, DateTimeKind.Utc) };
yield return new object[] { new DateTime(1990, 11, 26, 0, 0, 0, DateTimeKind.Unspecified) };
}
}
[Theory]
[MemberData(nameof(CanReadAllKindsOfDateTimes_Arguments))]
public void CanReadAllKindsOfDateTimes_DateTimeIsTheRootRecord(DateTime input)
{
using MemoryStream stream = Serialize(input);
PrimitiveTypeRecord<DateTime> dateTimeRecord = (PrimitiveTypeRecord<DateTime>)NrbfDecoder.Decode(stream);
Assert.Equal(input.Ticks, dateTimeRecord.Value.Ticks);
Assert.Equal(input.Kind, dateTimeRecord.Value.Kind);
}
[Serializable]
public class ClassWithDateTime
{
public DateTime Value;
}
[Theory]
[MemberData(nameof(CanReadAllKindsOfDateTimes_Arguments))]
public void CanReadAllKindsOfDateTimes_DateTimeIsMemberOfTheRootRecord(DateTime input)
{
using MemoryStream stream = Serialize(new ClassWithDateTime() { Value = input });
ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(stream);
Assert.Equal(input.Ticks, classRecord.GetDateTime(nameof(ClassWithDateTime.Value)).Ticks);
Assert.Equal(input.Kind, classRecord.GetDateTime(nameof(ClassWithDateTime.Value)).Kind);
}
}

View file

@ -426,7 +426,10 @@ public class InvalidInputTests : ReadTests
{
foreach (byte binaryType in new byte[] { (byte)0 /* BinaryType.Primitive */, (byte)7 /* BinaryType.PrimitiveArray */ })
{
yield return new object[] { recordType, binaryType, (byte)0 }; // value not used by the spec
yield return new object[] { recordType, binaryType, (byte)4 }; // value not used by the spec
yield return new object[] { recordType, binaryType, (byte)17 }; // used by the spec, but illegal in given context
yield return new object[] { recordType, binaryType, (byte)18 }; // used by the spec, but illegal in given context
yield return new object[] { recordType, binaryType, (byte)19 };
}
}
@ -478,4 +481,125 @@ public class InvalidInputTests : ReadTests
stream.Position = 0;
Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}
[Theory]
[InlineData(18, typeof(NotSupportedException))] // not part of the spec, but still less than max allowed value (22)
[InlineData(19, typeof(NotSupportedException))] // same as above
[InlineData(20, typeof(NotSupportedException))] // same as above
[InlineData(23, typeof(SerializationException))] // not part of the spec and more than max allowed value (22)
[InlineData(64, typeof(SerializationException))] // same as above but also matches AllowedRecordTypes.SerializedStreamHeader
public void InvalidSerializationRecordType(byte recordType, Type expectedException)
{
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);
WriteSerializedStreamHeader(writer);
writer.Write(recordType); // SerializationRecordType
writer.Write((byte)SerializationRecordType.MessageEnd);
stream.Position = 0;
Assert.Throws(expectedException, () => NrbfDecoder.Decode(stream));
}
[Fact]
public void MissingRootRecord()
{
const int RootRecordId = 1;
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);
WriteSerializedStreamHeader(writer, rootId: RootRecordId);
writer.Write((byte)SerializationRecordType.BinaryObjectString);
writer.Write(RootRecordId + 1); // a different ID
writer.Write("theString");
writer.Write((byte)SerializationRecordType.MessageEnd);
stream.Position = 0;
Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}
[Fact]
public void Invalid7BitEncodedStringLength()
{
// The highest bit of the last byte is set (so it's invalid).
byte[] invalidLength = [byte.MaxValue, byte.MaxValue, byte.MaxValue, byte.MaxValue, byte.MaxValue];
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);
WriteSerializedStreamHeader(writer);
writer.Write((byte)SerializationRecordType.BinaryObjectString);
writer.Write(1); // root record Id
writer.Write(invalidLength); // the length prefix
writer.Write(Encoding.UTF8.GetBytes("theString"));
writer.Write((byte)SerializationRecordType.MessageEnd);
stream.Position = 0;
Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}
[Theory]
[InlineData("79228162514264337593543950336")] // invalid format (decimal.MaxValue + 1)
[InlineData("1111111111111111111111111111111111111111111111111")] // overflow
public void InvalidDecimal(string textRepresentation)
{
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);
WriteSerializedStreamHeader(writer);
writer.Write((byte)SerializationRecordType.SystemClassWithMembersAndTypes);
writer.Write(1); // root record Id
writer.Write("ClassWithDecimalField"); // type name
writer.Write(1); // member count
writer.Write("memberName");
writer.Write((byte)BinaryType.Primitive);
writer.Write((byte)PrimitiveType.Decimal);
writer.Write(textRepresentation);
writer.Write((byte)SerializationRecordType.MessageEnd);
stream.Position = 0;
Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public void SurrogateCharacters(bool array)
{
using MemoryStream stream = new();
BinaryWriter writer = new(stream, Encoding.UTF8);
WriteSerializedStreamHeader(writer);
writer.Write((byte)SerializationRecordType.SystemClassWithMembersAndTypes);
writer.Write(1); // root record Id
writer.Write("ClassWithCharField"); // type name
writer.Write(1); // member count
writer.Write("memberName");
if (array)
{
writer.Write((byte)BinaryType.PrimitiveArray);
writer.Write((byte)PrimitiveType.Char);
writer.Write((byte)SerializationRecordType.ArraySinglePrimitive);
writer.Write(2); // array record Id
writer.Write(1); // array length
writer.Write((byte)PrimitiveType.Char);
}
else
{
writer.Write((byte)BinaryType.Primitive);
writer.Write((byte)PrimitiveType.Char);
}
writer.Write((byte)0xC0); // a surrogate character
writer.Write((byte)SerializationRecordType.MessageEnd);
stream.Position = 0;
Assert.Throws<SerializationException>(() => NrbfDecoder.Decode(stream));
}
}

View file

@ -1,4 +1,5 @@
using System.Formats.Nrbf.Utils;
using System.IO;
using System.Linq;
using Xunit;
@ -6,29 +7,91 @@ namespace System.Formats.Nrbf.Tests;
public class JaggedArraysTests : ReadTests
{
[Fact]
public void CanReadJaggedArraysOfPrimitiveTypes_2D()
[Theory]
[InlineData(true)]
[InlineData(false)]
public void CanReadJaggedArraysOfPrimitiveTypes_2D(bool useReferences)
{
int[][] input = new int[7][];
int[] same = [1, 2, 3];
for (int i = 0; i < input.Length; i++)
{
input[i] = [i, i, i];
input[i] = useReferences
? same // reuse the same object (represented as a single record that is referenced multiple times)
: [i, i, i]; // create new array
}
var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));
Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
}
[Theory]
[InlineData(1)] // SerializationRecordType.ObjectNull
[InlineData(200)] // SerializationRecordType.ObjectNullMultiple256
[InlineData(10_000)] // SerializationRecordType.ObjectNullMultiple
public void FlattenedLengthIncludesNullArrays(int nullCount)
{
int[][] input = new int[nullCount][];
var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));
Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(nullCount, arrayRecord.FlattenedLength);
}
[Fact]
public void ItIsPossibleToHaveBinaryArrayRecordsHaveAnElementTypeOfArrayWithoutBeingMarkedAsJagged()
{
int[][][] input = new int[3][][];
long totalElementsCount = 0;
for (int i = 0; i < input.Length; i++)
{
input[i] = new int[4][];
totalElementsCount++; // count the arrays themselves
for (int j = 0; j < input[i].Length; j++)
{
input[i][j] = [i, j, 0, 1, 2];
totalElementsCount += input[i][j].Length;
totalElementsCount++; // count the arrays themselves
}
}
byte[] serialized = Serialize(input).ToArray();
const int ArrayTypeByteIndex =
sizeof(byte) + sizeof(int) * 4 + // stream header
sizeof(byte) + // SerializationRecordType.BinaryArray
sizeof(int); // SerializationRecordId
Assert.Equal((byte)BinaryArrayType.Jagged, serialized[ArrayTypeByteIndex]);
// change the reported array type
serialized[ArrayTypeByteIndex] = (byte)BinaryArrayType.Single;
var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(new MemoryStream(serialized));
Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(3 + 3 * 4 + 3 * 4 * 5, totalElementsCount);
Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength);
}
[Fact]
public void CanReadJaggedArraysOfPrimitiveTypes_3D()
{
int[][][] input = new int[7][][];
long totalElementsCount = 0;
for (int i = 0; i < input.Length; i++)
{
totalElementsCount++; // count the arrays themselves
input[i] = new int[1][];
totalElementsCount++; // count the arrays themselves
input[i][0] = [i, i, i];
totalElementsCount += input[i][0].Length;
}
var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));
@ -36,6 +99,8 @@ public class JaggedArraysTests : ReadTests
Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(1, arrayRecord.Rank);
Assert.Equal(7 + 7 * 1 + 7 * 1 * 3, totalElementsCount);
Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength);
}
[Fact]
@ -60,6 +125,7 @@ public class JaggedArraysTests : ReadTests
Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(1, arrayRecord.Rank);
Assert.Equal(input.Length + input.Length * 3 * 3, arrayRecord.FlattenedLength);
}
[Fact]
@ -75,6 +141,7 @@ public class JaggedArraysTests : ReadTests
Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
}
[Fact]
@ -90,6 +157,7 @@ public class JaggedArraysTests : ReadTests
Verify(input, arrayRecord);
Assert.Equal(input, arrayRecord.GetArray(input.GetType()));
Assert.Equal(input.Length + input.Length * 3, arrayRecord.FlattenedLength);
}
[Serializable]
@ -102,14 +170,18 @@ public class JaggedArraysTests : ReadTests
public void CanReadJaggedArraysOfComplexTypes()
{
ComplexType[][] input = new ComplexType[3][];
long totalElementsCount = 0;
for (int i = 0; i < input.Length; i++)
{
input[i] = Enumerable.Range(0, i + 1).Select(j => new ComplexType { SomeField = j }).ToArray();
totalElementsCount += input[i].Length;
totalElementsCount++; // count the arrays themselves
}
var arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));
Verify(input, arrayRecord);
Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength);
var output = (ClassRecord?[][])arrayRecord.GetArray(input.GetType());
for (int i = 0; i < input.Length; i++)
{

View file

@ -45,10 +45,10 @@ public abstract class ReadTests
};
#pragma warning restore SYSLIB0011 // Type or member is obsolete
protected static void WriteSerializedStreamHeader(BinaryWriter writer, int major = 1, int minor = 0)
protected static void WriteSerializedStreamHeader(BinaryWriter writer, int major = 1, int minor = 0, int rootId = 1)
{
writer.Write((byte)SerializationRecordType.SerializedStreamHeader);
writer.Write(1); // root ID
writer.Write(rootId); // root ID
writer.Write(1); // header ID
writer.Write(major); // major version
writer.Write(minor); // minor version

View file

@ -223,10 +223,13 @@ public class RectangularArraysTests : ReadTests
internal static void Verify(Array input, ArrayRecord arrayRecord)
{
Assert.Equal(input.Rank, arrayRecord.Lengths.Length);
long totalElementsCount = 1;
for (int i = 0; i < input.Rank; i++)
{
Assert.Equal(input.GetLength(i), arrayRecord.Lengths[i]);
totalElementsCount *= input.GetLength(i);
}
Assert.Equal(totalElementsCount, arrayRecord.FlattenedLength);
Assert.Equal(input.GetType().FullName, arrayRecord.TypeName.FullName);
Assert.Equal(input.GetType().GetAssemblyNameIncludingTypeForwards(), arrayRecord.TypeName.AssemblyName!.FullName);
}

View file

@ -7,6 +7,8 @@
<ItemGroup>
<Compile Include="..\src\System\Formats\Nrbf\BinaryArrayType.cs" Link="BinaryArrayType.cs" />
<Compile Include="..\src\System\Formats\Nrbf\BinaryType.cs" Link="BinaryType.cs" />
<Compile Include="..\src\System\Formats\Nrbf\PrimitiveType.cs" Link="PrimitiveType.cs" />
</ItemGroup>
<ItemGroup>

View file

@ -73,6 +73,34 @@ public class TypeMatchTests : ReadTests
Verify(new Dictionary<string, List<ValueTuple<int, short>>>());
}
[Fact]
public void ThrowsForNullType()
{
List<int> input = new List<int>();
SerializationRecord record = NrbfDecoder.Decode(Serialize(input));
Assert.Throws<ArgumentNullException>(() => record.TypeNameMatches(type: null));
}
[Fact]
public void TakesCustomOffsetsIntoAccount()
{
int[] input = [1, 2, 3];
SerializationRecord record = NrbfDecoder.Decode(Serialize(input));
Assert.True(record.TypeNameMatches(typeof(int[])));
Type nonSzArray = typeof(int).Assembly.GetType("System.Int32[*]");
#if NET
Assert.False(nonSzArray.IsSZArray);
Assert.True(nonSzArray.IsVariableBoundArray);
#endif
Assert.Equal(1, nonSzArray.GetArrayRank());
Assert.False(record.TypeNameMatches(nonSzArray));
}
[Fact]
public void TakesGenericTypeDefinitionIntoAccount()
{

View file

@ -81,6 +81,10 @@ namespace System.Reflection.Metadata
/// <summary>
/// Gets the name of the culture associated with the assembly.
/// </summary>
/// <remarks>
/// Do not create a <see cref="System.Globalization.CultureInfo"/> instance from this string unless
/// you know the string has originated from a trustworthy source.
/// </remarks>
public string? CultureName { get; }
/// <summary>
@ -131,6 +135,10 @@ namespace System.Reflection.Metadata
/// <summary>
/// Initializes a new instance of the <seealso cref="AssemblyName"/> class based on the stored information.
/// </summary>
/// <remarks>
/// Do not create an <see cref="AssemblyName"/> instance with <see cref="CultureName"/> string unless
/// you know the string has originated from a trustworthy source.
/// </remarks>
public AssemblyName ToAssemblyName()
{
AssemblyName assemblyName = new();

View file

@ -95,7 +95,7 @@ namespace System.Reflection.Metadata
/// If <see cref="AssemblyName"/> returns null, simply returns <see cref="FullName"/>.
/// </remarks>
public string AssemblyQualifiedName
=> _assemblyQualifiedName ??= AssemblyName is null ? FullName : $"{FullName}, {AssemblyName.FullName}";
=> _assemblyQualifiedName ??= AssemblyName is null ? FullName : $"{FullName}, {AssemblyName.FullName}"; // see recursion comments in FullName
/// <summary>
/// Returns assembly name which contains this type, or null if this <see cref="TypeName"/> was not
@ -142,6 +142,17 @@ namespace System.Reflection.Metadata
{
get
{
// This is a recursive method over potentially hostile input. Protection against DoS is offered
// via the [Try]Parse method and TypeNameParserOptions.MaxNodes property at construction time.
// This FullName property getter and related methods assume that this TypeName instance has an
// acceptable node count.
//
// The node count controls the total amount of work performed by this method, including:
// - The max possible stack depth due to the recursive methods calls; and
// - The total number of bytes allocated by this function. For a deeply-nested TypeName
// object, the total allocation across the full object graph will be
// O(FullName.Length * GetNodeCount()).
if (_fullName is null)
{
if (IsConstructedGenericType)
@ -245,6 +256,8 @@ namespace System.Reflection.Metadata
{
get
{
// Lookups to Name and FullName might be recursive. See comments in FullName property getter.
if (_name is null)
{
if (IsConstructedGenericType)
@ -425,6 +438,17 @@ namespace System.Reflection.Metadata
/// <exception cref="InvalidOperationException">The current type name is not simple.</exception>
public TypeName WithAssemblyName(AssemblyNameInfo? assemblyName)
{
// Recursive method. See comments in FullName property getter for more information
// on how this is protected against attack.
//
// n.b. AssemblyNameInfo could also be hostile. The typical exploit is that a single
// long AssemblyNameInfo is associated with one or more simple TypeName objects,
// leading to an alg. complexity attack (DoS). It's important that TypeName doesn't
// actually *do* anything with the provided AssemblyNameInfo rather than store it.
// For example, don't use it inside a string concat operation unless the caller
// explicitly requested that to happen. If the input is hostile, the caller should
// never perform such concats in a loop.
if (!IsSimple)
{
TypeNameParserHelpers.ThrowInvalidOperation_NotSimpleName(FullName);

View file

@ -80,6 +80,8 @@ namespace System.Reflection.Metadata
return null;
}
// At this point, we have performed O(fullTypeNameLength) total work.
ReadOnlySpan<char> fullTypeName = _inputString.Slice(0, fullTypeNameLength);
_inputString = _inputString.Slice(fullTypeNameLength);
@ -142,6 +144,12 @@ namespace System.Reflection.Metadata
}
}
// At this point, we may have performed O(fullTypeNameLength + _inputString.Length) total work.
// This will be the case if there was whitespace after the full type name in the original input
// string. We could end up looking at these same whitespace chars again later in this method,
// such as when parsing decorators. We rely on the TryDive routine to limit the total number
// of times we might inspect the same character.
// If there was an error stripping the generic args, back up to
// before we started processing them, and let the decorator
// parser try handling it.
@ -202,6 +210,9 @@ namespace System.Reflection.Metadata
result = new(fullName: null, assemblyName, elementOrGenericType: result, declaringType, genericArgs);
}
// The loop below is protected by the dive check during the first decorator pass prior
// to assembly name parsing above.
if (previousDecorator != default) // some decorators were recognized
{
while (TryParseNextDecorator(ref capturedBeforeProcessing, out int parsedModifier))
@ -245,6 +256,8 @@ namespace System.Reflection.Metadata
return null;
}
// The loop below is protected by the dive check in GetFullTypeNameLength.
TypeName? declaringType = null;
int nameOffset = 0;
foreach (int nestedNameLength in nestedNameLengths)

View file

@ -16,6 +16,7 @@ namespace System.Reflection.Metadata
internal const int ByRef = -3;
private const char EscapeCharacter = '\\';
#if NET8_0_OR_GREATER
// Keep this in sync with GetFullTypeNameLength/NeedsEscaping
private static readonly SearchValues<char> s_endOfFullTypeNameDelimitersSearchValues = SearchValues.Create("[]&*,+\\");
#endif
@ -30,7 +31,7 @@ namespace System.Reflection.Metadata
foreach (TypeName genericArg in genericArgs)
{
result.Append('[');
result.Append(genericArg.AssemblyQualifiedName);
result.Append(genericArg.AssemblyQualifiedName); // see recursion comments in TypeName.FullName
result.Append(']');
result.Append(',');
}
@ -97,11 +98,16 @@ namespace System.Reflection.Metadata
return offset;
}
// Keep this in sync with s_endOfFullTypeNameDelimitersSearchValues
static bool NeedsEscaping(char c) => c is '[' or ']' or '&' or '*' or ',' or '+' or EscapeCharacter;
}
internal static ReadOnlySpan<char> GetName(ReadOnlySpan<char> fullName)
{
// The two-value form of MemoryExtensions.LastIndexOfAny does not suffer
// from the behavior mentioned in the comment at the top of GetFullTypeNameLength.
// It always takes O(m * i) worst-case time and is safe to use here.
int offset = fullName.LastIndexOfAny('.', '+');
if (offset > 0 && fullName[offset - 1] == EscapeCharacter) // this should be very rare (IL Emit & pure IL)
@ -182,6 +188,13 @@ namespace System.Reflection.Metadata
{
Debug.Assert(rankOrModifier >= 2);
// O(rank) work, so we have to assume the rank is trusted. We don't put a hard cap on this,
// but within the TypeName parser, we do require the input string to contain the correct number
// of commas. This forces the input string to have at least O(rank) length, so there's no
// alg. complexity attack possible here. Callers can of course pass any arbitrary value to
// TypeName.MakeArrayTypeName, but per first sentence in this comment, we have to assume any
// such arbitrary value which is programmatically fed in originates from a trustworthy source.
builder.Append('[');
builder.Append(',', rankOrModifier - 1);
builder.Append(']');
@ -310,6 +323,9 @@ namespace System.Reflection.Metadata
else if (TryStripFirstCharAndTrailingSpaces(ref input, ','))
{
// [,,, ...]
// The runtime restricts arrays to rank 32, but we don't enforce that here.
// Instead, the max rank is controlled by the total number of commas present
// in the array decorator.
checked { rank++; }
goto ReadNextArrayToken;
}

View file

@ -10,6 +10,13 @@ namespace System.Reflection.Metadata
/// <summary>
/// Limits the maximum value of <seealso cref="TypeName.GetNodeCount">node count</seealso> that parser can handle.
/// </summary>
/// <remarks>
/// <para>
/// Setting this to a large value can render <see cref="TypeName"/> susceptible to Denial of Service
/// attacks when parsing or handling malicious input.
/// </para>
/// <para>The default value is 20.</para>
/// </remarks>
public int MaxNodes
{
get => _maxNodes;