Skip to content

Commit

Permalink
Merge pull request #200 from tonybaloney/dict_round_trip
Browse files Browse the repository at this point in the history
Make PyDictionary and PyList convertable back to PyObject without marshalling
  • Loading branch information
tonybaloney authored Sep 17, 2024
2 parents baa7342 + 3534e06 commit 895d89d
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 16 deletions.
3 changes: 0 additions & 3 deletions src/CSnakes.Runtime/PyObjectTypeConverter.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
using CSnakes.Runtime.CPython;
using CSnakes.Runtime.Python;
using System.Collections;
using System.Runtime.CompilerServices;
using System.Numerics;
using System.Collections.Concurrent;
using System.Reflection;

Expand Down
5 changes: 5 additions & 0 deletions src/CSnakes.Runtime/Python/ICloneable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
namespace CSnakes.Runtime.Python;
internal interface ICloneable
{
internal PyObject Clone();
}
3 changes: 2 additions & 1 deletion src/CSnakes.Runtime/Python/PyDictionary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace CSnakes.Runtime.Python;

internal class PyDictionary<TKey, TValue>(PyObject dictionary) : IReadOnlyDictionary<TKey, TValue>, IDisposable
internal class PyDictionary<TKey, TValue>(PyObject dictionary) : IReadOnlyDictionary<TKey, TValue>, IDisposable, ICloneable
where TKey : notnull
{
private readonly Dictionary<TKey, TValue> _dictionary = [];
Expand Down Expand Up @@ -100,5 +100,6 @@ public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value)
return false;
}

PyObject ICloneable.Clone() => _dictionaryObject.Clone();
IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator();
}
17 changes: 7 additions & 10 deletions src/CSnakes.Runtime/Python/PyList.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@

namespace CSnakes.Runtime.Python;

internal class PyList<TItem> : IReadOnlyList<TItem>, IDisposable
internal class PyList<TItem>(PyObject listObject) : IReadOnlyList<TItem>, IDisposable, ICloneable
{
private readonly PyObject _listObject;

// If someone fetches the same index multiple times, we cache the result to avoid multiple round trips to Python
private readonly Dictionary<long, TItem> _convertedItems = new();

public PyList(PyObject listObject) => _listObject = listObject;
private readonly Dictionary<long, TItem> _convertedItems = [];

public TItem this[int index]
{
Expand All @@ -23,7 +19,7 @@ public TItem this[int index]

using (GIL.Acquire())
{
using PyObject value = PyObject.Create(CPythonAPI.PySequence_GetItem(_listObject, index));
using PyObject value = PyObject.Create(CPythonAPI.PySequence_GetItem(listObject, index));
TItem result = value.As<TItem>();
_convertedItems[index] = result;
return result;
Expand All @@ -37,21 +33,22 @@ public int Count
{
using (GIL.Acquire())
{
return (int)CPythonAPI.PySequence_Size(_listObject);
return (int)CPythonAPI.PySequence_Size(listObject);
}
}
}

public void Dispose() => _listObject.Dispose();
public void Dispose() => listObject.Dispose();

public IEnumerator<TItem> GetEnumerator()
{
// TODO: If someone fetches the same index multiple times, we cache the result to avoid multiple round trips to Python
using (GIL.Acquire())
{
return new PyEnumerable<TItem>(_listObject);
return new PyEnumerable<TItem>(listObject);
}
}

PyObject ICloneable.Clone() => listObject.Clone();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
6 changes: 4 additions & 2 deletions src/CSnakes.Runtime/Python/PyObject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace CSnakes.Runtime.Python;

[DebuggerDisplay("PyObject: repr={GetRepr()}, type={GetPythonType().ToString()}")]
public class PyObject : SafeHandle
public class PyObject : SafeHandle, ICloneable
{
protected PyObject(IntPtr pyObject, bool ownsHandle = true) : base(pyObject, ownsHandle)
{
Expand Down Expand Up @@ -494,7 +494,7 @@ public static PyObject From<T>(T value)

return value switch
{
PyObject pyObject => pyObject.Clone(),
ICloneable pyObject => pyObject.Clone(),
bool b => b ? True : False,
int i => Create(CPythonAPI.PyLong_FromLong(i)),
long l => Create(CPythonAPI.PyLong_FromLongLong(l)),
Expand Down Expand Up @@ -541,4 +541,6 @@ private static void MergeKeywordArguments(string[] kwnames, PyObject[] kwvalues,
combinedKwnames = [.. newKwnames];
combinedKwvalues = [.. newKwvalues];
}

PyObject ICloneable.Clone() => Clone();
}
13 changes: 13 additions & 0 deletions src/Integration.Tests/DictsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,17 @@ public void TestDicts_TestMapping()
Assert.Equal(1, result["dictkey1"]);
Assert.Equal(2, result["dictkey2"]);
}

[Fact]
public void TestDictionaryRoundTrip()
{
var testDicts = Env.TestDicts();

IReadOnlyDictionary<string, long> testDict = new Dictionary<string, long> { { "dictkey1", 1 }, { "dictkey2", 2 } };
var result = testDicts.TestDictStrInt(testDict);
Assert.Equal(1, result["dictkey1"]);

var roundTrip = testDicts.TestDictStrInt(result);
Assert.Equal(1, roundTrip["dictkey1"]);
}
}

0 comments on commit 895d89d

Please sign in to comment.