From 5735d5190b892de6d5db873b2a9d912d3217bf90 Mon Sep 17 00:00:00 2001 From: Joren Date: Fri, 10 Apr 2026 19:35:46 +0200 Subject: [PATCH] Improve EPG mapping and connection normalization --- .../dispatcharr_sensor/__init__.py | 83 +++++++++++++++++-- .../dispatcharr_sensor/config_flow.py | 26 +++++- .../dispatcharr_sensor/media_player.py | 14 +++- 3 files changed, 113 insertions(+), 10 deletions(-) diff --git a/custom_components/dispatcharr_sensor/__init__.py b/custom_components/dispatcharr_sensor/__init__.py index cec5ccb..36897b2 100644 --- a/custom_components/dispatcharr_sensor/__init__.py +++ b/custom_components/dispatcharr_sensor/__init__.py @@ -4,6 +4,7 @@ import logging from datetime import timedelta, datetime, timezone import xml.etree.ElementTree as ET import re +from urllib.parse import urlparse import aiohttp from homeassistant.core import HomeAssistant @@ -49,7 +50,9 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator): self.websession = async_get_clientsession(hass) self._access_token: str | None = None self.channel_map: dict = {} + self.channel_map_by_id: dict[str, dict] = {} self.stream_catalog: dict[str, dict] = {} + self.stream_catalog_by_id: dict[int, dict] = {} self.stream_catalog_count: int = 0 self.stream_catalog_active_count: int = 0 self._last_catalog_refresh: datetime | None = None @@ -62,8 +65,22 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator): def base_url(self) -> str: """Get the base URL for API calls.""" data = self.config_entry.data - protocol = "https" if data.get("ssl", False) else "http" - return f"{protocol}://{data['host']}:{data['port']}" + host_raw = str(data.get("host", "")).strip() + ssl_enabled = bool(data.get("ssl", False)) + port = int(data.get("port", 443 if ssl_enabled else 80)) + + if "://" in host_raw: + parsed = urlparse(host_raw) + host = parsed.hostname or host_raw + if parsed.port: + port = parsed.port + if parsed.scheme in ("http", "https"): + ssl_enabled = parsed.scheme == "https" + else: + host = host_raw + + protocol = "https" if ssl_enabled else "http" + return f"{protocol}://{host}:{port}" async def _get_new_token(self) -> None: """Get a new access token using username and password.""" @@ -127,14 +144,22 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator): def _normalize_stream_name(self, name: str) -> str: """Normalize stream name for fuzzy cross-endpoint matching.""" - return slugify(re.sub(r"^\w+:\s*|\s+HD$", "", name or "", flags=re.IGNORECASE)) + cleaned = name or "" + cleaned = re.sub(r"^\s*[A-Z]{2,4}\s*[|:-]\s*", "", cleaned) + cleaned = re.sub(r"\s*\[[^\]]+\]\s*$", "", cleaned) + cleaned = re.sub( + r"\s+(U?F?HD|SD|4K|HEVC|H265|H264)\b", "", cleaned, flags=re.IGNORECASE + ) + cleaned = re.sub(r"\s+\([^\)]*\)\s*$", "", cleaned) + cleaned = re.sub(r"^\w+:\s*", "", cleaned, flags=re.IGNORECASE) + return slugify(cleaned) async def _async_fetch_stream_catalog(self) -> None: """Fetch stream list from Dispatcharr streams endpoint.""" all_streams: list[dict] = [] page = 1 page_size = 500 - max_pages = 40 + max_pages = 1000 while page <= max_pages: payload = await self.api_request( @@ -152,12 +177,24 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator): break page += 1 + if page > max_pages: + _LOGGER.warning( + "Stopped stream catalog fetch at max_pages=%d (fetched=%d)", + max_pages, + len(all_streams), + ) + catalog: dict[str, dict] = {} + catalog_by_id: dict[int, dict] = {} active_count = 0 for stream in all_streams: if not isinstance(stream, dict): continue + stream_id = stream.get("id") + if isinstance(stream_id, int): + catalog_by_id[stream_id] = stream + name = stream.get("name") if not name: continue @@ -182,6 +219,7 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator): catalog[normalized] = stream self.stream_catalog = catalog + self.stream_catalog_by_id = catalog_by_id self.stream_catalog_count = len(all_streams) self.stream_catalog_active_count = active_count self._last_catalog_refresh = datetime.now(timezone.utc) @@ -192,6 +230,12 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator): return None return self.stream_catalog.get(self._normalize_stream_name(stream_name)) + def _get_catalog_stream_by_id(self, stream_id: int | None) -> dict | None: + """Lookup stream catalog row by numeric stream id.""" + if not isinstance(stream_id, int): + return None + return self.stream_catalog_by_id.get(stream_id) + async def async_populate_channel_map_from_xml(self): """Fetch the XML file once to build a reliable map of channels.""" _LOGGER.info("Populating Dispatcharr channel map from XML file...") @@ -207,6 +251,7 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator): try: root = ET.fromstring(xml_string) self.channel_map = {} + self.channel_map_by_id = {} for channel in root.iterfind("channel"): display_name = channel.findtext("display-name") channel_id = channel.get("id") @@ -215,11 +260,13 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator): if display_name and channel_id: slug_name = slugify(display_name) - self.channel_map[slug_name] = { + details = { "id": channel_id, "name": display_name, "logo_url": icon_url, } + self.channel_map[slug_name] = details + self.channel_map_by_id[channel_id] = details if not self.channel_map: raise ConfigEntryNotReady( @@ -305,7 +352,9 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator): details = self._get_channel_details_from_stream_name(stream_name) enriched_stream = stream.copy() - catalog_stream = self._get_catalog_stream(stream_name) + catalog_stream = self._get_catalog_stream_by_id(stream.get("stream_id")) + if not catalog_stream: + catalog_stream = self._get_catalog_stream(stream_name) if catalog_stream: enriched_stream["catalog_stream_id"] = catalog_stream.get("id") enriched_stream["catalog_tvg_id"] = catalog_stream.get("tvg_id") @@ -320,6 +369,10 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator): if not enriched_stream.get("logo_url"): enriched_stream["logo_url"] = catalog_stream.get("logo_url") + catalog_tvg_id = catalog_stream.get("tvg_id") + if not details and catalog_tvg_id: + details = self.channel_map_by_id.get(catalog_tvg_id) + if details: xmltv_id = details["id"] enriched_stream["xmltv_id"] = xmltv_id @@ -336,12 +389,30 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator): episode_num_tag = program.find( "episode-num[@system='onscreen']" ) + icon_tag = program.find("icon") + icon_url = ( + icon_tag.get("src") + if icon_tag is not None + else None + ) + duration_seconds = int( + (stop_time - start_time).total_seconds() + ) + categories = [ + cat.text + for cat in program.findall("category") + if cat.text + ] enriched_stream["program"] = { "title": program.findtext("title"), "description": program.findtext("desc"), "start_time": start_time.isoformat(), "end_time": stop_time.isoformat(), + "duration_seconds": duration_seconds, + "duration_minutes": round(duration_seconds / 60), "subtitle": program.findtext("sub-title"), + "image": icon_url, + "categories": categories, "episode_num": episode_num_tag.text if episode_num_tag is not None else None, diff --git a/custom_components/dispatcharr_sensor/config_flow.py b/custom_components/dispatcharr_sensor/config_flow.py index f9074f2..2218c0b 100644 --- a/custom_components/dispatcharr_sensor/config_flow.py +++ b/custom_components/dispatcharr_sensor/config_flow.py @@ -1,7 +1,9 @@ """Config flow for Dispatcharr Sensor integration.""" + from __future__ import annotations import logging from typing import Any +from urllib.parse import urlparse import voluptuous as vol from homeassistant import config_entries from homeassistant.data_entry_flow import FlowResult @@ -14,17 +16,34 @@ STEP_USER_DATA_SCHEMA = vol.Schema( { vol.Required("host"): str, vol.Required("port", default=9191): int, - vol.Optional("ssl",default=False): bool, + vol.Optional("ssl", default=False): bool, vol.Required("username"): str, vol.Required("password"): str, } ) + class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Handle a config flow for Dispatcharr Sensor.""" VERSION = 1 + def _normalize_connection_input(self, user_input: dict[str, Any]) -> dict[str, Any]: + """Normalize host/port/ssl from host field.""" + data = dict(user_input) + host_raw = str(data.get("host", "")).strip() + + if "://" in host_raw: + parsed = urlparse(host_raw) + if parsed.hostname: + data["host"] = parsed.hostname + if parsed.port: + data["port"] = parsed.port + if parsed.scheme in ("http", "https"): + data["ssl"] = parsed.scheme == "https" + + return data + async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> FlowResult: @@ -33,5 +52,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): return self.async_show_form( step_id="user", data_schema=STEP_USER_DATA_SCHEMA ) - - return self.async_create_entry(title="Dispatcharr", data=user_input) + + normalized = self._normalize_connection_input(user_input) + return self.async_create_entry(title="Dispatcharr", data=normalized) diff --git a/custom_components/dispatcharr_sensor/media_player.py b/custom_components/dispatcharr_sensor/media_player.py index b523cd7..60d2c3e 100644 --- a/custom_components/dispatcharr_sensor/media_player.py +++ b/custom_components/dispatcharr_sensor/media_player.py @@ -128,21 +128,33 @@ class DispatcharrStreamMediaPlayer(CoordinatorEntity, MediaPlayerEntity): stream_data = self.coordinator.data[self._stream_id] program_data = stream_data.get("program") or {} stream_stats = stream_data.get("stream_stats") or {} + channel_name = stream_data.get("channel_name") self._attr_state = STATE_PLAYING self._attr_app_name = "Dispatcharr" self._attr_entity_picture = stream_data.get("logo_url") + self._attr_media_image_url = program_data.get("image") or stream_data.get( + "logo_url" + ) self._attr_media_content_type = MediaType.TVSHOW self._attr_media_series_title = program_data.get("title") self._attr_media_title = program_data.get("subtitle") or program_data.get( "title" ) + if channel_name: + self._attr_name = channel_name # Extra attributes self._attr_extra_state_attributes = { "channel_number": stream_data.get("xmltv_id"), - "channel_name": stream_data.get("channel_name"), + "channel_name": channel_name, "program_description": program_data.get("description"), + "program_start_time": program_data.get("start_time"), + "program_end_time": program_data.get("end_time"), + "program_duration_seconds": program_data.get("duration_seconds"), + "program_duration_minutes": program_data.get("duration_minutes"), + "program_image": program_data.get("image"), + "program_categories": program_data.get("categories"), "clients": stream_data.get("client_count"), "catalog_current_viewers": stream_data.get("catalog_current_viewers"), "catalog_stream_id": stream_data.get("catalog_stream_id"),