mirror of
https://github.com/progval/Limnoria.git
synced 2025-04-28 05:51:16 -05:00
Use asyncio to manage the driver loop
This allows us to remove the custom shoehorned select() handling without changing the driver design too much.
This commit is contained in:
parent
0ed743bb8e
commit
f164ac7fbe
@ -39,9 +39,10 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import errno
|
import errno
|
||||||
import threading
|
|
||||||
import select
|
import select
|
||||||
import socket
|
import socket
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
|
||||||
@ -58,9 +59,8 @@ except:
|
|||||||
class SSLError(Exception):
|
class SSLError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||||
_instances = []
|
|
||||||
_selecting = threading.Lock()
|
|
||||||
def __init__(self, irc):
|
def __init__(self, irc):
|
||||||
assert irc is not None
|
assert irc is not None
|
||||||
self.irc = irc
|
self.irc = irc
|
||||||
@ -108,26 +108,19 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
|||||||
# hasn't finished yet. We'll keep track of how many we get.
|
# hasn't finished yet. We'll keep track of how many we get.
|
||||||
if e is None or e.args[0] != 11 or self.eagains > 120:
|
if e is None or e.args[0] != 11 or self.eagains > 120:
|
||||||
drivers.log.disconnect(self.currentServer, e)
|
drivers.log.disconnect(self.currentServer, e)
|
||||||
if self in self._instances:
|
|
||||||
self._instances.remove(self)
|
|
||||||
try:
|
try:
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
self.connected = False
|
self.connected = False
|
||||||
if self.irc is None:
|
if self.irc is None:
|
||||||
# This driver is dead already, but we're still running because
|
|
||||||
# of select() running in an other driver's thread that started
|
|
||||||
# before this one died and stil holding a reference to this
|
|
||||||
# instance.
|
|
||||||
# Just return, and we should never be called again.
|
|
||||||
return
|
return
|
||||||
self.scheduleReconnect()
|
self.scheduleReconnect()
|
||||||
else:
|
else:
|
||||||
log.debug('Got EAGAIN, current count: %s.', self.eagains)
|
log.debug('Got EAGAIN, current count: %s.', self.eagains)
|
||||||
self.eagains += 1
|
self.eagains += 1
|
||||||
|
|
||||||
def _sendIfMsgs(self):
|
async def _sendIfMsgs(self):
|
||||||
if not self.connected:
|
if not self.connected:
|
||||||
return
|
return
|
||||||
if not self.zombie:
|
if not self.zombie:
|
||||||
@ -137,71 +130,34 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
|||||||
del msgs[-1]
|
del msgs[-1]
|
||||||
self.outbuffer += ''.join(map(str, msgs))
|
self.outbuffer += ''.join(map(str, msgs))
|
||||||
if self.outbuffer:
|
if self.outbuffer:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
if minisix.PY2:
|
await loop.sock_sendall(self.conn, self.outbuffer.encode())
|
||||||
sent = self.conn.send(self.outbuffer)
|
self.outbuffer = ''
|
||||||
else:
|
|
||||||
sent = self.conn.send(self.outbuffer.encode())
|
|
||||||
self.outbuffer = self.outbuffer[sent:]
|
|
||||||
self.eagains = 0
|
self.eagains = 0
|
||||||
except socket.error as e:
|
except socket.error as e:
|
||||||
self._handleSocketError(e)
|
self._handleSocketError(e)
|
||||||
if self.zombie and not self.outbuffer:
|
if self.zombie and not self.outbuffer:
|
||||||
self._reallyDie()
|
self._reallyDie()
|
||||||
|
|
||||||
@classmethod
|
async def run(self):
|
||||||
def _select(cls):
|
|
||||||
timeout = conf.supybot.drivers.poll()
|
|
||||||
try:
|
|
||||||
if not cls._selecting.acquire(blocking=False):
|
|
||||||
# there's already a thread running this code, abort.
|
|
||||||
return
|
|
||||||
for inst in cls._instances:
|
|
||||||
# Do not use a list comprehension here, we have to edit the list
|
|
||||||
# and not to reassign it.
|
|
||||||
if not inst.connected or \
|
|
||||||
(minisix.PY3 and inst.conn._closed) or \
|
|
||||||
(minisix.PY2 and
|
|
||||||
inst.conn._sock.__class__ is socket._closedsocket):
|
|
||||||
cls._instances.remove(inst)
|
|
||||||
elif inst.conn.fileno() == -1:
|
|
||||||
inst.reconnect()
|
|
||||||
if not cls._instances:
|
|
||||||
return
|
|
||||||
rlist, wlist, xlist = select.select([x.conn for x in cls._instances],
|
|
||||||
[], [], timeout)
|
|
||||||
for instance in cls._instances:
|
|
||||||
if instance.conn in rlist:
|
|
||||||
instance._read()
|
|
||||||
except select.error as e:
|
|
||||||
if e.args[0] != errno.EINTR:
|
|
||||||
# 'Interrupted system call'
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
cls._selecting.release()
|
|
||||||
for instance in cls._instances:
|
|
||||||
if instance.irc and not instance.irc.zombie:
|
|
||||||
instance._sendIfMsgs()
|
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
if self.nextReconnectTime is not None and now > self.nextReconnectTime:
|
if self.nextReconnectTime is not None and now > self.nextReconnectTime:
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
elif self.writeCheckTime is not None and now > self.writeCheckTime:
|
elif self.writeCheckTime is not None and now > self.writeCheckTime:
|
||||||
self._checkAndWriteOrReconnect()
|
self._checkAndWriteOrReconnect()
|
||||||
if not self.connected:
|
if not self.connected:
|
||||||
# We sleep here because otherwise, if we're the only driver, we'll
|
|
||||||
# spin at 100% CPU while we're disconnected.
|
|
||||||
time.sleep(conf.supybot.drivers.poll())
|
|
||||||
return
|
return
|
||||||
self._sendIfMsgs()
|
await asyncio.gather(
|
||||||
self._select()
|
self._sendIfMsgs(),
|
||||||
|
self._read(),
|
||||||
|
)
|
||||||
|
|
||||||
def _read(self):
|
async def _read(self):
|
||||||
"""Called by _select() when we can read data."""
|
"""Called by _select() when we can read data."""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
new_data = self.conn.recv(1024)
|
new_data = await loop.sock_recv(self.conn, 1024)
|
||||||
if not new_data:
|
if not new_data:
|
||||||
# Socket was closed
|
# Socket was closed
|
||||||
self._handleSocketError(None)
|
self._handleSocketError(None)
|
||||||
@ -241,7 +197,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
|||||||
self._handleSocketError(e)
|
self._handleSocketError(e)
|
||||||
return
|
return
|
||||||
if self.irc and not self.irc.zombie:
|
if self.irc and not self.irc.zombie:
|
||||||
self._sendIfMsgs()
|
await self._sendIfMsgs()
|
||||||
|
|
||||||
def connect(self, **kwargs):
|
def connect(self, **kwargs):
|
||||||
self.reconnect(reset=False, **kwargs)
|
self.reconnect(reset=False, **kwargs)
|
||||||
@ -252,8 +208,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
|||||||
if self.connected:
|
if self.connected:
|
||||||
self.onDisconnect()
|
self.onDisconnect()
|
||||||
drivers.log.reconnect(self.irc.network)
|
drivers.log.reconnect(self.irc.network)
|
||||||
if self in self._instances:
|
|
||||||
self._instances.remove(self)
|
|
||||||
try:
|
try:
|
||||||
self.conn.shutdown(socket.SHUT_RDWR)
|
self.conn.shutdown(socket.SHUT_RDWR)
|
||||||
except: # "Transport endpoint not connected"
|
except: # "Transport endpoint not connected"
|
||||||
@ -314,12 +268,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
|||||||
return
|
return
|
||||||
# We allow more time for the connect here, since it might take longer.
|
# We allow more time for the connect here, since it might take longer.
|
||||||
# At least 10 seconds.
|
# At least 10 seconds.
|
||||||
self.conn.settimeout(max(10, conf.supybot.drivers.poll()*10))
|
|
||||||
try:
|
try:
|
||||||
# Connect before SSL, otherwise SSL is disabled if we use SOCKS.
|
# Connect before SSL, otherwise SSL is disabled if we use SOCKS.
|
||||||
# See http://stackoverflow.com/q/16136916/539465
|
# See http://stackoverflow.com/q/16136916/539465
|
||||||
self.conn.connect(
|
self.conn.connect((address, self.currentServer.port))
|
||||||
(address, self.currentServer.port))
|
|
||||||
if network_config.ssl() or \
|
if network_config.ssl() or \
|
||||||
self.currentServer.force_tls_verification:
|
self.currentServer.force_tls_verification:
|
||||||
self.starttls()
|
self.starttls()
|
||||||
@ -343,8 +295,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
|||||||
'<http://docs.limnoria.net/en/latest/use/faq.html#how-to-make-a-connection-secure>')
|
'<http://docs.limnoria.net/en/latest/use/faq.html#how-to-make-a-connection-secure>')
|
||||||
% self.irc.network)
|
% self.irc.network)
|
||||||
|
|
||||||
conf.supybot.drivers.poll.addCallback(self.setTimeout)
|
|
||||||
self.setTimeout()
|
|
||||||
self.connected = True
|
self.connected = True
|
||||||
self.resetDelay()
|
self.resetDelay()
|
||||||
except socket.error as e:
|
except socket.error as e:
|
||||||
@ -361,13 +311,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
|||||||
drivers.log.connectError(self.currentServer, e)
|
drivers.log.connectError(self.currentServer, e)
|
||||||
self.scheduleReconnect()
|
self.scheduleReconnect()
|
||||||
return
|
return
|
||||||
self._instances.append(self)
|
|
||||||
|
|
||||||
def setTimeout(self):
|
|
||||||
try:
|
|
||||||
self.conn.settimeout(conf.supybot.drivers.poll())
|
|
||||||
except Exception:
|
|
||||||
drivers.log.exception('Could not set socket timeout:')
|
|
||||||
|
|
||||||
def _checkAndWriteOrReconnect(self):
|
def _checkAndWriteOrReconnect(self):
|
||||||
self.writeCheckTime = None
|
self.writeCheckTime = None
|
||||||
@ -393,9 +336,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
|||||||
self.nextReconnectTime = when
|
self.nextReconnectTime = when
|
||||||
|
|
||||||
def die(self):
|
def die(self):
|
||||||
if self in self._instances:
|
|
||||||
self._instances.remove(self)
|
|
||||||
conf.supybot.drivers.poll.removeCallback(self.setTimeout)
|
|
||||||
self.zombie = True
|
self.zombie = True
|
||||||
if self.nextReconnectTime is not None:
|
if self.nextReconnectTime is not None:
|
||||||
self.nextReconnectTime = None
|
self.nextReconnectTime = None
|
||||||
|
@ -35,6 +35,7 @@ Contains various drivers (network, file, and otherwise) for using IRC objects.
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
import socket
|
import socket
|
||||||
|
import asyncio
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils
|
from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils
|
||||||
@ -146,15 +147,59 @@ def remove(name):
|
|||||||
"""Removes the driver with the given name from the loop."""
|
"""Removes the driver with the given name from the loop."""
|
||||||
_deadDrivers.add(name)
|
_deadDrivers.add(name)
|
||||||
|
|
||||||
|
|
||||||
|
def _loop_exception_handler(loop, context):
|
||||||
|
if 'Task was destroyed but it is pending' in context['message']:
|
||||||
|
# This happens when cancelling tasks because of KeyboardInterrupt.
|
||||||
|
# There is really nothing we can do about this, so it's
|
||||||
|
# pointless to log it.
|
||||||
|
return
|
||||||
|
log.error('Exception in drivers loop: %s', context['message'])
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
"""Runs the whole driver loop."""
|
"""Runs the whole driver loop."""
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
loop.set_exception_handler(_loop_exception_handler)
|
||||||
|
driver_names = []
|
||||||
|
futures = []
|
||||||
|
coroutines = [] # Used to cleanup on shutdown to avoid warnings
|
||||||
for (name, driver) in _drivers.items():
|
for (name, driver) in _drivers.items():
|
||||||
|
if name not in _deadDrivers:
|
||||||
|
try:
|
||||||
|
coroutine = driver.run()
|
||||||
|
future = asyncio.ensure_future(coroutine, loop=loop)
|
||||||
|
except Exception:
|
||||||
|
log.exception('Exception in drivers.run for driver %s:', name)
|
||||||
|
continue
|
||||||
|
driver_names.append(name)
|
||||||
|
coroutines.append(coroutine)
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
gather_task = asyncio.gather(*futures, return_exceptions=True, loop=loop)
|
||||||
|
timeout_gather_task = asyncio.wait_for(
|
||||||
|
gather_task,
|
||||||
|
timeout=conf.supybot.drivers.poll())
|
||||||
|
coroutines.append(timeout_gather_task)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(timeout_gather_task)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# Cleanup all the objects so they don't throw warnings.
|
||||||
|
gather_task.cancel()
|
||||||
|
for future in futures:
|
||||||
|
future.cancel()
|
||||||
|
for coroutine in coroutines:
|
||||||
|
coroutine.close()
|
||||||
|
raise
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for (name, future) in zip(driver_names, futures):
|
||||||
try:
|
try:
|
||||||
if name not in _deadDrivers:
|
future.result() # Raises an exception if driver.run() did.
|
||||||
driver.run()
|
|
||||||
except:
|
except:
|
||||||
log.exception('Uncaught exception in in drivers.run:')
|
log.exception('Uncaught exception in drivers.run:')
|
||||||
_deadDrivers.add(name)
|
_deadDrivers.add(name)
|
||||||
|
|
||||||
for name in _deadDrivers:
|
for name in _deadDrivers:
|
||||||
try:
|
try:
|
||||||
driver = _drivers[name]
|
driver = _drivers[name]
|
||||||
|
@ -137,7 +137,7 @@ class Schedule(drivers.IrcDriver):
|
|||||||
|
|
||||||
removePeriodicEvent = removeEvent
|
removePeriodicEvent = removeEvent
|
||||||
|
|
||||||
def run(self):
|
async def run(self):
|
||||||
if len(drivers._drivers) == 1 and not world.testing:
|
if len(drivers._drivers) == 1 and not world.testing:
|
||||||
log.error('Schedule is the only remaining driver, '
|
log.error('Schedule is the only remaining driver, '
|
||||||
'why do we continue to live?')
|
'why do we continue to live?')
|
||||||
|
@ -31,6 +31,7 @@
|
|||||||
from supybot.test import *
|
from supybot.test import *
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import supybot.schedule as schedule
|
import supybot.schedule as schedule
|
||||||
|
|
||||||
@ -43,6 +44,11 @@ class FakeSchedule(schedule.Schedule):
|
|||||||
def name(self):
|
def name(self):
|
||||||
return 'FakeSchedule'
|
return 'FakeSchedule'
|
||||||
|
|
||||||
|
|
||||||
|
def run(sched):
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
loop.run_until_complete(sched.run())
|
||||||
|
|
||||||
class TestSchedule(SupyTestCase):
|
class TestSchedule(SupyTestCase):
|
||||||
def testSchedule(self):
|
def testSchedule(self):
|
||||||
sched = FakeSchedule()
|
sched = FakeSchedule()
|
||||||
@ -55,14 +61,14 @@ class TestSchedule(SupyTestCase):
|
|||||||
sched.addEvent(add10, time.time() + 3)
|
sched.addEvent(add10, time.time() + 3)
|
||||||
sched.addEvent(add1, time.time() + 1)
|
sched.addEvent(add1, time.time() + 1)
|
||||||
timeFastForward(1.2)
|
timeFastForward(1.2)
|
||||||
sched.run()
|
run(sched)
|
||||||
self.assertEqual(i[0], 1)
|
self.assertEqual(i[0], 1)
|
||||||
timeFastForward(1.9)
|
timeFastForward(1.9)
|
||||||
sched.run()
|
run(sched)
|
||||||
self.assertEqual(i[0], 11)
|
self.assertEqual(i[0], 11)
|
||||||
|
|
||||||
sched.addEvent(add10, time.time() + 3, 'test')
|
sched.addEvent(add10, time.time() + 3, 'test')
|
||||||
sched.run()
|
run(sched)
|
||||||
self.assertEqual(i[0], 11)
|
self.assertEqual(i[0], 11)
|
||||||
sched.removeEvent('test')
|
sched.removeEvent('test')
|
||||||
self.assertEqual(i[0], 11)
|
self.assertEqual(i[0], 11)
|
||||||
@ -77,10 +83,10 @@ class TestSchedule(SupyTestCase):
|
|||||||
n = sched.addEvent(inc, time.time() + 1)
|
n = sched.addEvent(inc, time.time() + 1)
|
||||||
sched.rescheduleEvent(n, time.time() + 3)
|
sched.rescheduleEvent(n, time.time() + 3)
|
||||||
timeFastForward(1.2)
|
timeFastForward(1.2)
|
||||||
sched.run()
|
run(sched)
|
||||||
self.assertEqual(i[0], 0)
|
self.assertEqual(i[0], 0)
|
||||||
timeFastForward(2)
|
timeFastForward(2)
|
||||||
sched.run()
|
run(sched)
|
||||||
self.assertEqual(i[0], 1)
|
self.assertEqual(i[0], 1)
|
||||||
|
|
||||||
def testPeriodic(self):
|
def testPeriodic(self):
|
||||||
@ -90,20 +96,20 @@ class TestSchedule(SupyTestCase):
|
|||||||
i[0] += 1
|
i[0] += 1
|
||||||
n = sched.addPeriodicEvent(inc, 1, name='test_periodic')
|
n = sched.addPeriodicEvent(inc, 1, name='test_periodic')
|
||||||
timeFastForward(0.6)
|
timeFastForward(0.6)
|
||||||
sched.run() # 0.6
|
run(sched) # 0.6
|
||||||
self.assertEqual(i[0], 1)
|
self.assertEqual(i[0], 1)
|
||||||
timeFastForward(0.6)
|
timeFastForward(0.6)
|
||||||
sched.run() # 1.2
|
run(sched) # 1.2
|
||||||
self.assertEqual(i[0], 2)
|
self.assertEqual(i[0], 2)
|
||||||
timeFastForward(0.6)
|
timeFastForward(0.6)
|
||||||
sched.run() # 1.8
|
run(sched) # 1.8
|
||||||
self.assertEqual(i[0], 2)
|
self.assertEqual(i[0], 2)
|
||||||
timeFastForward(0.6)
|
timeFastForward(0.6)
|
||||||
sched.run() # 2.4
|
run(sched) # 2.4
|
||||||
self.assertEqual(i[0], 3)
|
self.assertEqual(i[0], 3)
|
||||||
sched.removePeriodicEvent(n)
|
sched.removePeriodicEvent(n)
|
||||||
timeFastForward(1)
|
timeFastForward(1)
|
||||||
sched.run() # 3.4
|
run(sched) # 3.4
|
||||||
self.assertEqual(i[0], 3)
|
self.assertEqual(i[0], 3)
|
||||||
|
|
||||||
def testCountedPeriodic(self):
|
def testCountedPeriodic(self):
|
||||||
@ -113,19 +119,19 @@ class TestSchedule(SupyTestCase):
|
|||||||
i[0] += 1
|
i[0] += 1
|
||||||
n = sched.addPeriodicEvent(inc, 1, name='test_periodic', count=3)
|
n = sched.addPeriodicEvent(inc, 1, name='test_periodic', count=3)
|
||||||
timeFastForward(0.6)
|
timeFastForward(0.6)
|
||||||
sched.run() # 0.6
|
run(sched) # 0.6
|
||||||
self.assertEqual(i[0], 1)
|
self.assertEqual(i[0], 1)
|
||||||
timeFastForward(0.6)
|
timeFastForward(0.6)
|
||||||
sched.run() # 1.2
|
run(sched) # 1.2
|
||||||
self.assertEqual(i[0], 2)
|
self.assertEqual(i[0], 2)
|
||||||
timeFastForward(0.6)
|
timeFastForward(0.6)
|
||||||
sched.run() # 1.8
|
run(sched) # 1.8
|
||||||
self.assertEqual(i[0], 2)
|
self.assertEqual(i[0], 2)
|
||||||
timeFastForward(0.6)
|
timeFastForward(0.6)
|
||||||
sched.run() # 2.4
|
run(sched) # 2.4
|
||||||
self.assertEqual(i[0], 3)
|
self.assertEqual(i[0], 3)
|
||||||
timeFastForward(1)
|
timeFastForward(1)
|
||||||
sched.run() # 3.4
|
run(sched) # 3.4
|
||||||
self.assertEqual(i[0], 3)
|
self.assertEqual(i[0], 3)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user