Skip to content

Commit

Permalink
fix Variable[slice].assign() #653
Browse files Browse the repository at this point in the history
  • Loading branch information
Oceania2018 committed Dec 5, 2020
1 parent 6f8beab commit d75366c
Show file tree
Hide file tree
Showing 13 changed files with 271 additions and 161 deletions.
3 changes: 0 additions & 3 deletions src/TensorFlowNET.Core/APIs/tf.ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ public void add_to_collection<T>(string name, T value)
public void add_to_collections<T>(List<string> names, T value)
=> get_default_graph().add_to_collections(names, value);

public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);

public Tensor assign(IVariableV1 @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);

Expand Down
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Core/Contexts/Context.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public void restore_mode()
context_switches.Pop();
}

[DebuggerStepThrough]
// [DebuggerStepThrough]
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tensor[] tensors)
{
var shouldRunInEager = executing_eagerly()
Expand All @@ -115,7 +115,7 @@ public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tenso
}
}

[DebuggerStepThrough]
// [DebuggerStepThrough]
public Tensors RunInAutoMode2(Func<Tensors> graphAction,
Func<Tensors> eagerAction,
Action<Operation> recordGradient,
Expand Down
24 changes: 24 additions & 0 deletions src/TensorFlowNET.Core/Operations/gen_array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,30 @@ public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tenso
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
input, begin, end, strides);

public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value,
int begin_mask = 0,
int end_mask = 0,
int ellipsis_mask = 0,
int new_axis_mask = 0,
int shrink_axis_mask = 0,
string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new
{
input, begin, end, strides, value,
begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ResourceStridedSliceAssign", name,
null,
input, begin, end, strides, value,
"begin_mask", begin_mask,
"end_mask", end_mask,
"ellipsis_mask", ellipsis_mask,
"new_axis_mask", new_axis_mask,
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
input, begin, end, strides, value);

public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides,
int begin_mask = 0,
int end_mask = 0,
Expand Down
37 changes: 0 additions & 37 deletions src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,43 +34,6 @@ public static ITensorOrOperation shape_safe_assign_variable_handle(Tensor handle
name: name);
}

/// <summary>
///
/// </summary>
/// <param name="self"></param>
/// <param name="value"></param>
/// <param name="use_locking"></param>
/// <param name="read_value"></param>
/// <returns>
/// If `read_value` is `True`, this method will return the new value of the
/// variable after the assignment has completed.Otherwise, when in graph mode
/// it will return the `Operation` that does the assignment, and when in eager
/// mode it will return `None`.
/// </returns>
public static Operation assign(this Tensor self, Tensor value, bool use_locking = false, string name = null, bool read_value = true)
{
var value_tensor = ops.convert_to_tensor(value, dtype: self.dtype);
self.assert_is_compatible_with(value_tensor);
var assign_op = gen_resource_variable_ops.assign_variable_op(self, value_tensor, name: name);
if (read_value)
{
return self._lazy_read(assign_op);
}

return assign_op;
}

public static Operation _lazy_read(this Tensor self, Operation op)
{
variable_accessed(self);
throw new NotImplementedException();
}

public static void variable_accessed(this Tensor variable)
{
throw new NotImplementedException();
}

public static bool is_resource_variable(IVariableV1 var)
{
return var is ResourceVariable;
Expand Down
21 changes: 21 additions & 0 deletions src/TensorFlowNET.Core/Tensors/ParsedSliceArgs.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class ParsedSliceArgs
{
public int[] Begin { get; set; }
public Tensor PackedBegin { get; set; }
public int[] End { get; set; }
public Tensor PackedEnd { get; set; }
public int[] Strides { get; set; }
public Tensor PackedStrides { get; set; }
public int BeginMask { get; set; }
public int EndMask { get; set; }
public int ShrinkAxisMask { get; set; }
public int NewAxisMask { get; set; }
public int EllipsisMask { get; set; }
}
}
24 changes: 24 additions & 0 deletions src/TensorFlowNET.Core/Tensors/Tensor.Assign.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using NumSharp;

namespace Tensorflow
{
public partial class Tensor
{
/// <summary>
/// Used to keep the original variable when slicing
/// </summary>
public ResourceVariable OriginalVar { get; set; }
public ParsedSliceArgs OriginalVarSlice { get; set; }

public ResourceVariable assign(Tensor tensor)
{
if (OriginalVar != null)
{
OriginalVar.StridedSliceAssign(tensor, OriginalVarSlice);
return OriginalVar;
}
else
throw new RuntimeError("Operation doesn't support.");
}
}
}
75 changes: 11 additions & 64 deletions src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,81 +30,28 @@ public Tensor this[params Slice[] slices]
{
get
{
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();
var args = tensor_util.ParseSlices(slices);

var index = 0;
var (new_axis_mask, shrink_axis_mask) = (0, 0);
var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0;

foreach (var s in slices)
{
if (s.IsNewAxis)
{
begin.Add(0);
end.Add(0);
strides.Add(1);
new_axis_mask |= (1 << index);
}
else if (s.IsEllipsis)
{
begin.Add(0);
end.Add(0);
strides.Add(1);
ellipsis_mask |= (1 << index);
}
else
{
if (s.Start.HasValue)
{
begin.Add(s.Start.Value);
}
else
{
begin.Add(0);
begin_mask |= (1 << index);
}

if (s.Stop.HasValue)
{
end.Add(s.Stop.Value);
}
else
{
end.Add(0);
end_mask |= (1 << index);
}

strides.Add(s.Step);
if (s.IsIndex)
shrink_axis_mask |= (1 << index);
}

index += 1;
}

return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
return tf_with(ops.name_scope(null, "strided_slice", args), scope =>
{
string name = scope;
if (begin != null)
if (args.Begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));
(array_ops.stack(args.Begin),
array_ops.stack(args.End),
array_ops.stack(args.Strides));

return gen_array_ops.strided_slice(
this,
packed_begin,
packed_end,
packed_strides,
begin_mask: begin_mask,
end_mask: end_mask,
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
begin_mask: args.BeginMask,
end_mask: args.EndMask,
shrink_axis_mask: args.ShrinkAxisMask,
new_axis_mask: args.NewAxisMask,
ellipsis_mask: args.EllipsisMask,
name: name);
}

Expand Down
71 changes: 71 additions & 0 deletions src/TensorFlowNET.Core/Tensors/tensor_util.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.

using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
Expand Down Expand Up @@ -584,5 +585,75 @@ public static string to_numpy_string(Tensor tensor)
return nd.ToString();
}
}

public static ParsedSliceArgs ParseSlices(Slice[] slices)
{
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();

var index = 0;
var (new_axis_mask, shrink_axis_mask) = (0, 0);
var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0;

foreach (var s in slices)
{
if (s.IsNewAxis)
{
begin.Add(0);
end.Add(0);
strides.Add(1);
new_axis_mask |= (1 << index);
}
else if (s.IsEllipsis)
{
begin.Add(0);
end.Add(0);
strides.Add(1);
ellipsis_mask |= (1 << index);
}
else
{
if (s.Start.HasValue)
{
begin.Add(s.Start.Value);
}
else
{
begin.Add(0);
begin_mask |= (1 << index);
}

if (s.Stop.HasValue)
{
end.Add(s.Stop.Value);
}
else
{
end.Add(0);
end_mask |= (1 << index);
}

strides.Add(s.Step);
if (s.IsIndex)
shrink_axis_mask |= (1 << index);
}

index += 1;
}

return new ParsedSliceArgs
{
Begin = begin.ToArray(),
End = end.ToArray(),
Strides = strides.ToArray(),
BeginMask = begin_mask,
EndMask = end_mask,
EllipsisMask = ellipsis_mask,
ShrinkAxisMask = shrink_axis_mask,
NewAxisMask = new_axis_mask
};
}
}
}
16 changes: 16 additions & 0 deletions src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,22 @@ public Tensor assign<T>(T value, bool use_locking = false, string name = null, b
return assign_op;
}

public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice)
{
_strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value);
}

void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null,
int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0)
{
var op = gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value,
begin_mask: begin_mask,
end_mask: end_mask,
ellipsis_mask: ellipsis_mask,
new_axis_mask: new_axis_mask,
shrink_axis_mask: shrink_axis_mask);
}

public IVariableV1 assign_lazy_load(Tensor value, string name = null)
{
var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
Expand Down
Loading

0 comments on commit d75366c

Please sign in to comment.