Files
2025-07-12 10:10:42 +02:00

351 lines
13 KiB
Python

from __future__ import annotations
import base64
import typing
from twisted import logger
from twisted.internet import protocol
from twisted.internet import ssl as tls
from twisted.words.protocols import irc
if typing.TYPE_CHECKING:
import configparser
from twisted.internet import abstract, base, posixbase
from twisted.python import failure
# When True, CommitBot logs every raw line it receives and sends at debug level.
IRC_DEBUG = False
class CommitBot(irc.IRCClient):
    """IRC client relaying tagged notifications to the channels it has joined.

    The connection settings (nick, channels, per-channel tag filters and
    optional SASL credentials) are read from ``self.factory.config`` in
    :meth:`connectionMade`.  A single delayed call, ``self.timeout``, is shared
    by the SASL negotiation and by the channel join supervision.
    """

    log = logger.Logger()

    # Maximum size of the base64 payload carried by one AUTHENTICATE line
    SASL_MAX_BLOCK_SIZE: typing.Final[int] = 400
    # Seconds granted to the server to complete the SASL exchange
    SASL_TIMEOUT = 10
    # Seconds granted to the server to confirm (or refuse) pending joins
    JOIN_TIMEOUT = 30
    # Seconds to wait before trying to get a collided nickname back
    NICK_RETRY_DELAY = 30

    factory: CommitBotFactory | None = None
    transport: abstract.FileDescriptor | None = None

    def __init__(self) -> None:
        """Initialise the join/SASL bookkeeping of a fresh connection."""
        super().__init__()
        # Channels for which a JOIN was sent and no answer came back yet
        self.pendingChannels: set[str] = set()
        # Channels the server confirmed we are in
        self.channelsJoined: set[str] = set()
        # Pending SASL or join timer, None when no timer is running
        self.timeout: base.DelayedCall | None = None
        # Accumulated base64 SASL payload, None outside of a SASL exchange
        self.saslData: str | None = None
        # Both are filled from the configuration in connectionMade()
        self.channels: list[list[str]] = []
        self.filters: dict[str, list[str]] = {}

    # helpers

    def _reactor(self) -> base.ReactorBase:
        """Return the reactor driving the current transport."""
        assert self.transport
        return typing.cast('base.ReactorBase', self.transport.reactor)

    def _cancelTimeout(self) -> None:
        """Cancel the pending timer, if any, and forget it."""
        pending, self.timeout = self.timeout, None
        if pending is not None and pending.active():
            pending.cancel()

    def _armJoinTimeout(self) -> None:
        """(Re)start the timer supervising the pending joins.

        Any timer still running is replaced, so a join requested while another
        one is still pending simply restarts the supervision delay.
        """
        self._cancelTimeout()
        self.timeout = self._reactor().callLater(self.JOIN_TIMEOUT, self._joinTimeout)

    @staticmethod
    def _normalizeChannel(entry: str | typing.Sequence[str]) -> tuple[str, str] | None:
        """Turn a channel description into a ``(name, key)`` pair.

        *entry* is either a bare channel name or a sequence holding the name
        optionally followed by its key.  The key is only kept when the sequence
        has exactly two items.  A leading ``#`` is added to the name when it is
        missing.  ``None`` is returned for empty entries and empty names.
        """
        if isinstance(entry, str):
            name, key = entry, ''
        elif not entry:
            return None
        else:
            name = entry[0]
            key = entry[1] if len(entry) == 2 else ''
        if not name:
            return None
        if not name.startswith('#'):
            name = '#' + name
        return name, key

    # public API

    def notify(self, tag: str, message: str) -> None:
        """Send *message* to every joined channel accepting *tag*.

        A channel with no filter (or an empty one) receives every message.
        """
        for channel in self.channelsJoined:
            filt = self.filters.get(channel)
            if filt is None and channel.startswith('#'):
                # Channels may be configured without their '#' (multiJoin adds
                # it), in which case the filter is stored under the bare name
                filt = self.filters.get(channel[1:])
            if filt and tag not in filt:
                continue
            self.msg(channel, message)

    # low level hooks

    def lineReceived(self, line: bytes) -> typing.Any:  # noqa: ANN401
        """Log the raw incoming line when IRC_DEBUG is set, then process it."""
        if IRC_DEBUG:
            self.log.debug('<<< {line!r}', line=line)
        return super().lineReceived(line)

    def _reallySendLine(self, line: str | bytes) -> typing.Any:  # noqa: ANN401
        """Log the raw outgoing line when IRC_DEBUG is set, then send it."""
        if IRC_DEBUG:
            self.log.debug('>>> {line!r}', line=line)
        return super()._reallySendLine(line)

    def connectionMade(self) -> None:
        """Load the settings from the factory configuration and register.

        When both ``sasl_user`` and ``sasl_pass`` are set, the ``sasl``
        capability is requested before the regular registration is started.
        """
        assert self.factory
        config = self.factory.config

        # Variables are not annotated in Twisted and mypy erroneously set them to None
        self.password = config['server_password'] or None  # type: ignore[assignment]
        self.nickname = config['nick']
        self.username = config['username']
        self.realname = config['realname']  # type: ignore[assignment]

        # Each entry is "name" or "name,key"
        self.channels = [cp.rsplit(',', maxsplit=1) for cp in config['channels'].split()]
        self.filters = {}
        for cp in self.channels:
            filt = config.get(f'filter.{cp[0]}')
            if filt is None:
                continue
            self.filters[cp[0]] = [tag for tag in filt.split(' ') if tag]

        self.log.info("Connected")

        if config['sasl_user'] and config['sasl_pass']:
            # Begin capability negotiation: request sasl
            self.log.info("Checking for SASL authentication support")
            self.sendLine('CAP REQ :sasl')

        super().connectionMade()

    def connectionLost(self, reason: failure.Failure = protocol.connectionDone) -> None:
        """Unregister from the factory and drop any timer still running."""
        assert self.factory
        # A timer firing later would try to write on the dead connection
        self._cancelTimeout()
        self.factory.protocolDead(self)
        super().connectionLost(reason)
        self.log.info("Disconnected")

    # callbacks for events

    def irc_CAP(self, _prefix: str, params: list[str]) -> None:
        """Handle the answer to our ``CAP REQ :sasl``.

        ``NAK`` (or an ``ACK`` not listing ``sasl``) ends the capability
        negotiation; an ``ACK`` listing ``sasl`` starts a PLAIN authentication
        guarded by a timer.  Other subcommands are ignored.
        """
        if len(params) < 2:
            # Malformed reply: nothing we can act upon
            return
        subcommand, args = params[1].upper(), params[2:]

        if subcommand == 'NAK':
            # No SASL available
            self.log.info("SASL authentication not supported")
            # End negotiation to finish the register
            self.sendLine('CAP END')
            return

        if subcommand != 'ACK':
            return

        extensions = args[-1].lower().split() if args else []
        if 'sasl' not in extensions:
            # No SASL although we requested it: should never happen
            # End negotiation to finish the register
            self.sendLine('CAP END')
            return

        # SASL enabled: authenticate
        self.log.info("SASL authentication started")
        self.saslData = ''
        self.sendLine('AUTHENTICATE PLAIN')
        self._cancelTimeout()
        self.timeout = self._reactor().callLater(self.SASL_TIMEOUT, self._saslTimedout)

    def irc_AUTHENTICATE(self, _prefix: str, params: list[str]) -> None:
        """Collect a SASL challenge and answer it.

        Challenges and responses are base64 encoded and split in chunks of
        SASL_MAX_BLOCK_SIZE characters.  A chunk shorter than that (or a lone
        ``+``, standing for an empty chunk) terminates the message.
        """
        if len(params) != 1:
            return
        if self.saslData is None:
            # Not in a SASL exchange: ignore unsolicited messages
            return

        chunk = params[0]
        if chunk != '+':
            # '+' is the empty chunk: it is not part of the payload even though
            # it is a valid base64 character
            self.saslData += chunk
        if len(chunk) == self.SASL_MAX_BLOCK_SIZE:
            # More data expected
            return

        payload = self.saslData
        self.saslData = ''

        challenge = base64.b64decode(payload) if payload else b''
        response = self.saslAuthenticate(challenge)
        out = base64.b64encode(response).decode('ascii')

        while len(out) >= self.SASL_MAX_BLOCK_SIZE:
            part, out = out[:self.SASL_MAX_BLOCK_SIZE], out[self.SASL_MAX_BLOCK_SIZE:]
            self.sendLine('AUTHENTICATE ' + part)
        # The last chunk must be shorter than the block size: send '+' when the
        # response is empty or was an exact multiple of the block size
        self.sendLine('AUTHENTICATE ' + (out or '+'))

    def irc_903(self, _prefix: str, _params: list[str]) -> None:  # RPL_SASLSUCCESS
        """SASL succeeded: finish the capability negotiation."""
        self.log.info("SASL authentication succeeded")
        # End negotiation to finish the register
        self._saslEnd()

    def irc_904(self, _prefix: str, _params: list[str]) -> None:  # ERR_SASLFAIL
        """SASL failed: close the connection."""
        self.log.warn("SASL authentication failed: closing")
        assert self.transport
        self.transport.loseConnection()

    def irc_905(self, _prefix: str, _params: list[str]) -> None:  # ERR_SASLTOOLONG
        """SASL payload refused as too long: go on without authentication."""
        self.log.warn("SASL authentication data was too long")
        # End negotiation to finish the register
        self._saslEnd()

    def irc_906(self, _prefix: str, _params: list[str]) -> None:  # ERR_SASLABORTED
        """SASL aborted: close the connection."""
        self.log.warn("SASL authentication aborted: closing")
        assert self.transport
        self.transport.loseConnection()

    irc_907 = irc_903  # ERR_SASLALREADY

    def _saslEnd(self) -> None:
        """Stop the SASL timer and finish the capability negotiation."""
        self._cancelTimeout()
        self.saslData = None
        # End negotiation to finish the register
        self.sendLine('CAP END')

    def _saslTimedout(self) -> None:
        """Give up on SASL when the server did not conclude it in time."""
        if self.timeout is None:
            return
        self.timeout = None
        self.saslData = None

        self.log.warn("SASL authentication timed out")
        # End negotiation to finish the register
        self.sendLine('CAP END')

    def signedOn(self) -> None:
        """Called when bot has successfully signed on to server."""
        self.log.info("Signed on")
        assert self.factory
        self.factory.protocolReady(self)
        self.multiJoin(self.channels)

    def alterCollidedNick(self, nickname: str) -> str:
        """Return *nickname* with a ``_`` appended.

        An attempt to switch back to *nickname* is scheduled
        NICK_RETRY_DELAY seconds later.
        """
        self._reactor().callLater(self.NICK_RETRY_DELAY, self.setNick, nickname)
        return nickname + '_'

    def join(self, channel: str, key: str | None = None) -> None:
        """Join *channel* and supervise the request with the join timer.

        The timer is restarted when another join is already being supervised,
        which happens when the bot is kicked from several channels in a row.
        """
        super().join(channel, key)
        self.pendingChannels.add(channel)
        self._armJoinTimeout()

    def multiJoin(self, channels: typing.Iterable[str | typing.Sequence[str]]) -> None:
        """Join several channels with as few JOIN commands as possible.

        Each item of *channels* is a channel name or a ``(name, key)``
        sequence; names lacking a leading ``#`` get one.  Empty items are
        skipped.  The join timer is started when at least one channel has been
        requested.
        """
        # MAX_COMMAND_LENGTH - JOIN - \r\n
        # spaces are included in the length count
        max_length = irc.MAX_COMMAND_LENGTH - 4 - 2

        pending = [entry
                   for entry in map(self._normalizeChannel, channels)
                   if entry is not None]

        while pending:
            length = 0
            batch: list[tuple[str, str]] = []
            while pending:
                name, key = pending[-1]
                cost = len(name) + len(key) + 2  # preceding space or comma
                # Always accept the first entry of a batch: an over-long one
                # would otherwise be retried forever
                if batch and length + cost > max_length:
                    break
                pending.pop()
                length += cost
                batch.append((name, key))
                self.pendingChannels.add(name)

            # Keys are matched to channels by position: put the channels
            # having a key first so that no empty key has to be transmitted
            batch.sort(key=lambda entry: not entry[1])
            names = ','.join(name for name, _key in batch)
            keys = ','.join(key for _name, key in batch if key)
            if keys:
                self.sendLine("JOIN {} {}".format(names, keys))
            else:
                self.sendLine("JOIN {}".format(names))

        if self.pendingChannels:
            self._armJoinTimeout()

    def joined(self, channel: str) -> None:
        """This will get called when the bot joins the channel."""
        self.log.info("Channel {channel} joined", channel=channel)
        self.channelsJoined.add(channel)
        self.pendingChannels.discard(channel)
        if not self.pendingChannels:
            self._cancelTimeout()

    def joinError(self, channel: str | None, reason: str) -> None:
        """Record that joining *channel* was refused for *reason*."""
        self.log.info("Failed to join {channel}: {reason}", channel=channel, reason=reason)
        self.pendingChannels.discard(channel)
        if not self.pendingChannels:
            self._cancelTimeout()

    def _joinTimeout(self) -> None:
        """Report every channel still pending when the join timer fires."""
        if self.timeout is None:
            return
        self.timeout = None

        for channel in self.pendingChannels:
            self.log.info("Failed to join {channel}: {reason}", channel=channel, reason="Timeout")
        self.pendingChannels.clear()

    def _ircJoinRefused(self, _prefix: str, params: list[str]) -> None:
        """Common handler of the numerics refusing a JOIN.

        The channel is expected as second parameter and the reason as last
        one; missing parameters are reported as ``None`` / empty reason.
        """
        channel = params[1] if len(params) > 1 else None
        reason = params[-1] if params else ''
        self.joinError(channel, reason)

    irc_ERR_INVITEONLYCHAN = _ircJoinRefused
    irc_ERR_CHANNELISFULL = _ircJoinRefused
    irc_ERR_BANNEDFROMCHAN = _ircJoinRefused
    irc_ERR_BADCHANNELKEY = _ircJoinRefused
    irc_ERR_TOOMANYCHANNELS = _ircJoinRefused

    def kickedFrom(self, channel: str, kicker: str, message: str) -> None:
        """Rejoin *channel* after a kick, or stop the reactor on a 'die' kick."""
        self.log.info("Kicked from {channel} by {kicker} (reason: {message})",
                      channel=channel, kicker=kicker, message=message)
        self.channelsJoined.discard(channel)
        if message == 'die':
            # NOTE(review): anybody allowed to kick the bot can stop the whole
            # reactor this way - confirm this is intended
            self._reactor().stop()
            return

        # Look for the key of the channel, configured with or without its '#'
        key = None
        for entry in self.channels:
            if entry[0] in (channel, channel[1:]):
                if len(entry) > 1:
                    key = entry[1]
                break
        self.join(channel, key)

    def saslAuthenticate(self, data: bytes) -> bytes:
        """Return the SASL PLAIN response to the challenge *data*.

        The credentials are sent in answer to the initial empty challenge; any
        other challenge gets an empty response.
        """
        if data != b'':
            return b''
        assert self.factory
        username = self.factory.config['sasl_user'].encode('utf-8')
        password = self.factory.config['sasl_pass'].encode('utf-8')
        # authorization identity, authentication identity, password
        return b'\x00'.join((username, username, password))

    def ctcpQuery(self, user: str, channel: str, messages: list[tuple[str, str]]) -> None:
        """Ignore CTCP queries."""
class CommitBotFactory(protocol.ReconnectingClientFactory):
    """Reconnecting factory building :class:`CommitBot` connections.

    It keeps track of the bots which completed their sign-on and forwards
    notifications to them.
    """

    protocol = CommitBot

    def __init__(self, config: configparser.SectionProxy) -> None:
        """Remember the configuration section describing the IRC server."""
        # Bots which signed on and did not lose their connection yet
        self.protocols: set[CommitBot] = set()
        self.config = config

    def connect(self, reactor: posixbase.PosixReactorBase) -> None:
        """Connect to the configured server, over TLS when ``tls`` is enabled.

        With TLS, the certificate found at ``ca_path`` (PEM format), when set,
        is given to ``optionsForClientTLS`` as the authority to trust.
        """
        if not self.config.getboolean('tls'):
            reactor.connectTCP(self.config['server'], int(self.config['port']), self)
            return

        authority = None
        pem_path = self.config.get('ca_path')
        if pem_path:
            with open(pem_path, 'rb') as pem_file:
                pem_data = pem_file.read()
            authority = tls.Certificate.loadPEM(pem_data)

        tls_options = tls.optionsForClientTLS(self.config['server'], authority)
        reactor.connectSSL(self.config['server'], int(self.config['port']), self, tls_options)

    def protocolReady(self, protocol: CommitBot) -> None:
        """Register a bot which signed on and reset the reconnection delay."""
        self.protocols.add(protocol)
        self.resetDelay()

    def protocolDead(self, protocol: CommitBot) -> None:
        """Forget a bot whose connection was lost."""
        # Don't use remove: if we don't get past the welcome, we will never add the protocol
        self.protocols.discard(protocol)

    def notify(self, *args: typing.Any, **kwargs: typing.Any) -> None:  # noqa: ANN401
        """Forward the arguments to ``notify`` of every registered bot."""
        for bot in self.protocols:
            bot.notify(*args, **kwargs)
def create(config: configparser.SectionProxy, reactor: posixbase.PosixReactorBase) -> CommitBotFactory:
    """Build a :class:`CommitBotFactory` for *config* and start connecting it.

    The factory is returned once its ``connect`` method has been called with
    *reactor*.
    """
    factory = CommitBotFactory(config)
    factory.connect(reactor)
    return factory