Skip to content

Commit

Permalink
First working version of refacttor
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Nov 11, 2018
1 parent f2bfdff commit c619514
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 120 deletions.
217 changes: 166 additions & 51 deletions uarray/NumPy Compat.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "issubclass() arg 1 must be a class",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-2-b2b69a6182b9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0muarray\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mnumba\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnjit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/p/uarray/uarray/uarray/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mast_higher\u001b[0m \u001b[0;31m# NOQA\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0moptimize\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/p/uarray/uarray/uarray/ast_higher.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mast\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mast\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mmoa\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mOuterProduct\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/p/uarray/uarray/uarray/ast.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mcore\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mmoa\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAdd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMultiply\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/p/uarray/uarray/uarray/moa.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m register(\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mReduceVector\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"fn\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mInt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"value\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSequence\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mInt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"length\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"getitem\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0m_reduce_vector\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m )\n",
"\u001b[0;32m/usr/local/miniconda3/envs/uarray/lib/python3.6/site-packages/matchpy/expressions/expressions.py\u001b[0m in \u001b[0;36msymbol\u001b[0;34m(name, symbol_type)\u001b[0m\n\u001b[1;32m 786\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0missubclass\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSymbol\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0msymbol_type\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mSymbol\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 787\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mSymbolWildcard\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 788\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mSymbolWildcard\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msymbol_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvariable_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 789\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mstaticmethod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/miniconda3/envs/uarray/lib/python3.6/site-packages/matchpy/expressions/expressions.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, symbol_type, variable_name)\u001b[0m\n\u001b[1;32m 894\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvariable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 895\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 896\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0missubclass\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msymbol_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSymbol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 897\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The type constraint must be a subclass of Symbol\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: issubclass() arg 1 must be a class"
]
}
],
"outputs": [],
"source": [
"from uarray import *\n",
"import numpy as np\n",
Expand All @@ -52,7 +34,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -69,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -78,9 +60,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 274 µs, sys: 127 µs, total: 401 µs\n",
"Wall time: 381 µs\n"
]
},
{
"data": {
"text/plain": [
"array([ 0, 5, 10, 15, 20, 25, 30, 35, 40, 45])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time some_fn(*args)"
]
Expand All @@ -101,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -119,9 +120,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 25 µs, sys: 0 ns, total: 25 µs\n",
"Wall time: 31 µs\n"
]
},
{
"data": {
"text/plain": [
"array([ 0., 5., 10., 15., 20., 25., 30., 35., 40., 45.])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time optimized_some_fn(*args)"
]
Expand All @@ -135,9 +155,25 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"Index(Sequence(Int(1), VectorCallable(Scalar(Int(5)))),\n",
" OuterProduct(BinaryUfunc(<ufunc 'multiply'>),\n",
" ToSequenceWithDim(NPArray(Expression(Name(id='a', ctx=Load()))),\n",
" Int(1)),\n",
" ToSequenceWithDim(NPArray(Expression(Name(id='b', ctx=Load()))),\n",
" Int(1))))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_some_fn.__optimize_steps__['resulting_expr']"
]
Expand All @@ -151,9 +187,36 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"def fn(a, b):\n",
" i_5 = ()\n",
" i_6 = b.shape[0]\n",
" i_1 = ((i_6,) + i_5)\n",
" i_0 = np.empty(i_1)\n",
" i_2 = b.shape[0]\n",
" for i_3 in range(i_2):\n",
" i_4 = i_0[i_3]\n",
" i_9 = 5\n",
" i_10 = a\n",
" i_7 = i_10[i_9]\n",
" i_11 = i_3\n",
" i_12 = b\n",
" i_8 = i_12[i_11]\n",
" i_4 = (i_7 * i_8)\n",
" i_0[i_3] = i_4\n",
" return i_0\n",
"\n"
]
}
],
"source": [
"print(optimized_some_fn.__optimize_steps__['ast_as_source'])"
]
Expand All @@ -174,7 +237,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -183,9 +246,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 7 µs, sys: 0 ns, total: 7 µs\n",
"Wall time: 11.2 µs\n"
]
},
{
"data": {
"text/plain": [
"array([ 0., 5., 10., 15., 20., 25., 30., 35., 40., 45.])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# run once first to compile\n",
"numba_optimized(*args)\n",
Expand All @@ -197,16 +279,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Great, another 2x speedup!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ast.dump(ast.parse(\"(1,) + ()\"))"
"Great, another speedup!"
]
},
{
Expand All @@ -225,16 +298,27 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"<function __main__.some_fn(a, b)>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"some_fn"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -243,18 +327,49 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"Index(Sequence(Int(1), VectorCallable(Scalar(Int(5)))),\n",
" OuterProduct(BinaryUfunc(<ufunc 'multiply'>),\n",
" NPArray(Expression(Name(id='a', ctx=Load()))),\n",
" NPArray(Expression(Name(id='b', ctx=Load())))))"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dims_not_known.__optimize_steps__['resulting_expr']"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"def fn(a, b):\n",
" i_14 = 5\n",
" i_16 = a\n",
" i_17 = b\n",
" i_15 = np.multiply.outer(i_16, i_17)\n",
" i_13 = i_15[i_14]\n",
" return i_13\n",
"\n"
]
}
],
"source": [
"print(dims_not_known.__optimize_steps__['ast_as_source'])"
]
Expand Down
Loading

0 comments on commit c619514

Please sign in to comment.