| | import collections |
| | import copy |
| | import itertools |
| | import os |
| |
|
| | import asdl |
| | import attr |
| | import networkx as nx |
| |
|
| | from seq2struct import ast_util |
| | from seq2struct.utils import registry |
| |
|
| |
|
def bimap(first, second):
    """Build a pair of dicts: forward (first -> second) and backward (second -> first)."""
    forward = {}
    backward = {}
    for a, b in zip(first, second):
        forward[a] = b
        backward[b] = a
    return forward, backward
| |
|
| |
|
def filter_nones(d):
    """Return a copy of ``d`` with None-valued and empty-list entries removed."""
    result = {}
    for key, value in d.items():
        if value is None or value == []:
            continue
        result[key] = value
    return result
| |
|
| |
|
def join(iterable, delimiter):
    """Yield the items of ``iterable`` with ``delimiter`` inserted between them.

    Fixed: previously an empty iterable made the bare ``next(it)`` raise
    StopIteration inside the generator, which PEP 479 (Python 3.7+) converts
    into a RuntimeError.  An empty iterable now simply yields nothing.
    """
    it = iter(iterable)
    try:
        yield next(it)
    except StopIteration:
        return
    for x in it:
        yield delimiter
        yield x
| |
|
| |
|
def intersperse(delimiter, seq):
    """Lazily yield the items of ``seq`` separated by ``delimiter``.

    Equivalent to placing ``delimiter`` between consecutive items; yields
    nothing for an empty ``seq``.
    """
    first = True
    for item in seq:
        if not first:
            yield delimiter
        yield item
        first = False
| |
|
| |
|
@registry.register('grammar', 'spider')
class SpiderLanguage:
    """Grammar for the Spider text-to-SQL dataset.

    Converts between Spider's JSON encoding of SQL queries and ASTs that
    conform to one of the Spider ASDL grammars.  The loaded ASDL grammar is
    modified in place at construction time according to the configuration
    flags (pointer types, literals, clause ordering, FROM-clause handling).
    """

    # Type of the tree root expected by the decoder.
    root_type = 'sql'

    def __init__(
            self,
            output_from=False,
            use_table_pointer=False,
            include_literals=True,
            include_columns=True,
            end_with_from=False,
            clause_order=None,
            infer_from_conditions=False,
            factorize_sketch=0):
        """Configure the grammar.

        Args:
            output_from: include the FROM clause in parsed trees.
            use_table_pointer: represent tables as pointer entries (ints).
            include_literals: keep string/number literals; otherwise they
                become Terminal placeholders (and LIMIT becomes a boolean).
            include_columns: keep column ids in col_units.
            end_with_from: move the 'from' field to the end of the sql type.
            clause_order: permutation of "SFWGOI" reordering the sql fields;
                requires factorize_sketch == 2.
            infer_from_conditions: drop FROM join conditions when parsing;
                the unparser re-infers them from the schema.
            factorize_sketch: which ASDL grammar variant to load (0, 1, 2).
        """
        # Tables/columns may be treated as pointers into the schema, i.e.
        # plain int indices rather than grammar-generated terminals.
        custom_primitive_type_checkers = {}
        self.pointers = set()
        if use_table_pointer:
            custom_primitive_type_checkers['table'] = lambda x: isinstance(x, int)
            self.pointers.add('table')
        self.include_columns = include_columns
        if include_columns:
            custom_primitive_type_checkers['column'] = lambda x: isinstance(x, int)
            self.pointers.add('column')

        # Load the ASDL grammar matching the requested factorization level.
        self.factorize_sketch = factorize_sketch
        if self.factorize_sketch == 0:
            asdl_file = "Spider.asdl"
        elif self.factorize_sketch == 1:
            asdl_file = "Spider_f1.asdl"
        elif self.factorize_sketch == 2:
            asdl_file = "Spider_f2.asdl"
        else:
            raise NotImplementedError
        self.ast_wrapper = ast_util.ASTWrapper(
            asdl.parse(
                os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    asdl_file)),
            custom_primitive_type_checkers=custom_primitive_type_checkers)
        if not use_table_pointer:
            # Without the pointer checker, table ids are plain grammar ints.
            self.ast_wrapper.singular_types['Table'].fields[0].type = 'int'
        if not include_columns:
            # Strip the col_id field entirely from col_unit.
            col_unit_fields = self.ast_wrapper.singular_types['col_unit'].fields
            assert col_unit_fields[1].name == 'col_id'
            del col_unit_fields[1]

        # Without literals, LIMIT degenerates to "is there a limit?".
        self.include_literals = include_literals
        if not self.include_literals:
            if self.factorize_sketch == 0:
                limit_field = self.ast_wrapper.singular_types['sql'].fields[6]
            else:
                limit_field = self.ast_wrapper.singular_types['sql_orderby'].fields[1]
            assert limit_field.name == 'limit'
            limit_field.opt = False
            limit_field.type = 'singleton'

        self.output_from = output_from
        self.end_with_from = end_with_from
        self.clause_order = clause_order
        self.infer_from_conditions = infer_from_conditions
        if self.clause_order:
            # Reorder the fields of the 'sql' product type according to the
            # clause-order string; only the fully factorized grammar (2) has
            # the six top-level fields S/F/W/G/O/I to permute.
            assert factorize_sketch == 2
            sql_fields = self.ast_wrapper.product_types['sql'].fields
            letter2field = {k: v for k, v in zip("SFWGOI", sql_fields)}
            new_sql_fields = [letter2field[k] for k in self.clause_order]
            self.ast_wrapper.product_types['sql'].fields = new_sql_fields
        else:
            if not self.output_from:
                # FROM is not produced by the grammar at all.
                sql_fields = self.ast_wrapper.product_types['sql'].fields
                assert sql_fields[1].name == 'from'
                del sql_fields[1]
            else:
                sql_fields = self.ast_wrapper.product_types['sql'].fields
                assert sql_fields[1].name == "from"
                if self.end_with_from:
                    # Move 'from' from position 1 to the end.
                    sql_fields.append(sql_fields[1])
                    del sql_fields[1]

    def parse(self, code, section):
        """Parse a Spider JSON SQL object into an AST (``section`` is unused)."""
        return self.parse_sql(code)

    def unparse(self, tree, item):
        """Unparse an AST back into SQL text using ``item.schema``."""
        unparser = SpiderUnparser(self.ast_wrapper, item.schema, self.factorize_sketch)
        return unparser.unparse_sql(tree)

    @classmethod
    def tokenize_field_value(cls, field_value):
        """Tokenize a primitive field value into a single-element token list.

        Bytes are decoded (latin1 is a lossless byte-to-str mapping), other
        non-strings are stringified, and surrounding double quotes are
        stripped.

        Fixed: ``bytes`` objects have no ``.encode`` method — the original
        called ``field_value.encode('latin1')`` and would raise
        AttributeError on any bytes input; it now decodes.  Also guards the
        quote-stripping against strings shorter than two characters, which
        previously raised IndexError on the empty string.
        """
        if isinstance(field_value, bytes):
            field_value_str = field_value.decode('latin1')
        elif isinstance(field_value, str):
            field_value_str = field_value
        else:
            field_value_str = str(field_value)
        if (len(field_value_str) >= 2
                and field_value_str[0] == '"' and field_value_str[-1] == '"'):
            field_value_str = field_value_str[1:-1]
        return [field_value_str]

    #
    # Parsing: Spider JSON encoding -> AST dicts
    #

    def parse_val(self, val):
        """Parse a Spider value: string, col_unit list, number, or nested sql."""
        if isinstance(val, str):
            if not self.include_literals:
                return {'_type': 'Terminal'}
            return {
                '_type': 'String',
                's': val,
            }
        elif isinstance(val, list):
            return {
                '_type': 'ColUnit',
                'c': self.parse_col_unit(val),
            }
        elif isinstance(val, float):
            if not self.include_literals:
                return {'_type': 'Terminal'}
            return {
                '_type': 'Number',
                'f': val,
            }
        elif isinstance(val, dict):
            return {
                '_type': 'ValSql',
                's': self.parse_sql(val),
            }
        else:
            raise ValueError(val)

    def parse_col_unit(self, col_unit):
        """Parse a (agg_id, col_id, is_distinct) triple into a col_unit node."""
        agg_id, col_id, is_distinct = col_unit
        result = {
            '_type': 'col_unit',
            'agg_id': {'_type': self.AGG_TYPES_F[agg_id]},
            'is_distinct': is_distinct,
        }
        if self.include_columns:
            result['col_id'] = col_id
        return result

    def parse_val_unit(self, val_unit):
        """Parse a (unit_op, col_unit1, col_unit2) triple; col_unit2 only for binary ops."""
        unit_op, col_unit1, col_unit2 = val_unit
        result = {
            '_type': self.UNIT_TYPES_F[unit_op],
            'col_unit1': self.parse_col_unit(col_unit1),
        }
        if unit_op != 0:
            result['col_unit2'] = self.parse_col_unit(col_unit2)
        return result

    def parse_table_unit(self, table_unit):
        """Parse a FROM table unit: either a nested sql or a table id."""
        table_type, value = table_unit
        if table_type == 'sql':
            return {
                '_type': 'TableUnitSql',
                's': self.parse_sql(value),
            }
        elif table_type == 'table_unit':
            return {
                '_type': 'Table',
                'table_id': value,
            }
        else:
            raise ValueError(table_type)

    def parse_cond(self, cond, optional=False):
        """Parse a condition list into a binary tree of And/Or over comparisons.

        Spider encodes conditions as [cond, 'and'/'or', cond, ...]; the
        result right-associates via recursion.
        """
        if optional and not cond:
            return None

        if len(cond) > 1:
            return {
                '_type': self.LOGIC_OPERATORS_F[cond[1]],
                'left': self.parse_cond(cond[:1]),
                'right': self.parse_cond(cond[2:]),
            }

        (not_op, op_id, val_unit, val1, val2), = cond
        result = {
            '_type': self.COND_TYPES_F[op_id],
            'val_unit': self.parse_val_unit(val_unit),
            'val1': self.parse_val(val1),
        }
        if op_id == 1:
            # BETWEEN takes a second value.
            result['val2'] = self.parse_val(val2)
        if not_op:
            result = {
                '_type': 'Not',
                'c': result,
            }
        return result

    def parse_sql(self, sql, optional=False):
        """Parse a full Spider sql dict into a tree shaped per factorize_sketch."""
        if optional and sql is None:
            return None
        if self.factorize_sketch == 0:
            # Flat: every clause is a direct field of the sql node.
            return filter_nones({
                '_type': 'sql',
                'select': self.parse_select(sql['select']),
                'where': self.parse_cond(sql['where'], optional=True),
                'group_by': [self.parse_col_unit(u) for u in sql['groupBy']],
                'order_by': self.parse_order_by(sql['orderBy']),
                'having': self.parse_cond(sql['having'], optional=True),
                'limit': sql['limit'] if self.include_literals else (sql['limit'] is not None),
                'intersect': self.parse_sql(sql['intersect'], optional=True),
                'except': self.parse_sql(sql['except'], optional=True),
                'union': self.parse_sql(sql['union'], optional=True),
                **({
                    'from': self.parse_from(sql['from'], self.infer_from_conditions),
                } if self.output_from else {})
            })
        elif self.factorize_sketch == 1:
            # Chain factorization: each clause nests inside the previous one.
            return filter_nones({
                '_type': 'sql',
                'select': self.parse_select(sql['select']),
                **({
                    'from': self.parse_from(sql['from'], self.infer_from_conditions),
                } if self.output_from else {}),
                'sql_where': filter_nones({
                    '_type': 'sql_where',
                    'where': self.parse_cond(sql['where'], optional=True),
                    'sql_groupby': filter_nones({
                        '_type': 'sql_groupby',
                        'group_by': [self.parse_col_unit(u) for u in sql['groupBy']],
                        'having': filter_nones({
                            '_type': 'having',
                            'having': self.parse_cond(sql['having'], optional=True),
                        }),
                        'sql_orderby': filter_nones({
                            '_type': 'sql_orderby',
                            'order_by': self.parse_order_by(sql['orderBy']),
                            'limit': filter_nones({
                                '_type': 'limit',
                                'limit': sql['limit'] if self.include_literals else (sql['limit'] is not None),
                            }),
                            'sql_ieu': filter_nones({
                                '_type': 'sql_ieu',
                                'intersect': self.parse_sql(sql['intersect'], optional=True),
                                'except': self.parse_sql(sql['except'], optional=True),
                                'union': self.parse_sql(sql['union'], optional=True),
                            })
                        })
                    })
                })
            })
        elif self.factorize_sketch == 2:
            # Flat factorization: each clause hangs off the sql node in its
            # own wrapper type.
            return filter_nones({
                '_type': 'sql',
                'select': self.parse_select(sql['select']),
                **({
                    'from': self.parse_from(sql['from'], self.infer_from_conditions),
                } if self.output_from else {}),
                "sql_where": filter_nones({
                    '_type': 'sql_where',
                    'where': self.parse_cond(sql['where'], optional=True),
                }),
                "sql_groupby": filter_nones({
                    '_type': 'sql_groupby',
                    'group_by': [self.parse_col_unit(u) for u in sql['groupBy']],
                    'having': self.parse_cond(sql['having'], optional=True),
                }),
                "sql_orderby": filter_nones({
                    '_type': 'sql_orderby',
                    'order_by': self.parse_order_by(sql['orderBy']),
                    'limit': sql['limit'] if self.include_literals else (sql['limit'] is not None),
                }),
                'sql_ieu': filter_nones({
                    '_type': 'sql_ieu',
                    'intersect': self.parse_sql(sql['intersect'], optional=True),
                    'except': self.parse_sql(sql['except'], optional=True),
                    'union': self.parse_sql(sql['union'], optional=True),
                })
            })

    def parse_select(self, select):
        """Parse a (is_distinct, aggs) pair into a select node."""
        is_distinct, aggs = select
        return {
            '_type': 'select',
            'is_distinct': is_distinct,
            'aggs': [self.parse_agg(agg) for agg in aggs],
        }

    def parse_agg(self, agg):
        """Parse a (agg_id, val_unit) pair into an agg node."""
        agg_id, val_unit = agg
        return {
            '_type': 'agg',
            'agg_id': {'_type': self.AGG_TYPES_F[agg_id]},
            'val_unit': self.parse_val_unit(val_unit),
        }

    def parse_from(self, from_, infer_from_conditions=False):
        """Parse a FROM clause; conditions are dropped when they will be re-inferred."""
        return filter_nones({
            '_type': 'from',
            'table_units': [
                self.parse_table_unit(u) for u in from_['table_units']],
            'conds': self.parse_cond(from_['conds'], optional=True) \
                if not infer_from_conditions else None,
        })

    def parse_order_by(self, order_by):
        """Parse an (order, val_units) pair, or None when there is no ORDER BY."""
        if not order_by:
            return None

        order, val_units = order_by
        return {
            '_type': 'order_by',
            'order': {'_type': self.ORDERS_F[order]},
            'val_units': [self.parse_val_unit(v) for v in val_units]
        }

    #
    # Mappings between Spider's numeric/string codes and grammar type names.
    #

    COND_TYPES_F, COND_TYPES_B = bimap(
        # Op ids 1..9; id 0 ('not') is represented by the Not wrapper node.
        range(1, 10),
        ('Between', 'Eq', 'Gt', 'Lt', 'Ge', 'Le', 'Ne', 'In', 'Like'))

    UNIT_TYPES_F, UNIT_TYPES_B = bimap(
        range(5),
        ('Column', 'Minus', 'Plus', 'Times', 'Divide'))

    AGG_TYPES_F, AGG_TYPES_B = bimap(
        range(6),
        ('NoneAggOp', 'Max', 'Min', 'Count', 'Sum', 'Avg'))

    ORDERS_F, ORDERS_B = bimap(
        ('asc', 'desc'),
        ('Asc', 'Desc'))

    LOGIC_OPERATORS_F, LOGIC_OPERATORS_B = bimap(
        ('and', 'or'),
        ('And', 'Or'))
| |
|
@attr.s
class SpiderUnparser:
    """Converts Spider ASTs back into SQL text."""

    ast_wrapper = attr.ib()      # ASTWrapper over the Spider ASDL grammar
    schema = attr.ib()           # database schema (tables, columns, FK graph)
    factorize_sketch = attr.ib(default=0)  # which grammar factorization produced the trees

    # Grammar type name -> SQL operator text.
    UNIT_TYPES_B = {
        'Minus': '-',
        'Plus': '+',
        'Times': '*',
        'Divide': '/',
    }
    COND_TYPES_B = {
        'Between': 'BETWEEN',
        'Eq': '=',
        'Gt': '>',
        'Lt': '<',
        'Ge': '>=',
        'Le': '<=',
        'Ne': '!=',
        'In': 'IN',
        'Like': 'LIKE'
    }

    @classmethod
    def conjoin_conds(cls, conds):
        """Fold a list of condition nodes into a right-nested And tree (or None)."""
        if not conds:
            return None
        if len(conds) == 1:
            return conds[0]
        return {'_type': 'And', 'left': conds[0], 'right': cls.conjoin_conds(conds[1:])}

    @classmethod
    def linearize_cond(cls, cond):
        """Flatten an And/Or tree into (leaf conditions, connective keywords)."""
        if cond['_type'] in ('And', 'Or'):
            conds, keywords = cls.linearize_cond(cond['right'])
            return [cond['left']] + conds, [cond['_type']] + keywords
        else:
            return [cond], []

    def unparse_val(self, val):
        """Unparse a value node (Terminal/String/ColUnit/Number/ValSql) to SQL text."""
        if val['_type'] == 'Terminal':
            return "'terminal'"
        if val['_type'] == 'String':
            return val['s']
        if val['_type'] == 'ColUnit':
            return self.unparse_col_unit(val['c'])
        if val['_type'] == 'Number':
            return str(val['f'])
        if val['_type'] == 'ValSql':
            return '({})'.format(self.unparse_sql(val['s']))

    def unparse_col_unit(self, col_unit):
        """Unparse a col_unit, qualifying the column name and applying DISTINCT/agg."""
        if 'col_id' in col_unit:
            column = self.schema.columns[col_unit['col_id']]
            if column.table is None:
                # e.g. the special '*' column has no owning table.
                column_name = column.orig_name
            else:
                column_name = '{}.{}'.format(column.table.orig_name, column.orig_name)
        else:
            # Grammar was built with include_columns=False; emit a placeholder.
            column_name = 'some_col'

        if col_unit['is_distinct']:
            column_name = 'DISTINCT {}'.format(column_name)
        agg_type = col_unit['agg_id']['_type']
        if agg_type == 'NoneAggOp':
            return column_name
        else:
            return '{}({})'.format(agg_type, column_name)

    def unparse_val_unit(self, val_unit):
        """Unparse a val_unit: a bare column, or a binary arithmetic expression."""
        if val_unit['_type'] == 'Column':
            return self.unparse_col_unit(val_unit['col_unit1'])
        col1 = self.unparse_col_unit(val_unit['col_unit1'])
        col2 = self.unparse_col_unit(val_unit['col_unit2'])
        return '{} {} {}'.format(col1, self.UNIT_TYPES_B[val_unit['_type']], col2)

    def unparse_cond(self, cond, negated=False):
        """Unparse a condition tree; ``negated`` is set by an enclosing Not node."""
        if cond['_type'] == 'And':
            assert not negated
            return '{} AND {}'.format(
                self.unparse_cond(cond['left']), self.unparse_cond(cond['right']))
        elif cond['_type'] == 'Or':
            assert not negated
            return '{} OR {}'.format(
                self.unparse_cond(cond['left']), self.unparse_cond(cond['right']))
        elif cond['_type'] == 'Not':
            return self.unparse_cond(cond['c'], negated=True)
        elif cond['_type'] == 'Between':
            tokens = [self.unparse_val_unit(cond['val_unit'])]
            if negated:
                tokens.append('NOT')
            tokens += [
                'BETWEEN',
                self.unparse_val(cond['val1']),
                'AND',
                self.unparse_val(cond['val2']),
            ]
            return ' '.join(tokens)
        # All remaining types are simple binary comparisons.
        tokens = [self.unparse_val_unit(cond['val_unit'])]
        if negated:
            tokens.append('NOT')
        tokens += [self.COND_TYPES_B[cond['_type']], self.unparse_val(cond['val1'])]
        return ' '.join(tokens)

    def refine_from(self, tree):
        """Infer and normalize the FROM clause of ``tree``, in place.

        1) Infer tables from the columns predicted elsewhere in the query.
        2) Mix them with the predicted tables, if any.
        3) Infer join conditions by connecting tables through the schema's
           foreign-key graph.

        BUG FIX: the original took a shallow copy (``tree = dict(tree)``)
        and assigned the rebuilt 'from' clause to that copy, so the caller
        never received it — unparse_sql then raised KeyError on
        ``tree['from']`` when the grammar omitted FROM, and used the stale
        join conditions when it didn't.  The nested mutations already
        leaked through the shallow copy, so mutating the caller's tree in
        place is the intended semantics.
        """
        # Subqueries in FROM: recurse into each and leave this level alone.
        if "from" in tree and tree["from"]["table_units"][0]["_type"] == 'TableUnitSql':
            for table_unit in tree["from"]["table_units"]:
                subquery_tree = table_unit["s"]
                self.refine_from(subquery_tree)
            return

        # Deduplicate predicted table units, keeping first occurrences.
        predicted_from_table_ids = set()
        if "from" in tree:
            table_unit_set = []
            for table_unit in tree["from"]["table_units"]:
                if table_unit["table_id"] not in predicted_from_table_ids:
                    predicted_from_table_ids.add(table_unit["table_id"])
                    table_unit_set.append(table_unit)
            tree["from"]["table_units"] = table_unit_set

        # Tables required by columns used anywhere outside nested sql nodes.
        candidate_column_ids = set(self.ast_wrapper.find_all_descendants_of_type(
            tree, 'column', lambda field: field.type != 'sql'))
        candidate_columns = [self.schema.columns[i] for i in candidate_column_ids]
        must_in_from_table_ids = set(
            column.table.id for column in candidate_columns if column.table is not None)

        all_from_table_ids = must_in_from_table_ids.union(predicted_from_table_ids)
        if not all_from_table_ids:
            # No table referenced at all (e.g. SELECT COUNT(*)): fall back
            # to the first table in the schema.
            all_from_table_ids = {0}

        # Connect the tables pairwise via shortest paths in the foreign-key
        # graph, emitting an equality join condition per traversed edge.
        covered_tables = set()
        candidate_table_ids = sorted(all_from_table_ids)
        start_table_id = candidate_table_ids[0]
        conds = []
        for table_id in candidate_table_ids[1:]:
            if table_id in covered_tables:
                continue
            try:
                path = nx.shortest_path(
                    self.schema.foreign_key_graph, source=start_table_id, target=table_id)
            except (nx.NetworkXNoPath, nx.NodeNotFound):
                # Unreachable table: keep it in FROM but emit no join cond.
                covered_tables.add(table_id)
                continue

            for source_table_id, target_table_id in zip(path, path[1:]):
                if target_table_id in covered_tables:
                    continue
                all_from_table_ids.add(target_table_id)
                col1, col2 = self.schema.foreign_key_graph[source_table_id][target_table_id]['columns']
                conds.append({
                    '_type': 'Eq',
                    'val_unit': {
                        '_type': 'Column',
                        'col_unit1': {
                            '_type': 'col_unit',
                            'agg_id': {'_type': 'NoneAggOp'},
                            'col_id': col1,
                            'is_distinct': False
                        },
                    },
                    'val1': {
                        '_type': 'ColUnit',
                        'c': {
                            '_type': 'col_unit',
                            'agg_id': {'_type': 'NoneAggOp'},
                            'col_id': col2,
                            'is_distinct': False
                        }
                    }
                })
        table_units = [{'_type': 'Table', 'table_id': i} for i in sorted(all_from_table_ids)]

        # Rebuild the FROM clause directly on the caller's tree (see docstring).
        tree['from'] = {
            '_type': 'from',
            'table_units': table_units,
        }
        cond_node = self.conjoin_conds(conds)
        if cond_node is not None:
            tree['from']['conds'] = cond_node

    def unparse_sql(self, tree):
        """Unparse a full 'sql' tree (any factorization) into a SQL string."""
        # Rebuild FROM first; refine_from guarantees tree['from'] exists.
        self.refine_from(tree)

        result = [
            # SELECT ...
            self.unparse_select(tree['select']),
            # FROM ...
            self.unparse_from(tree['from']),
        ]

        def find_subtree(_tree, name):
            # Locate the node that carries a clause, per factorization:
            #   sketch 0: everything is flat on the sql node;
            #   sketch 1: clauses nest, so descend into the named child;
            #   sketch 2: clauses hang off the sql node in wrapper nodes.
            if self.factorize_sketch == 0:
                return _tree, _tree
            elif name in _tree:
                if self.factorize_sketch == 1:
                    return _tree[name], _tree[name]
                elif self.factorize_sketch == 2:
                    return _tree, _tree[name]
                else:
                    raise NotImplementedError

        tree, target_tree = find_subtree(tree, "sql_where")
        # WHERE
        if 'where' in target_tree:
            result += [
                'WHERE',
                self.unparse_cond(target_tree['where'])
            ]

        tree, target_tree = find_subtree(tree, "sql_groupby")
        # GROUP BY
        if 'group_by' in target_tree:
            result += [
                'GROUP BY',
                ', '.join(self.unparse_col_unit(c) for c in target_tree['group_by'])
            ]

        tree, target_tree = find_subtree(tree, "sql_orderby")
        # ORDER BY
        if 'order_by' in target_tree:
            result.append(self.unparse_order_by(target_tree['order_by']))

        # NOTE(review): for factorize_sketch == 1 the lookups below walk
        # back to "sql_groupby" after descending into "sql_orderby", which
        # looks like it cannot succeed — confirm whether sketch 1 unparsing
        # is actually exercised.
        tree, target_tree = find_subtree(tree, "sql_groupby")
        # HAVING
        if 'having' in target_tree:
            result += ['HAVING', self.unparse_cond(target_tree['having'])]

        tree, target_tree = find_subtree(tree, "sql_orderby")
        # LIMIT (a bool when the grammar was built without literals)
        if 'limit' in target_tree:
            if isinstance(target_tree['limit'], bool):
                if target_tree['limit']:
                    result += ['LIMIT', '1']
            else:
                result += ['LIMIT', str(target_tree['limit'])]

        tree, target_tree = find_subtree(tree, "sql_ieu")
        # INTERSECT / EXCEPT / UNION
        if 'intersect' in target_tree:
            result += ['INTERSECT', self.unparse_sql(target_tree['intersect'])]
        if 'except' in target_tree:
            result += ['EXCEPT', self.unparse_sql(target_tree['except'])]
        if 'union' in target_tree:
            result += ['UNION', self.unparse_sql(target_tree['union'])]

        return ' '.join(result)

    def unparse_select(self, select):
        """Unparse a select node to 'SELECT [DISTINCT] agg, agg, ...'."""
        tokens = ['SELECT']
        if select['is_distinct']:
            tokens.append('DISTINCT')
        tokens.append(', '.join(self.unparse_agg(agg) for agg in select.get('aggs', [])))
        return ' '.join(tokens)

    def unparse_agg(self, agg):
        """Unparse an agg node, wrapping the val_unit in its aggregate if any."""
        unparsed_val_unit = self.unparse_val_unit(agg['val_unit'])
        agg_type = agg['agg_id']['_type']
        if agg_type == 'NoneAggOp':
            return unparsed_val_unit
        else:
            return '{}({})'.format(agg_type, unparsed_val_unit)

    def unparse_from(self, from_):
        """Unparse a FROM clause, attaching each join condition to the first
        JOIN at which all of its tables have been output (as an ON clause)."""
        if 'conds' in from_:
            all_conds, keywords = self.linearize_cond(from_['conds'])
        else:
            all_conds, keywords = [], []
        # ON clauses can only be conjoined; OR between join conds is invalid here.
        assert all(keyword == 'And' for keyword in keywords)

        # Index conditions by the tables they mention.
        cond_indices_by_table = collections.defaultdict(set)
        tables_involved_by_cond_idx = collections.defaultdict(set)
        for i, cond in enumerate(all_conds):
            for column in self.ast_wrapper.find_all_descendants_of_type(cond, 'column'):
                table = self.schema.columns[column].table
                if table is None:
                    continue
                cond_indices_by_table[table.id].add(i)
                tables_involved_by_cond_idx[i].add(table.id)

        output_table_ids = set()
        output_cond_indices = set()
        tokens = ['FROM']
        for i, table_unit in enumerate(from_.get('table_units', [])):
            if i > 0:
                tokens += ['JOIN']

            if table_unit['_type'] == 'TableUnitSql':
                tokens.append('({})'.format(self.unparse_sql(table_unit['s'])))
            elif table_unit['_type'] == 'Table':
                table_id = table_unit['table_id']
                tokens += [self.schema.tables[table_id].orig_name]
                output_table_ids.add(table_id)

                # Emit any not-yet-output condition whose tables have all
                # appeared by now.
                conds_to_output = []
                for cond_idx in sorted(cond_indices_by_table[table_id]):
                    if cond_idx in output_cond_indices:
                        continue
                    if tables_involved_by_cond_idx[cond_idx] <= output_table_ids:
                        conds_to_output.append(all_conds[cond_idx])
                        output_cond_indices.add(cond_idx)
                if conds_to_output:
                    tokens += ['ON']
                    tokens += list(intersperse(
                        'AND',
                        (self.unparse_cond(cond) for cond in conds_to_output)))
        return ' '.join(tokens)

    def unparse_order_by(self, order_by):
        """Unparse an order_by node to 'ORDER BY expr, expr Asc|Desc'."""
        return 'ORDER BY {} {}'.format(
            ', '.join(self.unparse_val_unit(v) for v in order_by['val_units']),
            order_by['order']['_type'])
| |
|