#!/usr/bin/env python3

# SPDX-License-Identifier: GPL-2.0-or-later
#
# This file is part of osm2pgsql (https://osm2pgsql.org/).
#
# Copyright (C) 2025-2026 by the osm2pgsql developer community.
# For a full list of authors see the git log.
"""
Test runner for BDD-style integration tests.

See osm2pgsql manual for more information on osm2pgsql style testing.
"""
import logging
import sys
import tempfile
import math
import re
import os
import contextlib
import json
import datetime as dt
from decimal import Decimal
from subprocess import Popen, PIPE
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from pathlib import Path
import importlib.util
import io
from importlib.machinery import SourceFileLoader

from behave import given, when, then, use_step_matcher, use_fixture, fixture
from behave.runner import ModelRunner, Context
from behave.formatter.base import StreamOpener
from behave.formatter.pretty import PrettyFormatter
from behave import runner_util
from behave.configuration import Configuration

# Root logger; all diagnostics of the test runner are emitted through it.
LOG = logging.getLogger()

import psycopg
from psycopg import sql

# The step patterns in this file are written as regular expressions whose
# named groups carry the same names as the step function parameters.
use_step_matcher('re')

# Sort rank per OPL object type prefix: nodes, then ways, then relations.
OBJECT_ORDER = {'n': 1, 'w': 2, 'r': 3}

def opl_sort(line):
    """ Return the sort key `(type rank, numeric id)` for an OPL line.

        The key is derived from the first space-separated token of the
        line, e.g. `n12` or `w3`.
    """
    ident, _, _ = line.partition(' ')
    type_char, number = ident[0], ident[1:]
    return OBJECT_ORDER[type_char], int(number)

#################### Replication mock ##############################

class ReplicationServerMock:
    """ Stand-in for a replication server that answers from a fixed list
        of state objects.

        Each state object must provide the attributes `sequence` and
        `timestamp`. Calling the instance with a URL checks the URL against
        the expected one and returns the instance itself.
    """

    def __init__(self, base_url, state_infos):
        self.expected_base_url = base_url
        self.state_infos = state_infos

    def __call__(self, base_url):
        assert base_url == self.expected_base_url,\
               f"Wrong replication service called. Expected '{self.expected_base_url}', got '{base_url}'"
        return self

    def get_state_info(self, seq=None, retries=2):
        """ Return the state with the given sequence number, the last
            configured state when `seq` is None, or None when the sequence
            is unknown. `retries` is accepted but ignored.
        """
        assert self.state_infos, 'Replication mock not properly set up'
        if seq is None:
            return self.state_infos[-1]

        return next((state for state in self.state_infos
                     if state.sequence == seq), None)

    def timestamp_to_sequence(self, timestamp, balanced_search=False):
        """ Return the sequence number of the state that covers `timestamp`.

            Timestamps before the first state map to the first sequence,
            timestamps not covered by any interval map to the last one.
        """
        assert self.state_infos, 'Replication mock not properly set up'

        states = self.state_infos
        if timestamp < states[0].timestamp:
            return states[0].sequence

        # Pair every state with its predecessor (the first one with itself).
        for earlier, later in zip(states[:1] + states[:-1], states):
            if timestamp >= earlier.timestamp and timestamp < later.timestamp:
                return earlier.sequence

        return states[-1].sequence

    def apply_diffs(self, handler, start_id, max_size=1024, idx="", simplify=True):
        """ Pretend to apply diffs starting at `start_id`.

            Returns the last sequence number 'applied' or None when
            `start_id` lies beyond the last configured state. One diff
            is accounted for per started 1024 units of `max_size`.
        """
        newest = self.state_infos[-1].sequence
        if start_id > newest:
            return None

        diff_count = int((max_size + 1023)/1024)
        return min(newest, start_id + diff_count - 1)

# Replication module is optional. When the osm2pgsql-replication script sits
# next to this file, it is loaded as the module `osm2pgsql_replication`.
# Otherwise the name is set to None and replication steps are skipped.
_replication_path = Path(__file__, '..', 'osm2pgsql-replication').resolve()

if _replication_path.is_file():
    # The script has no '.py' suffix, so it needs an explicit source loader.
    _repfl_spec = importlib.util.spec_from_loader(
        'osm2pgsql_replication',
        SourceFileLoader('osm2pgsql_replication', str(_replication_path)))

    osm2pgsql_replication = importlib.util.module_from_spec(_repfl_spec)
    try:
        _repfl_spec.loader.exec_module(osm2pgsql_replication)
    except Exception as e:
        # Keep the original error as the cause so the reason is visible.
        raise RuntimeError("osm2pgsql_replication script found but not readable.") from e

    from osmium.replication.server import OsmosisState
else:
    osm2pgsql_replication = None

#################### hooks #########################################

def hook_before_all(context):
    """ Global test setup: resolve the optional feature switches, check
        that the osm2pgsql binary can be run and create the template
        database.

        The command line options `test_tablespace` and `test_proj` come in
        as 'yes', 'no' or 'auto' and are replaced by booleans here.
    """
    context.config.setup_logging(logging.INFO)

    # Feature check: table spaces
    if context.user_args.test_tablespace == 'auto':
        with context.connect_db('postgres') as conn:
            with conn.cursor() as cur:
                cur.execute("""SELECT spcname FROM pg_tablespace
                                WHERE spcname = 'tablespacetest'""")
                context.user_args.test_tablespace = cur.rowcount > 0
        LOG.info('Check if tablespaces are available: %s',
                 'yes' if context.user_args.test_tablespace else 'no')
    else:
        context.user_args.test_tablespace = context.user_args.test_tablespace == 'yes'

    # Test that osm2pgsql works. The version information is read from stderr.
    proc = Popen([context.user_args.osm2pgsql_binary, '--version'],
                 stdout=PIPE, stderr=PIPE)
    _, serr = proc.communicate()
    osm2pgsql_version = serr.decode('utf-8')
    if proc.returncode != 0:
        LOG.critical("Could not run osm2pgsql. Error:\n%s", osm2pgsql_version)
        LOG.critical("osm2pgsql binary used: %s", context.user_args.osm2pgsql_binary)
        raise RuntimeError('Error running osm2pgsql')

    # Feature check: proj
    if context.user_args.test_proj == 'auto':
        context.user_args.test_proj = 'Proj [disabled]' not in osm2pgsql_version
    else:
        context.user_args.test_proj = context.user_args.test_proj == 'yes'

    # Log only after the switch was resolved to a boolean. Before that it
    # is a non-empty string and would always be reported as 'yes'.
    LOG.info('Check if proj is available: %s',
             'yes' if context.user_args.test_proj else 'no')

    use_fixture(template_test_db, context)



def hook_before_scenario(context, scenario):
    """ Per-scenario setup: create a fresh test database and working
        directory and reset all scenario state kept on the context.

        When the replication module is available, its `ReplicationServer`
        is cleared and its `urlrequest.urlopen` is replaced with a function
        that serves the responses registered in
        `context.urlrequest_responses`.
    """
    # Imported locally because the module is only needed by the URL mock.
    import urllib.error

    if 'config.have_proj' in scenario.tags and not context.user_args.test_proj:
        scenario.skip("Generic proj library not configured.")

    context.db = use_fixture(test_db, context)
    context.import_file = None
    context.import_data = None
    context.osm2pgsql_params = {'-d': context.user_args.test_db}
    context.osm2pgsql_returncode = None
    context.workdir = use_fixture(working_directory, context)
    context.nodes = NodeStore()
    context.sql_statements = {}
    context.urlrequest_responses = {}
    if osm2pgsql_replication is not None:
        osm2pgsql_replication.ReplicationServer = None

        def _mock_urlopen(request):
            if request.full_url not in context.urlrequest_responses:
                raise urllib.error.URLError('Unknown URL')

            return contextlib.closing(io.BytesIO(context.urlrequest_responses[request.full_url].encode('utf-8')))

        osm2pgsql_replication.urlrequest.urlopen = _mock_urlopen



#################### fixtures ######################################

@fixture
def template_test_db(context, **kwargs):
    """ Fixture providing the template database.

        The database is recreated from scratch with the postgis and hstore
        extensions installed. It is dropped again on cleanup.
    """
    template_name = context.user_args.template_test_db
    context.drop_db(template_name, recreate_template='default')

    with context.connect_db(context.user_args.template_test_db) as conn:
        for statement in ('CREATE EXTENSION postgis', 'CREATE EXTENSION hstore'):
            conn.execute(statement)

    yield context.user_args.template_test_db

    context.drop_db(context.user_args.template_test_db)


@fixture
def test_db(context, **kwargs):
    """ Fixture providing a connection to a fresh test database.

        The database is recreated from the template database. On cleanup
        it is dropped unless the user asked to keep it.
    """
    template_name = context.user_args.template_test_db
    context.drop_db(recreate_template=template_name)

    with context.connect_db() as connection:
        yield connection

    if context.user_args.keep_test_db:
        return

    context.drop_db()


@fixture
def working_directory(context, **kwargs):
    """ Fixture providing a temporary directory as a Path. The directory
        is removed on cleanup.
    """
    with tempfile.TemporaryDirectory() as dirname:
        workdir = Path(dirname)
        yield workdir

################### Node location creation #########################

class NodeStore:
    """ Keeps coordinates for node ids that were defined through a grid.

        The grid maps integer node ids to `(x, y)` tuples.
    """

    def __init__(self):
        # Per-instance mapping, so that stores never share grid data.
        self.grid = {}

    def set_grid(self, lines, grid_step, origin_x, origin_y):
        """ Replace the grid with the one described by `lines`.

            `lines` is a sequence of rows, each row a sequence of cell
            strings. Cells consisting of digits define a node with that id.
            Neighbouring cells are `grid_step` apart. The first cell of the
            last row is placed at `(origin_x, origin_y)` and y increases
            with the row index.
        """
        self.grid = {}
        origin_y -= grid_step * (len(lines) - 1)

        # Coordinates are rounded to get rid of floating point noise. Use
        # as many decimals as the step itself has (at least one), otherwise
        # grids finer than 0.1 would collapse onto the same coordinates.
        ndigits = max(1, -Decimal(str(grid_step)).as_tuple().exponent)

        for y, line in enumerate(lines):
            for x, pt in enumerate(line):
                if pt.isdigit():
                    self.grid[int(pt)] = (round(origin_x + x * grid_step, ndigits),
                                          round(origin_y + y * grid_step, ndigits))

    def get_as_opl(self):
        """ Return the grid nodes as a list of OPL lines.
        """
        return [f"n{i} x{x} y{y}" for i, (x, y) in self.grid.items()]

    def add_coordinates(self, lines):
        """ Yield the given OPL lines, appending the coordinates from the
            grid to node lines that come without an x coordinate.

            A node without coordinates that is not in the grid raises an
            AssertionError.
        """
        for line in lines:
            if line.startswith('n') and ' x' not in line:
                nid = int(line.split(' ', 1)[0][1:])
                assert nid in self.grid, \
                    f"OPL error. Node {nid} has no coordinates and is not in grid."
                x, y = self.grid[nid]
                yield f"{line} x{x} y{y}"
            else:
                yield line

    def parse_point(self, pt):
        """ Convert a point description into coordinates.

            `pt` is either a pair of space-separated numbers (returned as
            a list of floats) or a node id that is looked up in the grid
            (returned as the stored tuple).
        """
        pt = pt.strip()
        if ' ' in pt:
            return list(map(float, pt.split(' ', 1)))
        return self.grid[int(pt)]


################### ResultComparison ###############################

class ResultCompare:
    """ Compares expected values of one result column with database values.

        The column heading has the form `<expression>[!<format>]`. Without
        a format, the string representation of the value must equal the
        expected string. Supported formats:

        * `:<spec>` - format the value with the given Python format spec
        * `~<tol>` / `~<tol>%` - float comparison with absolute/relative
          tolerance
        * `i` - case-insensitive string comparison
        * `re` - full match of a regular expression
        * `substr` - expected string is contained in the value
        * `json` - compare against the parsed JSON
        * `geo` - geometry comparison against the WKT of the column

        `nodes` must provide `parse_point()` to turn a point description
        of the expected geometry into coordinates.
    """

    def __init__(self, heading, nodes):
        self.nodes = nodes
        if '!' in heading:
            self.name, self.fmt = heading.rsplit('!', 1)
            if self.fmt.startswith(':'):
                self.compare = self._intcompare_fmt
            elif self.fmt.startswith('~'):
                if self.fmt.endswith('%'):
                    rel_tol = float(self.fmt[1:-1]) / 100
                    self.compare = lambda exp, val: math.isclose(float(exp), val, rel_tol=rel_tol)
                else:
                    abs_tol = float(self.fmt[1:])
                    self.compare = lambda exp, val: math.isclose(float(exp), val, abs_tol=abs_tol)
            else:
                self.compare = getattr(self, f"_compare_{self.fmt}", None)
            assert self.compare is not None, f"Unknown formatter {self.fmt}"
        else:
            self.name = heading
            self.fmt = None
            self.compare = lambda exp, val: str(val) == exp

    def as_select(self):
        """ Return the SQL expression to select for this column.
        """
        if self.fmt == 'geo':
            return f"ST_AsText({self.name})"

        return self.name

    def equals(self, expected, value):
        """ Check if the database value matches the expected string.
            The expected string 'NULL' only matches a None value.
        """
        if expected == 'NULL':
            return value is None

        return self.compare(expected, value)

    def _intcompare_fmt(self, expected, value):
        return expected == f"{{{self.fmt}}}".format(value)

    def _compare_i(self, expected, value):
        return expected.lower() == str(value).lower()

    def _compare_re(self, expected, value):
        return re.fullmatch(expected, str(value)) is not None

    def _compare_substr(self, expected, value):
        return expected in str(value)

    def _compare_json(self, expected, value):
        return json.loads(expected) == value

    def _compare_geo(self, expected, value):
        # A NULL geometry never matches a concrete expected geometry.
        if value is None:
            return False

        m = re.fullmatch(r'([A-Z]+)\((.*)\)', value)

        return self._eq_geom(expected, m[1], m[2]) if m else False

    def _eq_geom(self, bdd_geom, pg_type, pg_coords):
        """ Compare the expected geometry in test notation with a WKT
            geometry that was split into type and coordinate part.
        """
        # MULTI* geometries
        if bdd_geom.startswith('[') and bdd_geom.endswith(']'):
            bdd_parts = bdd_geom[1:-1].split(';')
            pg_parts = pg_coords[1:-1].split('),(')
            return pg_type.startswith('MULTI') \
                    and len(bdd_parts) == len(pg_parts) \
                    and all(self._eq_geom(b.strip(), pg_type[5:], g.strip())
                            for b, g in zip(bdd_parts, pg_parts))

        # GEOMETRYCOLLECTIONS
        if bdd_geom.startswith('{') and bdd_geom.endswith('}'):
            bdd_parts = bdd_geom[1:-1].split(';')
            pg_parts = list(map(lambda s: re.fullmatch(r'([A-Z]+)\(([^A-Z]*)\)', s),
                                re.findall('[A-Z]+[^A-Z]+[^,A-Z]', pg_coords)))
            return pg_type.startswith('GEOMETRYCOLLECTION')\
                    and len(bdd_parts) == len(pg_parts)\
                    and all(g is not None and self._eq_geom(b.strip(), g[1], g[2])
                            for b, g in zip(bdd_parts, pg_parts))

        # POINT
        if ',' not in bdd_geom:
            return pg_type == 'POINT' and self._eq_point(bdd_geom, pg_coords)

        # LINESTRING
        if '(' not in bdd_geom:
            bdd_pts = [g.strip() for g in bdd_geom.split(',')]
            pg_pts = [g.strip() for g in pg_coords.split(',')]
            # The number of points must agree, zip() alone would accept
            # a line that only matches a prefix of the other.
            return pg_type == 'LINESTRING' \
                    and len(bdd_pts) == len(pg_pts) \
                    and all(self._eq_point(b, p) for b, p in zip(bdd_pts, pg_pts))

        # POLYGON
        if pg_type != 'POLYGON':
            return False
        # Polygon comparison is tricky because the polygons don't necessarily
        # end at the same point or have the same winding order.
        # Brute force all possible variants of the expected polygon
        bdd_parts = re.findall(r'\([^)]+\)', bdd_geom)
        pg_parts = [g.strip() for g in pg_coords[1:-1].split('),(')]
        return len(bdd_parts) == len(pg_parts) \
                and all(self._eq_ring(*parts) for parts in zip(bdd_parts, pg_parts))

    def _eq_point(self, bdd_pt, pg_pt):
        exp_geom = self.nodes.parse_point(bdd_pt)
        pg_geom = list(map(float, pg_pt.split(' ')))
        return len(exp_geom) == len(pg_geom) \
                and all(math.isclose(e, p, rel_tol=0.000001) for e, p in zip(exp_geom, pg_geom))

    def _eq_ring(self, bdd_ring, pg_ring):
        """ Compare an expected ring `(p1, p2, ..., p1)` with the point list
            of a WKT ring. Rings are equal if they match for any starting
            point in either direction.

            Raises a RuntimeError when the expected ring is not closed.
        """
        bdd_pts = [g.strip() for g in bdd_ring[1:-1].split(',')]
        pg_pts = [g.strip() for g in pg_ring.split(',')]
        if bdd_pts[0] != bdd_pts[-1]:
            raise RuntimeError(f"Invalid polygon {bdd_ring}. "
                                "First and last point need to be the same")
        if len(bdd_pts) != len(pg_pts):
            return False

        # Try the ring in given and in reverse order, each rotated to every
        # possible start point. The closing point is left out on both sides.
        for line in (bdd_pts[:-1], bdd_pts[-1:0:-1]):
            for i in range(len(line)):
                if all(self._eq_point(p1, p2) for p1, p2 in
                       zip(line[i:] + line[:i], pg_pts)):
                    return True

        return False

################### Steps: Database setup ##########################

@given("the database schema (?P<schema>.+)")
def create_db_schema(context, schema):
    """ Create a schema with the given name in the test database.
    """
    statement = "CREATE SCHEMA " + schema
    with context.db.cursor() as cursor:
        cursor.execute(statement)


@when("deleting table (?P<table>.+)")
def delete_table(context, table):
    """ Drop the given table from the test database.
    """
    statement = "DROP TABLE " + table
    with context.db.cursor() as cursor:
        cursor.execute(statement)


################### Steps: OSM data ################################

@given("the input file '(?P<osm_file>.+)'")
def osm_set_import_file(context, osm_file):
    """ Use the given file as input for the next osm2pgsql run.

        Relative paths are looked up in the test data directory or, when
        none is configured, in the directory of the feature file.
    """
    assert context.import_data is None, \
        "Import file cannot be used together with inline data."

    candidate = Path(osm_file)
    if not candidate.is_absolute():
        search_dir = context.user_args.test_data_dir or Path(context.feature.filename).parent
        candidate = (search_dir / osm_file).resolve()

    context.import_file = candidate


@given("the OSM data")
def osm_define_data(context):
    """ Add the OPL lines of the step text to the inline import data.
        Nodes without coordinates get them from the node grid.
    """
    assert context.import_file is None, \
        "Inline data cannot be used together with an import file."

    if not context.text.strip():
        # Still register (empty) inline data so that an import can run.
        context.append_osm_data([])
        return

    opl_lines = context.text.split('\n')
    context.append_osm_data(context.nodes.add_coordinates(opl_lines))


@given("the OSM data format string")
# NOTE(review): this definition reuses the name of the "the OSM data" step
# function above and rebinds it at module level.
def osm_define_data(context):
    """ Add inline OSM data that is given as the body of an f-string.

        The step text is evaluated as a triple-quoted f-string before the
        resulting lines are added like regular inline data.
    """
    assert context.import_file is None, \
        "Inline data cannot be used together with an import file."
    # NOTE: eval() executes arbitrary expressions from the feature file in
    # the scope of this function. Only use with trusted feature files.
    data = eval('f"""' + context.text + '"""')

    context.append_osm_data(context.nodes.add_coordinates(data.split('\n')))


@given("the (?P<step>[0-9.]+ )?grid(?: with origin (?P<origin_x>[0-9.-]+) (?P<origin_y>[0-9.-]+))?")
def osm_define_node_grid(context, step, origin_x, origin_y):
    """ Define the node grid from the step table and add its nodes to the
        inline import data.

        The grid step defaults to 0.1 and the origin to 20.0/20.0.
    """
    grid_step = 0.1 if not step else float(step.strip())
    lon = 20.0 if not origin_x else float(origin_x)
    lat = 20.0 if not origin_y else float(origin_y)

    assert -180.0 < lon < 180.0
    assert -90.0 < lat < 90.0

    # The table heading is the first line of the grid.
    grid_rows = [context.table.headings]
    grid_rows.extend(list(row) for row in context.table)

    context.nodes.set_grid(grid_rows, grid_step, lon, lat)
    context.append_osm_data(context.nodes.get_as_opl())

################### Steps: Style file ##############################

@given("the style file '(?P<style>.+)'")
def setup_style_file(context, style):
    """ Set the style file for the next osm2pgsql run.

        Relative paths are searched in the style data directory, the test
        data directory and finally in the directory of the feature file.
    """
    requested = Path(style)

    if requested.is_absolute():
        assert requested.is_file()
        found = requested
    else:
        for base_dir in (context.user_args.style_data_dir,
                         context.user_args.test_data_dir):
            if base_dir is not None and (base_dir / requested).is_file():
                found = base_dir / requested
                break
        else:
            found = Path(context.feature.filename).parent / requested
            assert found.is_file()

    context.osm2pgsql_params['-S'] = str(found.resolve())


@given("the lua style")
def setup_style_inline(context):
    """ Write the step text into a Lua file in the working directory and
        use it as style file for the next osm2pgsql run.
    """
    style_path = context.workdir / 'inline_style.lua'
    style_path.write_text(context.text)

    context.osm2pgsql_params['-S'] = str(style_path)


################### Steps: Running osm2pgsql #######################

@when(r"running osm2pgsql(?P<output> \w+)?(?: with parameters)?")
def execute_osm2pgsql(context, output):
    """ Run osm2pgsql with the parameters from the step table and the
        parameters collected on the context.

        Input is either the inline data (sent via stdin in OPL format) or
        the import file. Command line, output and return code are saved on
        the context. The input definition and the collected parameters are
        reset afterwards.
    """
    assert output in (' flex', ' pgsql', ' null', None)

    user_args = context.user_args
    params = context.osm2pgsql_params

    if output is not None:
        params['-O'] = output.strip()
    if output == ' pgsql' and '-S' not in params:
        params['-S'] = '{STYLE_DATA_DIR}/default.style'

    test_dir = (user_args.test_data_dir or Path('.')).resolve()
    style_dir = (user_args.style_data_dir or Path('.')).resolve()
    substitutions = (('{TEST_DATA_DIR}', str(test_dir)),
                     ('{STYLE_DATA_DIR}', str(style_dir)),
                     ('{TEST_DB}', user_args.test_db))

    def _template(param):
        for placeholder, replacement in substitutions:
            param = param.replace(placeholder, replacement)
        return param

    cmdline = [user_args.osm2pgsql_binary]

    if context.table:
        assert not any('<' in h for h in context.table.headings), \
            "Substition in the first line of a table are not supported."
        # The heading line of the table carries parameters as well.
        cells = list(context.table.headings)
        for row in context.table:
            cells.extend(row)
        cmdline.extend(_template(cell) for cell in cells if cell)

    # Parameters given explicitly in the table take precedence.
    for option, value in params.items():
        if option not in cmdline:
            cmdline += [option, _template(value)]

    needs_tablespace = any(p.startswith('--tablespace') for p in cmdline)
    if needs_tablespace and not user_args.test_tablespace:
        context.scenario.skip('tablespace tablespacetest not available')
        return

    if context.import_data is None:
        assert context.import_file is not None, "No input data given."
        cmdline.append(str(context.import_file))
        context.import_file = None
        data_stdin = None
    else:
        opl_lines = sorted(context.import_data.values(), key=opl_sort)
        data_stdin = '\n'.join(opl_lines).encode('utf-8')
        context.import_data = None
        cmdline.extend(('-r', 'opl', '-'))

    proc = Popen(cmdline, cwd=str(context.workdir),
                 stdin=PIPE, stdout=PIPE, stderr=PIPE)
    streams = proc.communicate(input=data_stdin)

    context.osm2pgsql_cmdline = ' '.join(cmdline)
    context.osm2pgsql_outdata = [raw.decode('utf-8').replace('\\n', '\n')
                                 for raw in streams]
    context.osm2pgsql_returncode = proc.returncode
    context.osm2pgsql_params = {'-d': user_args.test_db}

@when("running osm2pgsql-expire with parameters")
def execute_osm2pgsql_expire(context):
    """ Run osm2pgsql-expire with the parameters from the step table.

        The binary name is derived from the osm2pgsql binary by appending
        '-expire'. Command line, output and return code are saved on the
        context.
    """
    test_dir = str((context.user_args.test_data_dir or Path('.')).resolve())
    cmdline = [context.user_args.osm2pgsql_binary + '-expire']

    if context.table:
        assert not any('<' in h for h in context.table.headings), \
            "Substition in the first line of a table are not supported."
        # The heading line of the table carries parameters as well.
        cells = list(context.table.headings)
        for row in context.table:
            cells.extend(row)
        cmdline.extend(cell.replace('{TEST_DATA_DIR}', test_dir)
                       for cell in cells if cell)

    proc = Popen(cmdline, cwd=str(context.workdir),
                 stdout=PIPE, stderr=PIPE)
    streams = proc.communicate()

    context.osm2pgsql_cmdline = ' '.join(cmdline)
    context.osm2pgsql_outdata = [raw.decode('utf-8').replace('\\n', '\n')
                                 for raw in streams]
    context.osm2pgsql_returncode = proc.returncode

@then("execution is successful")
def osm2pgsql_check_success(context):
    """ Check that the last program run returned with exit code 0.
    """
    retcode = context.osm2pgsql_returncode
    assert retcode == 0, (
        f"osm2pgsql failed with error code {context.osm2pgsql_returncode}.\n"
        f"Command line: {context.osm2pgsql_cmdline}\n"
        f"Output:\n{context.osm2pgsql_outdata[0]}\n{context.osm2pgsql_outdata[1]}\n")

@then(r"execution fails(?: with return code (?P<expected>\d+))?")
def osm2pgsql_check_failure(context, expected):
    """ Check that the last program run failed, optionally with the given
        return code.
    """
    retcode = context.osm2pgsql_returncode
    assert retcode != 0, "osm2pgsql unexpectedly succeeded"

    if not expected:
        return

    assert retcode == int(expected), (
        f"osm2pgsql failed with return code {retcode} instead of {expected}\n"
        f"Output:\n{context.osm2pgsql_outdata[0]}\n{context.osm2pgsql_outdata[1]}\n")

@then(r"the (?P<kind>\w+) output contains")
def check_program_output(context, kind):
    """ Check that every line of the step text appears in the output of
        the last program run.

        `kind` selects the stream: 'standard' for stdout, 'error' for
        stderr. Any other value fails the step.
    """
    if kind == 'error':
        s = context.osm2pgsql_outdata[1]
    elif kind == 'standard':
        s = context.osm2pgsql_outdata[0]
    else:
        # Raise explicitly so that the failure comes with a message.
        raise AssertionError(
            f"Unknown output kind '{kind}'. Expect one of error, standard")

    for line in context.text.split('\n'):
        line = line.strip()
        assert line in s,\
               f"Output '{line}' not found in {kind} output:\n{s}\n"

@then(r"the (?P<kind>\w+) output matches contents of (?P<result_file>.*)")
def check_program_output_against_file(context, kind, result_file):
    """ Check that the output of the last program run is identical to the
        content of the given file.

        `kind` selects the stream: 'standard' for stdout, 'error' for
        stderr. Any other value fails the step. The file is looked up
        relative to the test data directory (or the current directory when
        none is configured).
    """
    if kind == 'error':
        s = context.osm2pgsql_outdata[1]
    elif kind == 'standard':
        s = context.osm2pgsql_outdata[0]
    else:
        # Raise explicitly so that the failure comes with a message.
        raise AssertionError(
            f"Unknown output kind '{kind}'. Expect one of error, standard")

    test_dir = (context.user_args.test_data_dir or Path('.')).resolve()
    # The program output was decoded as UTF-8, read the file the same way.
    expected = (test_dir / result_file).read_text(encoding='utf-8')
    assert s == expected, \
           f"Output does not match contents of {result_file}. Actual output:\n{s}"

################### Steps: Running Replication #####################

@given("the replication service at (?P<base_url>.*)")
def setup_replication_mock(context, base_url):
    """ Install a mocked replication server for the given URL.

        The optional step table lists the available states, one per row,
        with sequence number and UTC timestamp. The scenario is skipped
        when the replication module is not available.
    """
    if osm2pgsql_replication is None:
        context.scenario.skip("Replication binary not available. Skip.")
        return

    state_infos = []
    if context.table:
        for row in context.table:
            sequence = int(row[0])
            timestamp = dt.datetime.strptime(row[1], '%Y-%m-%dT%H:%M:%SZ')
            state_infos.append(
                OsmosisState(sequence, timestamp.replace(tzinfo=dt.timezone.utc)))

    osm2pgsql_replication.ReplicationServer = ReplicationServerMock(base_url, state_infos)


@given("the URL (?P<base_url>.*) returns")
def mock_url_response(context, base_url):
    """ Register the step text as the response for the given URL.
    """
    response_text = context.text
    context.urlrequest_responses[base_url] = response_text


@when("running osm2pgsql-replication")
def execute_osm2pgsql_replication(context):
    """ Run the main function of the replication module with the
        parameters from the step table.

        The test database is added as database parameter unless one is
        given. For the 'update' command, the osm2pgsql binary under test
        is handed in and, when no osm2pgsql parameters are given, the
        default style. Captured stdout and log output as well as the return
        code are saved on the context.
    """
    assert osm2pgsql_replication is not None
    assert osm2pgsql_replication.ReplicationServer is not None

    cmdline = []

    test_dir = (context.user_args.test_data_dir or Path('.')).resolve()
    style_dir = (context.user_args.style_data_dir or Path('.')).resolve()
    def _template(param):
        return param.replace('{TEST_DATA_DIR}', str(test_dir))\
                    .replace('{STYLE_DATA_DIR}', str(style_dir))\
                    .replace('{TEST_DB}', context.user_args.test_db)

    if context.table:
        assert not any('<' in h for h in context.table.headings), \
            "Substition in the first line of a table are not supported."
        cmdline.extend(_template(h) for h in context.table.headings if h)
        for row in context.table:
            cmdline.extend(_template(c) for c in row if c)

    if '-d' not in cmdline and '--database' not in cmdline:
        cmdline.extend(('-d', context.user_args.test_db))

    if cmdline[0] == 'update':
        cmdline.extend(('--osm2pgsql-cmd', context.user_args.osm2pgsql_binary))

        if '--' not in cmdline:
            cmdline.extend(('--', '-S', str(style_dir / 'default.style')))

    serr = io.StringIO()
    log_handler = logging.StreamHandler(serr)
    osm2pgsql_replication.LOG.addHandler(log_handler)
    try:
        with contextlib.redirect_stdout(io.StringIO()) as sout:
            context.osm2pgsql_returncode = osm2pgsql_replication.main(cmdline)
    finally:
        # Always detach the handler, otherwise it would keep collecting
        # the log output of later scenarios after a failed run.
        osm2pgsql_replication.LOG.removeHandler(log_handler)
    context.osm2pgsql_outdata = [sout.getvalue(), serr.getvalue()]


################### Steps: Inspect database ########################

@given("the SQL statement (?P<sql>.+)")
def db_define_sql_statement(context, sql):
    """ Save the step text as a SQL statement under the given name.
    """
    statement_text = context.text
    context.sql_statements[sql] = statement_text


@then("there are (?P<exists>no )?tables (?P<tables>.+)")
def db_table_existance(context, exists, tables):
    """ Check that all tables of the comma-separated list exist or, in
        the 'no tables' variant, that none of them exists.
    """
    expect_missing = exists == 'no '

    for name in (part.strip() for part in tables.split(',')):
        if expect_missing:
            assert not context.table_exists(name), f"Table '{name}' unexpectedly found"
        else:
            assert context.table_exists(name), f"Table '{name}' not found"


@then("table (?P<table>.+) contains(?P<exact> exactly)?")
def db_check_table_content(context, table, exact):
    """ Check that the table contains the rows of the step table. With
        'exactly', the table must not contain any other rows.
    """
    context.execute_steps("Then execution is successful")
    assert context.table_exists(table), f"Table {table} not found in database."

    name_parts = table.split('.', 1)
    only_listed_rows = exact is not None
    context.check_select(sql.Identifier(*name_parts), only_listed_rows)


@then("table (?P<table>.+) doesn't contain")
def db_check_table_content_negative(context, table):
    """ Check that the table contains none of the rows of the step table.
    """
    context.execute_steps("Then execution is successful")
    assert context.table_exists(table), f"Table {table} not found in database."

    name_parts = table.split('.', 1)
    context.check_select_not_contained(sql.Identifier(*name_parts))


@then(r"table (?P<table>.+) has (?P<row_num>\d+) rows?")
def db_table_row_count(context, table, row_num):
    """ Check that the table has exactly the given number of rows.
    """
    context.execute_steps("Then execution is successful")
    assert context.table_exists(table), f"Table {table} not found in database."

    # Split into schema and table name only, like everywhere else where
    # table names are parsed.
    query = sql.SQL("SELECT count(*) FROM {}").format(sql.Identifier(*table.split('.', 1)))
    for res in context.db.execute(query):
        assert res[0] == int(row_num),\
           f"Table {table}: expected {row_num} rows, got {res[0]}"


@then("statement (?P<stmt>.+) returns(?P<exact> exactly)?")
def db_check_table_content(context, stmt, exact):
    """ Check that the result of a previously defined SQL statement
        contains the rows of the step table. With 'exactly', the result
        must not contain any other rows.
    """
    context.execute_steps("Then execution is successful")
    assert stmt in context.sql_statements, \
        f"SQL statement '{stmt}' has not been defined."
    # The statement is used as a subquery in the FROM clause.
    context.check_select(sql.SQL(f"({context.sql_statements[stmt]}) _statement_sql"),
                         exact is not None)

################### Context ########################################

class Osm2pgsqlContext(Context):
    """ Behave context extended with the command line arguments of the
        test runner (`user_args`) and helper functions for database
        handling and result checking.
    """

    def __init__(self, runner, args):
        super().__init__(runner)
        self.user_args = args

    def connect_db(self, name=None):
        """ Return a new autocommit connection to the given database or,
            when no name is given, to the test database.
        """
        dbname = name or self.user_args.test_db
        return psycopg.connect(dbname=dbname, autocommit=True)

    def drop_db(self, name=None, recreate_template=None):
        """ Drop the given database (default: the test database) if it
            exists and optionally recreate it.

            With `recreate_template` set to 'default', an empty database is
            created. Any other non-empty value names the template database
            to create the new database from.
        """
        db = sql.Identifier(name or self.user_args.test_db)
        with self.connect_db('postgres') as conn:
            conn.execute(sql.SQL('DROP DATABASE IF EXISTS {}').format(db))
            if recreate_template == 'default':
                conn.execute(sql.SQL('CREATE DATABASE {}').format(db))
            elif recreate_template:
                conn.execute(sql.SQL('CREATE DATABASE {} WITH TEMPLATE {}')
                                .format(db, sql.Identifier(recreate_template)))

    def append_osm_data(self, lines, include_untagged=None):
        """ Add OPL lines to the inline import data. Empty lines are
            skipped, a line for an already known object replaces the
            previous one. `include_untagged` is accepted but ignored.
        """
        if self.import_data is None:
            self.import_data = {}
        for line in lines:
            if (l := line.strip()):
                self.import_data[l.split(' ', 1)[0]] = l

    def table_exists(self, table):
        """ Check if a table or view with the given name exists. Names
            without a schema are looked up in the schema 'public'.
        """
        sql_params = table.split('.', 1) if '.' in table else ('public', table)

        # Each catalog uses its own column for the relation name.
        for catalog, name_column in (('pg_tables', 'tablename'),
                                     ('pg_views', 'viewname')):
            query = sql.SQL('SELECT count(*) FROM {} WHERE schemaname = %s AND {} = %s')\
                       .format(sql.Identifier(catalog), sql.Identifier(name_column))
            for res in self.db.execute(query, sql_params):
                if res[0] == 1:
                    return True
        return False

    def _build_select(self, columns, from_clause):
        """ Create the query that selects the expressions of the given
            ResultCompare columns from `from_clause`.
        """
        return sql.SQL('SELECT {} FROM {}').format(
                sql.SQL(', '.join(f"({c.as_select()}) as c{i}" for i, c in enumerate(columns))),
                from_clause)

    @staticmethod
    def _row_matches(columns, expected_row, db_row):
        """ Check if all values of a database row match the expected row.
        """
        return all(col.equals(expected, value)
                   for col, expected, value in zip(columns, expected_row, db_row))

    def check_select(self, from_clause, exact):
        """ Check that every row of the step table is found in the result
            of selecting from `from_clause`. Each result row can satisfy
            only one expected row. With `exact` set, result rows without
            a matching expected row are an error, too.
        """
        columns = [ResultCompare(h, self.nodes) for h in self.table.headings]
        # Indexes of the expected rows that still need to be found.
        lines = set(range(0, len(self.table.rows)))

        seen_rows = []
        unexpected_rows = []
        for row in self.db.execute(self._build_select(columns, from_clause)):
            seen_rows.append(str(row))
            match = next((i for i in lines
                          if self._row_matches(columns, self.table[i], row)), None)
            if match is not None:
                lines.remove(match)
            elif exact:
                unexpected_rows.append(str(row))

        table_content = ''.join(f"\n{r}" for r in seen_rows)

        assert not lines, \
               "Rows not found:\n" \
               + '\n'.join(str(self.table[i]) for i in lines) \
               + "\nTable content:\n" \
               + table_content

        assert not unexpected_rows, \
            "Unexpected rows found:\n" + '\n'.join(unexpected_rows)\
               + "\nTable content:\n" \
               + table_content

    def check_select_not_contained(self, from_clause):
        """ Check that none of the rows of the step table is found in the
            result of selecting from `from_clause`.
        """
        columns = [ResultCompare(h, self.nodes) for h in self.table.headings]
        lines = range(0, len(self.table.rows))

        seen_rows = []
        matching_rows = []
        for row in self.db.execute(self._build_select(columns, from_clause)):
            seen_rows.append(str(row))
            if any(self._row_matches(columns, self.table[i], row) for i in lines):
                matching_rows.append(str(row))

        table_content = ''.join(f"\n{r}" for r in seen_rows)

        assert not matching_rows, \
            "Matching rows found:\n" + '\n'.join(matching_rows)\
               + "\nFull table content:\n" \
               + table_content


#################### runner and main ###############################

class Osm2pgsqlRunner(ModelRunner):
    """ Behave runner that executes the features given on the command
        line with the hooks and the context defined in this file.
    """

    def __init__(self, config, args):
        super().__init__(config)
        self.feature_locations = runner_util.collect_feature_locations(args.features)
        self.hooks = dict(before_all=hook_before_all,
                          before_scenario=hook_before_scenario)
        self.context = Osm2pgsqlContext(self, args)

    def run(self):
        """ Parse and run all features, printing the results to stdout in
            'pretty' format. Returns the result of `run_model()`.
        """
        parsed = runner_util.parse_features(self.feature_locations)
        self.features.extend(parsed)

        pretty = PrettyFormatter(StreamOpener(stream=sys.stdout), self.config)
        self.formatters = [pretty]
        return self.run_model()


def get_parser():
    """ Create the parser for the command line arguments of the test runner.
    """
    parser = ArgumentParser(description=__doc__,
                            prog='osm2pgsql-test-style',
                            formatter_class=RawDescriptionHelpFormatter)
    add = parser.add_argument
    tristate = ['yes', 'no', 'auto']

    # What to test.
    add('features', nargs='+',
        help='Feature files or paths')
    add('--osm2pgsql-binary',
        help='osm2pgsql binary to use for testing (default: osm2pgsql)')

    # Where to find additional files.
    add('--test-data-dir', type=Path,
        help='(optional) directory to search for test data')
    add('--style-data-dir', type=Path,
        help='(optional) directory to search for style files')

    # Database handling.
    add('--test-db', default='osm2pgsql-test',
        help='Name of database to use for testing (default: osm2pgsql-test)')
    add('--template-test-db', default='osm2pgsql-test-template',
        help='Name of database to use for creating the template db '
             '(default: osm2pgsql-test-template)')
    add('--keep-test-db', action='store_true',
        help='Keep the test database around after tests are done')

    # Optional features.
    add('--test-tablespace', default='auto', choices=tristate,
        help='Include tests requiring a tablespace')
    add('--test-proj', default='auto', choices=tristate,
        help='Include tests requiring the proj library')

    return parser

def main(prog_args=None):
    """ Run the tests as requested by the given command line arguments
        (default: the arguments of the process).

        Returns 0 on success and 1 when the tests have failed or the
        arguments could not be parsed.
    """
    parser = get_parser()
    try:
        args = parser.parse_args(args=prog_args)
    except SystemExit as ex:
        # argparse exits with code 0 after printing the help and with a
        # non-zero code on usage errors. Only the latter is a failure.
        return 1 if ex.code else 0

    if args.osm2pgsql_binary is None:
        # Not given: rely on the binary being found in the search path.
        args.osm2pgsql_binary = 'osm2pgsql'
    else:
        # The binary is run from a temporary working directory, so a
        # relative path needs to be made absolute.
        args.osm2pgsql_binary = str(Path(args.osm2pgsql_binary).resolve())

    config = Configuration(command_args=[])
    config.show_skipped = False
    runner = Osm2pgsqlRunner(config, args)

    failed = runner.run()

    if runner.undefined_steps:
        LOG.error('Error in feature definition. The following steps are unknown:\n - '
                  + '\n - '.join(f"{s.keyword} {s.name}" for s in runner.undefined_steps))

    return int(failed)

if __name__ == '__main__':
    # Run inside the try block so that unexpected exceptions are logged
    # and reported with a return code of their own.
    try:
        retcode = main()
    except Exception as ex:
        LOG.critical("Exception during execution: %s", ex)
        retcode = 3

    sys.exit(retcode)
