Improve EPG mapping and connection normalization
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
||||
from datetime import timedelta, datetime, timezone
|
||||
import xml.etree.ElementTree as ET
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -49,7 +50,9 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
self.websession = async_get_clientsession(hass)
|
||||
self._access_token: str | None = None
|
||||
self.channel_map: dict = {}
|
||||
self.channel_map_by_id: dict[str, dict] = {}
|
||||
self.stream_catalog: dict[str, dict] = {}
|
||||
self.stream_catalog_by_id: dict[int, dict] = {}
|
||||
self.stream_catalog_count: int = 0
|
||||
self.stream_catalog_active_count: int = 0
|
||||
self._last_catalog_refresh: datetime | None = None
|
||||
@@ -62,8 +65,22 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
def base_url(self) -> str:
|
||||
"""Get the base URL for API calls."""
|
||||
data = self.config_entry.data
|
||||
protocol = "https" if data.get("ssl", False) else "http"
|
||||
return f"{protocol}://{data['host']}:{data['port']}"
|
||||
host_raw = str(data.get("host", "")).strip()
|
||||
ssl_enabled = bool(data.get("ssl", False))
|
||||
port = int(data.get("port", 443 if ssl_enabled else 80))
|
||||
|
||||
if "://" in host_raw:
|
||||
parsed = urlparse(host_raw)
|
||||
host = parsed.hostname or host_raw
|
||||
if parsed.port:
|
||||
port = parsed.port
|
||||
if parsed.scheme in ("http", "https"):
|
||||
ssl_enabled = parsed.scheme == "https"
|
||||
else:
|
||||
host = host_raw
|
||||
|
||||
protocol = "https" if ssl_enabled else "http"
|
||||
return f"{protocol}://{host}:{port}"
|
||||
|
||||
async def _get_new_token(self) -> None:
|
||||
"""Get a new access token using username and password."""
|
||||
@@ -127,14 +144,22 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
|
||||
def _normalize_stream_name(self, name: str) -> str:
|
||||
"""Normalize stream name for fuzzy cross-endpoint matching."""
|
||||
return slugify(re.sub(r"^\w+:\s*|\s+HD$", "", name or "", flags=re.IGNORECASE))
|
||||
cleaned = name or ""
|
||||
cleaned = re.sub(r"^\s*[A-Z]{2,4}\s*[|:-]\s*", "", cleaned)
|
||||
cleaned = re.sub(r"\s*\[[^\]]+\]\s*$", "", cleaned)
|
||||
cleaned = re.sub(
|
||||
r"\s+(U?F?HD|SD|4K|HEVC|H265|H264)\b", "", cleaned, flags=re.IGNORECASE
|
||||
)
|
||||
cleaned = re.sub(r"\s+\([^\)]*\)\s*$", "", cleaned)
|
||||
cleaned = re.sub(r"^\w+:\s*", "", cleaned, flags=re.IGNORECASE)
|
||||
return slugify(cleaned)
|
||||
|
||||
async def _async_fetch_stream_catalog(self) -> None:
|
||||
"""Fetch stream list from Dispatcharr streams endpoint."""
|
||||
all_streams: list[dict] = []
|
||||
page = 1
|
||||
page_size = 500
|
||||
max_pages = 40
|
||||
max_pages = 1000
|
||||
|
||||
while page <= max_pages:
|
||||
payload = await self.api_request(
|
||||
@@ -152,12 +177,24 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
break
|
||||
page += 1
|
||||
|
||||
if page > max_pages:
|
||||
_LOGGER.warning(
|
||||
"Stopped stream catalog fetch at max_pages=%d (fetched=%d)",
|
||||
max_pages,
|
||||
len(all_streams),
|
||||
)
|
||||
|
||||
catalog: dict[str, dict] = {}
|
||||
catalog_by_id: dict[int, dict] = {}
|
||||
active_count = 0
|
||||
for stream in all_streams:
|
||||
if not isinstance(stream, dict):
|
||||
continue
|
||||
|
||||
stream_id = stream.get("id")
|
||||
if isinstance(stream_id, int):
|
||||
catalog_by_id[stream_id] = stream
|
||||
|
||||
name = stream.get("name")
|
||||
if not name:
|
||||
continue
|
||||
@@ -182,6 +219,7 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
catalog[normalized] = stream
|
||||
|
||||
self.stream_catalog = catalog
|
||||
self.stream_catalog_by_id = catalog_by_id
|
||||
self.stream_catalog_count = len(all_streams)
|
||||
self.stream_catalog_active_count = active_count
|
||||
self._last_catalog_refresh = datetime.now(timezone.utc)
|
||||
@@ -192,6 +230,12 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
return None
|
||||
return self.stream_catalog.get(self._normalize_stream_name(stream_name))
|
||||
|
||||
def _get_catalog_stream_by_id(self, stream_id: int | None) -> dict | None:
|
||||
"""Lookup stream catalog row by numeric stream id."""
|
||||
if not isinstance(stream_id, int):
|
||||
return None
|
||||
return self.stream_catalog_by_id.get(stream_id)
|
||||
|
||||
async def async_populate_channel_map_from_xml(self):
|
||||
"""Fetch the XML file once to build a reliable map of channels."""
|
||||
_LOGGER.info("Populating Dispatcharr channel map from XML file...")
|
||||
@@ -207,6 +251,7 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
try:
|
||||
root = ET.fromstring(xml_string)
|
||||
self.channel_map = {}
|
||||
self.channel_map_by_id = {}
|
||||
for channel in root.iterfind("channel"):
|
||||
display_name = channel.findtext("display-name")
|
||||
channel_id = channel.get("id")
|
||||
@@ -215,11 +260,13 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
|
||||
if display_name and channel_id:
|
||||
slug_name = slugify(display_name)
|
||||
self.channel_map[slug_name] = {
|
||||
details = {
|
||||
"id": channel_id,
|
||||
"name": display_name,
|
||||
"logo_url": icon_url,
|
||||
}
|
||||
self.channel_map[slug_name] = details
|
||||
self.channel_map_by_id[channel_id] = details
|
||||
|
||||
if not self.channel_map:
|
||||
raise ConfigEntryNotReady(
|
||||
@@ -305,7 +352,9 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
details = self._get_channel_details_from_stream_name(stream_name)
|
||||
enriched_stream = stream.copy()
|
||||
|
||||
catalog_stream = self._get_catalog_stream(stream_name)
|
||||
catalog_stream = self._get_catalog_stream_by_id(stream.get("stream_id"))
|
||||
if not catalog_stream:
|
||||
catalog_stream = self._get_catalog_stream(stream_name)
|
||||
if catalog_stream:
|
||||
enriched_stream["catalog_stream_id"] = catalog_stream.get("id")
|
||||
enriched_stream["catalog_tvg_id"] = catalog_stream.get("tvg_id")
|
||||
@@ -320,6 +369,10 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
if not enriched_stream.get("logo_url"):
|
||||
enriched_stream["logo_url"] = catalog_stream.get("logo_url")
|
||||
|
||||
catalog_tvg_id = catalog_stream.get("tvg_id")
|
||||
if not details and catalog_tvg_id:
|
||||
details = self.channel_map_by_id.get(catalog_tvg_id)
|
||||
|
||||
if details:
|
||||
xmltv_id = details["id"]
|
||||
enriched_stream["xmltv_id"] = xmltv_id
|
||||
@@ -336,12 +389,30 @@ class DispatcharrDataUpdateCoordinator(DataUpdateCoordinator):
|
||||
episode_num_tag = program.find(
|
||||
"episode-num[@system='onscreen']"
|
||||
)
|
||||
icon_tag = program.find("icon")
|
||||
icon_url = (
|
||||
icon_tag.get("src")
|
||||
if icon_tag is not None
|
||||
else None
|
||||
)
|
||||
duration_seconds = int(
|
||||
(stop_time - start_time).total_seconds()
|
||||
)
|
||||
categories = [
|
||||
cat.text
|
||||
for cat in program.findall("category")
|
||||
if cat.text
|
||||
]
|
||||
enriched_stream["program"] = {
|
||||
"title": program.findtext("title"),
|
||||
"description": program.findtext("desc"),
|
||||
"start_time": start_time.isoformat(),
|
||||
"end_time": stop_time.isoformat(),
|
||||
"duration_seconds": duration_seconds,
|
||||
"duration_minutes": round(duration_seconds / 60),
|
||||
"subtitle": program.findtext("sub-title"),
|
||||
"image": icon_url,
|
||||
"categories": categories,
|
||||
"episode_num": episode_num_tag.text
|
||||
if episode_num_tag is not None
|
||||
else None,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""Config flow for Dispatcharr Sensor integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
import voluptuous as vol
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
@@ -14,17 +16,34 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required("host"): str,
|
||||
vol.Required("port", default=9191): int,
|
||||
vol.Optional("ssl",default=False): bool,
|
||||
vol.Optional("ssl", default=False): bool,
|
||||
vol.Required("username"): str,
|
||||
vol.Required("password"): str,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Dispatcharr Sensor."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
def _normalize_connection_input(self, user_input: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Normalize host/port/ssl from host field."""
|
||||
data = dict(user_input)
|
||||
host_raw = str(data.get("host", "")).strip()
|
||||
|
||||
if "://" in host_raw:
|
||||
parsed = urlparse(host_raw)
|
||||
if parsed.hostname:
|
||||
data["host"] = parsed.hostname
|
||||
if parsed.port:
|
||||
data["port"] = parsed.port
|
||||
if parsed.scheme in ("http", "https"):
|
||||
data["ssl"] = parsed.scheme == "https"
|
||||
|
||||
return data
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
@@ -34,4 +53,5 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA
|
||||
)
|
||||
|
||||
return self.async_create_entry(title="Dispatcharr", data=user_input)
|
||||
normalized = self._normalize_connection_input(user_input)
|
||||
return self.async_create_entry(title="Dispatcharr", data=normalized)
|
||||
|
||||
@@ -128,21 +128,33 @@ class DispatcharrStreamMediaPlayer(CoordinatorEntity, MediaPlayerEntity):
|
||||
stream_data = self.coordinator.data[self._stream_id]
|
||||
program_data = stream_data.get("program") or {}
|
||||
stream_stats = stream_data.get("stream_stats") or {}
|
||||
channel_name = stream_data.get("channel_name")
|
||||
|
||||
self._attr_state = STATE_PLAYING
|
||||
self._attr_app_name = "Dispatcharr"
|
||||
self._attr_entity_picture = stream_data.get("logo_url")
|
||||
self._attr_media_image_url = program_data.get("image") or stream_data.get(
|
||||
"logo_url"
|
||||
)
|
||||
self._attr_media_content_type = MediaType.TVSHOW
|
||||
self._attr_media_series_title = program_data.get("title")
|
||||
self._attr_media_title = program_data.get("subtitle") or program_data.get(
|
||||
"title"
|
||||
)
|
||||
if channel_name:
|
||||
self._attr_name = channel_name
|
||||
|
||||
# Extra attributes
|
||||
self._attr_extra_state_attributes = {
|
||||
"channel_number": stream_data.get("xmltv_id"),
|
||||
"channel_name": stream_data.get("channel_name"),
|
||||
"channel_name": channel_name,
|
||||
"program_description": program_data.get("description"),
|
||||
"program_start_time": program_data.get("start_time"),
|
||||
"program_end_time": program_data.get("end_time"),
|
||||
"program_duration_seconds": program_data.get("duration_seconds"),
|
||||
"program_duration_minutes": program_data.get("duration_minutes"),
|
||||
"program_image": program_data.get("image"),
|
||||
"program_categories": program_data.get("categories"),
|
||||
"clients": stream_data.get("client_count"),
|
||||
"catalog_current_viewers": stream_data.get("catalog_current_viewers"),
|
||||
"catalog_stream_id": stream_data.get("catalog_stream_id"),
|
||||
|
||||
Reference in New Issue
Block a user