Use asyncio to manage the driver loop

This allows us to remove the custom shoehorned select() handling
without changing the driver design too much.
This commit is contained in:
Valentin Lorentz 2021-07-31 14:31:12 +02:00
parent 0ed743bb8e
commit f164ac7fbe
4 changed files with 87 additions and 96 deletions

View File

@ -39,9 +39,10 @@ import os
import sys import sys
import time import time
import errno import errno
import threading
import select import select
import socket import socket
import asyncio
import threading
import ipaddress import ipaddress
@ -58,9 +59,8 @@ except:
class SSLError(Exception): class SSLError(Exception):
pass pass
class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
_instances = []
_selecting = threading.Lock()
def __init__(self, irc): def __init__(self, irc):
assert irc is not None assert irc is not None
self.irc = irc self.irc = irc
@ -108,26 +108,19 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
# hasn't finished yet. We'll keep track of how many we get. # hasn't finished yet. We'll keep track of how many we get.
if e is None or e.args[0] != 11 or self.eagains > 120: if e is None or e.args[0] != 11 or self.eagains > 120:
drivers.log.disconnect(self.currentServer, e) drivers.log.disconnect(self.currentServer, e)
if self in self._instances:
self._instances.remove(self)
try: try:
self.conn.close() self.conn.close()
except: except:
pass pass
self.connected = False self.connected = False
if self.irc is None: if self.irc is None:
# This driver is dead already, but we're still running because
# of select() running in another driver's thread that started
# before this one died and still holding a reference to this
# instance.
# Just return, and we should never be called again.
return return
self.scheduleReconnect() self.scheduleReconnect()
else: else:
log.debug('Got EAGAIN, current count: %s.', self.eagains) log.debug('Got EAGAIN, current count: %s.', self.eagains)
self.eagains += 1 self.eagains += 1
def _sendIfMsgs(self): async def _sendIfMsgs(self):
if not self.connected: if not self.connected:
return return
if not self.zombie: if not self.zombie:
@ -137,71 +130,34 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
del msgs[-1] del msgs[-1]
self.outbuffer += ''.join(map(str, msgs)) self.outbuffer += ''.join(map(str, msgs))
if self.outbuffer: if self.outbuffer:
loop = asyncio.get_event_loop()
try: try:
if minisix.PY2: await loop.sock_sendall(self.conn, self.outbuffer.encode())
sent = self.conn.send(self.outbuffer) self.outbuffer = ''
else:
sent = self.conn.send(self.outbuffer.encode())
self.outbuffer = self.outbuffer[sent:]
self.eagains = 0 self.eagains = 0
except socket.error as e: except socket.error as e:
self._handleSocketError(e) self._handleSocketError(e)
if self.zombie and not self.outbuffer: if self.zombie and not self.outbuffer:
self._reallyDie() self._reallyDie()
@classmethod async def run(self):
def _select(cls):
timeout = conf.supybot.drivers.poll()
try:
if not cls._selecting.acquire(blocking=False):
# there's already a thread running this code, abort.
return
for inst in cls._instances:
# Do not use a list comprehension here, we have to edit the list
# and not to reassign it.
if not inst.connected or \
(minisix.PY3 and inst.conn._closed) or \
(minisix.PY2 and
inst.conn._sock.__class__ is socket._closedsocket):
cls._instances.remove(inst)
elif inst.conn.fileno() == -1:
inst.reconnect()
if not cls._instances:
return
rlist, wlist, xlist = select.select([x.conn for x in cls._instances],
[], [], timeout)
for instance in cls._instances:
if instance.conn in rlist:
instance._read()
except select.error as e:
if e.args[0] != errno.EINTR:
# 'Interrupted system call'
raise
finally:
cls._selecting.release()
for instance in cls._instances:
if instance.irc and not instance.irc.zombie:
instance._sendIfMsgs()
def run(self):
now = time.time() now = time.time()
if self.nextReconnectTime is not None and now > self.nextReconnectTime: if self.nextReconnectTime is not None and now > self.nextReconnectTime:
self.reconnect() self.reconnect()
elif self.writeCheckTime is not None and now > self.writeCheckTime: elif self.writeCheckTime is not None and now > self.writeCheckTime:
self._checkAndWriteOrReconnect() self._checkAndWriteOrReconnect()
if not self.connected: if not self.connected:
# We sleep here because otherwise, if we're the only driver, we'll
# spin at 100% CPU while we're disconnected.
time.sleep(conf.supybot.drivers.poll())
return return
self._sendIfMsgs() await asyncio.gather(
self._select() self._sendIfMsgs(),
self._read(),
)
def _read(self): async def _read(self):
"""Called by _select() when we can read data.""" """Called by _select() when we can read data."""
loop = asyncio.get_event_loop()
try: try:
new_data = self.conn.recv(1024) new_data = await loop.sock_recv(self.conn, 1024)
if not new_data: if not new_data:
# Socket was closed # Socket was closed
self._handleSocketError(None) self._handleSocketError(None)
@ -241,7 +197,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
self._handleSocketError(e) self._handleSocketError(e)
return return
if self.irc and not self.irc.zombie: if self.irc and not self.irc.zombie:
self._sendIfMsgs() await self._sendIfMsgs()
def connect(self, **kwargs): def connect(self, **kwargs):
self.reconnect(reset=False, **kwargs) self.reconnect(reset=False, **kwargs)
@ -252,8 +208,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
if self.connected: if self.connected:
self.onDisconnect() self.onDisconnect()
drivers.log.reconnect(self.irc.network) drivers.log.reconnect(self.irc.network)
if self in self._instances:
self._instances.remove(self)
try: try:
self.conn.shutdown(socket.SHUT_RDWR) self.conn.shutdown(socket.SHUT_RDWR)
except: # "Transport endpoint not connected" except: # "Transport endpoint not connected"
@ -314,12 +268,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
return return
# We allow more time for the connect here, since it might take longer. # We allow more time for the connect here, since it might take longer.
# At least 10 seconds. # At least 10 seconds.
self.conn.settimeout(max(10, conf.supybot.drivers.poll()*10))
try: try:
# Connect before SSL, otherwise SSL is disabled if we use SOCKS. # Connect before SSL, otherwise SSL is disabled if we use SOCKS.
# See http://stackoverflow.com/q/16136916/539465 # See http://stackoverflow.com/q/16136916/539465
self.conn.connect( self.conn.connect((address, self.currentServer.port))
(address, self.currentServer.port))
if network_config.ssl() or \ if network_config.ssl() or \
self.currentServer.force_tls_verification: self.currentServer.force_tls_verification:
self.starttls() self.starttls()
@ -343,8 +295,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
'<http://docs.limnoria.net/en/latest/use/faq.html#how-to-make-a-connection-secure>') '<http://docs.limnoria.net/en/latest/use/faq.html#how-to-make-a-connection-secure>')
% self.irc.network) % self.irc.network)
conf.supybot.drivers.poll.addCallback(self.setTimeout)
self.setTimeout()
self.connected = True self.connected = True
self.resetDelay() self.resetDelay()
except socket.error as e: except socket.error as e:
@ -361,13 +311,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
drivers.log.connectError(self.currentServer, e) drivers.log.connectError(self.currentServer, e)
self.scheduleReconnect() self.scheduleReconnect()
return return
self._instances.append(self)
def setTimeout(self):
try:
self.conn.settimeout(conf.supybot.drivers.poll())
except Exception:
drivers.log.exception('Could not set socket timeout:')
def _checkAndWriteOrReconnect(self): def _checkAndWriteOrReconnect(self):
self.writeCheckTime = None self.writeCheckTime = None
@ -393,9 +336,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
self.nextReconnectTime = when self.nextReconnectTime = when
def die(self): def die(self):
if self in self._instances:
self._instances.remove(self)
conf.supybot.drivers.poll.removeCallback(self.setTimeout)
self.zombie = True self.zombie = True
if self.nextReconnectTime is not None: if self.nextReconnectTime is not None:
self.nextReconnectTime = None self.nextReconnectTime = None

View File

@ -35,6 +35,7 @@ Contains various drivers (network, file, and otherwise) for using IRC objects.
import time import time
import socket import socket
import asyncio
from collections import namedtuple from collections import namedtuple
from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils
@ -146,15 +147,59 @@ def remove(name):
"""Removes the driver with the given name from the loop.""" """Removes the driver with the given name from the loop."""
_deadDrivers.add(name) _deadDrivers.add(name)
def _loop_exception_handler(loop, context):
if 'Task was destroyed but it is pending' in context['message']:
# This happens when cancelling tasks because of KeyboardInterrupt.
# There is really nothing we can do about this, so it's
# pointless to log it.
return
log.error('Exception in drivers loop: %s', context['message'])
def run(): def run():
"""Runs the whole driver loop.""" """Runs the whole driver loop."""
loop = asyncio.new_event_loop()
loop.set_exception_handler(_loop_exception_handler)
driver_names = []
futures = []
coroutines = [] # Used to cleanup on shutdown to avoid warnings
for (name, driver) in _drivers.items(): for (name, driver) in _drivers.items():
if name not in _deadDrivers:
try:
coroutine = driver.run()
future = asyncio.ensure_future(coroutine, loop=loop)
except Exception:
log.exception('Exception in drivers.run for driver %s:', name)
continue
driver_names.append(name)
coroutines.append(coroutine)
futures.append(future)
gather_task = asyncio.gather(*futures, return_exceptions=True, loop=loop)
timeout_gather_task = asyncio.wait_for(
gather_task,
timeout=conf.supybot.drivers.poll())
coroutines.append(timeout_gather_task)
try:
loop.run_until_complete(timeout_gather_task)
except KeyboardInterrupt:
# Cleanup all the objects so they don't throw warnings.
gather_task.cancel()
for future in futures:
future.cancel()
for coroutine in coroutines:
coroutine.close()
raise
except asyncio.TimeoutError:
pass
for (name, future) in zip(driver_names, futures):
try: try:
if name not in _deadDrivers: future.result() # Raises an exception if driver.run() did.
driver.run()
except: except:
log.exception('Uncaught exception in in drivers.run:') log.exception('Uncaught exception in drivers.run:')
_deadDrivers.add(name) _deadDrivers.add(name)
for name in _deadDrivers: for name in _deadDrivers:
try: try:
driver = _drivers[name] driver = _drivers[name]

View File

@ -137,7 +137,7 @@ class Schedule(drivers.IrcDriver):
removePeriodicEvent = removeEvent removePeriodicEvent = removeEvent
def run(self): async def run(self):
if len(drivers._drivers) == 1 and not world.testing: if len(drivers._drivers) == 1 and not world.testing:
log.error('Schedule is the only remaining driver, ' log.error('Schedule is the only remaining driver, '
'why do we continue to live?') 'why do we continue to live?')

View File

@ -31,6 +31,7 @@
from supybot.test import * from supybot.test import *
import time import time
import asyncio
import supybot.schedule as schedule import supybot.schedule as schedule
@ -43,6 +44,11 @@ class FakeSchedule(schedule.Schedule):
def name(self): def name(self):
return 'FakeSchedule' return 'FakeSchedule'
def run(sched):
loop = asyncio.new_event_loop()
loop.run_until_complete(sched.run())
class TestSchedule(SupyTestCase): class TestSchedule(SupyTestCase):
def testSchedule(self): def testSchedule(self):
sched = FakeSchedule() sched = FakeSchedule()
@ -55,14 +61,14 @@ class TestSchedule(SupyTestCase):
sched.addEvent(add10, time.time() + 3) sched.addEvent(add10, time.time() + 3)
sched.addEvent(add1, time.time() + 1) sched.addEvent(add1, time.time() + 1)
timeFastForward(1.2) timeFastForward(1.2)
sched.run() run(sched)
self.assertEqual(i[0], 1) self.assertEqual(i[0], 1)
timeFastForward(1.9) timeFastForward(1.9)
sched.run() run(sched)
self.assertEqual(i[0], 11) self.assertEqual(i[0], 11)
sched.addEvent(add10, time.time() + 3, 'test') sched.addEvent(add10, time.time() + 3, 'test')
sched.run() run(sched)
self.assertEqual(i[0], 11) self.assertEqual(i[0], 11)
sched.removeEvent('test') sched.removeEvent('test')
self.assertEqual(i[0], 11) self.assertEqual(i[0], 11)
@ -77,10 +83,10 @@ class TestSchedule(SupyTestCase):
n = sched.addEvent(inc, time.time() + 1) n = sched.addEvent(inc, time.time() + 1)
sched.rescheduleEvent(n, time.time() + 3) sched.rescheduleEvent(n, time.time() + 3)
timeFastForward(1.2) timeFastForward(1.2)
sched.run() run(sched)
self.assertEqual(i[0], 0) self.assertEqual(i[0], 0)
timeFastForward(2) timeFastForward(2)
sched.run() run(sched)
self.assertEqual(i[0], 1) self.assertEqual(i[0], 1)
def testPeriodic(self): def testPeriodic(self):
@ -90,20 +96,20 @@ class TestSchedule(SupyTestCase):
i[0] += 1 i[0] += 1
n = sched.addPeriodicEvent(inc, 1, name='test_periodic') n = sched.addPeriodicEvent(inc, 1, name='test_periodic')
timeFastForward(0.6) timeFastForward(0.6)
sched.run() # 0.6 run(sched) # 0.6
self.assertEqual(i[0], 1) self.assertEqual(i[0], 1)
timeFastForward(0.6) timeFastForward(0.6)
sched.run() # 1.2 run(sched) # 1.2
self.assertEqual(i[0], 2) self.assertEqual(i[0], 2)
timeFastForward(0.6) timeFastForward(0.6)
sched.run() # 1.8 run(sched) # 1.8
self.assertEqual(i[0], 2) self.assertEqual(i[0], 2)
timeFastForward(0.6) timeFastForward(0.6)
sched.run() # 2.4 run(sched) # 2.4
self.assertEqual(i[0], 3) self.assertEqual(i[0], 3)
sched.removePeriodicEvent(n) sched.removePeriodicEvent(n)
timeFastForward(1) timeFastForward(1)
sched.run() # 3.4 run(sched) # 3.4
self.assertEqual(i[0], 3) self.assertEqual(i[0], 3)
def testCountedPeriodic(self): def testCountedPeriodic(self):
@ -113,19 +119,19 @@ class TestSchedule(SupyTestCase):
i[0] += 1 i[0] += 1
n = sched.addPeriodicEvent(inc, 1, name='test_periodic', count=3) n = sched.addPeriodicEvent(inc, 1, name='test_periodic', count=3)
timeFastForward(0.6) timeFastForward(0.6)
sched.run() # 0.6 run(sched) # 0.6
self.assertEqual(i[0], 1) self.assertEqual(i[0], 1)
timeFastForward(0.6) timeFastForward(0.6)
sched.run() # 1.2 run(sched) # 1.2
self.assertEqual(i[0], 2) self.assertEqual(i[0], 2)
timeFastForward(0.6) timeFastForward(0.6)
sched.run() # 1.8 run(sched) # 1.8
self.assertEqual(i[0], 2) self.assertEqual(i[0], 2)
timeFastForward(0.6) timeFastForward(0.6)
sched.run() # 2.4 run(sched) # 2.4
self.assertEqual(i[0], 3) self.assertEqual(i[0], 3)
timeFastForward(1) timeFastForward(1)
sched.run() # 3.4 run(sched) # 3.4
self.assertEqual(i[0], 3) self.assertEqual(i[0], 3)