mirror of
https://github.com/progval/Limnoria.git
synced 2025-04-25 12:31:04 -05:00
Use asyncio to manage the driver loop
This allows us to remove the custom shoehorned select() handling without changing the driver design too much.
This commit is contained in:
parent
0ed743bb8e
commit
f164ac7fbe
@ -39,9 +39,10 @@ import os
|
||||
import sys
|
||||
import time
|
||||
import errno
|
||||
import threading
|
||||
import select
|
||||
import socket
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
import ipaddress
|
||||
|
||||
@ -58,9 +59,8 @@ except:
|
||||
class SSLError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
_instances = []
|
||||
_selecting = threading.Lock()
|
||||
def __init__(self, irc):
|
||||
assert irc is not None
|
||||
self.irc = irc
|
||||
@ -108,26 +108,19 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
# hasn't finished yet. We'll keep track of how many we get.
|
||||
if e is None or e.args[0] != 11 or self.eagains > 120:
|
||||
drivers.log.disconnect(self.currentServer, e)
|
||||
if self in self._instances:
|
||||
self._instances.remove(self)
|
||||
try:
|
||||
self.conn.close()
|
||||
except:
|
||||
pass
|
||||
self.connected = False
|
||||
if self.irc is None:
|
||||
# This driver is dead already, but we're still running because
|
||||
# of select() running in an other driver's thread that started
|
||||
# before this one died and stil holding a reference to this
|
||||
# instance.
|
||||
# Just return, and we should never be called again.
|
||||
return
|
||||
self.scheduleReconnect()
|
||||
else:
|
||||
log.debug('Got EAGAIN, current count: %s.', self.eagains)
|
||||
self.eagains += 1
|
||||
|
||||
def _sendIfMsgs(self):
|
||||
async def _sendIfMsgs(self):
|
||||
if not self.connected:
|
||||
return
|
||||
if not self.zombie:
|
||||
@ -137,71 +130,34 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
del msgs[-1]
|
||||
self.outbuffer += ''.join(map(str, msgs))
|
||||
if self.outbuffer:
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
if minisix.PY2:
|
||||
sent = self.conn.send(self.outbuffer)
|
||||
else:
|
||||
sent = self.conn.send(self.outbuffer.encode())
|
||||
self.outbuffer = self.outbuffer[sent:]
|
||||
await loop.sock_sendall(self.conn, self.outbuffer.encode())
|
||||
self.outbuffer = ''
|
||||
self.eagains = 0
|
||||
except socket.error as e:
|
||||
self._handleSocketError(e)
|
||||
if self.zombie and not self.outbuffer:
|
||||
self._reallyDie()
|
||||
|
||||
@classmethod
|
||||
def _select(cls):
|
||||
timeout = conf.supybot.drivers.poll()
|
||||
try:
|
||||
if not cls._selecting.acquire(blocking=False):
|
||||
# there's already a thread running this code, abort.
|
||||
return
|
||||
for inst in cls._instances:
|
||||
# Do not use a list comprehension here, we have to edit the list
|
||||
# and not to reassign it.
|
||||
if not inst.connected or \
|
||||
(minisix.PY3 and inst.conn._closed) or \
|
||||
(minisix.PY2 and
|
||||
inst.conn._sock.__class__ is socket._closedsocket):
|
||||
cls._instances.remove(inst)
|
||||
elif inst.conn.fileno() == -1:
|
||||
inst.reconnect()
|
||||
if not cls._instances:
|
||||
return
|
||||
rlist, wlist, xlist = select.select([x.conn for x in cls._instances],
|
||||
[], [], timeout)
|
||||
for instance in cls._instances:
|
||||
if instance.conn in rlist:
|
||||
instance._read()
|
||||
except select.error as e:
|
||||
if e.args[0] != errno.EINTR:
|
||||
# 'Interrupted system call'
|
||||
raise
|
||||
finally:
|
||||
cls._selecting.release()
|
||||
for instance in cls._instances:
|
||||
if instance.irc and not instance.irc.zombie:
|
||||
instance._sendIfMsgs()
|
||||
|
||||
|
||||
def run(self):
|
||||
async def run(self):
|
||||
now = time.time()
|
||||
if self.nextReconnectTime is not None and now > self.nextReconnectTime:
|
||||
self.reconnect()
|
||||
elif self.writeCheckTime is not None and now > self.writeCheckTime:
|
||||
self._checkAndWriteOrReconnect()
|
||||
if not self.connected:
|
||||
# We sleep here because otherwise, if we're the only driver, we'll
|
||||
# spin at 100% CPU while we're disconnected.
|
||||
time.sleep(conf.supybot.drivers.poll())
|
||||
return
|
||||
self._sendIfMsgs()
|
||||
self._select()
|
||||
await asyncio.gather(
|
||||
self._sendIfMsgs(),
|
||||
self._read(),
|
||||
)
|
||||
|
||||
def _read(self):
|
||||
async def _read(self):
|
||||
"""Called by _select() when we can read data."""
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
new_data = self.conn.recv(1024)
|
||||
new_data = await loop.sock_recv(self.conn, 1024)
|
||||
if not new_data:
|
||||
# Socket was closed
|
||||
self._handleSocketError(None)
|
||||
@ -241,7 +197,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
self._handleSocketError(e)
|
||||
return
|
||||
if self.irc and not self.irc.zombie:
|
||||
self._sendIfMsgs()
|
||||
await self._sendIfMsgs()
|
||||
|
||||
def connect(self, **kwargs):
|
||||
self.reconnect(reset=False, **kwargs)
|
||||
@ -252,8 +208,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
if self.connected:
|
||||
self.onDisconnect()
|
||||
drivers.log.reconnect(self.irc.network)
|
||||
if self in self._instances:
|
||||
self._instances.remove(self)
|
||||
try:
|
||||
self.conn.shutdown(socket.SHUT_RDWR)
|
||||
except: # "Transport endpoint not connected"
|
||||
@ -314,12 +268,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
return
|
||||
# We allow more time for the connect here, since it might take longer.
|
||||
# At least 10 seconds.
|
||||
self.conn.settimeout(max(10, conf.supybot.drivers.poll()*10))
|
||||
try:
|
||||
# Connect before SSL, otherwise SSL is disabled if we use SOCKS.
|
||||
# See http://stackoverflow.com/q/16136916/539465
|
||||
self.conn.connect(
|
||||
(address, self.currentServer.port))
|
||||
self.conn.connect((address, self.currentServer.port))
|
||||
if network_config.ssl() or \
|
||||
self.currentServer.force_tls_verification:
|
||||
self.starttls()
|
||||
@ -343,8 +295,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
'<http://docs.limnoria.net/en/latest/use/faq.html#how-to-make-a-connection-secure>')
|
||||
% self.irc.network)
|
||||
|
||||
conf.supybot.drivers.poll.addCallback(self.setTimeout)
|
||||
self.setTimeout()
|
||||
self.connected = True
|
||||
self.resetDelay()
|
||||
except socket.error as e:
|
||||
@ -361,13 +311,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
drivers.log.connectError(self.currentServer, e)
|
||||
self.scheduleReconnect()
|
||||
return
|
||||
self._instances.append(self)
|
||||
|
||||
def setTimeout(self):
|
||||
try:
|
||||
self.conn.settimeout(conf.supybot.drivers.poll())
|
||||
except Exception:
|
||||
drivers.log.exception('Could not set socket timeout:')
|
||||
|
||||
def _checkAndWriteOrReconnect(self):
|
||||
self.writeCheckTime = None
|
||||
@ -393,9 +336,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
self.nextReconnectTime = when
|
||||
|
||||
def die(self):
|
||||
if self in self._instances:
|
||||
self._instances.remove(self)
|
||||
conf.supybot.drivers.poll.removeCallback(self.setTimeout)
|
||||
self.zombie = True
|
||||
if self.nextReconnectTime is not None:
|
||||
self.nextReconnectTime = None
|
||||
|
@ -35,6 +35,7 @@ Contains various drivers (network, file, and otherwise) for using IRC objects.
|
||||
|
||||
import time
|
||||
import socket
|
||||
import asyncio
|
||||
from collections import namedtuple
|
||||
|
||||
from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils
|
||||
@ -146,15 +147,59 @@ def remove(name):
|
||||
"""Removes the driver with the given name from the loop."""
|
||||
_deadDrivers.add(name)
|
||||
|
||||
|
||||
def _loop_exception_handler(loop, context):
|
||||
if 'Task was destroyed but it is pending' in context['message']:
|
||||
# This happens when cancelling tasks because of KeyboardInterrupt.
|
||||
# There is really nothing we can do about this, so it's
|
||||
# pointless to log it.
|
||||
return
|
||||
log.error('Exception in drivers loop: %s', context['message'])
|
||||
|
||||
def run():
|
||||
"""Runs the whole driver loop."""
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.set_exception_handler(_loop_exception_handler)
|
||||
driver_names = []
|
||||
futures = []
|
||||
coroutines = [] # Used to cleanup on shutdown to avoid warnings
|
||||
for (name, driver) in _drivers.items():
|
||||
if name not in _deadDrivers:
|
||||
try:
|
||||
coroutine = driver.run()
|
||||
future = asyncio.ensure_future(coroutine, loop=loop)
|
||||
except Exception:
|
||||
log.exception('Exception in drivers.run for driver %s:', name)
|
||||
continue
|
||||
driver_names.append(name)
|
||||
coroutines.append(coroutine)
|
||||
futures.append(future)
|
||||
|
||||
gather_task = asyncio.gather(*futures, return_exceptions=True, loop=loop)
|
||||
timeout_gather_task = asyncio.wait_for(
|
||||
gather_task,
|
||||
timeout=conf.supybot.drivers.poll())
|
||||
coroutines.append(timeout_gather_task)
|
||||
try:
|
||||
loop.run_until_complete(timeout_gather_task)
|
||||
except KeyboardInterrupt:
|
||||
# Cleanup all the objects so they don't throw warnings.
|
||||
gather_task.cancel()
|
||||
for future in futures:
|
||||
future.cancel()
|
||||
for coroutine in coroutines:
|
||||
coroutine.close()
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
for (name, future) in zip(driver_names, futures):
|
||||
try:
|
||||
if name not in _deadDrivers:
|
||||
driver.run()
|
||||
future.result() # Raises an exception if driver.run() did.
|
||||
except:
|
||||
log.exception('Uncaught exception in in drivers.run:')
|
||||
log.exception('Uncaught exception in drivers.run:')
|
||||
_deadDrivers.add(name)
|
||||
|
||||
for name in _deadDrivers:
|
||||
try:
|
||||
driver = _drivers[name]
|
||||
|
@ -137,7 +137,7 @@ class Schedule(drivers.IrcDriver):
|
||||
|
||||
removePeriodicEvent = removeEvent
|
||||
|
||||
def run(self):
|
||||
async def run(self):
|
||||
if len(drivers._drivers) == 1 and not world.testing:
|
||||
log.error('Schedule is the only remaining driver, '
|
||||
'why do we continue to live?')
|
||||
|
@ -31,6 +31,7 @@
|
||||
from supybot.test import *
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
import supybot.schedule as schedule
|
||||
|
||||
@ -43,6 +44,11 @@ class FakeSchedule(schedule.Schedule):
|
||||
def name(self):
|
||||
return 'FakeSchedule'
|
||||
|
||||
|
||||
def run(sched):
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(sched.run())
|
||||
|
||||
class TestSchedule(SupyTestCase):
|
||||
def testSchedule(self):
|
||||
sched = FakeSchedule()
|
||||
@ -55,14 +61,14 @@ class TestSchedule(SupyTestCase):
|
||||
sched.addEvent(add10, time.time() + 3)
|
||||
sched.addEvent(add1, time.time() + 1)
|
||||
timeFastForward(1.2)
|
||||
sched.run()
|
||||
run(sched)
|
||||
self.assertEqual(i[0], 1)
|
||||
timeFastForward(1.9)
|
||||
sched.run()
|
||||
run(sched)
|
||||
self.assertEqual(i[0], 11)
|
||||
|
||||
sched.addEvent(add10, time.time() + 3, 'test')
|
||||
sched.run()
|
||||
run(sched)
|
||||
self.assertEqual(i[0], 11)
|
||||
sched.removeEvent('test')
|
||||
self.assertEqual(i[0], 11)
|
||||
@ -77,10 +83,10 @@ class TestSchedule(SupyTestCase):
|
||||
n = sched.addEvent(inc, time.time() + 1)
|
||||
sched.rescheduleEvent(n, time.time() + 3)
|
||||
timeFastForward(1.2)
|
||||
sched.run()
|
||||
run(sched)
|
||||
self.assertEqual(i[0], 0)
|
||||
timeFastForward(2)
|
||||
sched.run()
|
||||
run(sched)
|
||||
self.assertEqual(i[0], 1)
|
||||
|
||||
def testPeriodic(self):
|
||||
@ -90,20 +96,20 @@ class TestSchedule(SupyTestCase):
|
||||
i[0] += 1
|
||||
n = sched.addPeriodicEvent(inc, 1, name='test_periodic')
|
||||
timeFastForward(0.6)
|
||||
sched.run() # 0.6
|
||||
run(sched) # 0.6
|
||||
self.assertEqual(i[0], 1)
|
||||
timeFastForward(0.6)
|
||||
sched.run() # 1.2
|
||||
run(sched) # 1.2
|
||||
self.assertEqual(i[0], 2)
|
||||
timeFastForward(0.6)
|
||||
sched.run() # 1.8
|
||||
run(sched) # 1.8
|
||||
self.assertEqual(i[0], 2)
|
||||
timeFastForward(0.6)
|
||||
sched.run() # 2.4
|
||||
run(sched) # 2.4
|
||||
self.assertEqual(i[0], 3)
|
||||
sched.removePeriodicEvent(n)
|
||||
timeFastForward(1)
|
||||
sched.run() # 3.4
|
||||
run(sched) # 3.4
|
||||
self.assertEqual(i[0], 3)
|
||||
|
||||
def testCountedPeriodic(self):
|
||||
@ -113,19 +119,19 @@ class TestSchedule(SupyTestCase):
|
||||
i[0] += 1
|
||||
n = sched.addPeriodicEvent(inc, 1, name='test_periodic', count=3)
|
||||
timeFastForward(0.6)
|
||||
sched.run() # 0.6
|
||||
run(sched) # 0.6
|
||||
self.assertEqual(i[0], 1)
|
||||
timeFastForward(0.6)
|
||||
sched.run() # 1.2
|
||||
run(sched) # 1.2
|
||||
self.assertEqual(i[0], 2)
|
||||
timeFastForward(0.6)
|
||||
sched.run() # 1.8
|
||||
run(sched) # 1.8
|
||||
self.assertEqual(i[0], 2)
|
||||
timeFastForward(0.6)
|
||||
sched.run() # 2.4
|
||||
run(sched) # 2.4
|
||||
self.assertEqual(i[0], 3)
|
||||
timeFastForward(1)
|
||||
sched.run() # 3.4
|
||||
run(sched) # 3.4
|
||||
self.assertEqual(i[0], 3)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user