Improve EPG mapping and connection normalization

This commit is contained in:
2026-04-10 19:35:46 +02:00
parent 3ac86ddce4
commit 5735d5190b
3 changed files with 113 additions and 10 deletions

View File

@@ -4,6 +4,7 @@ import logging
from datetime import timedelta, datetime, timezone from datetime import timedelta, datetime, timezone
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import re import re
from urllib.parse import urlparse
import aiohttp import aiohttp
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@@ -49,7 +50,9 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
self.websession = async_get_clientsession(hass) self.websession = async_get_clientsession(hass)
self._access_token: str | None = None self._access_token: str | None = None
self.channel_map: dict = {} self.channel_map: dict = {}
self.channel_map_by_id: dict[str, dict] = {}
self.stream_catalog: dict[str, dict] = {} self.stream_catalog: dict[str, dict] = {}
self.stream_catalog_by_id: dict[int, dict] = {}
self.stream_catalog_count: int = 0 self.stream_catalog_count: int = 0
self.stream_catalog_active_count: int = 0 self.stream_catalog_active_count: int = 0
self._last_catalog_refresh: datetime | None = None self._last_catalog_refresh: datetime | None = None
@@ -62,8 +65,22 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
def base_url(self) -> str:
    """Build the base URL for API calls from the config entry.

    Tolerates a host value stored as a full URL ("https://host:port/…")
    by extracting its hostname, port and scheme; otherwise combines the
    separate host/port/ssl options. Trailing slashes on the host value
    are stripped so the result is always "scheme://host:port" with no
    path component.
    """
    data = self.config_entry.data
    # rstrip("/") fixes malformed results like "http://example.com/:9191"
    # when the user entered a host with a trailing slash.
    host_raw = str(data.get("host", "")).strip().rstrip("/")
    ssl_enabled = bool(data.get("ssl", False))
    # Default port follows the scheme when no port option is stored.
    port = int(data.get("port", 443 if ssl_enabled else 80))
    if "://" in host_raw:
        parsed = urlparse(host_raw)
        # Fall back to the raw value only if parsing yields no hostname.
        host = parsed.hostname or host_raw
        if parsed.port:
            port = parsed.port
        if parsed.scheme in ("http", "https"):
            ssl_enabled = parsed.scheme == "https"
    else:
        host = host_raw
    protocol = "https" if ssl_enabled else "http"
    return f"{protocol}://{host}:{port}"
async def _get_new_token(self) -> None: async def _get_new_token(self) -> None:
"""Get a new access token using username and password.""" """Get a new access token using username and password."""
@@ -127,14 +144,22 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
def _normalize_stream_name(self, name: str) -> str:
    """Normalize a stream name for fuzzy cross-endpoint matching.

    Strips provider prefixes ("UK | ", "VIP: "), trailing bracketed or
    parenthesised tags, and quality/codec markers before slugifying, so
    the same channel advertised under slightly different names by
    different endpoints maps to one key.
    """
    # Each step is (pattern, flags); all matches are removed in order.
    cleanup_steps = (
        (r"^\s*[A-Z]{2,4}\s*[|:-]\s*", 0),              # country/group prefix
        (r"\s*\[[^\]]+\]\s*$", 0),                      # trailing [tag]
        (r"\s+(U?F?HD|SD|4K|HEVC|H265|H264)\b", re.IGNORECASE),  # quality/codec
        (r"\s+\([^\)]*\)\s*$", 0),                      # trailing (qualifier)
        (r"^\w+:\s*", re.IGNORECASE),                   # remaining "word:" prefix
    )
    text = name or ""
    for pattern, flags in cleanup_steps:
        text = re.sub(pattern, "", text, flags=flags)
    return slugify(text)
async def _async_fetch_stream_catalog(self) -> None: async def _async_fetch_stream_catalog(self) -> None:
"""Fetch stream list from Dispatcharr streams endpoint.""" """Fetch stream list from Dispatcharr streams endpoint."""
all_streams: list[dict] = [] all_streams: list[dict] = []
page = 1 page = 1
page_size = 500 page_size = 500
max_pages = 40 max_pages = 1000
while page <= max_pages: while page <= max_pages:
payload = await self.api_request( payload = await self.api_request(
@@ -152,12 +177,24 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
break break
page += 1 page += 1
if page > max_pages:
_LOGGER.warning(
"Stopped stream catalog fetch at max_pages=%d (fetched=%d)",
max_pages,
len(all_streams),
)
catalog: dict[str, dict] = {} catalog: dict[str, dict] = {}
catalog_by_id: dict[int, dict] = {}
active_count = 0 active_count = 0
for stream in all_streams: for stream in all_streams:
if not isinstance(stream, dict): if not isinstance(stream, dict):
continue continue
stream_id = stream.get("id")
if isinstance(stream_id, int):
catalog_by_id[stream_id] = stream
name = stream.get("name") name = stream.get("name")
if not name: if not name:
continue continue
@@ -182,6 +219,7 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
catalog[normalized] = stream catalog[normalized] = stream
self.stream_catalog = catalog self.stream_catalog = catalog
self.stream_catalog_by_id = catalog_by_id
self.stream_catalog_count = len(all_streams) self.stream_catalog_count = len(all_streams)
self.stream_catalog_active_count = active_count self.stream_catalog_active_count = active_count
self._last_catalog_refresh = datetime.now(timezone.utc) self._last_catalog_refresh = datetime.now(timezone.utc)
@@ -192,6 +230,12 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
return None return None
return self.stream_catalog.get(self._normalize_stream_name(stream_name)) return self.stream_catalog.get(self._normalize_stream_name(stream_name))
def _get_catalog_stream_by_id(self, stream_id: int | None) -> dict | None:
    """Return the stream catalog row for a numeric stream id, or None.

    Non-int ids (None, or strings from partially populated API rows)
    never match and simply yield None rather than raising.
    """
    if isinstance(stream_id, int):
        return self.stream_catalog_by_id.get(stream_id)
    return None
async def async_populate_channel_map_from_xml(self): async def async_populate_channel_map_from_xml(self):
"""Fetch the XML file once to build a reliable map of channels.""" """Fetch the XML file once to build a reliable map of channels."""
_LOGGER.info("Populating Dispatcharr channel map from XML file...") _LOGGER.info("Populating Dispatcharr channel map from XML file...")
@@ -207,6 +251,7 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
try: try:
root = ET.fromstring(xml_string) root = ET.fromstring(xml_string)
self.channel_map = {} self.channel_map = {}
self.channel_map_by_id = {}
for channel in root.iterfind("channel"): for channel in root.iterfind("channel"):
display_name = channel.findtext("display-name") display_name = channel.findtext("display-name")
channel_id = channel.get("id") channel_id = channel.get("id")
@@ -215,11 +260,13 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
if display_name and channel_id: if display_name and channel_id:
slug_name = slugify(display_name) slug_name = slugify(display_name)
self.channel_map[slug_name] = { details = {
"id": channel_id, "id": channel_id,
"name": display_name, "name": display_name,
"logo_url": icon_url, "logo_url": icon_url,
} }
self.channel_map[slug_name] = details
self.channel_map_by_id[channel_id] = details
if not self.channel_map: if not self.channel_map:
raise ConfigEntryNotReady( raise ConfigEntryNotReady(
@@ -305,6 +352,8 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
details = self._get_channel_details_from_stream_name(stream_name) details = self._get_channel_details_from_stream_name(stream_name)
enriched_stream = stream.copy() enriched_stream = stream.copy()
catalog_stream = self._get_catalog_stream_by_id(stream.get("stream_id"))
if not catalog_stream:
catalog_stream = self._get_catalog_stream(stream_name) catalog_stream = self._get_catalog_stream(stream_name)
if catalog_stream: if catalog_stream:
enriched_stream["catalog_stream_id"] = catalog_stream.get("id") enriched_stream["catalog_stream_id"] = catalog_stream.get("id")
@@ -320,6 +369,10 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
if not enriched_stream.get("logo_url"): if not enriched_stream.get("logo_url"):
enriched_stream["logo_url"] = catalog_stream.get("logo_url") enriched_stream["logo_url"] = catalog_stream.get("logo_url")
catalog_tvg_id = catalog_stream.get("tvg_id")
if not details and catalog_tvg_id:
details = self.channel_map_by_id.get(catalog_tvg_id)
if details: if details:
xmltv_id = details["id"] xmltv_id = details["id"]
enriched_stream["xmltv_id"] = xmltv_id enriched_stream["xmltv_id"] = xmltv_id
@@ -336,12 +389,30 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
episode_num_tag = program.find( episode_num_tag = program.find(
"episode-num[@system='onscreen']" "episode-num[@system='onscreen']"
) )
icon_tag = program.find("icon")
icon_url = (
icon_tag.get("src")
if icon_tag is not None
else None
)
duration_seconds = int(
(stop_time - start_time).total_seconds()
)
categories = [
cat.text
for cat in program.findall("category")
if cat.text
]
enriched_stream["program"] = { enriched_stream["program"] = {
"title": program.findtext("title"), "title": program.findtext("title"),
"description": program.findtext("desc"), "description": program.findtext("desc"),
"start_time": start_time.isoformat(), "start_time": start_time.isoformat(),
"end_time": stop_time.isoformat(), "end_time": stop_time.isoformat(),
"duration_seconds": duration_seconds,
"duration_minutes": round(duration_seconds / 60),
"subtitle": program.findtext("sub-title"), "subtitle": program.findtext("sub-title"),
"image": icon_url,
"categories": categories,
"episode_num": episode_num_tag.text "episode_num": episode_num_tag.text
if episode_num_tag is not None if episode_num_tag is not None
else None, else None,

View File

@@ -1,7 +1,9 @@
"""Config flow for Dispatcharr Sensor integration.""" """Config flow for Dispatcharr Sensor integration."""
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any from typing import Any
from urllib.parse import urlparse
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
@@ -14,17 +16,34 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
{ {
vol.Required("host"): str, vol.Required("host"): str,
vol.Required("port", default=9191): int, vol.Required("port", default=9191): int,
vol.Optional("ssl",default=False): bool, vol.Optional("ssl", default=False): bool,
vol.Required("username"): str, vol.Required("username"): str,
vol.Required("password"): str, vol.Required("password"): str,
} }
) )
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Dispatcharr Sensor.""" """Handle a config flow for Dispatcharr Sensor."""
VERSION = 1 VERSION = 1
def _normalize_connection_input(self, user_input: dict[str, Any]) -> dict[str, Any]:
    """Normalize host/port/ssl when the host field contains a URL.

    Users frequently paste "https://dispatcharr.local:9191" (or just
    "dispatcharr.local:9191") into the host field; split that into its
    host / port / ssl components so URL building elsewhere is reliable.
    The input dict is copied, never mutated; unparseable values are left
    as entered.
    """
    data = dict(user_input)
    host_raw = str(data.get("host", "")).strip()
    if not host_raw:
        return data
    # urlparse only splits host/port when a scheme (or "//") prefix is
    # present, so fake one for bare "host:port" input.
    parsed = urlparse(host_raw if "://" in host_raw else f"//{host_raw}")
    if parsed.hostname:
        data["host"] = parsed.hostname
    try:
        if parsed.port:
            data["port"] = parsed.port
    except ValueError:
        # Malformed/out-of-range port in the pasted URL; keep the
        # value from the port form field instead.
        pass
    if parsed.scheme in ("http", "https"):
        data["ssl"] = parsed.scheme == "https"
    return data
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
@@ -34,4 +53,5 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
step_id="user", data_schema=STEP_USER_DATA_SCHEMA step_id="user", data_schema=STEP_USER_DATA_SCHEMA
) )
return self.async_create_entry(title="Dispatcharr", data=user_input) normalized = self._normalize_connection_input(user_input)
return self.async_create_entry(title="Dispatcharr", data=normalized)

View File

@@ -128,21 +128,33 @@ class DispatcharrStreamMediaPlayer(CoordinatorEntity, MediaPlayerEntity):
stream_data = self.coordinator.data[self._stream_id] stream_data = self.coordinator.data[self._stream_id]
program_data = stream_data.get("program") or {} program_data = stream_data.get("program") or {}
stream_stats = stream_data.get("stream_stats") or {} stream_stats = stream_data.get("stream_stats") or {}
channel_name = stream_data.get("channel_name")
self._attr_state = STATE_PLAYING self._attr_state = STATE_PLAYING
self._attr_app_name = "Dispatcharr" self._attr_app_name = "Dispatcharr"
self._attr_entity_picture = stream_data.get("logo_url") self._attr_entity_picture = stream_data.get("logo_url")
self._attr_media_image_url = program_data.get("image") or stream_data.get(
"logo_url"
)
self._attr_media_content_type = MediaType.TVSHOW self._attr_media_content_type = MediaType.TVSHOW
self._attr_media_series_title = program_data.get("title") self._attr_media_series_title = program_data.get("title")
self._attr_media_title = program_data.get("subtitle") or program_data.get( self._attr_media_title = program_data.get("subtitle") or program_data.get(
"title" "title"
) )
if channel_name:
self._attr_name = channel_name
# Extra attributes # Extra attributes
self._attr_extra_state_attributes = { self._attr_extra_state_attributes = {
"channel_number": stream_data.get("xmltv_id"), "channel_number": stream_data.get("xmltv_id"),
"channel_name": stream_data.get("channel_name"), "channel_name": channel_name,
"program_description": program_data.get("description"), "program_description": program_data.get("description"),
"program_start_time": program_data.get("start_time"),
"program_end_time": program_data.get("end_time"),
"program_duration_seconds": program_data.get("duration_seconds"),
"program_duration_minutes": program_data.get("duration_minutes"),
"program_image": program_data.get("image"),
"program_categories": program_data.get("categories"),
"clients": stream_data.get("client_count"), "clients": stream_data.get("client_count"),
"catalog_current_viewers": stream_data.get("catalog_current_viewers"), "catalog_current_viewers": stream_data.get("catalog_current_viewers"),
"catalog_stream_id": stream_data.get("catalog_stream_id"), "catalog_stream_id": stream_data.get("catalog_stream_id"),