Now supports ambiguous stations.

When multiple stations are found, the user is presented with buttons to
chosse from. The arguments entered before will be preserved
This commit is contained in:
JuliusFreudenberger 2019-10-17 20:19:37 +02:00
parent 0d92c0abe6
commit 7ee9e2a9f1
2 changed files with 68 additions and 8 deletions

61
bot.py
View file

@ -1,8 +1,8 @@
import logging import logging
import requests import requests
import re import re
from telegram import InlineQueryResultArticle, InputTextMessageContent from telegram import InlineQueryResultArticle, InputTextMessageContent, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import Updater, CommandHandler, InlineQueryHandler from telegram.ext import Updater, CommandHandler, InlineQueryHandler, CallbackQueryHandler
from exceptions import * from exceptions import *
@ -25,7 +25,15 @@ class Query:
request_tmp = ' '.join(args) request_tmp = ' '.join(args)
argument_names = re.findall(r' to | in | times ', request_tmp) argument_names = re.findall(r' to | in | times ', request_tmp)
arguments = re.split(r' to | in | times ', request_tmp) arguments = re.split(r' to | in | times ', request_tmp)
self.station_id = search_station(arguments[0]) if not arguments[0].isdigit():
reply = search_station(arguments[0])
if len(reply) == 1:
self.station_id = reply[0]['stationId']
else:
raise MultipleStationsFoundError(request_tmp, arguments[0], reply)
else:
self.station_id = arguments[0]
if ' to ' in argument_names: if ' to ' in argument_names:
self.destination = arguments[argument_names.index(' to ') + 1] self.destination = arguments[argument_names.index(' to ') + 1]
if ' in ' in argument_names: if ' in ' in argument_names:
@ -34,16 +42,51 @@ class Query:
self.departure_count = int(arguments[argument_names.index(' times ') + 1]) self.departure_count = int(arguments[argument_names.index(' times ') + 1])
def build_menu(buttons,
n_cols,
header_buttons=None,
footer_buttons=None):
menu = [buttons[i:i + n_cols] for i in range(0, len(buttons), n_cols)]
if header_buttons:
menu.insert(0, [header_buttons])
if footer_buttons:
menu.append([footer_buttons])
return menu
def reply_multiple_stations(message, message_text, queried_station, station_list):
button_list = []
for station in station_list:
button_list.append(InlineKeyboardButton(station['fullName'],
callback_data="/vvs " + message_text
.replace(queried_station, station['stationId'])))
reply_markup = InlineKeyboardMarkup(build_menu(button_list, n_cols=2))
message.reply_text("Multiple stations found:", reply_markup=reply_markup)
def handle_multiple_stations_reply(update, context):
query = parse_station(update.callback_query.data.split(' ')[1:])
departures = get_vvs_departures(query)
for reply in departures:
context.bot.send_message(update.effective_chat['id'], reply)
def search_station(query): def search_station(query):
request = requests.get("https://efa-api.asw.io/api/v1/station/?search=" + query) request = requests.get("https://efa-api.asw.io/api/v1/station/?search=" + query)
if request.status_code != 200 or (request.status_code == 200 and len(request.text) <= 2): if request.status_code != 200:
return -1 raise ServerCommunicationError
if request.status_code == 200 and len(request.text) <= 2:
raise StationNotFoundError
else: else:
return request.json()[0]['stationId'] return request.json()
def handle_vvs(update, context): def handle_vvs(update, context):
try:
query = parse_station(context.args) query = parse_station(context.args)
except MultipleStationsFoundError as error:
reply_multiple_stations(update.message, error.message_text, error.queried_station, error.station_list)
return
departures = get_vvs_departures(query) departures = get_vvs_departures(query)
for reply in departures: for reply in departures:
@ -53,7 +96,12 @@ def handle_vvs(update, context):
def parse_station(args): def parse_station(args):
if len(args) == 0: if len(args) == 0:
raise NoArgError raise NoArgError
try:
query = Query(args) query = Query(args)
except StationNotFoundError:
raise
except MultipleStationsFoundError:
raise
if query.station_id == -1: if query.station_id == -1:
raise StationNotFoundError raise StationNotFoundError
return query return query
@ -120,6 +168,7 @@ inline_station_search_handler = InlineQueryHandler(inline_station_search)
dispatcher.add_handler(inline_station_search_handler) dispatcher.add_handler(inline_station_search_handler)
dispatcher.add_handler(CommandHandler('vvs', handle_vvs)) dispatcher.add_handler(CommandHandler('vvs', handle_vvs))
dispatcher.add_error_handler(error_callback) dispatcher.add_error_handler(error_callback)
dispatcher.add_handler(CallbackQueryHandler(handle_multiple_stations_reply))
updater.start_polling() updater.start_polling()
updater.idle() updater.idle()

View file

@ -8,3 +8,14 @@ class StationNotFoundError(Exception):
class ServerCommunicationError(Exception): class ServerCommunicationError(Exception):
pass pass
class MultipleStationsFoundError(Exception):
message_text = ''
queried_station = ''
station_list = []
def __init__(self, message_text, queried_station, station_list):
self.message_text = message_text
self.queried_station = queried_station
self.station_list = station_list