Skip to content

Commit

Permalink
Fix type cache cross pollination (#306)
Browse files Browse the repository at this point in the history
* Add cross pollination specs and improve type filtering performance by caching

* Add type rejection benchmark

* Remove type checking during serialization, impossible to do

* Remove bad type caching

* Clean up code

* Add type checking into serialization with a separate type caching
  • Loading branch information
Arkatufus authored Mar 30, 2022
1 parent f310df6 commit 3886408
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 27 deletions.
14 changes: 11 additions & 3 deletions src/Hyperion.Akka.Integration.Tests/IntegrationSpec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,19 @@ public async Task CantDeserializeANaughtyTypeByDefault()

try
{
var serializer = system.Serialization.FindSerializerForType(typeof(DirectoryInfo));
var deserializer = system.Serialization.FindSerializerForType(typeof(DirectoryInfo));
var di = new DirectoryInfo(@"c:\");

byte[] serialized;
using (var stream = new MemoryStream())
{
var serializer = new Serializer(SerializerOptions.Default.WithDisallowUnsafeType(false));
serializer.Serialize(di, stream);
stream.Position = 0;
serialized = stream.ToArray();
}

var serialized = serializer.ToBinary(di);
var ex = Assert.Throws<SerializationException>(() => serializer.FromBinary<DirectoryInfo>(serialized));
var ex = Assert.Throws<SerializationException>(() => deserializer.FromBinary<DirectoryInfo>(serialized));
ex.InnerException.Should().BeOfType<EvilDeserializationException>();
}
finally
Expand Down
46 changes: 46 additions & 0 deletions src/Hyperion.Benchmarks/TypeRejectionBenchmark.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using System;
using System.IO;
using BenchmarkDotNet.Attributes;
using Hyperion.Internal;

namespace Hyperion.Benchmarks
{
[Config(typeof(HyperionConfig))]
public class TypeRejectionBenchmark
{
private Serializer _serializer;
private Stream _dangerousStream;

[GlobalSetup]
public void Setup()
{
var di = new DirectoryInfo("C:\\Windows\\Windows32");
var serializer = new Serializer(SerializerOptions.Default.WithDisallowUnsafeType(false));
_dangerousStream = new MemoryStream();
serializer.Serialize(di, _dangerousStream);

_serializer = new Serializer();
}

[GlobalCleanup]
public void Cleanup()
{
_dangerousStream.Dispose();
}

[Benchmark]
public void DeserializeDanger()
{
_dangerousStream.Position = 0;
try
{
_serializer.Deserialize<DirectoryInfo>(_dangerousStream);
}
catch(EvilDeserializationException)
{
// no-op
}
}

}
}
55 changes: 45 additions & 10 deletions src/Hyperion.Tests/UnsafeDeserializationExclusionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,29 @@ public UnsafeDeserializationExclusionTests(ITestOutputHelper output)
[Fact]
public void CantDeserializeANaughtyType()
{
var serializer = new Hyperion.Serializer();
var di =new System.IO.DirectoryInfo(@"c:\");
var serializer = new Serializer(SerializerOptions.Default.WithDisallowUnsafeType(false));
var deserializer = new Serializer();
var di = new DirectoryInfo(@"c:\");

using (var stream = new MemoryStream())
{
serializer.Serialize(di, stream);
stream.Position = 0;
Assert.Throws<EvilDeserializationException>(() =>
serializer.Deserialize<DirectoryInfo>(stream));
deserializer.Deserialize<DirectoryInfo>(stream));
}
}

[Fact]
public void CantSerializeANaughtyType()
{
var serializer = new Serializer();
var di = new FileInfo(@"c:\windows\windows32\dangerous.exe");

using (var stream = new MemoryStream())
{
Assert.Throws<EvilDeserializationException>(() =>
serializer.Serialize(di, stream));
}
}

Expand Down Expand Up @@ -121,44 +135,65 @@ public void TypeFilterShouldThrowOnNaughtyType()
var options = SerializerOptions.Default
.WithTypeFilter(typeFilter);

var serializer = new Serializer(options);
var serializer = new Serializer(SerializerOptions.Default.WithDisallowUnsafeType(false));
var deserializer = new Serializer(options);

using (var stream = new MemoryStream())
{
serializer.Serialize(new ClassA(), stream);
stream.Position = 0;
Action act = () => serializer.Deserialize<ClassA>(stream);
Action act = () => deserializer.Deserialize<ClassA>(stream);
act.Should().NotThrow();

stream.Position = 0;
Action actObj = () => serializer.Deserialize<object>(stream);
Action actObj = () => deserializer.Deserialize<object>(stream);
actObj.Should().NotThrow();
}

using (var stream = new MemoryStream())
{
serializer.Serialize(new ClassB(), stream);
stream.Position = 0;
Action act = () => serializer.Deserialize<ClassB>(stream);
Action act = () => deserializer.Deserialize<ClassB>(stream);
act.Should().NotThrow();

stream.Position = 0;
Action actObj = () => serializer.Deserialize<object>(stream);
Action actObj = () => deserializer.Deserialize<object>(stream);
actObj.Should().NotThrow();
}

using (var stream = new MemoryStream())
{
serializer.Serialize(new ClassC(), stream);
stream.Position = 0;
Action act = () => serializer.Deserialize<ClassC>(stream);
Action act = () => deserializer.Deserialize<ClassC>(stream);
act.Should().Throw<UserEvilDeserializationException>();

stream.Position = 0;
Action actObj = () => serializer.Deserialize<object>(stream);
Action actObj = () => deserializer.Deserialize<object>(stream);
actObj.Should().Throw<UserEvilDeserializationException>();
}
}

[Fact]
public void TypeCacheShouldNotBleedBetweenInstances()
{
var serializer = new Serializer();
using (var stream = new MemoryStream())
{
serializer.Serialize(new ClassA(), stream);
stream.Position = 0;
serializer.Deserialize<ClassA>(stream);
}

// Type should be cached when a serializer deserialize a message
serializer.TypeNameLookup.Values.Should().Contain(typeof(ClassA));

// Type cache should not be carried to other serializer instances
var newSerializer = new Serializer();
newSerializer.TypeNameLookup.Values.Should().NotContain(typeof(ClassA));
}

}
}

Expand Down
20 changes: 20 additions & 0 deletions src/Hyperion/Extensions/StreamEx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,26 @@ public static object ReadObject(this Stream stream, DeserializerSession session)
return value;
}

internal static ByteArrayKey? ReadByteArrayKey(this Stream stream, DeserializerSession session)
{
var length = stream.ReadByte();
switch (length)
{
case 0:
return null;
case 255:
length = stream.ReadInt32(session);
break;
default:
length--;
break;
}

var buffer = new byte[length];
stream.ReadFull(buffer, 0, length);
return new ByteArrayKey(buffer);
}

public static string ReadString(this Stream stream, DeserializerSession session)
{
var length = stream.ReadByte();
Expand Down
5 changes: 3 additions & 2 deletions src/Hyperion/Extensions/TypeEx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ private static Type GetTypeFromManifestName(Stream stream, DeserializerSession s
{
var bytes = stream.ReadLengthEncodedByteArray(session);
var byteArr = ByteArrayKey.Create(bytes);

return session.Serializer.TypeNameLookup.GetOrAdd(byteArr, b =>
{
var shortName = StringEx.FromUtf8Bytes(b.Bytes, 0, b.Bytes.Length);
Expand All @@ -169,7 +170,7 @@ private static Type GetTypeFromManifestName(Stream stream, DeserializerSession s
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool UnsafeInheritanceCheck(Type type)
internal static bool UnsafeInheritanceCheck(Type type)
{
#if NETSTANDARD1_6
if (type.IsValueType())
Expand Down Expand Up @@ -201,7 +202,7 @@ public static bool IsDisallowedType<TType>()

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool IsDisallowedType(Type type)
=> IsDisallowedType(type.FullName);
=> IsDisallowedType(type.AssemblyQualifiedName);

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool IsDisallowedType(string name)
Expand Down
56 changes: 46 additions & 10 deletions src/Hyperion/Serializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public class Serializer
internal readonly ConcurrentDictionary<ByteArrayKey, Type> TypeNameLookup =
new ConcurrentDictionary<ByteArrayKey, Type>(ByteArrayKeyComparer.Instance);

internal readonly ConcurrentDictionary<Type, TypeAccepted> AcceptedTypes =
new ConcurrentDictionary<Type, TypeAccepted>();

public Serializer() : this(new SerializerOptions())
{
}
Expand Down Expand Up @@ -146,24 +149,45 @@ public void Serialize(object obj, [NotNull] Stream stream, SerializerSession ses
{
if (obj == null)
throw new ArgumentNullException(nameof(obj));
if(session == null)
throw new ArgumentNullException(nameof(session));

var type = obj.GetType();
if (Options.DisallowUnsafeTypes)
{
if (AcceptedTypes.TryGetValue(type, out var acceptance))
{
if(!acceptance.Accepted)
{
if (acceptance.UserRejected)
throw new UserEvilDeserializationException("Unsafe Type Deserialization Detected!", type.FullName);
throw new EvilDeserializationException("Unsafe Type Deserialization Detected!", type.FullName);
}
}
else
{
if(TypeEx.IsDisallowedType(type))
{
AcceptedTypes[type] = new TypeAccepted(false, false);
throw new EvilDeserializationException("Unsafe Type Deserialization Detected!", type.FullName);
}
if(!Options.TypeFilter.IsAllowed(type.AssemblyQualifiedName))
{
AcceptedTypes[type] = new TypeAccepted(false, true);
throw new UserEvilDeserializationException("Unsafe Type Deserialization Detected!",
type.FullName);
}
AcceptedTypes[type] = new TypeAccepted(true, false);
}
}

var s = GetSerializerByType(type);
s.WriteManifest(stream, session);
s.WriteValue(stream, obj, session);
}

public void Serialize(object obj, [NotNull] Stream stream)
{
if (obj == null)
throw new ArgumentNullException(nameof(obj));
SerializerSession session = GetSerializerSession();

var type = obj.GetType();
var s = GetSerializerByType(type);
s.WriteManifest(stream, session);
s.WriteValue(stream, obj, session);
}
=> Serialize(obj, stream, GetSerializerSession());

public SerializerSession GetSerializerSession()
{
Expand Down Expand Up @@ -300,5 +324,17 @@ public ValueSerializer GetDeserializerByManifest([NotNull] Stream stream, [NotNu
throw new NotSupportedException("Unknown manifest value");
}
}

internal readonly struct TypeAccepted
{
public TypeAccepted(bool accepted, bool userRejected)
{
Accepted = accepted;
UserRejected = userRejected;
}

public bool Accepted { get; }
public bool UserRejected { get; }
}
}
}
7 changes: 5 additions & 2 deletions src/Hyperion/ValueSerializers/TypeSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ public override void WriteValue(Stream stream, object value, SerializerSession s
new ConcurrentDictionary<string, Type>();
public override object ReadValue(Stream stream, DeserializerSession session)
{
var shortname = stream.ReadString(session);
if (shortname == null)
var bytes = stream.ReadByteArrayKey(session);
if (bytes == null)
return null;
var byteArr = bytes.Value;

var shortname = StringEx.FromUtf8Bytes(byteArr.Bytes, 0, byteArr.Bytes.Length);
var options = session.Serializer.Options;

var type = TypeNameLookup.GetOrAdd(shortname,
name => TypeEx.LoadTypeByName(shortname, options.DisallowUnsafeTypes, options.TypeFilter));

Expand Down

0 comments on commit 3886408

Please sign in to comment.