summaryrefslogtreecommitdiffstats
path: root/mlir/utils
diff options
context:
space:
mode:
authorMahesh Ravishankar <ravishankarm@google.com>2019-08-31 09:52:18 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-08-31 09:52:44 -0700
commit49c3e4a50819bfea464dca4290f9f61eac9540e2 (patch)
tree5a523bbb2e7034b5ce9be3430a3a02fc2fad7e4d /mlir/utils
parent5a7014c390918ea73684b37616d7a4f61d9b2f17 (diff)
downloadbcm5719-llvm-49c3e4a50819bfea464dca4290f9f61eac9540e2.tar.gz
bcm5719-llvm-49c3e4a50819bfea464dca4290f9f61eac9540e2.zip
Add floating-point comparison operations to SPIR-V dialect.
Use the existing SPV_LogicalOp specification to add the floating-point comparison operations (both ordered and unordered versions). To make it easier to import the op-definitions automatically modify the dialect generation script to update the different .td files based on whether the operation is an arithmetic op, logical op, etc. Also allow specification of multiple opcodes with define_inst.sh. Since this reuses the SPV_LogicalOp framework, no tests specific to the floating point comparison ops are added with this CL. PiperOrigin-RevId: 266561634
Diffstat (limited to 'mlir/utils')
-rwxr-xr-xmlir/utils/spirv/define_inst.sh34
-rwxr-xr-xmlir/utils/spirv/gen_spirv_dialect.py101
2 files changed, 92 insertions, 43 deletions
diff --git a/mlir/utils/spirv/define_inst.sh b/mlir/utils/spirv/define_inst.sh
index 49b5e8df880..55e2fa0ed9b 100755
--- a/mlir/utils/spirv/define_inst.sh
+++ b/mlir/utils/spirv/define_inst.sh
@@ -17,20 +17,38 @@
# Script for defining a new op using SPIR-V spec from the Internet.
#
# Run as:
-# ./define_inst.sh <opname>
+# ./define_inst.sh <inst_category> (<opname>)*
-# For example:
-# ./define_inst.sh OpIAdd
-#
-# If <opname> is missing, this script updates existing ones.
+# <inst_category> is required. It can be one of
+# (Op|ArithmeticOp|LogicalOp|ControlFlowOp|StructureOp). Based on the
+# inst_category the file SPIRV<inst_category>s.td is updated with the
+# instruction definition. If <opname> is missing, this script updates existing
+# ones in SPIRV<inst_category>s.td
+# For example:
+# ./define_inst.sh ArithmeticOp OpIAdd
+# ./define_inst.sh LogicalOp OpFOrdEqual
set -e
-new_op=$1
+inst_category=$1
+
+case $inst_category in
+ Op | ArithmeticOp | LogicalOp | ControlFlowOp | StructureOp)
+ ;;
+ *)
+ echo "Usage : " $0 " <inst_category> (<opname>)*"
+ echo "<inst_category> must be one of " \
+ "(Op|ArithmeticOp|LogicalOp|ControlFlowOp|StructureOp)"
+ exit 1;
+ ;;
+esac
+
+shift
current_file="$(readlink -f "$0")"
current_dir="$(dirname "$current_file")"
python3 ${current_dir}/gen_spirv_dialect.py \
- --op-td-path ${current_dir}/../../include/mlir/Dialect/SPIRV/SPIRVOps.td \
- --new-inst "${new_op}"
+ --op-td-path \
+ ${current_dir}/../../include/mlir/Dialect/SPIRV/SPIRV${inst_category}s.td \
+ --inst-category $inst_category --new-inst "$@"
diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py
index e34945d40e7..e74a40e16d8 100755
--- a/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/mlir/utils/spirv/gen_spirv_dialect.py
@@ -360,7 +360,7 @@ def map_spec_operand_to_ods_argument(operand):
return '{}:${}'.format(arg_type, name)
-def get_op_definition(instruction, doc, existing_info):
+def get_op_definition(instruction, doc, existing_info, inst_category):
"""Generates the TableGen op definition for the given SPIR-V instruction.
Arguments:
@@ -372,19 +372,22 @@ def get_op_definition(instruction, doc, existing_info):
Returns:
- A string containing the TableGen op definition
"""
- fmt_str = 'def SPV_{opname}Op : SPV_Op<"{opname}", [{traits}]> {{\n'\
- ' let summary = {summary};\n\n'\
- ' let description = [{{\n'\
- '{description}\n\n'\
- ' ### Custom assembly form\n'\
- '{assembly}'\
- '}}];\n\n'\
- ' let arguments = (ins{args});\n\n'\
- ' let results = (outs{results});\n'\
- '{extras}'\
+ fmt_str = ('def SPV_{opname}Op : '
+ 'SPV_{inst_category}<"{opname}"{category_args}[{traits}]> '
+ '{{\n let summary = {summary};\n\n let description = '
+ '[{{\n{description}\n\n ### Custom assembly '
+ 'form\n{assembly}}}];\n')
+ if inst_category == 'Op':
+ fmt_str +='\n let arguments = (ins{args});\n\n'\
+ ' let results = (outs{results});\n\n'
+
+ fmt_str +='{extras}'\
'}}\n'
opname = instruction['opname'][2:]
+ category_args = existing_info.get('category_args', None)
+ if category_args is None:
+ category_args = ', '
summary, description = doc.split('\n', 1)
wrapper = textwrap.TextWrapper(
@@ -438,6 +441,8 @@ def get_op_definition(instruction, doc, existing_info):
return fmt_str.format(
opname=opname,
+ category_args=category_args,
+ inst_category=inst_category,
traits=existing_info.get('traits', ''),
summary=summary,
description=description,
@@ -447,6 +452,29 @@ def get_op_definition(instruction, doc, existing_info):
extras=existing_info.get('extras', ''))
+def get_string_between(base, start, end):
+ """Extracts a substring with a specified start and end from a string.
+
+ Arguments:
+ - base: string to extract from.
+ - start: string to use as the start of the substring.
+ - end: string to use as the end of the substring.
+
+ Returns:
+ - The substring if found
+ - The part of the base after end of the substring. Is the base string itself
+ if the substring wasnt found.
+ """
+ split = base.split(start, 1)
+ if len(split) == 2:
+ rest = split[1].split(end, 1)
+ assert len(rest) == 2, \
+ 'cannot find end "{end}" while extracting substring '\
+ 'starting with {start}'.format(start=start, end=end)
+ return rest[0].rstrip(end), rest[1]
+ return '', split[0]
+
+
def extract_td_op_info(op_def):
"""Extracts potentially manually specified sections in op's definition.
@@ -461,39 +489,32 @@ def extract_td_op_info(op_def):
assert len(opname) == 1, 'more than one ops in the same section!'
opname = opname[0]
+ # Get category_args
+ op_tmpl_params = op_def.split('<', 1)[1].split('>', 1)[0]
+ opstringname, rest = get_string_between(op_tmpl_params, '"', '"')
+ category_args = rest.split('[', 1)[0]
+
# Get traits
- op_tmpl_params = op_def.split('<', 1)[1].split('>', 1)[0].split(', ', 1)
- if len(op_tmpl_params) == 1:
- traits = ''
- else:
- traits = op_tmpl_params[1].strip('[]')
+ traits, _ = get_string_between(rest, '[', ']')
# Get custom assembly form
- rest = op_def.split('### Custom assembly form\n')
- assert len(rest) == 2, \
- '{}: cannot find "### Custom assembly form"'.format(opname)
- rest = rest[1].split(' let arguments = (ins')
- assert len(rest) == 2, '{}: cannot find arguments'.format(opname)
- assembly = rest[0].rstrip('}];\n')
+ assembly, rest = get_string_between(op_def, '### Custom assembly form\n',
+ '}];\n')
# Get arguments
- rest = rest[1].split(' let results = (outs')
- assert len(rest) == 2, '{}: cannot find results'.format(opname)
- args = rest[0].rstrip(');\n')
+ args, rest = get_string_between(rest, ' let arguments = (ins', ');\n')
# Get results
- rest = rest[1].split(');', 1)
- assert len(rest) == 2, \
- '{}: cannot find ");" ending results'.format(opname)
- results = rest[0]
+ results, rest = get_string_between(rest, ' let results = (outs', ');\n')
- extras = rest[1].strip(' }\n')
+ extras = rest.strip(' }\n')
if extras:
extras = '\n {}\n'.format(extras)
return {
# Prefix with 'Op' to make it consistent with SPIR-V spec
'opname': 'Op{}'.format(opname),
+ 'category_args': category_args,
'traits': traits,
'assembly': assembly,
'arguments': args,
@@ -502,7 +523,8 @@ def extract_td_op_info(op_def):
}
-def update_td_op_definitions(path, instructions, docs, filter_list):
+def update_td_op_definitions(path, instructions, docs, filter_list,
+ inst_category):
"""Updates SPIRVOps.td with newly generated op definition.
Arguments:
@@ -541,7 +563,7 @@ def update_td_op_definitions(path, instructions, docs, filter_list):
inst for inst in instructions if inst['opname'] == opname)
op_defs.append(
get_op_definition(instruction, docs[opname],
- op_info_dict.get(opname, {})))
+ op_info_dict.get(opname, {}), inst_category))
# Substitute the old op definitions
op_defs = [header] + op_defs + [footer]
@@ -588,7 +610,16 @@ if __name__ == '__main__':
dest='new_inst',
type=str,
default=None,
- help='SPIR-V instruction to be added to SPIRVOps.td')
+ nargs='*',
+ help='SPIR-V instruction to be added to ops file')
+ cli_parser.add_argument(
+ '--inst-category',
+ dest='inst_category',
+ type=str,
+ default='Op',
+ help='SPIR-V instruction category used for choosing '\
+ 'a suitable .td file and TableGen common base '\
+ 'class to define this op')
args = cli_parser.parse_args()
@@ -608,9 +639,9 @@ if __name__ == '__main__':
# Define new op
if args.new_inst is not None:
assert args.op_td_path is not None
- filter_list = [args.new_inst] if args.new_inst else []
docs = get_spirv_doc_from_html_spec()
- update_td_op_definitions(args.op_td_path, instructions, docs, filter_list)
+ update_td_op_definitions(args.op_td_path, instructions, docs, args.new_inst,
+ args.inst_category)
print('Done. Note that this script just generates a template; ', end='')
print('please read the spec and update traits, arguments, and ', end='')
print('results accordingly.')
OpenPOWER on IntegriCloud