#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.

import errno
import os
import socket
import ssl
import stat
import struct
import sys
import time

from gunicorn import util

PLATFORM = sys.platform


class BaseSocket:
    """Listening-socket wrapper shared by the TCP and Unix implementations.

    Subclasses set ``FAMILY`` to the address family to create. Attribute
    lookups that the wrapper does not define itself are delegated to the
    wrapped ``socket.socket`` stored in ``self.sock``.
    """

    def __init__(self, address, conf, log, fd=None):
        """Create (or adopt) the socket and apply listener options.

        :param address: bind address, kept as ``cfg_addr``.
        :param conf: configuration; ``reuse_port`` and ``backlog`` are read
            in :meth:`set_options`.
        :param log: logger used for non-fatal close errors.
        :param fd: an already bound file descriptor to adopt instead of
            creating a new socket. ``socket.fromfd`` duplicates it and the
            original descriptor is closed here.
        """
        self.log = log
        self.conf = conf

        self.cfg_addr = address
        if fd is None:
            sock = socket.socket(self.FAMILY, socket.SOCK_STREAM)
            bound = False
        else:
            # fromfd() dup()s the descriptor, so the original can be closed
            sock = socket.fromfd(fd, self.FAMILY, socket.SOCK_STREAM)
            os.close(fd)
            bound = True

        try:
            self.sock = self.set_options(sock, bound=bound)
        except BaseException:
            # Binding/listening failed: release the descriptor now instead
            # of leaking it (callers such as create_sockets retry in a loop).
            sock.close()
            raise

    def __str__(self):
        return "<socket %d>" % self.sock.fileno()

    def __getattr__(self, name):
        # __getattr__ only runs for names not found normally. If ``sock``
        # itself is missing (failed construction, __new__, copy/pickle),
        # delegating would call back into __getattr__ forever.
        if name == "sock":
            raise AttributeError(name)
        return getattr(self.sock, name)

    def set_options(self, sock, bound=False):
        """Configure *sock* as a non-blocking, inheritable listener.

        Enables SO_REUSEADDR (plus SO_REUSEPORT when ``conf.reuse_port`` is
        set and the platform defines it), binds the socket unless *bound* is
        true, and starts listening with ``conf.backlog``.

        :returns: the same socket object.
        """
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if (self.conf.reuse_port
                and hasattr(socket, 'SO_REUSEPORT')):  # pragma: no cover
            try:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            except OSError as err:
                # the constant exists but the OS rejects the option:
                # treat SO_REUSEPORT as unsupported rather than fatal
                if err.errno not in (errno.ENOPROTOOPT, errno.EINVAL):
                    raise
        if not bound:
            self.bind(sock)
        sock.setblocking(0)

        # make sure that the socket can be inherited
        if hasattr(sock, "set_inheritable"):
            sock.set_inheritable(True)

        sock.listen(self.conf.backlog)
        return sock

    def bind(self, sock):
        """Bind *sock* to the configured address."""
        sock.bind(self.cfg_addr)

    def close(self):
        """Close the wrapped socket (no-op if already closed).

        Errors raised while closing are logged at info level; afterwards
        ``self.sock`` is always ``None``.
        """
        if self.sock is None:
            return

        try:
            self.sock.close()
        except OSError as e:
            self.log.info("Error while closing socket %s", str(e))

        self.sock = None

    def get_backlog(self):
        """Return the accept-queue length; -1 means "not supported"."""
        return -1


class TCPSocket(BaseSocket):
    """IPv4 TCP listener (``AF_INET``) with ``TCP_NODELAY`` enabled."""

    FAMILY = socket.AF_INET

    def __str__(self):
        """Return the listener as a URL, e.g. ``http://127.0.0.1:8000``."""
        if self.conf.is_ssl:
            scheme = "https"
        else:
            scheme = "http"

        addr = self.sock.getsockname()
        return "%s://%s:%d" % (scheme, addr[0], addr[1])

    def set_options(self, sock, bound=False):
        """Disable Nagle's algorithm, then apply the common listener options."""
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        return super().set_options(sock, bound=bound)

    if PLATFORM == "linux":
        def get_backlog(self):
            """Return the accept-queue length read from ``TCP_INFO``.

            Relies on Linux reporting the listener's queued-connection count
            in ``tcpi_unacked``. Best effort: returns 0 when the socket is
            closed or the value cannot be read.
            """
            if self.sock:
                # tcp_info struct from include/uapi/linux/tcp.h:
                # 8 one-byte fields followed by 24 u32 fields (104 bytes)
                fmt = 'B' * 8 + 'I' * 24
                try:
                    tcp_info_struct = self.sock.getsockopt(socket.IPPROTO_TCP,
                                                           socket.TCP_INFO, 104)
                    # 12 is tcpi_unacked
                    return struct.unpack(fmt, tcp_info_struct)[12]
                except (AttributeError, OSError, struct.error):
                    # AttributeError: TCP_INFO not exposed by this Python.
                    # struct.error: fewer than 104 bytes came back (buflen
                    # is only an upper bound for getsockopt).
                    pass
            return 0
    else:
        def get_backlog(self):
            """Backlog inspection is only implemented on Linux."""
            return -1


class TCP6Socket(TCPSocket):
    """IPv6 TCP listener (``AF_INET6``)."""

    FAMILY = socket.AF_INET6

    def __str__(self):
        """Return the listener as a URL with a bracketed IPv6 host.

        The scheme is ``https`` when SSL is configured, as in
        ``TCPSocket.__str__``; previously IPv6 listeners always said ``http``.
        """
        scheme = "https" if self.conf.is_ssl else "http"
        # AF_INET6 getsockname() yields (host, port, flowinfo, scope_id)
        (host, port, _, _) = self.sock.getsockname()
        return "%s://[%s]:%d" % (scheme, host, port)


class UnixSocket(BaseSocket):
    """Unix-domain stream listener (``AF_UNIX``) bound to a filesystem path."""

    FAMILY = socket.AF_UNIX

    def __init__(self, addr, conf, log, fd=None):
        """Create the listener, first removing a stale socket file at *addr*.

        When no inherited *fd* is given, an existing socket file at *addr* is
        removed so it can be re-bound.

        :raises ValueError: if *addr* exists but is not a socket.
        """
        if fd is None:
            try:
                st = os.stat(addr)
            except OSError as e:
                if e.args[0] != errno.ENOENT:
                    raise
            else:
                if stat.S_ISSOCK(st.st_mode):
                    os.remove(addr)
                else:
                    raise ValueError("%r is not a socket" % addr)
        super().__init__(addr, conf, log, fd=fd)

    def __str__(self):
        return "unix:%s" % self.cfg_addr

    def bind(self, sock):
        """Bind under ``conf.umask`` and hand the file to ``conf.uid``/``gid``."""
        old_umask = os.umask(self.conf.umask)
        try:
            sock.bind(self.cfg_addr)
            util.chown(self.cfg_addr, self.conf.uid, self.conf.gid)
        finally:
            # umask is process-wide: restore it even if bind/chown failed
            os.umask(old_umask)


def _sock_type(addr):
    if isinstance(addr, tuple):
        if util.is_ipv6(addr[0]):
            sock_type = TCP6Socket
        else:
            sock_type = TCPSocket
    elif isinstance(addr, (str, bytes)):
        sock_type = UnixSocket
    else:
        raise TypeError("Unable to create socket from: %r" % addr)
    return sock_type


def create_sockets(conf, log, fds=None):
    """
    Create a new socket for the configured addresses or file descriptors.

    If a configured address is a tuple then a TCP socket is created.
    If it is a string, a Unix socket is created. Otherwise, a TypeError is
    raised.

    :param conf: configuration; ``address``, ``certfile`` and ``keyfile``
        are read here, the remaining settings by the socket classes.
    :param log: logger used to report bind failures.
    :param fds: optional iterable of already bound file descriptors. When
        any descriptor is present (from *fds* or as an int in
        ``conf.address``), only descriptors are used and the other
        configured addresses are ignored.
    :returns: list of listener objects.
    :raises ValueError: if the configured certfile or keyfile does not exist.
    :raises SystemExit: if an address still cannot be bound after all
        attempts.
    """
    listeners = []
    attempts = 5  # bind attempts per address, one second apart

    # get it only once
    addresses = conf.address
    fdaddr = [bind for bind in addresses if isinstance(bind, int)]
    if fds:
        fdaddr += list(fds)
    laddr = [bind for bind in addresses if not isinstance(bind, int)]

    # check ssl config early to raise the error on startup
    # only the certfile is needed since it can contain the keyfile
    if conf.certfile and not os.path.exists(conf.certfile):
        raise ValueError('certfile "%s" does not exist' % conf.certfile)

    if conf.keyfile and not os.path.exists(conf.keyfile):
        raise ValueError('keyfile "%s" does not exist' % conf.keyfile)

    # sockets are already bound
    if fdaddr:
        for fd in fdaddr:
            # Temporary socket used only to discover the bound address.
            # fromfd() dup()s fd, so close the duplicate once we are done.
            with socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) as probe:
                sock_name = probe.getsockname()
            sock_type = _sock_type(sock_name)
            listener = sock_type(sock_name, conf, log, fd=fd)
            listeners.append(listener)

        return listeners

    # no socket is bound yet: first initialization of gunicorn in this env.
    for addr in laddr:
        sock_type = _sock_type(addr)
        sock = None
        for attempt in range(attempts):
            try:
                sock = sock_type(addr, conf, log)
            except OSError as e:
                if e.errno == errno.EADDRINUSE:
                    log.error("Connection in use: %s", str(addr))
                if e.errno == errno.EADDRNOTAVAIL:
                    log.error("Invalid address: %s", str(addr))
                msg = "connection to {addr} failed: {error}"
                log.error(msg.format(addr=str(addr), error=str(e)))
                # only wait if another attempt will actually follow
                if attempt < attempts - 1:
                    log.debug("Retrying in 1 second.")
                    time.sleep(1)
            else:
                break

        if sock is None:
            log.error("Can't connect to %s", str(addr))
            sys.exit(1)

        listeners.append(sock)

    return listeners


def close_sockets(listeners, unlink=True):
    """Close every listener, optionally removing Unix socket files.

    :param listeners: listener objects, as returned by create_sockets.
    :param unlink: when true, unlink the filesystem path of each Unix
        socket after closing it.
    """
    for listener in listeners:
        # read the name first: close() drops the wrapped socket
        name = listener.getsockname()
        listener.close()
        if not unlink:
            continue
        if _sock_type(name) is UnixSocket:
            os.unlink(name)


def _get_alpn_protocols(conf):
    """Build the ALPN protocol list for the server TLS context.

    Entries follow the preference order of ``conf.http_protocols`` and keep
    only names present in ``ALPN_PROTOCOL_MAP``. The list is empty when no
    protocols are configured, when "h2" is not among them, or when HTTP/2
    support is not available.
    """
    from gunicorn.config import ALPN_PROTOCOL_MAP

    wanted = conf.http_protocols
    # ALPN is only configured when HTTP/2 was requested
    if not wanted or "h2" not in wanted:
        return []

    from gunicorn.http2 import is_http2_available
    if not is_http2_available():
        return []

    return [ALPN_PROTOCOL_MAP[name] for name in wanted
            if name in ALPN_PROTOCOL_MAP]


def ssl_context(conf):
    """Return the SSL context used for server sockets.

    Calls the ``conf.ssl_context`` hook with ``conf`` and a zero-argument
    factory. The factory builds the default server-side context: CA file,
    certificate chain, ``cert_reqs`` verify mode, optional cipher list and,
    when HTTP/2 is configured, ALPN protocols.
    """
    def default_ssl_context_factory():
        ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH,
                                         cafile=conf.ca_certs)
        ctx.load_cert_chain(certfile=conf.certfile, keyfile=conf.keyfile)
        ctx.verify_mode = conf.cert_reqs
        if conf.ciphers:
            ctx.set_ciphers(conf.ciphers)

        # advertise the configured protocols via ALPN when HTTP/2 is on
        protocols = _get_alpn_protocols(conf)
        if protocols:
            ctx.set_alpn_protocols(protocols)
        return ctx

    hook = conf.ssl_context
    return hook(conf, default_ssl_context_factory)


def ssl_wrap_socket(sock, conf):
    """Wrap *sock* in TLS as the server side, using ssl_context(conf)."""
    context = ssl_context(conf)
    return context.wrap_socket(
        sock,
        server_side=True,
        suppress_ragged_eofs=conf.suppress_ragged_eofs,
        do_handshake_on_connect=conf.do_handshake_on_connect,
    )


def get_negotiated_protocol(ssl_socket):
    """Return the ALPN protocol chosen during the TLS handshake.

    Returns:
        str: The negotiated protocol name ('h2', 'http/1.1', etc.), or
             None when nothing was negotiated, *ssl_socket* is not an
             ``ssl.SSLSocket``, or the query raises AttributeError/SSLError.
    """
    if isinstance(ssl_socket, ssl.SSLSocket):
        try:
            return ssl_socket.selected_alpn_protocol()
        except (AttributeError, ssl.SSLError):
            pass
    return None


def is_http2_negotiated(ssl_socket):
    """Tell whether ALPN settled on HTTP/2 ("h2") for *ssl_socket*.

    Returns:
        bool: True only if the negotiated protocol is exactly "h2".
    """
    return get_negotiated_protocol(ssl_socket) == "h2"
