diff --git a/bot.py b/bot.py index 405a060..345e333 100644 --- a/bot.py +++ b/bot.py @@ -1,8 +1,8 @@ import logging import requests import re -from telegram import InlineQueryResultArticle, InputTextMessageContent -from telegram.ext import Updater, CommandHandler, InlineQueryHandler +from telegram import InlineQueryResultArticle, InputTextMessageContent, InlineKeyboardButton, InlineKeyboardMarkup +from telegram.ext import Updater, CommandHandler, InlineQueryHandler, CallbackQueryHandler from exceptions import * @@ -25,7 +25,15 @@ class Query: request_tmp = ' '.join(args) argument_names = re.findall(r' to | in | times ', request_tmp) arguments = re.split(r' to | in | times ', request_tmp) - self.station_id = search_station(arguments[0]) + if not arguments[0].isdigit(): + reply = search_station(arguments[0]) + if len(reply) == 1: + self.station_id = reply[0]['stationId'] + else: + raise MultipleStationsFoundError(request_tmp, arguments[0], reply) + else: + self.station_id = arguments[0] + if ' to ' in argument_names: self.destination = arguments[argument_names.index(' to ') + 1] if ' in ' in argument_names: @@ -34,16 +42,51 @@ class Query: self.departure_count = int(arguments[argument_names.index(' times ') + 1]) +def build_menu(buttons, + n_cols, + header_buttons=None, + footer_buttons=None): + menu = [buttons[i:i + n_cols] for i in range(0, len(buttons), n_cols)] + if header_buttons: + menu.insert(0, [header_buttons]) + if footer_buttons: + menu.append([footer_buttons]) + return menu + + +def reply_multiple_stations(message, message_text, queried_station, station_list): + button_list = [] + for station in station_list: + button_list.append(InlineKeyboardButton(station['fullName'], + callback_data="/vvs " + message_text + .replace(queried_station, station['stationId']))) + reply_markup = InlineKeyboardMarkup(build_menu(button_list, n_cols=2)) + message.reply_text("Multiple stations found:", reply_markup=reply_markup) + + +def handle_multiple_stations_reply(update, context): + query = parse_station(update.callback_query.data.split(' ')[1:]) + departures = get_vvs_departures(query) + for reply in departures: + context.bot.send_message(update.effective_chat['id'], reply) + + def search_station(query): request = requests.get("https://efa-api.asw.io/api/v1/station/?search=" + query) - if request.status_code != 200 or (request.status_code == 200 and len(request.text) <= 2): - return -1 + if request.status_code != 200: + raise ServerCommunicationError + if request.status_code == 200 and len(request.text) <= 2: + raise StationNotFoundError else: - return request.json()[0]['stationId'] + return request.json() def handle_vvs(update, context): - query = parse_station(context.args) + try: + query = parse_station(context.args) + except MultipleStationsFoundError as error: + reply_multiple_stations(update.message, error.message_text, error.queried_station, error.station_list) + return departures = get_vvs_departures(query) for reply in departures: @@ -53,7 +96,12 @@ def handle_vvs(update, context): def parse_station(args): if len(args) == 0: raise NoArgError - query = Query(args) + try: + query = Query(args) + except StationNotFoundError: + raise + except MultipleStationsFoundError: + raise if query.station_id == -1: raise StationNotFoundError return query @@ -120,6 +168,7 @@ inline_station_search_handler = InlineQueryHandler(inline_station_search) dispatcher.add_handler(inline_station_search_handler) dispatcher.add_handler(CommandHandler('vvs', handle_vvs)) dispatcher.add_error_handler(error_callback) +dispatcher.add_handler(CallbackQueryHandler(handle_multiple_stations_reply)) updater.start_polling() updater.idle() diff --git a/exceptions.py b/exceptions.py index bab4dba..4f028c0 100644 --- a/exceptions.py +++ b/exceptions.py @@ -8,3 +8,14 @@ class StationNotFoundError(Exception): class ServerCommunicationError(Exception): pass + + +class MultipleStationsFoundError(Exception): + message_text = '' + queried_station = '' + station_list = [] + + def __init__(self, message_text, queried_station, station_list): + self.message_text = message_text + self.queried_station = queried_station + self.station_list = station_list