Skip to content

Commit

Permalink
Adjust condition and add optimization logic
Browse files Browse the repository at this point in the history
  • Loading branch information
yqzhishen committed Apr 5, 2024
1 parent 087b929 commit 821b3dd
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
5 changes: 3 additions & 2 deletions deployment/modules/rectified_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def forward(self, condition, x_end=None, depth=None, steps: int = 10):
else:
x = t_start * x_end + (1 - t_start) * noise

dt = (1 - t_start) / max(1, steps)
if dt > 0.:
t_width = 1. - t_start
if t_width >= 0.:
dt = t_width / max(1, steps)
for t in torch.arange(steps, dtype=torch.long, device=device)[:, None].float() * dt + t_start:
x = self.sample_euler(x, t, dt, condition)

Expand Down
25 changes: 17 additions & 8 deletions utils/onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,16 +277,25 @@ def _extract_conv_nodes_recursive(subgraph: GraphProto):
to_be_removed.append(sub_node)
[subgraph.node.remove(_n) for _n in to_be_removed]

toplevel_if_idx = toplevel_if_node = None
toplevel_entry_node_idx = toplevel_entry_node = None
# Find the **last** If node in toplevel graph
for i, n in enumerate(graph.node):
if n.op_type == 'If':
toplevel_if_idx = i
toplevel_if_node = n
if toplevel_if_node is not None:
for a in toplevel_if_node.attribute:
b = onnx.helper.get_attribute_value(a)
_extract_conv_nodes_recursive(b)
toplevel_entry_node_idx = i
toplevel_entry_node = n
# If not found, find the **last** Loop node in toplevel graph
if toplevel_entry_node is None:
for i, n in enumerate(graph.node):
if n.op_type == 'Loop':
toplevel_entry_node_idx = i
toplevel_entry_node = n
if toplevel_entry_node is not None:
for a in toplevel_entry_node.attribute:
# Apply to all sub-graphs
v = onnx.helper.get_attribute_value(a)
if isinstance(v, GraphProto):
_extract_conv_nodes_recursive(v)

# Insert the extracted nodes before the first 'If' node which carries the main denoising loop.
for key in reversed(node_dict):
alias, node = node_dict[key]
Expand All @@ -295,7 +304,7 @@ def _extract_conv_nodes_recursive(subgraph: GraphProto):
node.output.remove(node.output[0])
node.output.insert(0, alias)
# Insert node into the main graph.
graph.node.insert(toplevel_if_idx, node)
graph.node.insert(toplevel_entry_node_idx, node)
# Rename value info of the output.
for v in graph.value_info:
if v.name == out_name:
Expand Down

0 comments on commit 821b3dd

Please sign in to comment.