Skip to content

Commit

Permalink
Merge pull request opentiny#21 from GaoNeng-wWw/feat/code-generate
Browse files Browse the repository at this point in the history
feat: build sequential
  • Loading branch information
GaoNeng-wWw authored Mar 2, 2024
2 parents a552b78 + a7ff174 commit 5160a28
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 103 deletions.
67 changes: 31 additions & 36 deletions dl-flow-backend/src/code-generate/__tests__/ast.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ class Layer${i}:
children: ['node-1', 'node-2'],
};
const ast = service.buildGroup(group as any, {
// 我觉得x6就是一坨 s**t
'group-1': group as any,
'node-1': {
id: 'node-1',
Expand Down Expand Up @@ -217,11 +216,11 @@ class Layer${i}:
zIndex: 0,
},
});
expect(ast.children.length).toBe(3);
expect(ast.children[0]).toBeInstanceOf(VarDecl);
expect(ast.children[1]).toBeInstanceOf(VarDecl);
expect(ast.children[2]).toBeInstanceOf(VarDecl);
expect(ast.children[2].codeGen()).toContain('group1');
// expect(ast.children.length).toBe(3);
// expect(ast.children[0]).toBeInstanceOf(VarDecl);
// expect(ast.children[1]).toBeInstanceOf(VarDecl);
// expect(ast.children[2]).toBeInstanceOf(VarDecl);
expect(ast.children[0].codeGen()).toContain('group1');
});
it('has nest', () => {
const nodes = {
Expand Down Expand Up @@ -303,17 +302,9 @@ class Layer${i}:
children: ['node-1', 'node-2', 'group-2'],
};
const ast = service.buildGroup(group as any, nodes as any);
expect(ast.children.length).toBe(7);
expect(
ast.children.every((child) => child instanceof VarDecl),
).toBeTruthy();
expect(ast.codeGen()).toBe(`nodenode1 = paddle.nn.Conv1D()
nodenode2 = paddle.nn.Conv1D()
nodenode4 = paddle.nn.Conv1d()
nodenode11 = paddle.nn.Conv1d()
group_group3 = paddle.concat(x=[nodenode11])
group_group2 = paddle.concat(x=[nodenode4,group_group3])
group_group1 = paddle.concat(x=[nodenode1,nodenode2,group_group2])`);
expect(ast.codeGen()).toBe(`group_group3 = paddle.concat(x=[nodenode11])
group_group2 = paddle.concat(x=[group_group3,nodenode4])
group_group1 = paddle.concat(x=[group_group2,nodenode2,nodenode1])`);
});
});
it('build', () => {
Expand Down Expand Up @@ -387,28 +378,32 @@ group_group1 = paddle.concat(x=[nodenode1,nodenode2,group_group2])`);
},
};
const ast = service.build(
[node['node-1'], node['node-2'], node['layer'], node['group']],
[
node['node-1'],
node['node-2'],
node['layer'],
node['node-3'],
node['group'],
],
node,
);
expect(ast.codeGen().replace(/\n| /gim, '')).toEqual(
`true = True
false = False
nodenode1 = paddle.nn.Conv1D()
nodenode2 = paddle.nn.Conv1D()
class Layer1:
def __init__(self,x):
pass
layer1 = Layer1(x=1)
nodenode3 = paddle.nn.Conv1D()
class Layer1:
def __init__(self,x):
pass
layer-1 = Layer1(x = 1)
group_group = paddle.concat(x=[nodenode3,layer-1])`.replace(/\n| /gim, ''),
`true = True
false = False
nodenode1 = paddle.nn.Conv1D()
nodenode2 = paddle.nn.Conv1D()
class Layer1:
def __init__(self,x):
pass
layer1 = Layer1(x=1)
nodenode3 = paddle.nn.Conv1D()
group_group = paddle.concat(x=[layer1,nodenode3])
model=paddle.nn.Sequential(nodenode1,nodenode2,group_group)`.replace(
/\n| /gim,
'',
),
);
});
});
196 changes: 129 additions & 67 deletions dl-flow-backend/src/code-generate/ast.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export class AST {
continue;
}
}
ast.children.push(new VarDecl('model', this.buildSequential(ast)));
return ast;
}
/**
Expand Down Expand Up @@ -121,84 +122,120 @@ export class AST {
const varDecl = new VarDecl(`node${cellId.replace(/-/gim, '')}`, fnCall);
return varDecl;
}
extractGroup(group: Cell, nodeTable: StandardizationNodes) {
const stack: Cell[] = [];
if (this.isGroup(group)) {
stack.push(group);
}
for (const child of (group.children as unknown as string[]) ?? []) {
if (this.isGroup(nodeTable[child])) {
stack.push(...this.extractGroup(nodeTable[child], nodeTable));
}
}
return stack;
}
isChild(group: Cell, child: Cell) {
return (group.children as unknown as string[]).includes(child.id);
}
buildLayer(layer: Layer) {
const clazzDef = new ClazzDef(layer.code);
return clazzDef;
}
buildGroup(group: Cell, standardizationNodes: StandardizationNodes) {
const subAst: IAST = {
type: 'root',
children: [],
codeGen: () => subAst.children.map((child) => child.codeGen()).join('\n'),
};
const groups: { [x: string]: string[] } = {};
for (const childId of (group.children as unknown as string[]) ?? []) {
const child = standardizationNodes[childId];
if (this.isGroup(child)) {
if (!groups[child.id]) {
groups[child.id] = [];
}
const subAST = this.buildGroup(child, standardizationNodes);
subAst.children.push(...subAST.children);
groups[child.id] = subAST.children
.filter(
(v, i) => i !== subAST.children.length - 1 && v instanceof VarDecl,
)
.map((v: VarDecl) => v.name);
const ast = new GroupAst();
const stack: [Cell, 'start' | 'node' | 'end'][] = [];
const groups = this.extractGroup(group, standardizationNodes);
for (const g of groups) {
stack.push([g, 'start']);
for (const child of (g.children as unknown as string[]) ?? []) {
stack.push([standardizationNodes[child], 'node']);
}
if (this.isNN(child.data)) {
const nn = this.buildNN(child.data, child.id);
subAst.children.push(nn);
if (groups[child.id]) {
groups[child.id].push(nn.name);
}
stack.push([g, 'end']);
}
let activeGroup: Cell | null = null;
while (stack.length) {
const [cell, type] = stack.pop();
const children = [];
if (type === 'end') {
activeGroup = cell;
}
if (this.isLayer(child.data)) {
const clazz = this.buildLayer(child.data);
const callee = new Identifier(child.data.clazz);
const clazzInstance = new CallExpression(
callee,
child.data.properties.map((v) => `${v.id} = ${v.data}`),
if (type === 'node') {
if (this.isGroup(cell)) {
children.push(`group_${cell.id.replace(/-/gim, '')}`);
} else {
if (this.isNN(cell.data)) {
children.push(`node${cell.id.replace(/-/gim, '')}`);
}
if (this.isLayer(cell.data)) {
children.push(`${cell.id.replace('-', '')}`);
}
}
while (true) {
const [cell, type] = stack.pop();
if (type === 'start') {
break;
}
if (this.isGroup(cell)) {
children.push(`group_${cell.id.replace(/-/gim, '')}`);
} else {
if (this.isNN(cell.data)) {
children.push(`node${cell.id.replace(/-/gim, '')}`);
}
if (this.isLayer(cell.data)) {
children.push(`${cell.id.replace('-', '')}`);
}
}
}
if (!activeGroup) {
throw new Error(
`Can not find active group. Please check your schema.`,
);
}
const callee = new Identifier('paddle.concat');
const call = new CallExpression(callee, [
['x=[', children.join(','), ']'].join(''),
]);
const concatVar = new VarDecl(
`group_${activeGroup.id}`.replace(/-/gim, ''),
call,
);
const instance = new VarDecl(child.id, clazzInstance);
subAst.children.push(clazz);
subAst.children.push(instance);
ast.children.push(concatVar);
ast.childId.push(...children);
}
}
groups[group.id] = group.children
.filter((child) => !this.isGroup(child))
.map((v) => v.id);
const varDecl = subAst.children.filter(
(item) => item instanceof VarDecl,
) as VarDecl[];
const keys = Object.keys(groups);
const values = keys
.map((k) => Object.values(groups[k]))
.reduce((pre, cur) => {
return [...pre, ...cur];
}, []);
const names =
keys.length === 1
? varDecl.map((decl) => decl.name)
: varDecl
.map((decl) => {
return decl.name;
})
.filter((v) => {
return !keys.includes(v) && !values.includes(v);
});

const callee = new Identifier('paddle.concat');
const call = new CallExpression(callee, [
['x=[', names.join(','), ']'].join(''),
]);
const concatVar = new VarDecl(
`group_${group.id}`.replace(/-/gim, ''),
call,
return ast;
}
buildSequential(ast: IAST) {
const stack = [
...ast.children
.map((child) => {
if (child instanceof VarDecl) {
return child.name;
}
if (child instanceof GroupAst) {
return child.children.map((child) =>
child instanceof VarDecl ? child.name : null,
);
}
return null;
})
.flat()
.filter(
(child) => child !== null && child !== 'true' && child !== 'false',
),
];
const groups = ast.children
.map((child) => (child instanceof GroupAst ? child : null))
.filter((child) => child !== null);
const childrenId = groups
.map((group) => group.childId)
.reduce((pre, cur) => [...pre, ...cur], []);
return new CallExpression(
new Identifier('paddle.nn.Sequential'),
stack.filter((item) => !childrenId.includes(item)),
);
subAst.children.push(concatVar);
return subAst;
}

isGroup(cell: Cell) {
return cell?.shape && cell.shape.includes('group');
}
Expand Down Expand Up @@ -257,6 +294,19 @@ export class ClazzDef implements IClazzDefine {
}
}

export class Statement implements IStmt, Node {
children: Node[] = [];
}

export class GroupAst implements IGroupAst {
type = 'root' as const;
children: ASTItem[] = [];
childId: string[] = [];
codeGen() {
return this.children.map((child) => child.codeGen()).join('\n');
}
}

type IVarDecl = {
name: string;
val: ASTItem;
Expand All @@ -275,11 +325,23 @@ type IClazzDefine = {
code: string;
codeGen: () => string;
};
type IStmt = {
children: Node[];
};

type ASTItem = IVarDecl | IIdentifier | ICallExpression | IClazzDefine;
type ASTItem =
| IVarDecl
| IIdentifier
| ICallExpression
| IClazzDefine
| IGroupAst;

type IAST = {
type: 'root';
children: ASTItem[];
codeGen: () => string;
};

interface IGroupAst extends IAST {
childId: string[];
}

0 comments on commit 5160a28

Please sign in to comment.