#!/usr/bin/env python3
"""
Manage a PostgreSQL inventory of hosts and their live TLS certificates.

see: https://www.postgresqltutorial.com/postgresql-python/connect/

Sub-commands:
    add [-active | -inactive] HOST  # add to database (optionally set status)
    remove HOST                     # remove from database
    list-all                        # list all hosts and their check status
    list-active                     # list hosts with active status
    active HOST                     # set status to active
    inactive HOST                   # set status to inactive
    info HOST                       # get host info from database
    check HOST                      # check live cert for host and update its
                                    # db entry if the host is in the db
    check-all                       # check live certs for all active hosts
    critical                        # list active hosts whose certs expire
                                    # within CRITICAL_INTERVAL
    renew HOST                      # renew a host's cert (not implemented yet)

Dhya Thu, 03 Oct 2024 14:32:32 -0700
"""
import argparse
import concurrent.futures
import re
import shutil
import ssl
import sys
from dataclasses import dataclass
from datetime import datetime
from datetime import timezone
from socket import socket, AF_INET, SOCK_STREAM
from typing import Final, Optional

import psycopg
from cryptography import x509
from cryptography.x509.oid import ExtensionOID
from cryptography.x509.oid import NameOID
from psycopg import sql

# import os
# os.chdir("/usr/local/lib/certmanager")
from config import config

DT_FMT: Final[str] = '%F %T %Z'

# PostgreSQL interval: certs expiring sooner than now() + this are "critical".
CRITICAL_INTERVAL: Final[str] = '14 DAY'

# One DNS label: 1-63 letters/digits/hyphens, not starting or ending with '-'.
# Used with fullmatch() so a trailing newline cannot sneak past a '$' anchor.
_FQDN_LABEL_RE = re.compile(r'[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?',
                            re.IGNORECASE)


def fqdn_type(host_name):
    """
    Validate that *host_name* is a syntactically valid FQDN.

    Intended as an argparse ``type=`` callable (not currently attached to
    any sub-command). Returns *host_name* unchanged when valid; otherwise
    prints an error and exits with status 1 (previously an out-of-range
    length silently returned False instead).
    """
    labels = host_name.split('.')
    if (1 < len(host_name) <= 253
            and all(_FQDN_LABEL_RE.fullmatch(label) for label in labels)):
        return host_name
    print(f'Error: "{host_name}" is not a valid hostname.')
    sys.exit(1)


class _RowFactoryMixin:
    """Provide a psycopg row factory that builds instances via cls.from_row."""

    @classmethod
    def row_factory(cls, cursor):
        """Return a row maker mapping the cursor's column names to from_row()."""
        columns = [column.name for column in cursor.description]

        def make_row(values):
            return cls.from_row(**dict(zip(columns, values)))

        return make_row


@dataclass
class Host(_RowFactoryMixin):
    """
    One row of the ``main`` table.

    All fields are optional except {host_name}, e.g.:
        h1 = Host("example.com")
    """
    host_name: str  # non-default argument must come first
    host_id: Optional[int] = None  # primary key, autoincremented
    common_name: Optional[str] = None
    issuer: Optional[str] = None
    not_valid_before: Optional[datetime] = None
    not_valid_after: Optional[datetime] = None
    active: bool = False  # a host can have active or inactive status
    checked: Optional[datetime] = None  # date of most recent cert check
    check_status: Optional[str] = None  # status of most recent cert check
    check_err: Optional[str] = None  # error (if any) of most recent cert check
    renewed: Optional[datetime] = None  # datetime of most recent cert renewal
    renew_status: Optional[str] = None  # status of most recent cert renewal
    renew_out: Optional[str] = None  # output of most recent cert renewal
    renew_err: Optional[str] = None  # error output of most recent cert renewal

    @classmethod
    def from_row(
        cls, *, host_id, host_name, common_name, issuer, not_valid_before,
        not_valid_after, active, checked, check_status, check_err, renewed,
        renew_status, renew_out, renew_err
    ):
        """Build a Host from a ``main`` row; every column must be present."""
        return cls(
            host_id=host_id, host_name=host_name, common_name=common_name,
            issuer=issuer, not_valid_before=not_valid_before,
            not_valid_after=not_valid_after, active=active, checked=checked,
            check_status=check_status, check_err=check_err, renewed=renewed,
            renew_status=renew_status, renew_out=renew_out,
            renew_err=renew_err
        )


@dataclass
class San(_RowFactoryMixin):
    """One Subject Alternative Name row of the ``san`` table."""
    name: str  # non-default argument must come first
    host_id: Optional[int] = None  # owning host; compared with main.host_id
    status: Optional[str] = None  # status of most recent cert check

    @classmethod
    def from_row(cls, *, host_id, name, status):
        """Build a San from a ``san`` row."""
        return cls(host_id=host_id, name=name, status=status)


def db_retrieve(query, params=None):
    """
    Execute a read query and return all result rows.

    Args:
        query: SQL string or ``psycopg.sql.Composable``.
        params: optional sequence bound to ``%s`` placeholders in *query*.

    Returns:
        list of row tuples from the default psycopg row factory.
    """
    conn_params = config()
    with psycopg.connect(**conn_params) as conn:
        with conn.cursor() as cur:
            # To debug a composed query: print(query.as_string(conn))
            cur.execute(query, params)
            return cur.fetchall()


def db_execute(query, params=None):
    """
    Execute a data-modifying statement and commit it.

    Args:
        query: SQL string or ``psycopg.sql.Composable``.
        params: optional sequence bound to ``%s`` placeholders in *query*.
    """
    conn_params = config()
    with psycopg.connect(**conn_params) as conn:
        with conn.cursor() as cur:
            # To debug a composed query: print(query.as_string(conn))
            cur.execute(query, params)
        conn.commit()


def is_managed(host_name):
    """Return True if *host_name* has a row in the ``main`` table."""
    # Bound parameter: host_name comes from the command line.
    rows = db_retrieve(
        "SELECT EXISTS(SELECT 1 FROM main WHERE host_name = %s)",
        (host_name,))
    return bool(rows[0][0])


def _as_text(value):
    """
    Return *value* decoded as UTF-8 if it is ``bytes``, otherwise unchanged.

    The original code called ``.decode()`` on values read from text columns,
    so the driver apparently yields ``bytes`` here (NOTE(review): presumably
    an SQL_ASCII database or bytea columns -- confirm). Decoding through this
    helper works whether the driver returns ``bytes`` or ``str``.
    """
    if isinstance(value, bytes):
        return value.decode('utf-8')
    return value


def _fmt_dt(value, *, utc=False):
    """Format an optional datetime with DT_FMT ('' for None), optionally in UTC."""
    if value is None:
        return ''
    if utc:
        value = value.astimezone(timezone.utc)
    return value.strftime(DT_FMT)


def _cert_dns_names(cert):
    """Return the certificate's SAN DNS names, or [] if it has no SAN extension."""
    try:
        san = cert.extensions.get_extension_for_oid(
            ExtensionOID.SUBJECT_ALTERNATIVE_NAME)
    except x509.ExtensionNotFound:
        return []
    return san.value.get_values_for_type(x509.DNSName)


def cert_error_db_update(host, err):
    """Record a failed certificate check (timestamp, 'failure', error text)."""
    query = sql.SQL(
        """
        UPDATE {table}
        SET checked = {f1}, check_status = {f2}, check_err = {f3}
        WHERE host_name = {f4}
        """
    ).format(
        # .Identifier = PGSQL identifier, e.g. table & column names, uses
        #   double quotes
        # .Literal = PGSQL literal, e.g. field values, uses single quotes
        table=sql.Identifier('main'),
        f1=sql.Literal(datetime.now(timezone.utc)),
        f2=sql.Literal('failure'),
        f3=sql.Literal(str(err)),
        f4=sql.Literal(host)
    )
    db_execute(query)


def get_sans(host):
    """
    Retrieve all ``san`` rows belonging to *host*.

    Returns a (possibly empty) list of rows. Previously only the first row
    was returned via fetchone().
    """
    return db_retrieve(
        "SELECT * FROM san WHERE host_id = "
        "(SELECT host_id FROM main WHERE host_name = %s)",
        (host,))


def cert_db_update(host, cert):
    """
    Record a successful certificate check for *host* in the ``main`` table.

    Stores validity dates, subject Common Name and issuer Organization Name.
    Requires *cert* to carry both names (get_cert() guarantees this);
    otherwise IndexError is raised.
    """
    query = sql.SQL(
        """
        UPDATE {table}
        SET not_valid_before = {f1}, not_valid_after = {f2},
            common_name = {f3}, issuer = {f4},
            checked = {f5}, check_status = {f6}
        WHERE host_name = {f7}
        """
    ).format(
        table=sql.Identifier('main'),
        # the *_utc properties are already timezone-aware (UTC)
        f1=sql.Literal(cert.not_valid_before_utc),
        f2=sql.Literal(cert.not_valid_after_utc),
        f3=sql.Literal(
            cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value),
        f4=sql.Literal(
            cert.issuer.get_attributes_for_oid(
                NameOID.ORGANIZATION_NAME)[0].value),
        f5=sql.Literal(datetime.now(timezone.utc)),
        f6=sql.Literal('success'),
        f7=sql.Literal(host)
    )
    db_execute(query)
    # TODO: sync _cert_dns_names(cert) into the san table (compare with
    # get_sans(host)), e.g.:
    #   INSERT INTO san (host_id, name)
    #   VALUES ((SELECT host_id FROM main WHERE host_name = %s), %s);


class GetCertException(Exception):
    """
    Raised by get_cert() when a certificate cannot be fetched or is unusable.

    Raise with: raise GetCertException("Exception text")
    """


def get_cert(host_to_check, port=443, timeout=5.0):
    """
    Fetch the live TLS certificate presented by *host_to_check*.

    Verification is disabled (CERT_NONE, no hostname check) so expired,
    self-signed or mismatched certificates can still be inspected.

    Args:
        host_to_check: host to connect to over IPv4; also sent as SNI.
        port: TCP port (default 443).
        timeout: socket timeout in seconds (default 5.0).

    Returns:
        tuple[x509.Certificate, ssl.SSLSocket]: the parsed certificate and
        the still-open TLS socket, which the caller must close.

    Raises:
        GetCertException: on any connect/handshake/parse error, or if the
        cert lacks a subject Common Name or an issuer Organization Name.
        For managed hosts the failure is first recorded in the database.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE

    sock = None
    ssock = None
    try:
        sock = socket(AF_INET, SOCK_STREAM, 0)
        sock.settimeout(timeout)
        ssock = context.wrap_socket(sock, server_hostname=host_to_check)
        ssock.connect((host_to_check, port))
        cert = x509.load_der_x509_certificate(ssock.getpeercert(True))
    # Typical errors seen here:
    #   EOF occurred in violation of protocol (_ssl.c:1123)
    #   [Errno -2] Name or service not known
    #   [Errno -3] Temporary failure in name resolution
    #   [Errno -5] No address associated with hostname
    #   [Errno 110] Connection timed out
    #   [Errno 111] Connection refused
    #   [Errno 113] No route to host
    except Exception as err:
        # Don't leak the socket; closing a socket detached by wrap_socket()
        # is a harmless no-op.
        for open_sock in (ssock, sock):
            if open_sock is not None:
                open_sock.close()
        if is_managed(host_to_check):
            cert_error_db_update(host_to_check, err)
        raise GetCertException(err) from err

    # Self-signed certificates may lack a subject CN and/or issuer O; callers
    # (cert_db_update, handle_check) index [0] into both, so reject early.
    no_common_name = not cert.subject.get_attributes_for_oid(
        NameOID.COMMON_NAME)
    no_org_name = not cert.issuer.get_attributes_for_oid(
        NameOID.ORGANIZATION_NAME)
    if no_common_name or no_org_name:
        if no_common_name and no_org_name:
            err_string = ("Certificate has no Common Name "
                          "and no Organization Name")
        elif no_common_name:
            err_string = "Certificate has no Common Name"
        else:
            err_string = "Certificate has no Organization Name"
        ssock.close()
        if is_managed(host_to_check):
            cert_error_db_update(host_to_check, err_string)
        raise GetCertException(err_string)
    return cert, ssock


def _active_hosts():
    """Return the names (as str) of all hosts whose ``active`` flag is TRUE."""
    query = sql.SQL(
        "SELECT {host_name} FROM {table} WHERE {active} IS TRUE"
    ).format(
        host_name=sql.Identifier('host_name'),
        table=sql.Identifier('main'),
        active=sql.Identifier('active')
    )
    return [_as_text(row[0]) for row in db_retrieve(query)]


def _set_active(host, active):
    """Set the ``active`` flag of *host*, or print a notice if unmanaged."""
    if not is_managed(host):
        print(f"Host \'{host}\' is not managed by this application.",
              "Exiting.")
        return
    query = sql.SQL(
        "UPDATE {table} SET active = {f1} WHERE host_name = {f2}"
    ).format(
        table=sql.Identifier('main'),
        f1=sql.Literal(bool(active)),
        f2=sql.Literal(host)
    )
    db_execute(query)


def handle_add(args):
    """
    Add ``args.host`` to the database.

    With ``-active``/``-inactive`` the host's active flag is set accordingly;
    otherwise the column's database default applies.
    """
    if is_managed(args.host):
        print(f"Host '{args.host}' is already managed by this application.")
        return
    make_active = getattr(args, 'active', False)
    make_inactive = getattr(args, 'inactive', False)
    if make_active or make_inactive:
        query = sql.SQL(
            "INSERT INTO {table} (host_name, active) VALUES ({f1}, {f2})"
        ).format(
            table=sql.Identifier('main'),
            f1=sql.Literal(args.host),
            f2=sql.Literal(bool(make_active))
        )
    else:
        query = sql.SQL(
            "INSERT INTO {table} (host_name) VALUES ({f1})"
        ).format(
            table=sql.Identifier('main'),
            f1=sql.Literal(args.host)
        )
    db_execute(query)


def handle_remove(args):
    """Remove ``args.host`` from the ``main`` table."""
    if not is_managed(args.host):
        print(f"Host '{args.host}' is not managed by this application.")
        return
    query = sql.SQL(
        "DELETE FROM {table} WHERE {host_name} = {f1}"
    ).format(
        table=sql.Identifier('main'),
        host_name=sql.Identifier('host_name'),
        f1=sql.Literal(args.host)
    )
    db_execute(query)


def handle_list_all(args):
    """List all hosts, sorted by name, with their most recent check status."""
    query = sql.SQL(
        "SELECT {host_name}, {check_status} FROM {table}"
    ).format(
        host_name=sql.Identifier('host_name'),
        check_status=sql.Identifier('check_status'),
        table=sql.Identifier('main')
    )
    rows = db_retrieve(query)  # query once, not twice
    if not rows:
        print("No hosts found in the database")
        return
    for host_name, check_status in sorted(rows, key=lambda row: row[0]):
        # check_status is NULL for hosts that were never checked
        status = (_as_text(check_status) if check_status is not None
                  else 'never checked')
        print(f"{_as_text(host_name):<36} Check status: {status}")


def handle_list_active(args):
    """List hosts marked as active, sorted by name."""
    hosts = _active_hosts()
    if not hosts:
        print("No active hosts found in the database")
        return
    for host in sorted(hosts):
        print(host)


def handle_active(args):
    """Set ``args.host`` status to active."""
    _set_active(args.host, True)


def handle_inactive(args):
    """Set ``args.host`` status to inactive."""
    _set_active(args.host, False)


def handle_info(args):
    """
    Print the database record for ``args.host``; exit 1 if it is unmanaged.

    Validity dates are shown in UTC followed, in parentheses, by the value
    as returned by the driver (NOTE(review): presumably the session's time
    zone -- confirm).
    """
    if not is_managed(args.host):
        print(f"Host '{args.host}' is not managed by this application.")
        sys.exit(1)

    conn_params = config()
    with psycopg.connect(**conn_params) as conn:
        with conn.cursor(row_factory=Host.row_factory) as cur:
            cur.execute("SELECT * FROM main WHERE host_name = %s",
                        (args.host,))
            row = cur.fetchone()

    lines = [
        f"Host name: {_as_text(row.host_name)} ",
        f"Cert Common Name: {_as_text(row.common_name)} ",
        f"Cert Issuer: {_as_text(row.issuer)} ",
        "Cert Not Valid Before: "
        f"{_fmt_dt(row.not_valid_before, utc=True)}"
        f" ({_fmt_dt(row.not_valid_before)})",
        "Cert Not Valid After: "
        f"{_fmt_dt(row.not_valid_after, utc=True)}"
        f" ({_fmt_dt(row.not_valid_after)})",
        f"Active: {row.active}",
        f"Checked date: {_fmt_dt(row.checked)}",
        f"Check status: {_as_text(row.check_status)}",
        f"Check error msg: {_as_text(row.check_err)}",
        f"Renewed date: {_fmt_dt(row.renewed)}",
    ]
    # Trailing "\n" keeps the original blank line after the record.
    print("\n".join(lines) + "\n")
    # TODO: display SANs once get_sans() data is populated.


def handle_check(args):
    """
    Check the live certificate for ``args.host`` and print its details.

    If the host is managed, its database entry is refreshed. Exits with
    status 1 if the certificate cannot be fetched.
    """
    print("Getting live cert info for:", args.host)
    managed = is_managed(args.host)  # look up once
    if not managed:
        print(f"NOTICE: Host '{args.host}' is not managed by this application.")
    try:
        cert, ssock = get_cert(args.host)
    except GetCertException as err:
        print("Error getting certificate:", err)
        sys.exit(1)
    with ssock:  # close the TLS connection once we've read its version
        print("SSL protocol version:", ssock.version())
    # get_cert() guarantees both attributes exist
    print("Issuer:", cert.issuer.get_attributes_for_oid(
        NameOID.ORGANIZATION_NAME)[0].value)
    print("Subject CN:", cert.subject.get_attributes_for_oid(
        NameOID.COMMON_NAME)[0].value)
    print("Not valid before:", cert.not_valid_before_utc)
    print("Not valid after:", cert.not_valid_after_utc)
    print("DNS Names:")
    for name in _cert_dns_names(cert):
        print(" ", name)
    if managed:
        cert_db_update(args.host, cert)


def handle_checkall(args):
    """Refresh certificate info for all active hosts concurrently."""
    hosts = _active_hosts()
    with concurrent.futures.ThreadPoolExecutor(max_workers=12) as executor:
        future_to_host = {executor.submit(get_cert, h): h for h in hosts}
        for future in concurrent.futures.as_completed(future_to_host):
            h = future_to_host[future]
            try:
                print("Getting live cert info for:", h)
                # get_cert() returns tuple[Certificate, SSLSocket]
                cert, ssock = future.result()
                ssock.close()  # only the certificate is needed
                cert_db_update(h, cert)
            except GetCertException as err:
                print("Error getting certificate for: ", h, ": ", err, sep="")
            except Exception as exc:  # report and continue with other hosts
                print(f"{h!r} generated an exception: {exc}")


def handle_critical(args):
    """List active hosts whose certs expire before now() + CRITICAL_INTERVAL."""
    query = sql.SQL(
        """
        SELECT {host_name}, {nva} FROM {table}
        WHERE {active} IS TRUE
          AND {nva} < now() + interval {interval}
        ORDER BY {nva}
        """
    ).format(
        host_name=sql.Identifier('host_name'),
        table=sql.Identifier('main'),
        active=sql.Identifier('active'),
        nva=sql.Identifier('not_valid_after'),
        interval=sql.Literal(CRITICAL_INTERVAL)
    )
    rows = db_retrieve(query)
    if not rows:
        print("No hosts are listed as critical at this time")
        return
    for host_name, not_valid_after in rows:
        # the previous format string had no slot, so the date was dropped
        print(f"{_as_text(host_name):<30} Expires: "
              f"{not_valid_after.strftime(DT_FMT)}")


def handle_renew(args):
    """
    Handle the 'renew' sub-command (not implemented yet).

    Prints the certbot command that would be run; exits 1 if certbot is not
    found in $PATH.
    """
    if certbot_cmd := shutil.which('certbot'):
        renew_cmd = [
            certbot_cmd,
            '--cert-name', args.host,
            '--dns-rfc2136',
            '--dns-rfc2136-credentials', '/etc/letsencrypt/rfc2136.ini',
            '-d', args.host
        ]
        # subprocess.run(renew_cmd, check=False)
        print("Not implemented yet")
        print("Renew command:", *renew_cmd)
        return True
    sys.stderr.write("\nThe certbot binary was not found in $PATH\n"
                     "Please make sure certbot is installed.\n"
                     "This program will now exit.\n")
    sys.exit(1)


parser = argparse.ArgumentParser()
# Fallback when no sub-command is given; without it argparse leaves 'func'
# unset ("AttributeError: 'Namespace' object has no attribute 'func'").
parser.set_defaults(func=lambda _: parser.print_help())
subparsers = parser.add_subparsers(help='Functions')

# "add" parser -- '-active' and '-inactive' are mutually exclusive
add_parser = subparsers.add_parser('add', help='add host')
add_status_group = add_parser.add_mutually_exclusive_group()
add_status_group.add_argument('-active', action="store_true",
                              help='set host status to active')
add_status_group.add_argument('-inactive', action="store_true",
                              help='set host status to inactive')
add_parser.add_argument('host', type=str, help='host name')
add_parser.set_defaults(func=handle_add)


def _register_command(name, help_text, handler, *, takes_host=False):
    """
    Register sub-command *name* on ``subparsers`` and return its parser.

    Args:
        name: sub-command as typed on the command line.
        help_text: one-line help shown in the top-level usage.
        handler: callable invoked with the parsed ``argparse.Namespace``.
        takes_host: when true, the sub-command requires a positional
            ``host`` argument.
    """
    command_parser = subparsers.add_parser(name, help=help_text)
    if takes_host:
        command_parser.add_argument('host', type=str, help='host name')
    command_parser.set_defaults(func=handler)
    return command_parser


remove_parser = _register_command(
    'remove', 'remove host', handle_remove, takes_host=True)
listall_parser = _register_command(
    'list-all', 'list all hosts and their check status', handle_list_all)
listactive_parser = _register_command(
    'list-active', 'list hosts with active status', handle_list_active)
active_parser = _register_command(
    'active', 'set host to active', handle_active, takes_host=True)
inactive_parser = _register_command(
    'inactive', 'set host to inactive', handle_inactive, takes_host=True)
info_parser = _register_command(
    'info', 'display info about host', handle_info, takes_host=True)
check_parser = _register_command(
    'check', 'check live SSL cert for host', handle_check, takes_host=True)
checkall_parser = _register_command(
    'check-all', 'check live SSL certs for all active hosts', handle_checkall)
critical_parser = _register_command(
    'critical', 'list active hosts that are critical', handle_critical)
renew_parser = _register_command(
    'renew', 'renew SSL cert for host', handle_renew, takes_host=True)

# Parse the command line and dispatch to the chosen sub-command's handler
# (the parser-level default prints help when no sub-command is given).
args = parser.parse_args()
if args.func:
    args.func(args)