BUG-197: improve session ID tracking
[bgpcep.git] / pcep / pcepy / peer / pce.py
1 # PCE and its handlers
2
3 # Copyright (c) 2012,2013 Cisco Systems, Inc. and others.  All rights reserved.
4 #
5 # This program and the accompanying materials are made available under the
6 # terms of the Eclipse Public License v1.0 which accompanies this distribution,
7 # and is available at http://www.eclipse.org/legal/epl-v10.html
8
9
10 import time
11 import socket as _socket
12
13 from . import base
14 from . import lsp as _lsp
15 from pcepy import session as _session
16 from pcepy import message as _message
17
class Pce(base.Peer):
    """A simulated Path Computation Element.

    In addition to the sessions managed by base.Peer, a Pce keeps the list of
    listening servers it created; it stays active while any server remains.
    """

    CONFIG_SERVER_CONFIG = 'pce.server_config'
    CONFIG_SESSION_CONFIG = 'pce.session_config'

    def __init__(self, name, context):
        super(Pce, self).__init__(name, context)
        # Servers created by create_server(); Listener.on_close removes them.
        self._servers = list()

    def _get_active(self):
        return bool(self._servers) or super(Pce, self)._get_active()

    def _create_handlers(self):
        super(Pce, self)._create_handlers()
        for handler_class in (Listener, Reporter,):
            self.add_handler(self._create_handler(handler_class))

    def _merged_config(self, key, name):
        """Return a new dict: peer config `key` overlaid with config `(key, name)`.

        Missing or empty entries contribute nothing. The result is always a
        fresh dict, never the stored configuration object itself.
        """
        merged = dict(self[key] or ())
        specific = self[key, name]
        if specific:
            merged.update(specific)
        return merged

    def create_server(self, node):
        """Bind to a server socket specified by node and put it on the bus.

        If node has no port, _session.PCEP_PORT is used. The server receives
        the CONFIG_SERVER_CONFIG settings, overridden by the per-node ones.
        Returns the created _session.PcepServer.
        """
        # Config is looked up by the node's name before any port substitution.
        server_config = self._merged_config(Pce.CONFIG_SERVER_CONFIG, node.name)
        if not node.port:
            node = node.with_port(_session.PCEP_PORT)
        server = _session.PcepServer(self, node, server_config)
        self._servers.append(server)
        self.context.bus.add(server)
        return server

    def create_session(self, local, socket, remote):
        """Create PCEP session from socket and put it on the bus.

        The session receives the CONFIG_SESSION_CONFIG settings, overridden by
        the ones configured for the remote node. Returns the created
        _session.PcepAccept.
        """
        session_config = self._merged_config(
            Pce.CONFIG_SESSION_CONFIG, remote.name
        )
        pcep_session = _session.PcepAccept(
            self, local, socket, remote, session_config
        )
        self._sessions.append(pcep_session)
        self.context.bus.add(pcep_session)
        return pcep_session

    def shutdown(self):
        "Ask all server sessions to close. Then close all incoming sessions"
        for server in self._servers:
            server.closing = True
        super(Pce, self).shutdown()
79
80
class Listener(base.Handler):
    """Manage incoming connections on PceServer sockets.

    A server may be given a listening deadline: the peer setting
    CONFIG_TIMEOUT is resolved into the server state key STATE_TIMEOUT.
    If the deadline passes, a socket.timeout error is emitted and the server
    is closed; accepting a connection disarms the deadline.
    """

    CONFIG_TIMEOUT = 'pcep_server.timeout'
    STATE_TIMEOUT = '_listener.timeout'

    def on_open(self, peer, eventargs):
        """Arm the listening deadline of a newly opened server."""
        server = eventargs['session']
        if not server.is_server():
            return
        # NOTE(review): this tests for CONFIG_TIMEOUT on the server, yet only
        # ever arms STATE_TIMEOUT from the peer-wide setting; confirm that a
        # server-level CONFIG_TIMEOUT is really meant to skip arming.
        if Listener.CONFIG_TIMEOUT not in server:
            timeout = peer[Listener.CONFIG_TIMEOUT]
            if timeout:
                server[Listener.STATE_TIMEOUT] = _session.resolve_timeout(timeout)

    def on_close(self, peer, eventargs):
        """Forget a closed server."""
        server = eventargs['session']
        if not server.is_server():
            return
        peer._servers.remove(server)

    def on_connection(self, peer, eventargs):
        """Turn an accepted socket into a PCEP session with a PCC node."""
        # TODO: restrict addresses by config
        base._LOGGER.debug('Creating accepted session on peer %s' % peer)
        server = eventargs['server']
        socket = eventargs['socket']
        address, port = eventargs['address'][:2]
        address = peer.context.address_from(address)
        node = peer.context.get_node(_session.Node.ROLE_PCC,
            address=address, port=port
        )
        session = peer.create_session(server.local, socket, node)
        base._LOGGER.debug('Created accepted session %s' % session)
        # A connection arrived, so disarm the listening deadline; otherwise
        # on_timeout would later fail and close this server anyway. The armed
        # deadline is stored under STATE_TIMEOUT, not CONFIG_TIMEOUT.
        # TODO: only if not expecting more connections
        if server[Listener.STATE_TIMEOUT] is not None:
            del server[Listener.STATE_TIMEOUT]

    def on_timeout(self, peer, eventargs):
        """Fail and close a server whose listening deadline has passed."""
        session = eventargs['session']
        now = eventargs['now']

        timeout = session[Listener.STATE_TIMEOUT]
        if not timeout or timeout > now:
            return

        del session[Listener.STATE_TIMEOUT]
        peer.emit('on_socket_error', session=session,
            error=_socket.timeout('%s: timed out' % session)
        )
        session.closing = True

    def timeout(self, session):
        """Return the armed deadline of a server session; None otherwise."""
        if session.is_server():
            return session[Listener.STATE_TIMEOUT]
133
134
class Requester(base.Handler):
    """Placeholder handler for incoming PCReq messages.

    Nothing is implemented yet (may be implemented later); every handler
    method is inherited unchanged from base.Handler.
    """
139
140
class Reporter(base.Handler):
    """Manage reception of PCRpt messages.

    Manage the per-session state database (lsp.StateDb). Await specific state
    reports.

    The Reporter emits these events to its peer:
        on_state_report(session, lsp, report, new):
            A [new] state report has arrived; called before adding to statedb.

        on_synchronized(session, statedb):
            State synchronization has completed and is recorded in statedb.

        on_await_report(session, key, arrived):
            State report for key has arrived (with report) or timed out (None).
    """

    STATE_STATEDB = '_reporter.statedb'
    STATE_AWAITED = '_reporter.awaited'
    STATE_STATE = '_reporter.state'

    # Values for STATE_STATE:
    #   RS_NONE    - stateless session; reports are not tracked
    #   RS_SYNCING - state synchronization expected / in progress
    #   RS_AVOID   - database version is valid; synchronization may be skipped
    #   RS_SYNCED  - synchronization has completed
    RS_NONE, RS_SYNCING, RS_AVOID, RS_SYNCED = range(0, 4)

    def on_session_open(self, peer, eventargs):
        """Set up report tracking and the initial sync state for a session."""
        session = eventargs['session']
        if session[base.Opener.STATE_PCEPTYPE] == base.Opener.PCEPTYPE_STATELESS:
            session[Reporter.STATE_STATE] = Reporter.RS_NONE
            return
        statedb = self._get_statedb(peer, session)
        session[Reporter.STATE_STATEDB] = statedb
        session[Reporter.STATE_AWAITED] = self._get_awaited(peer, session)

        local_open = session[base.Opener.STATE_LOCAL_OPEN]
        remote_open = session[base.Opener.STATE_REMOTE_OPEN]
        avoid = statedb.can_avoid(pce_open=local_open, pcc_open=remote_open)
        if avoid:
            base._LOGGER.info('Session "%s" has valid database version "%s"'
                % (session, statedb.version)
            )
            state = Reporter.RS_AVOID
        else:
            state = Reporter.RS_SYNCING
        session[Reporter.STATE_STATE] = state

    def on_message(self, peer, eventargs):
        """Process every report of a PCRpt received on a stateful session.

        For each report: require the LSP object, look up or create the LSP,
        emit on_state_report, check the database version, record the report,
        notify matching awaits, then advance the synchronization state.
        """
        session = eventargs['session']
        statedb = session[Reporter.STATE_STATEDB]
        if statedb is None:
            # Stateless sessions never get a state database (on_session_open).
            return
        message = eventargs['message']
        if not isinstance(message, _message.PCRpt):
            return
        awaited = session[Reporter.STATE_AWAITED]
        use_dbv = session[base.Opener.STATE_USE_DBV]
        state = session[Reporter.STATE_STATE]

        for report in message.poll('report'):
            report_lsp = report.poll('lsp')
            if report_lsp is None:
                peer.make_pcep_error(
                    origin = self,
                    session = session,
                    cause = report,
                    send = _message.code.Error.MandatoryObjectMissing_LSP,
                    closing = False
                )
                continue
            lsp, new = self._find_lsp(session, statedb, report_lsp)

            peer.emit('on_state_report', session=session,
                lsp = lsp, report = report, new = new
            )

            try:
                statedb.get_version(report_lsp, use_dbv)
            except ValueError as value_error:
                peer.make_pcep_error(
                    origin = self,
                    session = session,
                    cause = (report, value_error),
                    send = _message.code.Error.MandatoryObjectMissing_DBV,
                    closing = True
                )
                # closing=True was requested: skip the remaining reports.
                return

            if new:
                statedb.add(lsp)
            lsp.report = report

            for key in awaited.match(report):
                peer.emit('on_await_report', session=session,
                    key = key, arrived = report
                )

            state = self._advance_sync(peer, session, statedb, state,
                report_lsp.synchronize
            )

    def _find_lsp(self, session, statedb, report_lsp):
        """Return (lsp, new): the known LSP, or a fresh one not yet in statedb."""
        lsp = statedb[report_lsp.lsp_id]
        if lsp is not None:
            return lsp, False
        name = report_lsp.get(_message.tlv.LspSymbolicName)
        if name is None:
            base._LOGGER.error('New LSP "%s" missing name in "%s"'
                % (report_lsp, session)
            )
            # FIXME: should be a PCEP error
        else:
            name = name.lsp_name
        return _lsp.Lsp(name=name, lsp_id=report_lsp.lsp_id), True

    def _advance_sync(self, peer, session, statedb, state, synchronize):
        """Update the sync state after one report; return the new state."""
        if not synchronize:
            # A report without the SYNC flag ends (or skips) synchronization.
            if state != Reporter.RS_SYNCED:
                state = Reporter.RS_SYNCED
                session[Reporter.STATE_STATE] = state
                peer.emit('on_synchronized', session=session,
                    statedb = statedb
                )
        elif state == Reporter.RS_AVOID:
            # The PCC resynchronizes although the database version was valid.
            base._LOGGER.warning('Session "%s": synchronization not avoided'
                % session
            )
            state = Reporter.RS_SYNCING
            session[Reporter.STATE_STATE] = state
        elif state == Reporter.RS_SYNCED:
            # A SYNC-flagged report is only unexpected once synchronization
            # has completed; while RS_SYNCING it is the normal case.
            base._LOGGER.error('Session "%s": already synchronized'
                % session
            )
        return state

    def on_timeout(self, peer, eventargs):
        """Emit on_await_report(arrived=None) for every expired await."""
        session = eventargs['session']
        now = eventargs['now']
        awaited = session[Reporter.STATE_AWAITED]
        if awaited is None:
            return
        outs = awaited.out(now)
        for out in outs:
            peer.emit('on_await_report', session=session,
                key = out, arrived = None
            )

    def timeout(self, session):
        """Return the awaited-report timeout of session, or None if untracked."""
        awaited = session[Reporter.STATE_AWAITED]
        return None if awaited is None else awaited.timeout

    def _get_statedb(self, peer, session):
        """Create state database for session"""
        return _lsp.StateDb(self, peer, session)

    def _get_awaited(self, peer, session):
        """Create awaited report database for session"""
        return _lsp.Awaited(self, session)

    def await_reports(self, session, criteria, timeout=None):
        """Watch for a [set of] state reports satisfying criteria.

        criteria maps an await key to its criterion; a criterion's own
        'timeout' entry overrides the timeout argument.
        Also reachable under its original name ``await``.
        """
        awaited = session[Reporter.STATE_AWAITED]
        for key, criterion in criteria.items():
            awaited.add(self._get_await(session, key, criterion, timeout))

    def _get_await(self, session, key, criterion, timeout=None):
        """Transform a criterium into an Await object."""
        return _lsp.Await(key, criterion, criterion.get('timeout', timeout))


# ``await`` is a reserved keyword since Python 3.7, so it cannot be spelled in
# a ``def``; keep the original method name available as a class attribute for
# existing callers (``reporter.await(...)`` on Python 2,
# ``getattr(reporter, 'await')(...)`` on Python 3).
setattr(Reporter, 'await', Reporter.__dict__['await_reports'])
299