Files
awesome-copilot/skills/sql-server-table-reconciliation/scripts/reconcile.py
T
Kaps 0e722d8f42 Add Skill: SQL Server table reconciliation using Arrow (#1736)
* Implement: Initial version of table reconciliation

* Refactor: Extracted inline script to scripts/reconcilliation.py

---------

Co-authored-by: Kapil Samant <kapilsamant@microsoft.com>
2026-05-19 10:55:57 +10:00

381 lines
13 KiB
Python

#!/usr/bin/env python3
"""SQL Server Table Reconciliation Script.
Compare identical tables across two SQL Server instances using mssql-python
driver and Apache Arrow. Detect missing rows, column mismatches, schema drift,
and produce a reconciliation report.
Usage:
python reconcile.py \
--source-server prod-server.database.windows.net \
--source-database ProdDB \
--target-server staging-server.database.windows.net \
--target-database StagingDB \
--tables "dbo.Orders,dbo.Items" \
--auth entra \
--output console \
--chunk-size 100000
Environment variables for credentials (when --auth sql):
MSSQL_USER - SQL Server username
MSSQL_PASSWORD - SQL Server password
"""
import argparse
import os
import sys
from getpass import getpass
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
from mssql_python import connect as mssql_connect
# --- Connection Setup ---
def connect(server, database, auth_mode, user=None, password=None):
    """Open a connection to one SQL Server instance via the mssql-python driver.

    Credentials are never hardcoded: in ``"sql"`` mode they come from the
    arguments, then the environment, then an interactive prompt.

    Args:
        server: Host name of the SQL Server instance.
        database: Name of the database to connect to.
        auth_mode: ``"sql"`` selects username/password authentication; any
            other value selects Entra (``ActiveDirectoryDefault``).
        user: SQL login. Falls back to ``MSSQL_USER``, then a prompt.
        password: SQL password. Falls back to ``MSSQL_PASSWORD``, then a
            no-echo prompt.

    Returns:
        The object returned by ``mssql_connect`` for the built connection string.
    """

    def _value(raw):
        # A raw ';' would terminate the attribute early (and let the rest of
        # the value be read as further attributes), so such values are wrapped
        # in braces with '}' doubled, per ODBC connection-string quoting.
        # Plain values are emitted unchanged.
        text = str(raw)
        if text != text.strip() or any(ch in text for ch in ";{}"):
            return "{" + text.replace("}", "}}") + "}"
        return text

    # NOTE(review): TrustServerCertificate=yes skips server certificate
    # validation even though Encrypt=yes is set - confirm this is intended
    # for production endpoints.
    if auth_mode == "sql":
        user = user or os.environ.get("MSSQL_USER") or input("Username: ")
        password = password or os.environ.get("MSSQL_PASSWORD") or getpass("Password: ")
        conn_str = (
            f"Server={_value(server)};Database={_value(database)};"
            f"UID={_value(user)};PWD={_value(password)};"
            f"TrustServerCertificate=yes;Encrypt=yes"
        )
    else:
        # Entra (Azure AD) authentication
        conn_str = (
            f"Server={_value(server)};Database={_value(database)};"
            f"Authentication=ActiveDirectoryDefault;"
            f"TrustServerCertificate=yes;Encrypt=yes"
        )
    return mssql_connect(conn_str)
# --- Table Discovery ---
def resolve_tables(conn, table_spec):
    """Expand a comma-separated table spec into a list of ``schema.table`` names.

    Accepts ``'dbo.*'``, ``'dbo.Orders,dbo.Items'`` or ``'dbo.Orders'``.
    A ``schema.*`` entry is expanded by querying INFORMATION_SCHEMA.TABLES on
    ``conn`` for the base tables of that schema; explicit entries are returned
    as given and ``conn`` is not used for them.

    Args:
        conn: Connection used only to expand ``schema.*`` wildcards.
        table_spec: Comma-separated entries, each ``schema.table`` or ``schema.*``.

    Returns:
        List of ``schema.table`` strings in the order they were specified.

    Raises:
        ValueError: If an entry is not of the form ``schema.table``.
    """
    query = (
        "\n"
        "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES\n"
        "WHERE TABLE_SCHEMA = ? AND TABLE_TYPE = 'BASE TABLE'\n"
        "ORDER BY TABLE_NAME\n"
    )
    tables = []
    for raw_spec in table_spec.split(","):
        spec = raw_spec.strip()
        if not spec:
            # Tolerate trailing or doubled commas instead of crashing.
            continue
        schema, dot, tbl = spec.partition(".")
        if not dot or not schema or not tbl or "." in tbl:
            raise ValueError(
                f"Invalid table spec {spec!r}: expected 'schema.table' or 'schema.*'"
            )
        if tbl == "*":
            cur = conn.cursor()
            cur.execute(query, [schema])
            rows = cur.arrow().to_pandas()
            tables.extend(f"{schema}.{t}" for t in rows["TABLE_NAME"])
        else:
            tables.append(spec)
    return tables
# --- Schema Comparison ---
def compare_schema(source_conn, target_conn, table):
    """Compare the column definitions of ``table`` on both connections.

    Reads INFORMATION_SCHEMA.COLUMNS on each side and reports columns present
    on only one side, plus, for shared columns, differences in data type,
    nullability, character length, numeric precision and numeric scale.

    Args:
        source_conn: Connection to the source instance.
        target_conn: Connection to the target instance.
        table: Table name in ``schema.table`` form.

    Returns:
        ``(drift, common_cols)``: a list of human-readable drift messages and
        the alphabetically sorted list of column names present on both sides.
    """
    query = (
        "\n"
        "SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, CHARACTER_MAXIMUM_LENGTH,\n"
        "NUMERIC_PRECISION, NUMERIC_SCALE\n"
        "FROM INFORMATION_SCHEMA.COLUMNS\n"
        "WHERE TABLE_SCHEMA = ?\n"
        "AND TABLE_NAME = ?\n"
        "ORDER BY ORDINAL_POSITION\n"
    )
    schema_name, table_name = table.split(".")

    def _fetch_columns(conn):
        # Same metadata query on either side, returned as a pandas DataFrame.
        cur = conn.cursor()
        cur.execute(query, [schema_name, table_name])
        return cur.arrow().to_pandas()

    def _show(value):
        # Render metadata values for messages: NULL marker, and 50.0 -> "50"
        # (nullable integer columns arrive as floats in pandas).
        if pd.isna(value):
            return "NULL"
        if isinstance(value, float) and value.is_integer():
            return str(int(value))
        return str(value)

    source_schema = _fetch_columns(source_conn)
    target_schema = _fetch_columns(target_conn)

    src_cols = set(source_schema["COLUMN_NAME"])
    tgt_cols = set(target_schema["COLUMN_NAME"])
    drift = []
    only_in_source = src_cols - tgt_cols
    only_in_target = tgt_cols - src_cols
    if only_in_source:
        drift.append(f"Columns only in source: {sorted(only_in_source)}")
    if only_in_target:
        drift.append(f"Columns only in target: {sorted(only_in_target)}")
    common_cols = sorted(src_cols & tgt_cols)

    # Attributes checked in addition to DATA_TYPE, with their message labels.
    extra_attrs = (
        ("IS_NULLABLE", "nullable"),
        ("CHARACTER_MAXIMUM_LENGTH", "max length"),
        ("NUMERIC_PRECISION", "precision"),
        ("NUMERIC_SCALE", "scale"),
    )
    src_types = source_schema.set_index("COLUMN_NAME")
    tgt_types = target_schema.set_index("COLUMN_NAME")
    for col in common_cols:
        s = src_types.loc[col]
        t = tgt_types.loc[col]
        type_changed = s["DATA_TYPE"] != t["DATA_TYPE"]
        if type_changed:
            drift.append(
                f" {col}: type {s['DATA_TYPE']} vs {t['DATA_TYPE']}"
            )
        for attr, label in extra_attrs:
            # Once the type itself differs, size/precision differences are
            # just a consequence; only nullability is still worth reporting.
            if type_changed and attr != "IS_NULLABLE":
                continue
            s_val, t_val = s[attr], t[attr]
            s_null, t_null = pd.isna(s_val), pd.isna(t_val)
            if s_null and t_null:
                # NULL on both sides is a match (NaN != NaN would misreport).
                continue
            if s_null or t_null or s_val != t_val:
                drift.append(f" {col}: {label} {_show(s_val)} vs {_show(t_val)}")
    return drift, common_cols
# --- Primary Key Detection ---
def detect_primary_key(conn, table):
    """Return the primary-key column names of ``table`` as seen on ``conn``.

    Queries sys.indexes / sys.index_columns / sys.columns for the index flagged
    ``is_primary_key`` on the given ``schema.table`` and returns its column
    names ordered by ``key_ordinal``. The list is empty when the query yields
    no rows.
    """
    schema_part, table_part = table.split(".")
    pk_sql = (
        "\n"
        "SELECT c.name\n"
        "FROM sys.indexes i\n"
        "JOIN sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id\n"
        "JOIN sys.columns c ON ic.object_id = c.object_id AND ic.column_id = c.column_id\n"
        "WHERE i.is_primary_key = 1\n"
        "AND OBJECT_SCHEMA_NAME(i.object_id) = ?\n"
        "AND OBJECT_NAME(i.object_id) = ?\n"
        "ORDER BY ic.key_ordinal\n"
    )
    cursor = conn.cursor()
    cursor.execute(pk_sql, [schema_part, table_part])
    # The Arrow result exposes the single selected column as "name".
    return cursor.arrow().column("name").to_pylist()
# --- Data Extraction (Arrow) ---
def _quote_identifier(name):
    """Bracket-quote a single SQL Server identifier, doubling any ``]``."""
    name = name.strip()
    if name.startswith("[") and name.endswith("]"):
        # Already quoted by the caller; leave untouched.
        return name
    return "[" + name.replace("]", "]]") + "]"


def _quote_table(table):
    """Bracket-quote each dot-separated part of a ``schema.table`` name."""
    if "[" in table:
        # Caller supplied its own quoting; splitting on '.' could cut a
        # bracketed name in half, so pass it through unchanged.
        return table
    return ".".join(_quote_identifier(part) for part in table.split("."))


def extract_table(conn, table, pk_cols, chunk_size=100000):
    """Fetch every row of ``table`` as an Arrow result, ordered by the key columns.

    Identifiers are bracket-quoted so names containing spaces or reserved
    words (e.g. ``dbo.Order Details``) produce valid SQL.

    Args:
        conn: Connection to read from.
        table: Table name in ``schema.table`` form.
        pk_cols: Key column names used for the ORDER BY clause.
        chunk_size: Accepted for interface compatibility; not currently used -
            the whole result is fetched with a single ``arrow()`` call.

    Returns:
        The value of ``cursor.arrow()`` for the executed query.
    """
    pk_order = ", ".join(_quote_identifier(col) for col in pk_cols)
    query = f"SELECT * FROM {_quote_table(table)} ORDER BY {pk_order}"
    cur = conn.cursor()
    cur.execute(query)
    return cur.arrow()
# --- Hash Pre-check (for large tables) ---
def extract_hashes(conn, table, pk_cols, compare_cols):
    """Fetch the key columns plus a SHA2_256 hash of the compared columns per row.

    Each compared column is hashed as ``CONCAT(<null flag>, <value>)``.
    CONCAT_WS silently skips NULL arguments, so without the flag two rows that
    differ only in *which* column is NULL would hash identically.

    Args:
        conn: Connection to read from.
        table: Table name in ``schema.table`` form.
        pk_cols: Key column names; selected and used for ordering.
        compare_cols: Column names folded into ``row_hash``.

    Returns:
        The value of ``cursor.arrow()`` for the executed query.
    """
    pk_select = ", ".join(_quote_identifier(col) for col in pk_cols)
    hash_args = []
    for col in compare_cols:
        quoted = _quote_identifier(col)
        hash_args.append(f"CONCAT(IIF({quoted} IS NULL, '0', '1'), {quoted})")
    # CONCAT_WS rejects fewer than two value arguments; pad with constants.
    while len(hash_args) < 2:
        hash_args.append("''")
    col_concat = ", ".join(hash_args)
    # NOTE: values that themselves contain '|' can still collide across
    # adjacent columns; the separator is not escaped.
    query = f"""
SELECT {pk_select},
HASHBYTES('SHA2_256', CONCAT_WS('|', {col_concat})) AS row_hash
FROM {_quote_table(table)}
ORDER BY {pk_select}
"""
    cur = conn.cursor()
    cur.execute(query)
    return cur.arrow()
# --- Comparison Logic ---
def reconcile(source_table, target_table, pk_cols, compare_cols):
    """Compare two extracted tables row by row.

    Steps:
        1. Convert both inputs to pandas, indexed by the key columns.
        2. Identify keys missing from, or extra in, the target.
        3. Compare ``compare_cols`` for keys present on both sides.
        4. NULL vs NULL counts as a match (``DataFrame.compare`` semantics).

    Args:
        source_table: Object with ``to_pandas()`` holding the source rows.
        target_table: Object with ``to_pandas()`` holding the target rows.
        pk_cols: List of key column names used as the row identity.
        compare_cols: Non-key column names whose values are compared.

    Returns:
        Dict with ``missing_in_target`` and ``extra_in_target`` (sets of keys;
        tuples for composite keys), ``mismatches`` (the ``DataFrame.compare``
        result, one row per differing key), ``total_source`` and
        ``total_target`` (row counts).
    """
    src_df = source_table.to_pandas().set_index(pk_cols)
    tgt_df = target_table.to_pandas().set_index(pk_cols)

    # Missing/extra rows. tolist() yields scalars for a single key column and
    # tuples for a composite key, both hashable.
    src_keys = set(src_df.index.tolist())
    tgt_keys = set(tgt_df.index.tolist())
    missing_in_target = src_keys - tgt_keys
    extra_in_target = tgt_keys - src_keys
    common_keys = src_keys & tgt_keys

    # Column-level mismatches on common rows
    common_src = src_df.loc[src_df.index.isin(common_keys), compare_cols]
    common_tgt = tgt_df.loc[tgt_df.index.isin(common_keys), compare_cols]
    # DataFrame.compare() rejects frames whose indexes are not in the same
    # order, and each side arrives in its own server's sort order (which can
    # differ, e.g. under different collations). Align the target to the
    # source order first; skipped when already aligned.
    if not common_src.index.equals(common_tgt.index):
        common_tgt = common_tgt.reindex(common_src.index)
    diff = common_src.compare(common_tgt, keep_shape=False)

    return {
        "missing_in_target": missing_in_target,
        "extra_in_target": extra_in_target,
        "mismatches": diff,
        "total_source": len(src_df),
        "total_target": len(tgt_df),
    }
# --- Per-Table Pipeline ---
def reconcile_table(source_conn, target_conn, table, pk_override=None, columns=None,
                    chunk_size=100000):
    """Run the full reconciliation pipeline for one table.

    Compares schemas, determines the key columns (override, then source
    primary key, then target primary key), extracts both sides and compares
    them.

    Args:
        source_conn: Connection to the source instance.
        target_conn: Connection to the target instance.
        table: Table name in ``schema.table`` form.
        pk_override: Optional list of key columns; skips auto-detection.
        columns: Optional list of columns to compare; defaults to every
            column common to both sides. Key columns are always excluded.
        chunk_size: Forwarded to ``extract_table``.

    Returns:
        The ``reconcile`` result dict extended with ``table``,
        ``schema_drift`` and ``status`` (``"PASS"``/``"FAIL"``), or a
        ``"SKIPPED"`` dict with an ``error`` message when no key is found.
    """
    schema_drift, common_cols = compare_schema(source_conn, target_conn, table)

    # `or` short-circuits, so each detection query only runs if still needed.
    pk_cols = (
        pk_override
        or detect_primary_key(source_conn, table)
        or detect_primary_key(target_conn, table)
    )
    if not pk_cols:
        return {"table": table, "error": "No PK detected", "status": "SKIPPED"}

    # Key columns become the DataFrame index inside reconcile(), so selecting
    # them as value columns would raise KeyError; drop them even when the
    # caller listed them explicitly.
    candidates = columns if columns else common_cols
    compare_cols = [c for c in candidates if c not in pk_cols]

    source_data = extract_table(source_conn, table, pk_cols, chunk_size)
    target_data = extract_table(target_conn, table, pk_cols, chunk_size)
    result = reconcile(source_data, target_data, pk_cols, compare_cols)

    result["table"] = table
    result["schema_drift"] = schema_drift
    has_differences = bool(
        result["missing_in_target"]
        or result["extra_in_target"]
        or len(result["mismatches"])
    )
    result["status"] = "FAIL" if has_differences else "PASS"
    return result
# --- Report Generation ---
def _summary_rows(all_results):
    """Flatten result dicts into one plain summary row per table for export."""
    return [
        {
            "table": r["table"],
            "status": r["status"],
            "source_rows": r.get("total_source", 0),
            "target_rows": r.get("total_target", 0),
            "missing": len(r.get("missing_in_target", [])),
            "extra": len(r.get("extra_in_target", [])),
            "mismatches": len(r.get("mismatches", [])),
        }
        for r in all_results
    ]


def generate_report(all_results, output_format="console"):
    """Print per-table details and a combined summary; optionally export a file.

    Args:
        all_results: List of result dicts as produced by ``reconcile_table``.
            Entries carrying an ``error`` key are reported as skipped.
        output_format: ``"console"`` prints only; ``"csv"`` additionally
            writes ``reconciliation_report.csv`` and ``"json"`` writes
            ``reconciliation_report.json`` in the current directory.
    """
    for r in all_results:
        print(f"\n--- {r['table']} ---")
        if r.get("error"):
            print(f" SKIPPED: {r['error']}")
            continue
        print(f" Source: {r['total_source']:,} Target: {r['total_target']:,}")
        print(
            f" Missing: {len(r['missing_in_target'])} "
            f"Extra: {len(r['extra_in_target'])} "
            f"Mismatches: {len(r['mismatches'])}"
        )
        print(
            f" Result: {'✓ IDENTICAL' if r['status'] == 'PASS' else '✗ DIFFERENCES FOUND'}"
        )
        if r.get("schema_drift"):
            print(" Schema drift:")
            for d in r["schema_drift"]:
                print(f" {d}")

    # Summary
    passed = sum(1 for r in all_results if r["status"] == "PASS")
    failed = sum(1 for r in all_results if r["status"] == "FAIL")
    skipped = sum(1 for r in all_results if r["status"] == "SKIPPED")
    print(
        f"\n=== Summary: {passed} passed, {failed} failed, "
        f"{skipped} skipped / {len(all_results)} tables ==="
    )

    # Export if requested
    if output_format == "csv":
        export_columns = [
            "table", "status", "source_rows", "target_rows",
            "missing", "extra", "mismatches",
        ]
        # Explicit columns keep the header row even when there are no results.
        df = pd.DataFrame(_summary_rows(all_results), columns=export_columns)
        df.to_csv("reconciliation_report.csv", index=False)
        print("\nReport saved to reconciliation_report.csv")
    elif output_format == "json":
        import json

        with open("reconciliation_report.json", "w", encoding="utf-8") as f:
            json.dump(_summary_rows(all_results), f, indent=2)
        print("\nReport saved to reconciliation_report.json")
# --- Main ---
def main():
    """Parse command-line arguments, reconcile each requested table and report.

    Connects to the source and target instances, resolves the table list
    against the source, runs ``reconcile_table`` for each table and passes the
    collected results to ``generate_report``. Both connections are closed
    before reporting, including when reconciliation raises.
    """
    parser = argparse.ArgumentParser(
        description="Compare SQL Server tables across two instances."
    )
    parser.add_argument("--source-server", required=True, help="Source SQL Server host")
    parser.add_argument("--source-database", required=True, help="Source database name")
    parser.add_argument("--target-server", required=True, help="Target SQL Server host")
    parser.add_argument("--target-database", required=True, help="Target database name")
    parser.add_argument(
        "--tables",
        required=True,
        help="Comma-separated schema.table names or schema.* wildcard",
    )
    parser.add_argument(
        "--auth",
        choices=["sql", "entra"],
        default="sql",
        help="Authentication mode (default: sql)",
    )
    parser.add_argument(
        "--primary-key",
        default=None,
        help="Comma-separated PK column(s). Auto-detected if omitted.",
    )
    parser.add_argument(
        "--columns",
        default=None,
        help="Comma-separated columns to compare. All non-PK columns if omitted.",
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=100000,
        help="Rows per batch for large tables (default: 100000)",
    )
    parser.add_argument(
        "--output",
        choices=["console", "csv", "json"],
        default="console",
        help="Output format (default: console)",
    )
    args = parser.parse_args()

    pk_override = [c.strip() for c in args.primary_key.split(",")] if args.primary_key else None
    columns = [c.strip() for c in args.columns.split(",")] if args.columns else None

    print(f"Connecting to source: {args.source_server}/{args.source_database}")
    source_conn = connect(args.source_server, args.source_database, args.auth)
    try:
        print(f"Connecting to target: {args.target_server}/{args.target_database}")
        target_conn = connect(args.target_server, args.target_database, args.auth)
        try:
            # Wildcards are expanded against the source instance only.
            tables = resolve_tables(source_conn, args.tables)
            print(f"Tables to reconcile: {tables}")
            results = []
            for table in tables:
                print(f"Reconciling {table}...")
                results.append(
                    reconcile_table(
                        source_conn, target_conn, table,
                        pk_override=pk_override,
                        columns=columns,
                        chunk_size=args.chunk_size,
                    )
                )
        finally:
            # Release the connections even if a table fails mid-run.
            target_conn.close()
    finally:
        source_conn.close()

    generate_report(results, output_format=args.output)


if __name__ == "__main__":
    main()