Source code for mpcq.client

from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from datetime import date, datetime
from typing import Any, Iterable, Literal, Sequence, cast, no_type_check

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from adam_core.observations import ADESObservations
from adam_core.time import Timestamp
from astropy.time import Time, TimeDelta
from google.cloud import bigquery

from .observations import CrossMatchedMPCObservations, MPCObservations
from .orbits import MPCOrbits, MPCPrimaryObjects
from .submissions import (
    MPCSubmissionHistory,
    MPCSubmissionResults,
    infer_submission_time,
)

METERS_PER_ARCSECONDS = 30.87
MAX_CROSSMATCH_INPUT_ROWS_PER_QUERY = 500
ObservationColumnMode = Literal["minimal", "ades", "full"]
OrbitColumnMode = Literal["minimal", "full"]

# Small default payload intended for most query/list use cases.
OBSERVATION_COLUMNS_MINIMAL = [
    "obsid",
    "provid",
    "permid",
    "obstime",
    "ra",
    "dec",
    "stn",
    "mag",
    "band",
    "status",
]

# ADES-compatible payload including expanded ADES fields.
OBSERVATION_COLUMNS_ADES = [
    "obsid",
    "obssubid",
    "trksub",
    "trkid",
    "provid",
    "permid",
    "submission_id",
    "submission_block_id",
    "obs80",
    "status",
    "ref",
    "mode",
    "stn",
    "trx",
    "rcv",
    "sys",
    "ctr",
    "pos1",
    "pos2",
    "pos3",
    "poscov11",
    "poscov12",
    "poscov13",
    "poscov22",
    "poscov23",
    "poscov33",
    "prog",
    "obstime",
    "ra",
    "dec",
    "rastar",
    "decstar",
    "obscenter",
    "deltara",
    "deltadec",
    "dist",
    "pa",
    "rmsra",
    "rmsdec",
    "rmsdist",
    "rmspa",
    "rmscorr",
    "delay",
    "rmsdelay",
    "doppler",
    "rmsdoppler",
    "astcat",
    "mag",
    "rmsmag",
    "band",
    "photcat",
    "photap",
    "nucmag",
    "logsnr",
    "seeing",
    "exp",
    "rmsfit",
    "com",
    "frq",
    "disc",
    "subfrm",
    "subfmt",
    "prectime",
    "precra",
    "precdec",
    "unctime",
    "notes",
    "remarks",
    "deprecated",
    "localuse",
    "nstars",
    "prev_desig",
    "prev_ref",
    "rmstime",
    "trkmpc",
    "designation_asterisk",
    "obstime_text",
]

ORBIT_COLUMNS_MINIMAL = [
    "provid",
    "epoch",
    "q",
    "e",
    "i",
    "node",
    "argperi",
    "peri_time",
    "h",
    "g",
]


def _iso_utc(col: pa.ChunkedArray) -> list[str]:
    """Convert a timestamp column to ISO-8601 UTC strings.

    BigQuery returns TIMESTAMP as Arrow timestamp[us, tz=UTC]. Casting to string yields
    'YYYY-MM-DD HH:MM:SS.ffffffZ'. Replace the space with 'T' for ISO-8601.
    """
    arr = pc.replace_substring(col.cast(pa.string()), " ", "T").combine_chunks()
    return cast(list[str], arr.to_pylist())


@no_type_check
def _mask_nonfinite(values: Any) -> Any:
    return np.ma.masked_array(values, mask=np.isnan(values))


def _escape_sql_string(value: str) -> str:
    return value.replace("'", "''")


def _normalize_string_value(value: Any) -> str:
    return str(value).strip()


def _sql_string_list(values: Sequence[Any]) -> str:
    return ", ".join(
        [f"'{_escape_sql_string(_normalize_string_value(value))}'" for value in values]
    )


@dataclass(frozen=True)
class Where:
    column: str
    op: Literal[
        "=",
        "!=",
        "<",
        "<=",
        ">",
        ">=",
        "between",
        "in",
        "is null",
        "is not null",
        "startswith",
        "endswith",
        "contains",
        "istartswith",
        "iendswith",
        "icontains",
    ]
    value: Any | Sequence[Any] | tuple[Any, Any] | None = None


def _normalize_columns(
    columns: list[str] | str | None,
    mode_columns: list[str],
    all_columns: Iterable[str],
    required: Iterable[str],
) -> list[str]:
    if columns is None:
        selected = list(mode_columns)
    elif columns == "*":
        selected = list(all_columns)
    else:
        selected = list(columns)

    # Always include required metadata columns
    for col in required:
        if col not in selected:
            selected.append(col)
    return selected


def _build_where_clause(
    filters: list[Where] | None,
    valid_columns: set[str],
    param_prefix: str,
) -> tuple[
    str,
    list[bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter],
]:
    if not filters:
        return "", []

    params: list[bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter] = []
    clauses: list[str] = []
    param_index = 0

    for f in filters:
        col = f.column
        if col not in valid_columns:
            raise ValueError(f"Invalid column in where: {col}")

        op = f.op.lower()
        pname = f"{param_prefix}{param_index}"

        if op in {"=", "!=", "<", "<=", ">", ">="}:
            if f.value is None:
                raise ValueError(f"Operator {f.op} requires a non-null value for {col}")
            ptype, pvalue = _to_bq_param(f.value)
            params.append(bigquery.ScalarQueryParameter(pname, ptype, pvalue))
            clauses.append(f"{col} {f.op} @{pname}")
            param_index += 1
        elif op == "between":
            if not isinstance(f.value, tuple) or len(f.value) != 2:
                raise ValueError("between requires a (low, high) tuple")
            pname1 = f"{param_prefix}{param_index}"
            pname2 = f"{param_prefix}{param_index + 1}"
            ptype1, pvalue1 = _to_bq_param(f.value[0])
            ptype2, pvalue2 = _to_bq_param(f.value[1])
            params.append(bigquery.ScalarQueryParameter(pname1, ptype1, pvalue1))
            params.append(bigquery.ScalarQueryParameter(pname2, ptype2, pvalue2))
            clauses.append(f"{col} BETWEEN @{pname1} AND @{pname2}")
            param_index += 2
        elif op == "in":
            if not isinstance(f.value, (list, tuple)):
                raise ValueError("in requires a list/tuple value")
            if len(f.value) == 0:
                raise ValueError("in requires a non-empty list/tuple value")
            pname = f"{param_prefix}{param_index}"
            array_type, _ = _to_bq_param(f.value[0])
            array_values = [_to_bq_param(v)[1] for v in f.value]
            params.append(bigquery.ArrayQueryParameter(pname, array_type, array_values))
            clauses.append(f"{col} IN UNNEST(@{pname})")
            param_index += 1
        elif op in {"is null", "is not null"}:
            clauses.append(f"{col} {op.upper()}")
        elif op in {"startswith", "endswith", "contains"}:
            if f.value is None:
                raise ValueError(f"Operator {f.op} requires a non-null value for {col}")
            params.append(bigquery.ScalarQueryParameter(pname, "STRING", str(f.value)))
            func = {
                "startswith": "STARTS_WITH",
                "endswith": "ENDS_WITH",
                "contains": "STRPOS",
            }[op]
            if func == "STRPOS":
                clauses.append(f"STRPOS(CAST({col} AS STRING), CAST(@{pname} AS STRING)) > 0")
            else:
                clauses.append(f"{func}(CAST({col} AS STRING), CAST(@{pname} AS STRING))")
            param_index += 1
        elif op in {"istartswith", "iendswith", "icontains"}:
            if f.value is None:
                raise ValueError(f"Operator {f.op} requires a non-null value for {col}")
            params.append(bigquery.ScalarQueryParameter(pname, "STRING", str(f.value).lower()))
            lowered = f"LOWER(CAST({col} AS STRING))"
            if op == "icontains":
                clauses.append(f"STRPOS({lowered}, CAST(@{pname} AS STRING)) > 0")
            elif op == "istartswith":
                clauses.append(f"STARTS_WITH({lowered}, CAST(@{pname} AS STRING))")
            else:
                clauses.append(f"ENDS_WITH({lowered}, CAST(@{pname} AS STRING))")
            param_index += 1
        else:
            raise ValueError(f"Unsupported operator: {f.op}")

    return ("WHERE " + " AND ".join(clauses), params)


def _to_bq_param(value: Any) -> tuple[str, Any]:
    if isinstance(value, Time):
        return "TIMESTAMP", value.to_datetime()
    if isinstance(value, datetime):
        return "TIMESTAMP", value
    if isinstance(value, date):
        return "DATE", value
    if isinstance(value, bool):
        return "BOOL", value
    if isinstance(value, int):
        return "INT64", value
    if isinstance(value, float):
        return "FLOAT64", value
    return "STRING", str(value)


class MPCClient(ABC):
    @abstractmethod
    def query_observations(
        self,
        provids: list[str] | None = None,
        columns: list[str] | str | None = None,
        column_mode: ObservationColumnMode = "minimal",
        where: list[Where] | None = None,
        limit: int | None = None,
        dedupe: bool = True,
    ) -> MPCObservations:
        """
        Query the MPC database for the observations and associated data for the given
        provisional designations.

        Parameters
        ----------
        provids : List[str] | None
            List of provisional designations to query. Optional.
        columns : list[str] | str | None
            Explicit subset of columns, "*" for full schema, or None to use column_mode.
        column_mode : Literal["minimal", "ades", "full"]
            Default column bundle when columns is None.
        where : list[Where] | None
            Additional filters using allowed operators.
        limit : int | None
            Limit the number of rows returned. Required if both provids and where are None.
        dedupe : bool
            If True, use SELECT DISTINCT to deduplicate expanded-identifier joins.

        Returns
        -------
        observations : MPCObservations
            The observations and associated data for the given provisional designations.
        """
        pass

    @abstractmethod
    def query_orbits(
        self,
        provids: list[str] | None = None,
        columns: list[str] | str | None = None,
        column_mode: OrbitColumnMode = "minimal",
        where: list[Where] | None = None,
        limit: int | None = None,
        dedupe: bool = True,
    ) -> MPCOrbits:
        """
        Query the MPC database for the orbits and associated data for the given
        provisional designations.

        Parameters
        ----------
        provids : List[str] | None
            List of provisional designations to query. Optional.
        columns : list[str] | str | None
            Explicit subset of columns, "*" for full schema, or None to use column_mode.
        column_mode : Literal["minimal", "full"]
            Default column bundle when columns is None.
        where : list[Where] | None
            Additional filters using allowed operators.
        limit : int | None
            Limit the number of rows returned. Required if both provids and where are None.
        dedupe : bool
            If True, use SELECT DISTINCT to deduplicate expanded-identifier joins.

        Returns
        -------
        orbits : MPCOrbits
            The orbits and associated data for the given provisional designations.
        """
        pass

    @abstractmethod
    def query_submission_info(self, submission_ids: list[str]) -> MPCSubmissionResults:
        """
        Query for observation status and mapping (observation ID to trksub, provid, etc.) for a
        given list of submission IDs.

        Parameters
        ----------
        submission_ids : List[str]
            List of submission IDs to query.

        Returns
        -------
        submission_info : MPCSubmissionResults
            The observation status and mapping for the given submission IDs.
        """
        pass

    @abstractmethod
    def query_submission_history(self, provids: list[str]) -> MPCSubmissionHistory:
        """
        Query for submission history for a given list of provisional designations.

        Parameters
        ----------
        provids : List[str]
            List of provisional designations to query.

        Returns
        -------
        submission_history : MPCSubmissionHistory
            The submission history for the given provisional designations.
        """
        pass

    @abstractmethod
    def query_primary_objects(self, provids: list[str]) -> MPCPrimaryObjects:
        """
        Query the MPC database for the primary objects and associated data for the given
        provisional designations.

        Parameters
        ----------
        provids : List[str]
            List of provisional designations to query.

        Returns
        -------
        primary_objects : MPCPrimaryObjects
            The primary objects and associated data for the given provisional designations.
        """
        pass

    @abstractmethod
    def cross_match_observations(
        self,
        ades_observations: ADESObservations,
        obstime_tolerance_seconds: int = 30,
        arcseconds_tolerance: float = 2.0,
    ) -> CrossMatchedMPCObservations:
        """
        Cross-match the given ADES observations with the MPC observations.

        Parameters
        ----------
        ades_observations : ADESObservations
            The ADES observations to cross-match.
        obstime_tolerance_seconds : int, optional
            Time tolerance in seconds for matching observations.
        arcseconds_tolerance : float, optional
            Angular separation tolerance in arcseconds.

        Returns
        -------
        cross_matched_mpc_observations : CrossMatchedMPCObservations
            The MPC observations that match the given ADES observations.
        """
        pass

    @abstractmethod
    def find_duplicates(
        self,
        provid: str,
        obstime_tolerance_seconds: int = 30,
        arcseconds_tolerance: float = 2.0,
    ) -> CrossMatchedMPCObservations:
        """
        Find duplicates in the MPC observations for a given object by comparing
        observations against each other using time and position tolerances.

        Parameters
        ----------
        provid : str
            The provisional designation to check for duplicates.
        obstime_tolerance_seconds : int, optional
            Time tolerance in seconds for matching observations.
        arcseconds_tolerance : float, optional
            Angular separation tolerance in arcseconds.

        Returns
        -------
        cross_matched_mpc_observations : CrossMatchedMPCObservations
            The MPC observations that are potential duplicates, with separation
            information included.
        """
        pass


[docs] class BigQueryMPCClient(MPCClient): def __init__( self, dataset_id: str, **kwargs: Any, ) -> None: self.client = bigquery.Client(**kwargs) self.dataset_id = dataset_id
[docs] def query_observations( self, provids: list[str] | None = None, columns: list[str] | str | None = None, column_mode: ObservationColumnMode = "minimal", where: list[Where] | None = None, limit: int | None = None, dedupe: bool = True, ) -> MPCObservations: """ Query the MPC database for the observations and associated data for the given provisional designations. Parameters ---------- provids : List[str] | None List of provisional designations to query. Optional. columns : list[str] | str | None Explicit subset of columns, "*" for full schema, or None to use column_mode. column_mode : Literal["minimal", "ades", "full"] Default column bundle when columns is None. where : list[Where] | None Additional filters using allowed operators. limit : int | None Limit the number of rows returned. Required if both provids and where are None. dedupe : bool If True, use SELECT DISTINCT to deduplicate expanded-identifier joins. Returns ------- observations : MPCObservations The observations and associated data for the given provisional designations. """ # Validation for no filters if provids is None and where is None and limit is None: raise ValueError("limit is required when neither provids nor where are provided") all_columns = list(MPCObservations.schema.names) all_column_set = set(all_columns) required_cols = ["requested_provid", "primary_designation"] mode_map: dict[ObservationColumnMode, list[str]] = { "minimal": [c for c in OBSERVATION_COLUMNS_MINIMAL if c in all_column_set], "ades": [c for c in OBSERVATION_COLUMNS_ADES if c in all_column_set], "full": list(all_columns), } if column_mode not in mode_map: raise ValueError(f"Unsupported observation column_mode: {column_mode}") selected_cols = _normalize_columns( columns, mode_map[column_mode], all_columns, required_cols ) # Build optional WHERE where_sql, params = _build_where_clause(where, all_column_set, "p_") # Build base query parts with_clause = "" from_clause = "" order_by = "ORDER BY obs_sbn.obstime ASC" if provids is not None: provids_str = _sql_string_list(provids) with_clause = f""" WITH requested_provids AS ( SELECT provid FROM UNNEST(ARRAY[{provids_str}]) AS provid ), requested_identifiers AS ( SELECT rp.provid AS requested_provid, CASE WHEN ni.permid IS NOT NULL THEN ni.permid ELSE ci.unpacked_primary_provisional_designation END AS primary_designation, ci.unpacked_primary_provisional_designation AS primary_provid, ci_alt.unpacked_secondary_provisional_designation AS secondary_provid, ni.permid AS numbered_permid FROM requested_provids AS rp LEFT JOIN `{self.dataset_id}.public_current_identifications` AS ci ON ci.unpacked_secondary_provisional_designation = rp.provid LEFT JOIN `{self.dataset_id}.public_current_identifications` AS ci_alt ON ci.unpacked_primary_provisional_designation = ci_alt.unpacked_primary_provisional_designation LEFT JOIN `{self.dataset_id}.public_numbered_identifications` AS ni ON ci.unpacked_primary_provisional_designation = ni.unpacked_primary_provisional_designation ), candidate_matches AS ( SELECT ri.requested_provid, ri.primary_designation, obs_sbn.* FROM requested_identifiers AS ri JOIN `{self.dataset_id}.public_obs_sbn` AS obs_sbn ON obs_sbn.provid = ri.primary_provid UNION ALL SELECT ri.requested_provid, ri.primary_designation, obs_sbn.* FROM requested_identifiers AS ri JOIN `{self.dataset_id}.public_obs_sbn` AS obs_sbn ON obs_sbn.provid = ri.secondary_provid UNION ALL SELECT ri.requested_provid, ri.primary_designation, obs_sbn.* FROM requested_identifiers AS ri JOIN `{self.dataset_id}.public_obs_sbn` AS obs_sbn ON obs_sbn.permid = ri.numbered_permid ) """ from_clause = "FROM candidate_matches AS obs_sbn" order_by = "ORDER BY requested_provid ASC, obs_sbn.obstime ASC" else: with_clause = f""" WITH candidate_matches AS ( SELECT obs_sbn.provid AS requested_provid, obs_sbn.provid AS primary_designation, obs_sbn.* FROM `{self.dataset_id}.public_obs_sbn` AS obs_sbn ) """ from_clause = "FROM candidate_matches AS obs_sbn" # Build SELECT list, prepend metadata select_list = [] if "requested_provid" in selected_cols: select_list.append("obs_sbn.requested_provid AS requested_provid") if "primary_designation" in selected_cols: select_list.append("obs_sbn.primary_designation AS primary_designation") for col in selected_cols: if col in {"requested_provid", "primary_designation"}: continue if col in {"all_pub_ref", "datastream_metadata"}: select_list.append(f"TO_JSON_STRING(obs_sbn.{col}) AS {col}") else: select_list.append(f"obs_sbn.{col}") select_sql = ",\n ".join(select_list) limit_sql = f"LIMIT {int(limit)}" if limit is not None else "" select_keyword = "SELECT DISTINCT" if dedupe else "SELECT" query = f""" {with_clause} {select_keyword} {select_sql} {from_clause} {where_sql} {order_by} {limit_sql}; """ job_config = bigquery.QueryJobConfig(query_parameters=params) results = self.client.query(query, job_config=job_config).result() table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True) obstime_iso = _iso_utc(table["obstime"]) if "obstime" in table.column_names else None created_at_iso = ( _iso_utc(table["created_at"]) if "created_at" in table.column_names else None ) updated_at_iso = ( _iso_utc(table["updated_at"]) if "updated_at" in table.column_names else None ) # Ensure time-like columns cast when present; fill missing required schema columns kwargs: dict[str, Any] = {} for name in MPCObservations.schema.names: if name in table.column_names: if name in {"obstime", "created_at", "updated_at"}: if name == "obstime" and obstime_iso is not None: val = Timestamp.from_iso8601(obstime_iso, scale="utc") elif name == "created_at" and created_at_iso is not None: val = Timestamp.from_iso8601(created_at_iso, scale="utc") elif name == "updated_at" and updated_at_iso is not None: val = Timestamp.from_iso8601(updated_at_iso, scale="utc") else: continue else: val = table[name] kwargs[name] = val return MPCObservations.from_kwargs(**kwargs)
[docs] def all_orbits(self) -> MPCOrbits: """ Query the MPC database for all orbits and associated data. Returns ------- orbits : MPCOrbits The orbits and associated data for all objects in the MPC database. """ query = f""" SELECT mpc_orbits.id, mpc_orbits.unpacked_primary_provisional_designation AS provid, mpc_orbits.epoch_mjd, mpc_orbits.q, mpc_orbits.e, mpc_orbits.i, mpc_orbits.node, mpc_orbits.argperi, mpc_orbits.peri_time, mpc_orbits.q_unc, mpc_orbits.e_unc, mpc_orbits.i_unc, mpc_orbits.node_unc, mpc_orbits.argperi_unc, mpc_orbits.peri_time_unc, mpc_orbits.a1, mpc_orbits.a2, mpc_orbits.a3, mpc_orbits.h, mpc_orbits.g, mpc_orbits.created_at, mpc_orbits.updated_at FROM `{self.dataset_id}.public_mpc_orbits` AS mpc_orbits ORDER BY mpc_orbits.epoch_mjd ASC; """ query_job = self.client.query(query) results = query_job.result() table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True) created_at_iso = ( _iso_utc(table["created_at"]) if "created_at" in table.column_names else None ) updated_at_iso = ( _iso_utc(table["updated_at"]) if "updated_at" in table.column_names else None ) epoch_ts = Timestamp.from_mjd(table["epoch_mjd"], scale="tt") return MPCOrbits.from_kwargs( # Note, since we didn't request a specific provid we use the one MPC provides requested_provid=table["provid"], id=table["id"], provid=table["provid"], epoch=epoch_ts, q=table["q"], e=table["e"], i=table["i"], node=table["node"], argperi=table["argperi"], peri_time=table["peri_time"], q_unc=table["q_unc"], e_unc=table["e_unc"], i_unc=table["i_unc"], node_unc=table["node_unc"], argperi_unc=table["argperi_unc"], peri_time_unc=table["peri_time_unc"], a1=table["a1"], a2=table["a2"], a3=table["a3"], h=table["h"], g=table["g"], created_at=Timestamp.from_iso8601(created_at_iso, scale="utc") if created_at_iso is not None else None, updated_at=Timestamp.from_iso8601(updated_at_iso, scale="utc") if updated_at_iso is not None else None, )
[docs] def query_orbits( self, provids: list[str] | None = None, columns: list[str] | str | None = None, column_mode: OrbitColumnMode = "minimal", where: list[Where] | None = None, limit: int | None = None, dedupe: bool = True, ) -> MPCOrbits: """ Query the MPC database for the orbits and associated data for the given provisional designations. Parameters ---------- provids : List[str] | None List of provisional designations to query. Optional. columns : list[str] | str | None Explicit subset of columns, "*" for full schema, or None to use column_mode. column_mode : Literal["minimal", "full"] Default column bundle when columns is None. where : list[Where] | None Additional filters using allowed operators. limit : int | None Limit the number of rows returned. Required if both provids and where are None. dedupe : bool If True, use SELECT DISTINCT to deduplicate expanded-identifier joins. Returns ------- orbits : MPCOrbits The orbits and associated data for the given provisional designations. """ if provids is None and where is None and limit is None: raise ValueError("limit is required when neither provids nor where are provided") all_columns = list(MPCOrbits.schema.names) all_column_set = set(all_columns) required_cols = ["requested_provid", "primary_designation", "provid", "epoch"] mode_map: dict[OrbitColumnMode, list[str]] = { "minimal": [c for c in ORBIT_COLUMNS_MINIMAL if c in all_column_set], "full": list(all_columns), } if column_mode not in mode_map: raise ValueError(f"Unsupported orbit column_mode: {column_mode}") selected_cols = _normalize_columns( columns, mode_map[column_mode], all_columns, required_cols ) select_list = [] if "requested_provid" in selected_cols: select_list.append("rp.provid AS requested_provid") if "primary_designation" in selected_cols: select_list.append( "CASE WHEN ni.permid IS NOT NULL THEN ni.permid ELSE ci.unpacked_primary_provisional_designation END AS primary_designation" ) if "provid" in selected_cols: select_list.append("mpc_orbits.unpacked_primary_provisional_designation AS provid") if "epoch" in selected_cols: select_list.append("mpc_orbits.epoch_mjd") # Remaining columns for col in selected_cols: if col in {"requested_provid", "primary_designation", "provid", "epoch"}: continue if col in {"mpc_orb_jsonb", "datastream_metadata"}: select_list.append(f"TO_JSON_STRING(mpc_orbits.{col}) AS {col}") else: select_list.append(f"mpc_orbits.{col}") select_sql = ",\n ".join(select_list) where_sql, params = _build_where_clause(where, all_column_set, "o_") with_requested = "" join_condition = "" order_by = "ORDER BY mpc_orbits.epoch_mjd ASC" if provids is not None: provids_str = _sql_string_list(provids) with_requested = f""" WITH requested_provids AS ( SELECT provid FROM UNNEST(ARRAY[{provids_str}]) AS provid ) """ join_condition = f""" FROM requested_provids AS rp LEFT JOIN `{self.dataset_id}.public_current_identifications` AS ci ON ci.unpacked_secondary_provisional_designation = rp.provid LEFT JOIN `{self.dataset_id}.public_current_identifications` AS ci_alt ON ci.unpacked_primary_provisional_designation = ci_alt.unpacked_primary_provisional_designation LEFT JOIN `{self.dataset_id}.public_numbered_identifications` AS ni ON ci.unpacked_primary_provisional_designation = ni.unpacked_primary_provisional_designation LEFT JOIN `{self.dataset_id}.public_mpc_orbits` AS mpc_orbits ON ci.unpacked_primary_provisional_designation = mpc_orbits.unpacked_primary_provisional_designation """ order_by = "ORDER BY requested_provid ASC, mpc_orbits.epoch_mjd ASC" else: join_condition = f""" FROM `{self.dataset_id}.public_mpc_orbits` AS mpc_orbits """ if "requested_provid" in selected_cols: select_list[0] = ( "mpc_orbits.unpacked_primary_provisional_designation AS requested_provid" ) if "primary_designation" in selected_cols: idx = 1 if "requested_provid" in selected_cols else 0 select_list[idx] = ( "mpc_orbits.unpacked_primary_provisional_designation AS primary_designation" ) select_sql = ",\n ".join(select_list) limit_sql = f"LIMIT {int(limit)}" if limit is not None else "" select_keyword = "SELECT DISTINCT" if dedupe else "SELECT" query = f""" {with_requested} {select_keyword} {select_sql} {join_condition} {where_sql} {order_by} {limit_sql}; """ job_config = bigquery.QueryJobConfig(query_parameters=params) results = self.client.query(query, job_config=job_config).result() table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True) created_at_iso = ( _iso_utc(table["created_at"]) if "created_at" in table.column_names else None ) updated_at_iso = ( _iso_utc(table["updated_at"]) if "updated_at" in table.column_names else None ) fitting_datetime_iso = ( _iso_utc(table["fitting_datetime"]) if "fitting_datetime" in table.column_names else None ) # Handle NULL values in the epoch_mjd column: ideally # we should have the Timestamp class be able to handle this mjd_array = table["epoch_mjd"].to_numpy(zero_copy_only=False) mjds = _mask_nonfinite(mjd_array) epoch = Time(mjds, format="mjd", scale="tt") # Build kwargs dynamically kwargs: dict[str, Any] = {} for name in MPCOrbits.schema.names: if name in table.column_names: if name == "epoch": kwargs[name] = Timestamp.from_iso8601(epoch.isot, scale="tt") elif name in {"created_at", "updated_at", "fitting_datetime"}: if name == "created_at": if created_at_iso is not None: kwargs[name] = Timestamp.from_iso8601(created_at_iso, scale="utc") else: continue elif name == "updated_at": if updated_at_iso is not None: kwargs[name] = Timestamp.from_iso8601(updated_at_iso, scale="utc") else: continue else: if fitting_datetime_iso is not None: kwargs[name] = Timestamp.from_iso8601(fitting_datetime_iso, scale="utc") else: continue else: kwargs[name] = table[name] # Always ensure provid and epoch are present if requested if ( "provid" in MPCOrbits.schema.names and "provid" not in kwargs and "provid" in table.column_names ): kwargs["provid"] = table["provid"] if "epoch" in MPCOrbits.schema.names and "epoch" not in kwargs: kwargs["epoch"] = Timestamp.from_iso8601(epoch.isot, scale="tt") return MPCOrbits.from_kwargs(**kwargs)
[docs] def query_submission_info(self, submission_ids: list[str]) -> MPCSubmissionResults: """ Query for observation status and mapping (observation ID to trksub, provid, etc.) for a given list of submission IDs. Parameters ---------- submission_ids : list[str] List of submission IDs to query. Returns ------- submission_info : MPCSubmissionResults The observation status and mapping for the given submission IDs. """ submission_ids_str = _sql_string_list(submission_ids) query = f""" WITH requested_submission_ids AS ( SELECT submission_id FROM UNNEST(ARRAY[{submission_ids_str}]) AS submission_id ) SELECT DISTINCT sb.submission_id AS requested_submission_id, obs_sbn.obsid, obs_sbn.obssubid, obs_sbn.trksub, CASE WHEN ni.permid IS NOT NULL THEN ni.permid ELSE ci.unpacked_primary_provisional_designation END AS primary_designation, obs_sbn.permid, obs_sbn.provid, obs_sbn.submission_id, obs_sbn.status FROM requested_submission_ids AS sb LEFT JOIN `{self.dataset_id}.public_obs_sbn` AS obs_sbn ON sb.submission_id = obs_sbn.submission_id LEFT JOIN `{self.dataset_id}.public_current_identifications` AS ci ON ci.unpacked_secondary_provisional_designation = obs_sbn.provid OR ci.unpacked_primary_provisional_designation = obs_sbn.provid LEFT JOIN `{self.dataset_id}.public_numbered_identifications` AS ni ON obs_sbn.permid = ni.permid ORDER BY requested_submission_id ASC, obs_sbn.obsid ASC; """ query_job = self.client.query(query) results = query_job.result() table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True) return MPCSubmissionResults.from_pyarrow(table)
[docs] def query_submission_history(self, provids: list[str]) -> MPCSubmissionHistory: """ Query for submission history for a given list of provisional designations. Parameters ---------- provids : list[str] List of provisional designations to query. Returns ------- submission_history : MPCSubmissionHistory The submission history for the given provisional designations. """ provids_str = _sql_string_list(provids) query = f""" WITH requested_provids AS ( SELECT provid FROM UNNEST(ARRAY[{provids_str}]) AS provid ), requested_identifiers AS ( SELECT rp.provid AS requested_provid, CASE WHEN ni.permid IS NOT NULL THEN ni.permid ELSE ci.unpacked_primary_provisional_designation END AS primary_designation, ci.unpacked_primary_provisional_designation AS primary_provid, ci_alt.unpacked_secondary_provisional_designation AS secondary_provid, ni.permid AS numbered_permid FROM requested_provids AS rp LEFT JOIN `{self.dataset_id}.public_current_identifications` AS ci ON ci.unpacked_secondary_provisional_designation = rp.provid LEFT JOIN `{self.dataset_id}.public_current_identifications` AS ci_alt ON ci.unpacked_primary_provisional_designation = ci_alt.unpacked_primary_provisional_designation LEFT JOIN `{self.dataset_id}.public_numbered_identifications` AS ni ON ci.unpacked_primary_provisional_designation = ni.unpacked_primary_provisional_designation ), candidate_obs AS ( SELECT ri.requested_provid, ri.primary_designation, obs_sbn.obsid, obs_sbn.obstime, obs_sbn.submission_id FROM requested_identifiers AS ri JOIN `{self.dataset_id}.public_obs_sbn` AS obs_sbn ON obs_sbn.provid = ri.primary_provid UNION ALL SELECT ri.requested_provid, ri.primary_designation, obs_sbn.obsid, obs_sbn.obstime, obs_sbn.submission_id FROM requested_identifiers AS ri JOIN `{self.dataset_id}.public_obs_sbn` AS obs_sbn ON obs_sbn.provid = ri.secondary_provid UNION ALL SELECT ri.requested_provid, ri.primary_designation, obs_sbn.obsid, obs_sbn.obstime, obs_sbn.submission_id FROM requested_identifiers AS ri JOIN `{self.dataset_id}.public_obs_sbn` AS obs_sbn ON obs_sbn.permid = ri.numbered_permid ) SELECT DISTINCT requested_provid, primary_designation, obsid, obstime, submission_id FROM candidate_obs ORDER BY requested_provid ASC, obstime ASC; """ query_job = self.client.query(query) results = query_job.result() # Convert the results to a PyArrow table table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True) table = ( table.group_by(["requested_provid", "primary_designation", "submission_id"]) .aggregate([("obsid", "count_distinct"), ("obstime", "min"), ("obstime", "max")]) .sort_by([("primary_designation", "ascending"), ("submission_id", "ascending")]) .rename_columns( [ "requested_provid", "primary_designation", "submission_id", "num_obs", "first_obs_time", "last_obs_time", ] ) ) # Create array that tracks the index of each row table = table.append_column("idx", pa.array(np.arange(len(table)))) # Find the first and last index of each group (first and last submission) # and append boolean columns to the table first_last_idx = table.group_by(["primary_designation"], use_threads=False).aggregate( [("idx", "first"), ("idx", "last")] ) first = np.zeros(len(table), dtype=bool) last = np.zeros(len(table), dtype=bool) first[first_last_idx["idx_first"].to_numpy(zero_copy_only=False)] = True last[first_last_idx["idx_last"].to_numpy(zero_copy_only=False)] = True table = table.append_column("first_submission", pa.array(first)) table = table.append_column("last_submission", pa.array(last)) # Calculate the arc length of each submission start_times = Time(table["first_obs_time"].to_numpy(zero_copy_only=False), scale="utc") end_times = Time(table["last_obs_time"].to_numpy(zero_copy_only=False), scale="utc") arc_length = end_times.utc.mjd - start_times.utc.mjd return MPCSubmissionHistory.from_kwargs( requested_provid=table["requested_provid"], primary_designation=table["primary_designation"], submission_id=table["submission_id"], submission_time=infer_submission_time( table["submission_id"].to_numpy(zero_copy_only=False), end_times.utc.isot, ), first_submission=table["first_submission"], last_submission=table["last_submission"], num_obs=table["num_obs"], first_obs_time=Timestamp.from_iso8601(start_times.utc.isot, scale="utc"), last_obs_time=Timestamp.from_iso8601(end_times.utc.isot, scale="utc"), arc_length=arc_length, )
[docs] def query_primary_objects(self, provids: list[str]) -> MPCPrimaryObjects: """ Query the MPC database for the primary objects and associated data for the given provisional designations. Parameters ---------- provids : list[str] List of provisional designations to query. Returns ------- primary_objects : MPCPrimaryObjects The primary objects and associated data for the given provisional designations. """ provids_str = _sql_string_list(provids) query = f"""WITH requested_provids AS ( SELECT provid FROM UNNEST(ARRAY[{provids_str}]) AS provid ) SELECT DISTINCT rp.provid AS requested_provid, CASE WHEN ni.permid IS NOT NULL THEN ni.permid ELSE ci.unpacked_primary_provisional_designation END AS primary_designation, po.unpacked_primary_provisional_designation as provid, po.created_at, po.updated_at FROM requested_provids AS rp LEFT JOIN `{self.dataset_id}.public_current_identifications` AS ci ON ci.unpacked_secondary_provisional_designation = rp.provid LEFT JOIN `{self.dataset_id}.public_current_identifications` AS ci_alt ON ci.unpacked_primary_provisional_designation = ci_alt.unpacked_primary_provisional_designation LEFT JOIN `{self.dataset_id}.public_numbered_identifications` AS ni ON ci.unpacked_primary_provisional_designation = ni.unpacked_primary_provisional_designation LEFT JOIN `{self.dataset_id}.public_primary_objects` AS po ON ci.unpacked_primary_provisional_designation = po.unpacked_primary_provisional_designation ORDER BY requested_provid ASC; """ query_job = self.client.query(query) results = query_job.result() table = results.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True) created_at = Time( table["created_at"].to_numpy(zero_copy_only=False), format="datetime64", scale="utc", ) updated_at = Time( table["updated_at"].to_numpy(zero_copy_only=False), format="datetime64", scale="utc", ) return MPCPrimaryObjects.from_kwargs( requested_provid=table["requested_provid"], primary_designation=table["primary_designation"], provid=table["provid"], created_at=Timestamp.from_iso8601(created_at.utc.isot, scale="utc"), updated_at=Timestamp.from_iso8601(updated_at.utc.isot, scale="utc"), )
[docs] def cross_match_observations( self, ades_observations: ADESObservations, obstime_tolerance_seconds: int = 30, arcseconds_tolerance: float = 2.0, ) -> CrossMatchedMPCObservations: """ Cross-match the given ADES observations with the MPC observations. Parameters ---------- ades_observations : ADESObservations The ADES observations to cross-match. obstime_tolerance_seconds : float, optional Time tolerance in seconds for matching observations. arcseconds_tolerance : float, optional Angular separation tolerance in arcseconds. Returns ------- cross_matched_mpc_observations : CrossMatchedMPCObservations The MPC observations that match the given ADES observations. """ # We use the ADESObservation.obssubid as the unique identifier # to track the cross-match requests. assert pc.all(pc.invert(pc.is_null(ades_observations.obsSubID))).as_py() # Convert arcseconds to meters at Earth's surface (approximate) meters_tolerance = arcseconds_tolerance * METERS_PER_ARCSECONDS coarse_dec_tolerance_deg = arcseconds_tolerance / 3600.0 input_rows = [] for obsSubID, obsTime, ra, dec, stn in zip( ades_observations.obsSubID.to_numpy(zero_copy_only=False), ades_observations.obsTime.to_astropy(), ades_observations.ra.to_numpy(zero_copy_only=False), ades_observations.dec.to_numpy(zero_copy_only=False), ades_observations.stn.to_numpy(zero_copy_only=False), ): obstime_iso = obsTime.utc.isot input_rows.append( { "id": _normalize_string_value(obsSubID), "stn": _normalize_string_value(stn), "ra": float(ra), "dec": float(dec), "obstime_iso": obstime_iso, "month_bucket": obstime_iso[:7], } ) if len(input_rows) == 0: return CrossMatchedMPCObservations.empty() # Keep bounds tight to preserve partition pruning and prevent a single # wide-spanning request from scanning large historical ranges. bucketed_rows = defaultdict(list) for row in input_rows: bucketed_rows[row["month_bucket"]].append(row) result_tables = [] for month_key in sorted(bucketed_rows.keys()): month_rows = bucketed_rows[month_key] for start in range(0, len(month_rows), MAX_CROSSMATCH_INPUT_ROWS_PER_QUERY): batch_rows = month_rows[start : start + MAX_CROSSMATCH_INPUT_ROWS_PER_QUERY] min_obstime = min(row["obstime_iso"] for row in batch_rows) max_obstime = max(row["obstime_iso"] for row in batch_rows) min_bound = ( Time(min_obstime, format="isot", scale="utc") - TimeDelta(obstime_tolerance_seconds, format="sec") ).utc.isot max_bound = ( Time(max_obstime, format="isot", scale="utc") + TimeDelta(obstime_tolerance_seconds, format="sec") ).utc.isot station_literals = ", ".join( [ f"'{_escape_sql_string(stn)}'" for stn in sorted({row["stn"] for row in batch_rows}) ] ) struct_entries = [ ( "STRUCT(" f"'{_escape_sql_string(row['id'])}' AS id, " f"'{_escape_sql_string(row['stn'])}' AS stn, " f"{row['ra']} AS ra, " f"{row['dec']} AS dec, " f"TIMESTAMP('{row['obstime_iso']}') AS obstime" ")" ) for row in batch_rows ] struct_str = ",\n ".join(struct_entries) matching_query = f""" WITH input_observations AS ( SELECT id, stn, ra, dec, obstime, ST_GEOGPOINT(ra, dec) AS input_geo FROM UNNEST([ {struct_str} ]) ), candidate_observations AS ( SELECT obs.obsid, obs.trksub, obs.provid, obs.permid, obs.submission_id, obs.obssubid, obs.obstime, obs.ra, obs.dec, obs.rmsra, obs.rmsdec, obs.mag, obs.rmsmag, obs.band, obs.stn, obs.updated_at, obs.created_at, obs.status, SAFE_CAST(obs.ra AS FLOAT64) AS ra_f64, SAFE_CAST(obs.dec AS FLOAT64) AS dec_f64, ST_GEOGPOINT(SAFE_CAST(obs.ra AS FLOAT64), SAFE_CAST(obs.dec AS FLOAT64)) AS obs_geo FROM `{self.dataset_id}.public_obs_sbn` AS obs WHERE obs.stn IN ({station_literals}) AND obs.obstime BETWEEN TIMESTAMP('{min_bound}') AND TIMESTAMP('{max_bound}') ) SELECT input.id AS input_id, ST_DISTANCE(obs.obs_geo, input.input_geo) AS separation_meters, TIMESTAMP_DIFF(obs.obstime, input.obstime, SECOND) AS separation_seconds, obs.obsid, obs.trksub, obs.provid, obs.permid, obs.submission_id, obs.obssubid, obs.obstime, obs.ra, obs.dec, obs.rmsra, obs.rmsdec, obs.mag, obs.rmsmag, obs.band, obs.stn, obs.updated_at, obs.created_at, obs.status FROM input_observations AS input JOIN candidate_observations AS obs ON obs.stn = input.stn AND obs.obs_geo IS NOT NULL AND obs.obstime BETWEEN TIMESTAMP_SUB(input.obstime, INTERVAL {obstime_tolerance_seconds} SECOND) AND TIMESTAMP_ADD(input.obstime, INTERVAL {obstime_tolerance_seconds} SECOND) AND obs.dec_f64 BETWEEN input.dec - {coarse_dec_tolerance_deg} AND input.dec + {coarse_dec_tolerance_deg} AND obs.ra_f64 BETWEEN input.ra - ({coarse_dec_tolerance_deg} / GREATEST(0.1, COS(input.dec * ACOS(-1) / 180.0))) AND input.ra + ({coarse_dec_tolerance_deg} / GREATEST(0.1, COS(input.dec * ACOS(-1) / 180.0))) WHERE ST_DISTANCE(obs.obs_geo, input.input_geo) <= {meters_tolerance} ORDER BY input_id, separation_meters, separation_seconds """ result_table = ( self.client.query(matching_query) .result() .to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True) ) if len(result_table) > 0: result_tables.append(result_table) if len(result_tables) == 0: return CrossMatchedMPCObservations.empty() table = pa.concat_tables(result_tables) if len(result_tables) > 1 else result_tables[0] table = table.combine_chunks() obstime_iso = _iso_utc(table["obstime"]) created_at_iso = _iso_utc(table["created_at"]) updated_at_iso = _iso_utc(table["updated_at"]) separation_arcseconds = ( table["separation_meters"].to_numpy(zero_copy_only=False) / METERS_PER_ARCSECONDS ) return CrossMatchedMPCObservations.from_kwargs( request_id=table["input_id"], separation_arcseconds=separation_arcseconds, separation_seconds=table["separation_seconds"], mpc_observations=MPCObservations.from_kwargs( obsid=table["obsid"], trksub=table["trksub"], provid=table["provid"], permid=table["permid"], submission_id=table["submission_id"], obssubid=table["obssubid"], obstime=Timestamp.from_iso8601(obstime_iso, scale="utc"), ra=table["ra"], dec=table["dec"], rmsra=table["rmsra"], rmsdec=table["rmsdec"], mag=table["mag"], rmsmag=table["rmsmag"], band=table["band"], stn=table["stn"], updated_at=Timestamp.from_iso8601(updated_at_iso, scale="utc"), created_at=Timestamp.from_iso8601(created_at_iso, scale="utc"), status=table["status"], ), )
[docs] def find_duplicates( self, provid: str, obstime_tolerance_seconds: int = 30, arcseconds_tolerance: float = 2.0, ) -> CrossMatchedMPCObservations: meters_tolerance = arcseconds_tolerance * METERS_PER_ARCSECONDS coarse_dec_tolerance_deg = arcseconds_tolerance / 3600.0 provid = _escape_sql_string(_normalize_string_value(provid)) query = f""" WITH obs AS ( SELECT obsid, stn, ra, dec, obstime, created_at, updated_at, trksub, provid, permid, submission_id, obssubid, rmsra, rmsdec, mag, rmsmag, band, status, SAFE_CAST(ra AS FLOAT64) AS ra_f64, SAFE_CAST(dec AS FLOAT64) AS dec_f64, ST_GEOGPOINT(SAFE_CAST(ra AS FLOAT64), SAFE_CAST(dec AS FLOAT64)) AS geo FROM `{self.dataset_id}.public_obs_sbn` WHERE provid = '{provid}' ) SELECT a.obsid AS input_id, b.obsid, ST_DISTANCE(a.geo, b.geo) AS separation_meters, TIMESTAMP_DIFF(b.obstime, a.obstime, SECOND) AS separation_seconds, b.trksub, b.provid, b.permid, b.submission_id, b.obssubid, b.obstime, b.ra, b.dec, b.rmsra, b.rmsdec, b.mag, b.rmsmag, b.band, b.stn, b.created_at, b.updated_at, b.status FROM obs a JOIN obs b ON b.stn = a.stn -- Same station AND a.obsid < b.obsid -- Avoid self-matches and duplicates AND a.geo IS NOT NULL AND b.geo IS NOT NULL AND b.obstime BETWEEN TIMESTAMP_SUB(a.obstime, INTERVAL {obstime_tolerance_seconds} SECOND) AND TIMESTAMP_ADD(a.obstime, INTERVAL {obstime_tolerance_seconds} SECOND) AND b.dec_f64 BETWEEN a.dec_f64 - {coarse_dec_tolerance_deg} AND a.dec_f64 + {coarse_dec_tolerance_deg} AND b.ra_f64 BETWEEN a.ra_f64 - ({coarse_dec_tolerance_deg} / GREATEST(0.1, COS(a.dec_f64 * ACOS(-1) / 180.0))) AND a.ra_f64 + ({coarse_dec_tolerance_deg} / GREATEST(0.1, COS(a.dec_f64 * ACOS(-1) / 180.0))) AND ST_DISTANCE(a.geo, b.geo) <= {meters_tolerance} ORDER BY a.obsid, separation_meters """ # Execute query and get results results = ( self.client.query(query) .result() .to_arrow(progress_bar_type="tqdm", create_bqstorage_client=True) ) if len(results) == 0: return CrossMatchedMPCObservations.empty() # Convert timestamps to ISO strings (no astropy) obstime_iso = pc.binary_join_element_wise( pc.replace_substring(results["obstime"].cast(pa.string()), " ", "T"), pa.scalar("Z"), pa.scalar(""), ) created_at_iso = pc.binary_join_element_wise( pc.replace_substring(results["created_at"].cast(pa.string()), " ", "T"), pa.scalar("Z"), pa.scalar(""), ) updated_at_iso = pc.binary_join_element_wise( pc.replace_substring(results["updated_at"].cast(pa.string()), " ", "T"), pa.scalar("Z"), pa.scalar(""), ) # Convert meters to arcseconds separation_arcseconds = ( results["separation_meters"].to_numpy(zero_copy_only=False) / METERS_PER_ARCSECONDS ) return CrossMatchedMPCObservations.from_kwargs( request_id=results["input_id"].cast(pa.string()), separation_arcseconds=separation_arcseconds, separation_seconds=results["separation_seconds"], mpc_observations=MPCObservations.from_kwargs( obsid=results["obsid"], trksub=results["trksub"], provid=results["provid"], permid=results["permid"], submission_id=results["submission_id"], obssubid=results["obssubid"], obstime=Timestamp.from_iso8601(obstime_iso, scale="utc"), ra=results["ra"], dec=results["dec"], rmsra=results["rmsra"], rmsdec=results["rmsdec"], mag=results["mag"], rmsmag=results["rmsmag"], band=results["band"], stn=results["stn"], updated_at=Timestamp.from_iso8601(updated_at_iso, scale="utc"), created_at=Timestamp.from_iso8601(created_at_iso, scale="utc"), status=results["status"], ), )