Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorBaratta committed Jan 4, 2024
1 parent d06388c commit c5b2000
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
16 changes: 9 additions & 7 deletions ffcx/codegeneration/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
import logging

import ufl
from typing import List, Optional
import ffcx.codegeneration.lnodes as L
from ffcx.ir.elementtables import UniqueTableReferenceT
from ffcx.ir.representationutils import QuadratureRule
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal

logger = logging.getLogger("ffcx")

Expand Down Expand Up @@ -90,7 +94,7 @@ def get(self, mt, tabledata, quadrature_rule, access):
# Call the handler
return handler(mt, tabledata, quadrature_rule, access)

def coefficient(self, mt, tabledata, quadrature_rule, access):
def coefficient(self, mt: ModifiedTerminal, tabledata: UniqueTableReferenceT, quadrature_rule: QuadratureRule, access: str):
"""Return definition code for coefficients."""
# For applying tensor product to coefficients, we need to know if the coefficient
# has a tensor factorisation and if the quadrature rule has a tensor factorisation.
Expand Down Expand Up @@ -123,7 +127,7 @@ def coefficient(self, mt, tabledata, quadrature_rule, access):
FE = self.access.table_access(tabledata, self.entitytype, mt.restriction, iq, ic)

code = []
pre_code = []
pre_code: List[L.LNode] = []

if bs > 1 and not tabledata.is_piecewise:
# For bs > 1, the coefficient access has a stride of bs. e.g.: XYZXYZXYZ
Expand All @@ -133,7 +137,8 @@ def coefficient(self, mt, tabledata, quadrature_rule, access):
# have a sequential access pattern.
dof_access, dof_access_map = self.symbols.coefficient_dof_access_blocked(mt.terminal, ic, bs, begin)

# If a map is necessary from stride 1 to bs, the code must be added before the quadrature loop.
# If a map is necessary from stride 1 to bs, the code must be added
# before the quadrature loop.
if dof_access_map:
pre_code += [L.ArrayDecl(dof_access.array, sizes=num_dofs)]
pre_body = [L.Assign(dof_access, dof_access_map)]
Expand All @@ -146,10 +151,7 @@ def coefficient(self, mt, tabledata, quadrature_rule, access):
code += [L.VariableDecl(access, 0.0)]
code += [L.create_nested_for_loops([ic], body)]

code = L.Section("Coefficient definition", code)
pre_code = L.Section("Coefficient pre definition", pre_code)

return pre_code, code
return L.Section("Coefficient pre definition", pre_code), L.Section("Coefficient definition", code)

def _define_coordinate_dofs_lincomb(self, mt, tabledata, quadrature_rule, access):
"""Define x or J as a linear combination of coordinate dofs with given table data."""
Expand Down
4 changes: 2 additions & 2 deletions ffcx/codegeneration/lnodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from typing import List
from typing import List, Optional, Tuple
import numbers
import ufl
import numpy as np
Expand Down Expand Up @@ -728,7 +728,7 @@ def as_statement(node):
class Section(LNode):
"""A section of code with a name and a list of statements."""

def __init__(self, name: str, statements: LExpr, annotations: List[str] = None):
def __init__(self, name: str, statements: List, annotations: Optional[List] = None):
self.name = name
self.statements = [as_statement(st) for st in statements]
self.annotations = annotations or []
Expand Down

0 comments on commit c5b2000

Please sign in to comment.