Compare commits

..

22 Commits

Author SHA1 Message Date
Richie a80de99175 adding math to bob 2026-04-12 10:08:23 -04:00
Richie 50d56a8a39 added config.toml to git ignore 2026-04-12 10:08:23 -04:00
Richie 30dc36588c updated BenchmarkConfig to have from_toml 2026-04-12 10:08:23 -04:00
Richie 68190901cb setup FinetuneConfig 2026-04-12 10:08:23 -04:00
Richie 275762843f deleted train.sh 2026-04-12 10:08:23 -04:00
Richie face93262f added containers dir 2026-04-12 10:08:23 -04:00
Richie ee34a0986b conveted to summarization_prompts 2026-04-12 10:08:23 -04:00
Richie e8b20bc7df moved renamed container.py to vllm_container.py 2026-04-12 10:08:23 -04:00
Richie 6c459985fa created working finetuing pipeline 2026-04-12 10:08:23 -04:00
Richie 20a204612f added data dir for traning 2026-04-12 10:08:23 -04:00
Richie 27b609052c updated spell check 2026-04-12 10:08:23 -04:00
Richie 20fb24e244 added storage pool 2026-04-12 10:08:23 -04:00
Richie 230ab1d7f6 added tiktoken 2026-04-12 10:08:23 -04:00
Richie 9ffaa1b755 added summarization_prompts.py to sore the prompts 2026-04-12 10:08:23 -04:00
Richie c6b4ed4814 added tools dir for on off scripts i used 2026-04-12 10:08:23 -04:00
Richie 88ceeb55a1 added batch_bill_summarizer.py
batch bill  summarizer sends a batch api call to gpt
2026-04-12 10:08:23 -04:00
Richie 6c57d74644 decreased root_pool/models snapshot life 2026-04-12 10:08:23 -04:00
Richie cb98090f95 added bill_token_compression.py
tested on sample size of 100 bills matching the distribution of our data
Compression saves ~11.5% on prompt tokens; completion/reasoning are roughly equal across the two sets.
prompt	completion	reasoning	total
compressed	349,460	157,110	112,128	506,570
uncompressed	394,948	154,710	110,080	549,658
delta	−45,488	+2,400	+2,048	−43,088
2026-04-12 10:08:23 -04:00
Richie 63cb48a3dd created main prompt bench 2026-04-12 10:08:23 -04:00
Richie 6f6d247d3e fixed sunshine.nix 2026-04-12 10:08:23 -04:00
Richie 6b63315579 converting bob to a server 2026-04-12 10:08:23 -04:00
Richie a093c72eb9 creating prompt_bench downloader 2026-04-12 10:08:23 -04:00
160 changed files with 2261 additions and 10633 deletions
+1 -1
View File
@@ -23,6 +23,6 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Build default package
run: "nixos-rebuild build --accept-flake-config --flake ./#${{ matrix.system }}"
run: "nixos-rebuild build --flake ./#${{ matrix.system }}"
- name: copy to nix-cache
run: nix copy --accept-flake-config --to unix:///host-nix/var/nix/daemon-socket/socket .#nixosConfigurations.${{ matrix.system }}.config.system.build.toplevel
+30
View File
@@ -0,0 +1,30 @@
name: fix_eval_warnings
on:
workflow_run:
workflows: ["build_systems"]
types: [completed]
jobs:
check-warnings:
if: >-
github.event.workflow_run.conclusion != 'cancelled' &&
github.event.workflow_run.head_branch == 'main' &&
(github.event.workflow_run.event == 'push' || github.event.workflow_run.event == 'schedule')
runs-on: self-hosted
permissions:
contents: write
pull-requests: write
steps:
- uses: actions/checkout@v4
- name: Fix eval warnings
env:
GH_TOKEN: ${{ secrets.GH_TOKEN_FOR_UPDATES }}
run: >-
nix develop .#devShells.x86_64-linux.default -c
python -m python.eval_warnings.main
--run-id "${{ github.event.workflow_run.id }}"
--repo "${{ github.repository }}"
--ollama-url "${{ secrets.OLLAMA_URL }}"
--run-url "${{ github.event.workflow_run.html_url }}"
+13 -7
View File
@@ -6,18 +6,24 @@ on:
jobs:
merge:
runs-on: self-hosted
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: merge_flake_lock_update
run: >-
nix develop .#devShells.x86_64-linux.default -c
python -m python.gitea_flake_lock merge
--repo "${{ github.repository }}"
run: |
pr_number=$(gh pr list --state open --author RichieCahill --label flake_lock_update --json number --jq '.[0].number')
echo "pr_number=$pr_number" >> $GITHUB_ENV
if [ -n "$pr_number" ]; then
gh pr merge "$pr_number" --rebase
else
echo "No open PR found with label flake_lock_update"
fi
env:
GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }}
GITEA_URL: https://gitea.tmmworkshop.com
GITHUB_TOKEN: ${{ secrets.GH_TOKEN_FOR_UPDATES }}
+1 -1
View File
@@ -1,13 +1,13 @@
name: pytest
on:
workflow_dispatch:
push:
branches:
- main
pull_request:
branches:
- main
merge_group:
jobs:
pytest:
+11 -14
View File
@@ -6,21 +6,18 @@ on:
jobs:
lockfile:
runs-on: self-hosted
permissions:
actions: write
contents: write
pull-requests: write
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Nix
uses: DeterminateSystems/nix-installer-action@main
- name: Update flake.lock
run: nix flake update
- name: Create or update flake.lock PR
env:
GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }}
GITEA_URL: https://gitea.tmmworkshop.com
run: >-
nix develop .#devShells.x86_64-linux.default -c
python -m python.gitea_flake_lock update
--repo "${{ github.repository }}"
uses: DeterminateSystems/update-flake-lock@main
with:
token: ${{ secrets.GH_TOKEN_FOR_UPDATES }}
pr-title: "Update flake.lock"
pr-labels: |
dependencies
automated
flake_lock_update
+3 -3
View File
@@ -170,6 +170,6 @@ test.*
frontend/dist/
frontend/node_modules/
# data from testing llms
data/*
.ebook_search_bm25
# data dir for training, validation, and testing
data/
config.toml
+1 -1
View File
@@ -40,6 +40,7 @@
"cgroupdriver",
"charliermarsh",
"Checkpointing",
"cloudflared",
"codellama",
"codezombiech",
"compactmode",
@@ -203,7 +204,6 @@
"peerconnection",
"PESKYFOX",
"PGID",
"pgvector",
"pipewire",
"pkgs",
"plugdev",
+2 -12
View File
@@ -23,10 +23,7 @@
boot = {
tmp.useTmpfs = true;
kernelPackages = lib.mkDefault pkgs.linuxPackages_6_12;
zfs = {
package = lib.mkDefault pkgs.zfs_2_4;
forceImportRoot = lib.mkDefault false;
};
zfs.package = lib.mkDefault pkgs.zfs_2_4;
};
hardware.enableRedistributableFirmware = true;
@@ -40,17 +37,10 @@
nixpkgs = {
overlays = builtins.attrValues outputs.overlays;
config = {
allowUnfree = true;
permittedInsecurePackages = [
"openssl-1.1.1w" # This is for discord-canary
];
};
config.allowUnfree = true;
};
services = {
dbus.implementation = "dbus";
# firmware update
fwupd.enable = true;
-1
View File
@@ -34,7 +34,6 @@ in
warn-dirty = false;
flake-registry = ""; # disable global flake registries
connect-timeout = 10;
download-buffer-size = 536870912;
fallback = true;
};
-256
View File
@@ -1,256 +0,0 @@
{
config,
lib,
pkgs,
...
}:
let
monitoringInterface = "ztwfunumly";
nodeTextfileDir = "/var/lib/prometheus-node-exporter-textfile";
mkProcessNameTemplate =
perPid: template: if perPid then "${template}:{{.PID}}:{{.StartTime}}" else template;
mkProcessMatchers = perPid: [
{
name = mkProcessNameTemplate perPid "{{.Username}}:{{.Matches.Module}}";
cmdline = [ "^/nix/store[^ ]*/bin/python[^ ]* -m (?P<Module>[^ ]+)" ];
}
{
name = mkProcessNameTemplate perPid "{{.Username}}:{{.Matches.Wrapped}}";
cmdline = [
"^/nix/store[^ ]*/bin/python[^ ]* /nix/store[^ ]*/bin/\\.?(?P<Wrapped>[^ /]+?)(?:-wrapped)?(?:\\s|$)"
];
}
{
name = mkProcessNameTemplate perPid "{{.Username}}:{{.Matches.Wrapped}}";
cmdline = [
"^/nix/store[^ ]*/bin/node /nix/store[^ ]*-(?P<Wrapped>[A-Za-z0-9._+-]+)-[0-9][^ /]*/"
];
}
{
name = mkProcessNameTemplate perPid "{{.Username}}:{{.Matches.Wrapped}}";
cmdline = [ "^/nix/store[^ ]*/(?:bin/|lib/[^ ]*/)?\\.?(?P<Wrapped>[^ /]+?)(?:-wrapped)?(?:\\s|$)" ];
}
{
name = mkProcessNameTemplate perPid "{{.Username}}:{{.ExeBase}}";
cmdline = [ ".+" ];
}
];
perPidConfig = pkgs.writeText "process-exporter-per-pid.yaml" (
builtins.toJSON {
process_names = mkProcessMatchers true;
}
);
zpoolLatencyScript = pkgs.writeShellScript "zpool-latency-exporter" ''
set -euo pipefail
out_dir=${lib.escapeShellArg nodeTextfileDir}
host=${lib.escapeShellArg config.networking.hostName}
tmp_file="$(mktemp "$out_dir/zpool.prom.XXXXXX")"
trap 'rm -f "$tmp_file"' EXIT
pools="$(zpool list -H -o name | paste -sd, -)"
cat >"$tmp_file" <<'EOF'
# HELP zpool_iostat_total_wait_read_ns Average total read wait time reported by zpool iostat.
# TYPE zpool_iostat_total_wait_read_ns gauge
# HELP zpool_iostat_total_wait_write_ns Average total write wait time reported by zpool iostat.
# TYPE zpool_iostat_total_wait_write_ns gauge
# HELP zpool_iostat_disk_wait_read_ns Average disk read wait time reported by zpool iostat.
# TYPE zpool_iostat_disk_wait_read_ns gauge
# HELP zpool_iostat_disk_wait_write_ns Average disk write wait time reported by zpool iostat.
# TYPE zpool_iostat_disk_wait_write_ns gauge
# HELP zpool_iostat_syncq_wait_read_ns Average synchronous queue read wait time reported by zpool iostat.
# TYPE zpool_iostat_syncq_wait_read_ns gauge
# HELP zpool_iostat_syncq_wait_write_ns Average synchronous queue write wait time reported by zpool iostat.
# TYPE zpool_iostat_syncq_wait_write_ns gauge
# HELP zpool_iostat_asyncq_wait_read_ns Average asynchronous queue read wait time reported by zpool iostat.
# TYPE zpool_iostat_asyncq_wait_read_ns gauge
# HELP zpool_iostat_asyncq_wait_write_ns Average asynchronous queue write wait time reported by zpool iostat.
# TYPE zpool_iostat_asyncq_wait_write_ns gauge
EOF
zpool iostat -Hplvy -y 1 1 | awk -F '\t' -v host="$host" -v pools="$pools" '
function esc(str, out) {
out = str
gsub(/\\/, "\\\\", out)
gsub(/"/, "\\\"", out)
return out
}
function emit(metric, pool, vdev, value) {
if (value == "" || value == "-") {
return
}
printf "%s{host=\"%s\",pool=\"%s\",vdev=\"%s\"} %s\n",
metric,
esc(host),
esc(pool),
esc(vdev),
value
}
BEGIN {
split(pools, pool_names, ",")
for (idx in pool_names) {
if (pool_names[idx] != "") {
known_pools[pool_names[idx]] = 1
}
}
}
NF == 0 {
next
}
{
row_name = $1
if (row_name in known_pools) {
current_pool = row_name
current_vdev = "_pool"
} else if (current_pool == "") {
next
} else {
current_vdev = row_name
}
emit("zpool_iostat_total_wait_read_ns", current_pool, current_vdev, $8)
emit("zpool_iostat_total_wait_write_ns", current_pool, current_vdev, $9)
emit("zpool_iostat_disk_wait_read_ns", current_pool, current_vdev, $10)
emit("zpool_iostat_disk_wait_write_ns", current_pool, current_vdev, $11)
emit("zpool_iostat_syncq_wait_read_ns", current_pool, current_vdev, $12)
emit("zpool_iostat_syncq_wait_write_ns", current_pool, current_vdev, $13)
emit("zpool_iostat_asyncq_wait_read_ns", current_pool, current_vdev, $14)
emit("zpool_iostat_asyncq_wait_write_ns", current_pool, current_vdev, $15)
}
' >>"$tmp_file"
mv "$tmp_file" "$out_dir/zpool.prom"
trap - EXIT
'';
in
{
networking.firewall.interfaces.${monitoringInterface}.allowedTCPPorts = [
9100
9134
9256
9257
9633
];
services.prometheus.exporters = {
node = {
enable = true;
enabledCollectors = [
"pressure"
"processes"
"systemd"
];
extraFlags = [ "--collector.textfile.directory=${nodeTextfileDir}" ];
};
process = {
enable = true;
user = "root";
group = "root";
settings.process_names = mkProcessMatchers false;
extraFlags = [
"-gather-smaps=false"
"-remove-empty-groups=true"
"-threads=false"
];
};
smartctl.enable = true;
zfs.enable = true;
};
programs.atop = {
enable = true;
atopService.enable = true;
atopRotateTimer.enable = true;
atopacctService.enable = true;
settings.interval = 30;
};
systemd = {
services = {
prometheus-process-pid-exporter = {
description = "Prometheus process exporter with per-PID naming";
wantedBy = [ "multi-user.target" ];
after = [ "network.target" ];
serviceConfig = {
ExecStart = ''
${pkgs.prometheus-process-exporter}/bin/process-exporter \
--web.listen-address 0.0.0.0:9257 \
--config.path ${perPidConfig} \
-children=false \
-gather-smaps=false \
-remove-empty-groups=true \
-threads=false
'';
User = "root";
Group = "root";
Restart = "always";
WorkingDirectory = "/tmp";
CapabilityBoundingSet = [ "" ];
DeviceAllow = [ "" ];
LockPersonality = true;
MemoryDenyWriteExecute = true;
NoNewPrivileges = true;
PrivateDevices = true;
PrivateTmp = true;
ProtectClock = true;
ProtectControlGroups = true;
ProtectHome = true;
ProtectHostname = true;
ProtectKernelLogs = true;
ProtectKernelModules = true;
ProtectKernelTunables = true;
ProtectSystem = "strict";
RemoveIPC = true;
RestrictAddressFamilies = [
"AF_INET"
"AF_INET6"
];
RestrictNamespaces = true;
RestrictRealtime = true;
RestrictSUIDSGID = true;
SystemCallArchitectures = "native";
UMask = "0077";
};
};
zpool-latency-exporter = {
description = "Exports ZFS latency metrics for node_exporter textfile collection";
after = [ "zfs-import.target" ];
requires = [ "zfs-import.target" ];
path = [
config.boot.zfs.package
pkgs.coreutils
pkgs.gawk
];
serviceConfig = {
Type = "oneshot";
ExecStart = zpoolLatencyScript;
};
};
};
timers.zpool-latency-exporter = {
wantedBy = [ "timers.target" ];
timerConfig = {
OnBootSec = "2m";
OnUnitActiveSec = "60s";
Unit = "zpool-latency-exporter.service";
};
};
tmpfiles.rules = [ "d ${nodeTextfileDir} 0755 root root - -" ];
};
}
+1 -1
View File
@@ -12,7 +12,7 @@
brain.id = "SSCGIPI-IV3VYKB-TRNIJE3-COV4T2H-CDBER7F-I2CGHYA-NWOEUDU-3T5QAAN"; # cspell:disable-line
ipad.id = "KI76T3X-SFUGV2L-VSNYTKR-TSIUV5L-SHWD3HE-GQRGRCN-GY4UFMD-CW6Z6AX"; # cspell:disable-line
jeeves.id = "ICRHXZW-ECYJCUZ-I4CZ64R-3XRK7CG-LL2HAAK-FGOHD22-BQA4AI6-5OAL6AG"; # cspell:disable-line
phone.id = "JPVQKQW-CFXOJXT-Q5G5F3H-QIDHDRE-GKHPTQB-GXZUQSP-U7FR7F7-INP3AAH"; # cspell:disable-line
phone.id = "TBRULKD-7DZPGGZ-F6LLB7J-MSO54AY-7KLPBIN-QOFK6PX-W2HBEWI-PHM2CQI"; # cspell:disable-line
rhapsody-in-green.id = "ASL3KC4-3XEN6PA-7BQBRKE-A7JXLI6-DJT43BY-Q4WPOER-7UALUAZ-VTPQ6Q4"; # cspell:disable-line
};
};
+1 -1
View File
@@ -4,7 +4,7 @@
flags = [ "--accept-flake-config" ];
randomizedDelaySec = "1h";
persistent = true;
flake = "git+https://gitea.tmmworkshop.com/richie/dotfiles?ref=main";
flake = "github:RichieCahill/dotfiles";
allowReboot = true;
dates = "Sat *-*-* 06:00:00";
};
-76
View File
@@ -1,76 +0,0 @@
# ZFS failed root import recovery
## Fast path
If the machine fails to boot because ZFS refuses to import `root_pool`:
### GRUB
1. At the bootloader menu, select the normal NixOS entry.
2. Press `e`.
3. Find the line that starts with `linux`.
4. Append this to the end of that line:
```text
zfs_force=1
```
5. Boot once with `Ctrl+x` or `F10`.
### systemd-boot
1. At the bootloader menu, highlight the normal NixOS entry.
2. Press `e`.
3. Append this to the end of the options line:
```text
zfs_force=1
```
4. Press `Enter` to boot once.
## After boot
Run:
```bash
sudo zpool status
sudo zpool import
journalctl -b | rg "ZFS|zfs|import|root_pool"
```
## Expected result
`sudo zpool status` should show `root_pool` as `ONLINE`.
## Reboot test
Run:
```bash
sudo reboot
```
Do not add `zfs_force=1` the second time.
## If it still fails
Boot once more with:
```text
zfs_force=1
```
Then run:
```bash
sudo zpool status -v
sudo zpool history | tail -n 50
journalctl -b | rg "ZFS|zfs|import|root_pool"
```
## Notes
- Root pool name is `root_pool`.
- This is a one-time recovery path after disk moves, controller changes, dirty exports, or interrupted imports.
- Some hosts also need the LUKS unlock USB key inserted before boot.
Generated
+26 -42
View File
@@ -8,11 +8,11 @@
},
"locked": {
"dir": "pkgs/firefox-addons",
"lastModified": 1781150628,
"narHash": "sha256-b4mp8l3qWuSCyYYo9HSngDtcB3PpecYiOXjULrjwwlw=",
"lastModified": 1773979456,
"narHash": "sha256-9kBMJ5IvxqNlkkj/swmE8uK1Sc7TL/LIRUI958m7uBM=",
"owner": "rycee",
"repo": "nur-expressions",
"rev": "753319310f4673a2dabbfab87482187b40bf9bac",
"rev": "81e28f47ac18d9e89513929c77e711e657b64851",
"type": "gitlab"
},
"original": {
@@ -29,11 +29,11 @@
]
},
"locked": {
"lastModified": 1781189114,
"narHash": "sha256-5inaamLgUMWy+MOBE9ChF9QAF1o/74LFuHkI0W/9rqc=",
"lastModified": 1774007980,
"narHash": "sha256-FOnZjElEI8pqqCvB6K/1JRHTE8o4rer8driivTpq2uo=",
"owner": "nix-community",
"repo": "home-manager",
"rev": "486595d2cf49cfcd649b58a284fa11ac0e34da22",
"rev": "9670de2921812bc4e0452f6e3efd8c859696c183",
"type": "github"
},
"original": {
@@ -43,15 +43,12 @@
}
},
"nixos-hardware": {
"inputs": {
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1781168557,
"narHash": "sha256-LOnLQ2tpYF9gqIDDr3+j3DbpJJr/QCH6zPRT2GzEUOE=",
"lastModified": 1774018263,
"narHash": "sha256-HHYEwK1A22aSaxv2ibhMMkKvrDGKGlA/qObG4smrSqc=",
"owner": "nixos",
"repo": "nixos-hardware",
"rev": "6358ff76821101c178e3ab4919a62799bfe3652e",
"rev": "2d4b4717b2534fad5c715968c1cece04a172b365",
"type": "github"
},
"original": {
@@ -63,24 +60,27 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1767892417,
"narHash": "sha256-8bW3q88CEg2u4hSP66Vf4lpbLonHz7hqDNBMcCY7E9U=",
"rev": "3497aa5c9457a9d88d71fa93a4a8368816fbeeba",
"type": "tarball",
"url": "https://releases.nixos.org/nixos/unstable/nixos-26.05pre924538.3497aa5c9457/nixexprs.tar.xz"
"lastModified": 1773821835,
"narHash": "sha256-TJ3lSQtW0E2JrznGVm8hOQGVpXjJyXY2guAxku2O9A4=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "b40629efe5d6ec48dd1efba650c797ddbd39ace0",
"type": "github"
},
"original": {
"type": "tarball",
"url": "https://channels.nixos.org/nixos-unstable/nixexprs.tar.xz"
"owner": "nixos",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs-master": {
"locked": {
"lastModified": 1781229721,
"narHash": "sha256-ORvqDbb/LYxiJljGIejapjkc/kJbVote2N1WSb9W45I=",
"lastModified": 1774051532,
"narHash": "sha256-d3CGMweyYIcPuTj5BKq+1Lx4zwlgL31nVtN647tOZKo=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "173d0ad7a974f8543a9ab01d2271b2e290341b33",
"rev": "8620c0b5cc8fbe76502442181be1d0514bc3a1b7",
"type": "github"
},
"original": {
@@ -106,28 +106,12 @@
"type": "github"
}
},
"nixpkgs_2": {
"locked": {
"lastModified": 1781074563,
"narHash": "sha256-md8WlXOlfnIeHeOScMTTHFyf2d6iaTwPl2apR5EQ3P4=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "9ae611a455b90cf061d8f332b977e387bda8e1ca",
"type": "github"
},
"original": {
"owner": "nixos",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"firefox-addons": "firefox-addons",
"home-manager": "home-manager",
"nixos-hardware": "nixos-hardware",
"nixpkgs": "nixpkgs_2",
"nixpkgs": "nixpkgs",
"nixpkgs-master": "nixpkgs-master",
"nixpkgs-stable": "nixpkgs-stable",
"sops-nix": "sops-nix",
@@ -141,11 +125,11 @@
]
},
"locked": {
"lastModified": 1780547341,
"narHash": "sha256-Gq8KNx5A7hBB3uGJaj6eQfLDIz5YdLu92gqBcvHvoUo=",
"lastModified": 1773889674,
"narHash": "sha256-+ycaiVAk3MEshJTg35cBTUa0MizGiS+bgpYw/f8ohkg=",
"owner": "Mic92",
"repo": "sops-nix",
"rev": "9ed65852b6257fbeae4355bc24ecfea307ca759a",
"rev": "29b6519f3e0780452bca0ac0be4584f04ac16cc5",
"type": "github"
},
"original": {
+24
View File
@@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
+2 -28
View File
@@ -17,41 +17,16 @@
python-env = final: _prev: {
my_python = final.python314.withPackages (
ps:
let
bm25s = ps.buildPythonPackage rec {
pname = "bm25s";
version = "0.3.9";
pyproject = true;
src = final.fetchPypi {
inherit pname version;
hash = "sha256-iVxnnZUrfeg1XttfPhpiCh4vKU0dQrkZvwghzOLi9Zc=";
};
build-system = [ ps.setuptools ];
dependencies = with ps; [
numpy
scipy
];
pythonImportsCheck = [ "bm25s" ];
};
in
with ps;
[
ps: with ps; [
alembic
apprise
apscheduler
beautifulsoup4
ebooklib
fastapi
fastapi-cli
httpx
huggingface-hub
mypy
numpy
orjson
pgvector
polars
psycopg
pydantic
@@ -65,7 +40,6 @@
scalene
sqlalchemy
sqlalchemy
bm25s
tenacity
textual
tiktoken
+14 -3
View File
@@ -12,6 +12,7 @@ dependencies = [
"alembic",
"apprise",
"apscheduler",
"huggingface-hub",
"httpx",
"python-multipart",
"polars",
@@ -26,7 +27,11 @@ dependencies = [
[project.scripts]
database = "python.database_cli:app"
van-inventory = "python.van_inventory.main:serve"
whisper-transcribe = "python.tools.whisper.transcribe:main"
prompt-bench = "python.prompt_bench.main:cli"
prompt-bench-download = "python.prompt_bench.downloader:cli"
finetune = "python.prompt_bench.finetune:cli"
finetune-container = "python.prompt_bench.finetune_container:cli"
build-finetune-dataset = "python.prompt_bench.build_finetune_dataset:cli"
[dependency-groups]
dev = [
@@ -51,7 +56,6 @@ lint.ignore = [
"COM812", # (TEMP) conflicts when used with the formatter
"ISC001", # (TEMP) conflicts when used with the formatter
"S603", # (PERM) This is known to cause a false positive
"S607", # (PERM) This is becoming a consistent annoyance
]
[tool.ruff.lint.per-file-ignores]
@@ -80,7 +84,14 @@ lint.ignore = [
"python/congress_tracker/**" = [
"TC003", # (perm) this creates issues because sqlalchemy uses these at runtime
]
"python/eval_warnings/**" = [
"S607", # (perm) gh and git are expected on PATH in the runner environment
]
"python/prompt_bench/**" = [
"FBT002", # (perm) typer requires boolean defaults for --flag/--no-flag options
"PLR0913", # (perm) typer CLIs naturally have many parameters
"S607", # (perm) docker and nvidia-smi are expected on PATH
]
"python/alembic/**" = [
"INP001", # (perm) this creates LSP issues for alembic
]
@@ -1,93 +0,0 @@
"""adding audiobook libreary metadata.
Revision ID: d7864d1ffc17
Revises: c8a794340928
Create Date: 2026-06-03 20:24:09.200837
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from python.orm import RichieBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "d7864d1ffc17"
down_revision: str | None = "c8a794340928"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = RichieBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"audiobook_author",
sa.Column("name", sa.String(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_audiobook_author")),
sa.UniqueConstraint("name", name=op.f("uq_audiobook_author_name")),
schema=schema,
)
op.create_table(
"audiobook_series",
sa.Column("name", sa.String(), nullable=False),
sa.Column("author_id", sa.Integer(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(
["author_id"],
[f"{schema}.audiobook_author.id"],
name=op.f("fk_audiobook_series_author_id_audiobook_author"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_audiobook_series")),
sa.UniqueConstraint("author_id", "name", name=op.f("uq_audiobook_series_author_id")),
schema=schema,
)
op.create_table(
"audiobook",
sa.Column("title", sa.String(), nullable=False),
sa.Column("author_id", sa.Integer(), nullable=False),
sa.Column("series_id", sa.Integer(), nullable=True),
sa.Column("series_index", sa.Integer(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(
["author_id"],
[f"{schema}.audiobook_author.id"],
name=op.f("fk_audiobook_author_id_audiobook_author"),
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["series_id"],
[f"{schema}.audiobook_series.id"],
name=op.f("fk_audiobook_series_id_audiobook_series"),
ondelete="SET NULL",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_audiobook")),
schema=schema,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("audiobook", schema=schema)
op.drop_table("audiobook_series", schema=schema)
op.drop_table("audiobook_author", schema=schema)
# ### end Alembic commands ###
@@ -1,200 +0,0 @@
"""add ebook search tables.
Revision ID: 2db132cace1a
Revises: b3c60cc5beb5
Create Date: 2026-06-10 22:10:54.379159
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import pgvector
import sqlalchemy as sa
from alembic import op
from python.orm import RichieBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "2db132cace1a"
down_revision: str | None = "b3c60cc5beb5"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = RichieBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"ebook_embedding_model",
sa.Column("name", sa.String(), nullable=False),
sa.Column("dimension", sa.Integer(), nullable=False),
sa.Column("is_default", sa.Boolean(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_embedding_model")),
sa.UniqueConstraint("name", name=op.f("uq_ebook_embedding_model_name")),
schema=schema,
)
op.create_table(
"ebook_source",
sa.Column("title", sa.String(), nullable=False),
sa.Column("author", sa.String(), nullable=True),
sa.Column("language", sa.String(), nullable=True),
sa.Column("publisher", sa.String(), nullable=True),
sa.Column("identifier", sa.String(), nullable=True),
sa.Column("file_path", sa.String(), nullable=False),
sa.Column("file_sha256", sa.String(length=64), nullable=False),
sa.Column("file_mtime", sa.DateTime(timezone=True), nullable=False),
sa.Column("file_size", sa.BigInteger(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_source")),
sa.UniqueConstraint("file_path", name=op.f("uq_ebook_source_file_path")),
sa.UniqueConstraint("file_sha256", name=op.f("uq_ebook_source_file_sha256")),
schema=schema,
)
op.create_table(
"ebook_chapter",
sa.Column("source_id", sa.Integer(), nullable=False),
sa.Column("spine_index", sa.Integer(), nullable=False),
sa.Column("title", sa.String(), nullable=True),
sa.Column("href", sa.String(), nullable=True),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(
["source_id"],
[f"{schema}.ebook_source.id"],
name=op.f("fk_ebook_chapter_source_id_ebook_source"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_chapter")),
sa.UniqueConstraint("source_id", "spine_index", name=op.f("uq_ebook_chapter_source_id")),
schema=schema,
)
op.create_table(
"ebook_chunk",
sa.Column("source_id", sa.Integer(), nullable=False),
sa.Column("chapter_id", sa.Integer(), nullable=True),
sa.Column("chunk_index", sa.Integer(), nullable=False),
sa.Column("text", sa.String(), nullable=False),
sa.Column("token_start", sa.Integer(), nullable=False),
sa.Column("token_count", sa.Integer(), nullable=False),
sa.Column("page_label", sa.String(), nullable=True),
sa.Column("content_sha256", sa.String(length=64), nullable=False),
sa.Column("search_text", sa.String(), nullable=False),
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(
["chapter_id"],
[f"{schema}.ebook_chapter.id"],
name=op.f("fk_ebook_chunk_chapter_id_ebook_chapter"),
ondelete="SET NULL",
),
sa.ForeignKeyConstraint(
["source_id"],
[f"{schema}.ebook_source.id"],
name=op.f("fk_ebook_chunk_source_id_ebook_source"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_chunk")),
sa.UniqueConstraint("source_id", "chunk_index", name="uq_ebook_chunk_source_id_chunk_index"),
sa.UniqueConstraint("source_id", "content_sha256", name="uq_ebook_chunk_source_id_content_sha256"),
schema=schema,
)
op.create_table(
"ebook_chunk_embedding_1024",
sa.Column("chunk_id", sa.BigInteger(), nullable=False),
sa.Column("model_id", sa.Integer(), nullable=False),
sa.Column("embedding", pgvector.sqlalchemy.vector.VECTOR(dim=1024), nullable=False),
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(
["chunk_id"],
[f"{schema}.ebook_chunk.id"],
name=op.f("fk_ebook_chunk_embedding_1024_chunk_id_ebook_chunk"),
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["model_id"],
[f"{schema}.ebook_embedding_model.id"],
name=op.f("fk_ebook_chunk_embedding_1024_model_id_ebook_embedding_model"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_chunk_embedding_1024")),
sa.UniqueConstraint("chunk_id", "model_id", name=op.f("uq_ebook_chunk_embedding_1024_chunk_id")),
schema=schema,
)
op.create_table(
"ebook_chunk_embedding_2560",
sa.Column("chunk_id", sa.BigInteger(), nullable=False),
sa.Column("model_id", sa.Integer(), nullable=False),
sa.Column("embedding", pgvector.sqlalchemy.vector.VECTOR(dim=2560), nullable=False),
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(
["chunk_id"],
[f"{schema}.ebook_chunk.id"],
name=op.f("fk_ebook_chunk_embedding_2560_chunk_id_ebook_chunk"),
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["model_id"],
[f"{schema}.ebook_embedding_model.id"],
name=op.f("fk_ebook_chunk_embedding_2560_model_id_ebook_embedding_model"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_chunk_embedding_2560")),
sa.UniqueConstraint("chunk_id", "model_id", name=op.f("uq_ebook_chunk_embedding_2560_chunk_id")),
schema=schema,
)
op.create_table(
"ebook_chunk_embedding_4096",
sa.Column("chunk_id", sa.BigInteger(), nullable=False),
sa.Column("model_id", sa.Integer(), nullable=False),
sa.Column("embedding", pgvector.sqlalchemy.vector.VECTOR(dim=4096), nullable=False),
sa.Column("id", sa.BigInteger(), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.Column("updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
sa.ForeignKeyConstraint(
["chunk_id"],
[f"{schema}.ebook_chunk.id"],
name=op.f("fk_ebook_chunk_embedding_4096_chunk_id_ebook_chunk"),
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["model_id"],
[f"{schema}.ebook_embedding_model.id"],
name=op.f("fk_ebook_chunk_embedding_4096_model_id_ebook_embedding_model"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_ebook_chunk_embedding_4096")),
sa.UniqueConstraint("chunk_id", "model_id", name=op.f("uq_ebook_chunk_embedding_4096_chunk_id")),
schema=schema,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("ebook_chunk_embedding_4096", schema=schema)
op.drop_table("ebook_chunk_embedding_2560", schema=schema)
op.drop_table("ebook_chunk_embedding_1024", schema=schema)
op.drop_table("ebook_chunk", schema=schema)
op.drop_table("ebook_chapter", schema=schema)
op.drop_table("ebook_source", schema=schema)
op.drop_table("ebook_embedding_model", schema=schema)
# ### end Alembic commands ###
@@ -1,63 +0,0 @@
"""updated series_index to float and added UniqueConstraint to audiobook and audiobook_author.
Revision ID: b3c60cc5beb5
Revises: d7864d1ffc17
Create Date: 2026-06-10 20:02:43.073725
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import op
from python.orm import RichieBase
if TYPE_CHECKING:
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "b3c60cc5beb5"
down_revision: str | None = "d7864d1ffc17"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
schema = RichieBase.schema_name
def upgrade() -> None:
"""Upgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"audiobook",
"series_index",
existing_type=sa.INTEGER(),
type_=sa.Float(),
existing_nullable=False,
schema=schema,
)
op.create_unique_constraint(
op.f("uq_audiobook_author_id"),
"audiobook",
["author_id", "series_id", "title"],
schema=schema,
postgresql_nulls_not_distinct=True,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(op.f("uq_audiobook_author_id"), "audiobook", schema=schema, type_="unique")
op.alter_column(
"audiobook",
"series_index",
existing_type=sa.Float(),
type_=sa.INTEGER(),
existing_nullable=False,
schema=schema,
)
# ### end Alembic commands ###
+1 -1
View File
@@ -9,9 +9,9 @@ import typer
import uvicorn
from fastapi import FastAPI
from python.api.middleware import ZstdMiddleware
from python.api.routers import contact_router, views_router
from python.common import configure_logger
from python.fastapi_tools import ZstdMiddleware
from python.orm.common import get_postgres_engine
logger = logging.getLogger(__name__)
@@ -1,4 +1,4 @@
"""Zstd response compression middleware."""
"""Middleware for the FastAPI application."""
from compression import zstd
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+1 -1
View File
@@ -9,7 +9,7 @@ from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from python.fastapi_tools.db import DbSession
from python.api.dependencies import DbSession
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
+1 -1
View File
@@ -9,7 +9,7 @@ from fastapi.templating import Jinja2Templates
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from python.fastapi_tools.db import DbSession
from python.api.dependencies import DbSession
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
-1
View File
@@ -1 +0,0 @@
"""EPUB search package."""
-57
View File
@@ -1,57 +0,0 @@
"""Grounded answer generation."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from python.ebook_search.llm_interface import request_chat_completion
if TYPE_CHECKING:
from python.ebook_search.config import EbookSearchConfig
from python.ebook_search.search import SearchResult
logger = logging.getLogger(__name__)
def answer_query(query: str, results: list[SearchResult], config: EbookSearchConfig) -> str:
"""Answer a question using only retrieved chunks."""
if not config.answer_enabled:
logger.info("ebook_answer_skipped_disabled")
return "Answer generation is disabled. Source chunks are shown below."
if not results:
logger.info("ebook_answer_skipped_no_results")
return "No relevant sources were found."
logger.info(
"ebook_answer_request_start base_url=%s model=%s sources=%s query_length=%s",
config.vllm_base_url,
config.chat_model,
len(results),
len(query),
)
context = "\n\n".join(
f"[{index}] {result.source_title}{' - ' + result.chapter_title if result.chapter_title else ''}\n{result.text}"
for index, result in enumerate(results, start=1)
)
content = request_chat_completion(
config,
[
{
"role": "system",
"content": (
"Answer only from the provided context. Cite sources with bracketed numbers like [1]. "
"If the context is insufficient, say so."
),
},
{"role": "user", "content": f"Question:\n{query}\n\nContext:\n{context}"},
],
)
logger.info(
"ebook_answer_request_complete model=%s answer_length=%s",
config.chat_model,
len(content),
)
return content or "The model returned an empty answer."
-1
View File
@@ -1 +0,0 @@
"""Web and external API adapters for EPUB search."""
-60
View File
@@ -1,60 +0,0 @@
"""Background BM25 refresh tasks for the web app."""
from __future__ import annotations
import logging
from threading import Timer
from typing import TYPE_CHECKING
from sqlalchemy.orm import Session
from python.ebook_search.bm25_corpus import load_bm25_corpus, refresh_bm25_corpus
if TYPE_CHECKING:
from fastapi import FastAPI
from sqlalchemy.engine import Engine
from python.ebook_search.config import EbookSearchConfig
logger = logging.getLogger(__name__)
def schedule_bm25_refresh(app: FastAPI) -> None:
"""Schedule a delayed BM25 corpus refresh, replacing any pending refresh."""
existing_timer = getattr(app.state, "bm25_refresh_timer", None)
if existing_timer is not None:
existing_timer.cancel()
timer = Timer(app.state.config.bm25_refresh_delay_seconds, refresh_bm25_for_app, args=(app,))
timer.daemon = True
timer.start()
app.state.bm25_refresh_timer = timer
logger.info(
"ebook_bm25_refresh_scheduled delay_seconds=%s",
app.state.config.bm25_refresh_delay_seconds,
)
def cancel_bm25_refresh(app: FastAPI) -> None:
"""Cancel any pending BM25 corpus refresh."""
existing_timer = getattr(app.state, "bm25_refresh_timer", None)
if existing_timer is not None:
existing_timer.cancel()
app.state.bm25_refresh_timer = None
logger.info("ebook_bm25_refresh_cancelled")
def refresh_bm25_for_app(app: FastAPI) -> None:
"""Refresh the BM25 corpus using the app engine and config."""
try:
refresh_bm25_for_engine(app.state.engine, app.state.config)
except Exception:
logger.exception("ebook_bm25_refresh_failed")
def refresh_bm25_for_engine(engine: Engine, config: EbookSearchConfig) -> None:
"""Refresh the BM25 corpus using a SQLAlchemy engine."""
with Session(engine) as session:
refresh_bm25_corpus(session, config)
load_bm25_corpus.cache_clear()
logger.info("ebook_bm25_corpus_cache_cleared_after_refresh")
-77
View File
@@ -1,77 +0,0 @@
"""FastAPI HTMX app for EPUB search."""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Annotated
import typer
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from sqlalchemy.orm import Session
from python.common import configure_logger
from python.ebook_search.api.bm25_tasks import cancel_bm25_refresh
from python.ebook_search.api.routes import admin_router, page_router, search_router
from python.ebook_search.api.web import STATIC_DIR
from python.ebook_search.bm25_corpus import ensure_bm25_corpus
from python.ebook_search.config import load_config
from python.orm.common import get_postgres_engine
if TYPE_CHECKING:
from collections.abc import AsyncIterator
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
"""Manage application startup and shutdown resources."""
logger.info("ebook_search_startup")
app.state.engine = get_postgres_engine(name="RICHIE", vector_engine=True)
with Session(app.state.engine) as session:
ensure_bm25_corpus(session, app.state.config)
try:
yield
finally:
logger.info("ebook_search_shutdown")
cancel_bm25_refresh(app)
app.state.engine.dispose()
def create_app() -> FastAPI:
"""Create the EPUB search web app."""
app = FastAPI(title="EPUB Search", lifespan=lifespan)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.state.config = load_config()
logger.info(
"ebook_search_config_loaded top_k=%s embedding_model=%s rerank_enabled=%s answer_enabled=%s library_paths=%s",
app.state.config.top_k,
app.state.config.embedding_model,
app.state.config.rerank.enabled,
app.state.config.answer_enabled,
len(app.state.config.library_paths),
)
app.include_router(admin_router)
app.include_router(page_router)
app.include_router(search_router)
return app
def serve(
host: Annotated[str, typer.Option("--host", "-h", help="Host to bind to")] = "127.0.0.1",
port: Annotated[int, typer.Option("--port", "-p", help="Port to bind to")] = 8070,
log_level: Annotated[str, typer.Option("--log-level", "-l", help="Log level")] = "INFO",
) -> None:
"""Start the EPUB search server."""
configure_logger(log_level)
uvicorn.run(create_app(), host=host, port=port)
if __name__ == "__main__":
typer.run(serve)
@@ -1,11 +0,0 @@
"""EPUB search web route modules."""
from python.ebook_search.api.routes.admin import router as admin_router
from python.ebook_search.api.routes.page import router as page_router
from python.ebook_search.api.routes.search import router as search_router
__all__ = [
"admin_router",
"page_router",
"search_router",
]
-107
View File
@@ -1,107 +0,0 @@
"""Admin routes for the EPUB search web UI."""
from __future__ import annotations
import logging
from dataclasses import replace
from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse
from sqlalchemy.orm import Session
from python.ebook_search.api.bm25_tasks import schedule_bm25_refresh
from python.ebook_search.api.web import templates
from python.ebook_search.embeddings import embed_missing_chunks, embedding_model_stats
from python.ebook_search.ingest import ingest_configured_paths
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin")
EMBED_ALL_BATCH_SIZE = 32
@router.get("", response_class=HTMLResponse)
def admin(request: Request) -> HTMLResponse:
"""Render the admin page."""
with Session(request.app.state.engine) as session:
stats = embedding_model_stats(session)
logger.info("ebook_admin_page_loaded models=%s", len(stats))
return templates.TemplateResponse(request, "admin.html", {"config": request.app.state.config, "stats": stats})
@router.post("/scan", response_class=HTMLResponse)
def scan_library(request: Request) -> HTMLResponse:
"""Scan configured library paths for EPUB changes."""
try:
with Session(request.app.state.engine) as session:
count = ingest_configured_paths(session, request.app.state.config)
session.commit()
except Exception as error:
logger.exception("ebook_admin_scan_failed")
return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500)
logger.info("ebook_admin_scan_complete changed_files=%s", count)
if count > 0:
schedule_bm25_refresh(request.app)
return templates.TemplateResponse(request, "partials/admin_status.html", {"message": f"Indexed {count} EPUBs"})
@router.post("/embed-missing", response_class=HTMLResponse)
def embed_missing(request: Request) -> HTMLResponse:
"""Embed chunks missing vectors for the configured model."""
try:
with Session(request.app.state.engine) as session:
count = embed_missing_chunks(session, request.app.state.config)
session.commit()
except Exception as error:
logger.exception("ebook_admin_embed_missing_failed")
return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500)
logger.info("ebook_admin_embed_missing_complete chunks=%s", count)
return templates.TemplateResponse(
request,
"partials/admin_status.html",
{"message": f"Embedded {count} chunks"},
)
@router.post("/embed-all", response_class=HTMLResponse)
def embed_all(request: Request) -> HTMLResponse:
"""Embed all chunks missing vectors in fixed-size batches."""
total = 0
batches = 0
config = replace(request.app.state.config, embedding_batch_size=EMBED_ALL_BATCH_SIZE)
try:
with Session(request.app.state.engine) as session:
while True:
count = embed_missing_chunks(session, config)
if count == 0:
break
session.commit()
total += count
batches += 1
logger.info(
"ebook_admin_embed_all_batch_complete batch=%s chunks=%s total_chunks=%s",
batches,
count,
total,
)
except Exception as error:
logger.exception(
"ebook_admin_embed_all_failed batches=%s chunks=%s",
batches,
total,
)
return templates.TemplateResponse(
request,
"partials/error.html",
{"message": f"Embed all failed after {total} chunks in {batches} batches: {error}"},
status_code=500,
)
logger.info("ebook_admin_embed_all_complete batches=%s chunks=%s", batches, total)
return templates.TemplateResponse(
request,
"partials/admin_status.html",
{"message": f"Embedded {total} chunks in {batches} batches of {EMBED_ALL_BATCH_SIZE}"},
)
-57
View File
@@ -1,57 +0,0 @@
"""Page routes for the EPUB search web UI."""
from __future__ import annotations
import logging
from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse
from sqlalchemy import select
from sqlalchemy.orm import Session
from python.ebook_search.api.web import templates
from python.orm.richie import EbookSource
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/", response_class=HTMLResponse)
def index(request: Request) -> HTMLResponse:
"""Render the search page."""
return templates.TemplateResponse(request, "search.html", {"config": request.app.state.config})
@router.get("/books", response_class=HTMLResponse)
def books(request: Request) -> HTMLResponse:
"""Render the indexed books page."""
with Session(request.app.state.engine) as session:
sources = list(session.scalars(select(EbookSource).order_by(EbookSource.title)).all())
logger.info("ebook_books_page_loaded count=%s", len(sources))
return templates.TemplateResponse(request, "books.html", {"sources": sources})
@router.get("/books/{source_id}", response_class=HTMLResponse)
def book_detail(source_id: int, request: Request) -> HTMLResponse:
"""Render details for one indexed book."""
with Session(request.app.state.engine) as session:
source = session.get(EbookSource, source_id)
if source is not None:
chapter_count = len(source.chapters)
chunk_count = len(source.chunks)
else:
chapter_count = 0
chunk_count = 0
logger.info(
"ebook_book_detail_loaded source_id=%s found=%s chapters=%s chunks=%s",
source_id,
source is not None,
chapter_count,
chunk_count,
)
return templates.TemplateResponse(
request,
"book_detail.html",
{"chapter_count": chapter_count, "chunk_count": chunk_count, "source": source},
)
-58
View File
@@ -1,58 +0,0 @@
"""Search routes for the EPUB search web UI."""
from __future__ import annotations
import logging
from dataclasses import replace
from time import perf_counter
from typing import Annotated
from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse
from python.ebook_search.answer import answer_query
from python.ebook_search.api.web import templates
from python.ebook_search.search import search_ebooks
from python.ebook_search.timing import runtime_step_from_start
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/search", response_class=HTMLResponse)
def search(
request: Request,
query: Annotated[str, Form()],
rerank: Annotated[str | None, Form()] = None,
) -> HTMLResponse:
"""Run a search and render HTMX results."""
try:
response = search_ebooks(request.app.state.engine, query, request.app.state.config, rerank=rerank == "true")
except Exception as error:
logger.exception("ebook_search_request_failed")
return templates.TemplateResponse(request, "partials/error.html", {"message": str(error)}, status_code=500)
answer_start = perf_counter()
if request.app.state.config.answer_enabled:
try:
answer = answer_query(query, response.results, request.app.state.config)
except RuntimeError as error:
logger.warning("ebook_answer_request_failed_falling_back error=%s", error)
answer = "Answer generation failed. Source chunks are still shown below."
else:
logger.info("ebook_answer_skipped_disabled")
answer = "Answer generation is disabled. Source chunks are shown below."
answer_step_name = "Answer generation" if request.app.state.config.answer_enabled else "Answer skipped"
response = replace(
response,
timings=(*response.timings, runtime_step_from_start(answer_step_name, answer_start)),
)
logger.info(
"ebook_search_request_complete results=%s rank_label=%s runtime_ms=%.1f",
len(response.results),
response.rank_label,
response.total_runtime_ms,
)
return templates.TemplateResponse(request, "partials/results.html", {"answer": answer, "response": response})
-140
View File
@@ -1,140 +0,0 @@
body {
margin: 0;
background: #f7f7f4;
color: #202124;
font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
}
main {
max-width: 960px;
margin: 0 auto;
padding: 24px;
}
nav {
display: flex;
gap: 12px;
align-items: center;
margin-bottom: 20px;
}
nav form {
margin: 0;
}
.actions {
display: flex;
flex-wrap: wrap;
gap: 12px;
margin-bottom: 24px;
}
textarea {
display: block;
width: 100%;
margin: 8px 0 12px;
}
button {
padding: 8px 14px;
}
.check {
display: inline-flex;
gap: 8px;
align-items: center;
margin-right: 12px;
}
.rank-label {
margin-top: 24px;
font-weight: 700;
}
.results {
padding-left: 24px;
}
.meta,
.scores,
.status {
color: #626a73;
}
.scores {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin: 12px 0;
}
.scores div {
display: inline-flex;
gap: 4px;
align-items: baseline;
}
.scores dt {
font-weight: 700;
}
.scores dd {
margin: 0;
}
.runtime {
margin-top: 16px;
}
.timing-chart {
display: grid;
gap: 8px;
padding: 0;
list-style: none;
}
.timing-chart li {
display: grid;
grid-template-columns: minmax(150px, 1fr) minmax(160px, 2fr) auto auto;
gap: 8px;
align-items: center;
}
.timing-bar {
height: 10px;
overflow: hidden;
background: #e5e5df;
}
.timing-bar span {
display: block;
height: 100%;
background: #3767c8;
}
.timing-value,
.timing-remaining {
color: #626a73;
font-variant-numeric: tabular-nums;
}
table {
width: 100%;
border-collapse: collapse;
}
th,
td {
padding: 8px;
border-bottom: 1px solid #d8d8d2;
text-align: left;
}
th {
font-weight: 700;
}
.error {
color: #9f1d20;
font-weight: 700;
}
@@ -1,57 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>EPUB Admin</title>
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
<link rel="stylesheet" href="/static/style.css">
</head>
<body>
<main>
<nav>
<a href="/">Search</a>
<a href="/books">Books</a>
<a href="/admin">Admin</a>
</nav>
<h1>Admin</h1>
<section id="admin-status"></section>
<section class="actions">
<form hx-post="/admin/scan" hx-target="#admin-status" hx-swap="innerHTML">
<button type="submit">Scan</button>
</form>
<form hx-post="/admin/embed-missing" hx-target="#admin-status" hx-swap="innerHTML">
<button type="submit">Embed</button>
</form>
<form hx-post="/admin/embed-all" hx-target="#admin-status" hx-swap="innerHTML">
<button type="submit">Embed all</button>
</form>
</section>
<section>
<h2>Embeddings</h2>
<table>
<thead>
<tr>
<th>Model</th>
<th>Dimensions</th>
<th>Embedded</th>
<th>Missing</th>
<th>Total chunks</th>
</tr>
</thead>
<tbody>
{% for item in stats %}
<tr>
<td>{{ item.model_name }}</td>
<td>{{ item.dimension }}</td>
<td>{{ item.embedded_chunks }}</td>
<td>{{ item.missing_chunks }}</td>
<td>{{ item.total_chunks }}</td>
</tr>
{% endfor %}
</tbody>
</table>
</section>
</main>
</body>
</html>
@@ -1,32 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{% if source %}{{ source.title }}{% else %}Book not found{% endif %}</title>
<link rel="stylesheet" href="/static/style.css">
</head>
<body>
<main>
<nav>
<a href="/">Search</a>
<a href="/books">Books</a>
<a href="/admin">Admin</a>
</nav>
{% if source %}
<h1>{{ source.title }}</h1>
<p class="meta">{{ source.author or "Unknown author" }}</p>
<dl>
<dt>File</dt>
<dd>{{ source.file_path }}</dd>
<dt>Chapters</dt>
<dd>{{ chapter_count }}</dd>
<dt>Chunks</dt>
<dd>{{ chunk_count }}</dd>
</dl>
{% else %}
<h1>Book not found</h1>
{% endif %}
</main>
</body>
</html>
@@ -1,31 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>EPUB Books</title>
<link rel="stylesheet" href="/static/style.css">
</head>
<body>
<main>
<nav>
<a href="/">Search</a>
<a href="/books">Books</a>
<a href="/admin">Admin</a>
</nav>
<h1>Books</h1>
{% if sources %}
<ol class="results">
{% for source in sources %}
<li>
<h2><a href="/books/{{ source.id }}">{{ source.title }}</a></h2>
<p class="meta">{{ source.author or "Unknown author" }}</p>
</li>
{% endfor %}
</ol>
{% else %}
<p>No EPUBs indexed.</p>
{% endif %}
</main>
</body>
</html>
@@ -1 +0,0 @@
<p class="status">{{ message }}</p>
@@ -1 +0,0 @@
<p class="error">{{ message }}</p>
@@ -1,74 +0,0 @@
<div class="rank-label">{{ response.rank_label }}</div>
{% if response.timings %}
<section class="runtime">
<h2>Runtime</h2>
<p class="meta">Total {{ "%.1f"|format(response.total_runtime_ms) }} ms</p>
<ol class="timing-chart">
{% set total = response.total_runtime_ms %}
{% set ns = namespace(remaining=total) %}
{% for step in response.timings %}
{% set width = (step.duration_ms / total * 100) if total else 0 %}
{% if step.counts_toward_total %}
{% set ns.remaining = ns.remaining - step.duration_ms %}
{% endif %}
<li>
<span class="timing-label">{{ step.name }}</span>
<span class="timing-bar"><span style="width: {{ "%.2f"|format(width) }}%"></span></span>
<span class="timing-value">{{ "%.1f"|format(step.duration_ms) }} ms</span>
<span class="timing-remaining">{{ "%.1f"|format([ns.remaining, 0]|max) }} ms left</span>
</li>
{% endfor %}
</ol>
</section>
{% endif %}
<section class="answer">
<h2>Answer</h2>
<p>{{ answer }}</p>
</section>
{% if response.results %}
<ol class="results">
{% for result in response.results %}
<li>
<h2>{{ result.source_title }}</h2>
<p class="meta">
{% if result.source_author %}{{ result.source_author }}{% endif %}
{% if result.chapter_title %} · {{ result.chapter_title }}{% endif %}
{% if result.page_label %} · page {{ result.page_label }}{% endif %}
</p>
<p>{{ result.text }}</p>
<dl class="scores">
<div>
<dt>final</dt>
<dd>{{ "%.3f"|format(result.score) }}</dd>
</div>
{% if result.rerank_score is not none %}
<div>
<dt>rerank</dt>
<dd>{{ "%.3f"|format(result.rerank_score) }}</dd>
</div>
{% endif %}
{% if result.vector_score is not none %}
<div>
<dt>vector cosine</dt>
<dd>{{ "%.3f"|format(result.vector_score) }}</dd>
</div>
{% endif %}
{% if result.bm25_score is not none %}
<div>
<dt>BM25</dt>
<dd>{{ "%.6f"|format(result.bm25_score) }}</dd>
</div>
{% endif %}
{% if result.fused_score is not none %}
<div>
<dt>RRF</dt>
<dd>{{ "%.3f"|format(result.fused_score) }}</dd>
</div>
{% endif %}
</dl>
</li>
{% endfor %}
</ol>
{% else %}
<p>No results.</p>
{% endif %}
@@ -1,30 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>EPUB Search</title>
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
<link rel="stylesheet" href="/static/style.css">
</head>
<body>
<main>
<nav>
<a href="/">Search</a>
<a href="/books">Books</a>
<a href="/admin">Admin</a>
</nav>
<h1>EPUB Search</h1>
<form hx-post="/search" hx-target="#results" hx-swap="innerHTML">
<label for="query">Search</label>
<textarea id="query" name="query" rows="4" required></textarea>
<label class="check">
<input type="checkbox" name="rerank" value="true" {% if config.rerank.enabled %}checked{% endif %}>
Rerank
</label>
<button type="submit">Search</button>
</form>
<section id="results"></section>
</main>
</body>
</html>
-13
View File
@@ -1,13 +0,0 @@
"""Shared web UI resources for EPUB search."""
from __future__ import annotations
from pathlib import Path
from fastapi.templating import Jinja2Templates
PACKAGE_DIR = Path(__file__).resolve().parent
TEMPLATE_DIR = PACKAGE_DIR / "templates"
STATIC_DIR = PACKAGE_DIR / "static"
templates = Jinja2Templates(directory=TEMPLATE_DIR)
-237
View File
@@ -1,237 +0,0 @@
"""Persisted BM25 corpus management."""
from __future__ import annotations
import json
import logging
import shutil
import tempfile
from dataclasses import dataclass
from datetime import UTC, datetime
from functools import cache
from pathlib import Path
from typing import TYPE_CHECKING
import bm25s
from sqlalchemy import func, select, union_all
from python.orm.richie import EbookChapter, EbookChunk, EbookSource
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from python.ebook_search.config import EbookSearchConfig
logger = logging.getLogger(__name__)
MANIFEST_NAME = "manifest.json"
REQUIRED_INDEX_FILES = frozenset(
{
"data.csc.index.npy",
"indices.csc.index.npy",
"indptr.csc.index.npy",
"params.index.json",
"vocab.index.json",
"corpus.jsonl",
}
)
@dataclass(frozen=True)
class BM25Manifest:
"""Metadata describing a persisted BM25 corpus."""
created_at: datetime
db_updated_at: datetime | None
chunk_count: int
@dataclass(frozen=True)
class BM25Corpus:
"""Loaded persisted BM25 corpus and retriever."""
retriever: object | None
records: tuple[dict[str, object], ...]
manifest: BM25Manifest
class BM25CorpusUnavailableError(RuntimeError):
"""Raised when the persisted BM25 corpus cannot be loaded."""
def bm25_index_path(config: EbookSearchConfig) -> Path:
"""Return the configured BM25 index path relative to the current working directory."""
path = Path(config.bm25_index_dir).expanduser()
if path.is_absolute():
return path
return Path.cwd() / path
def ensure_bm25_corpus(session: Session, config: EbookSearchConfig) -> None:
"""Create or refresh the persisted BM25 corpus when it is missing or stale."""
index_path = bm25_index_path(config)
manifest = read_bm25_manifest(index_path)
db_updated_at = corpus_last_updated_at(session)
if not bm25_index_exists(index_path, manifest):
logger.info("ebook_bm25_index_missing path=%s", index_path)
refresh_bm25_corpus(session, config, db_updated_at=db_updated_at)
return
if db_updated_at is not None and manifest is not None and manifest.created_at < db_updated_at:
logger.info(
"ebook_bm25_index_stale path=%s created_at=%s db_updated_at=%s",
index_path,
manifest.created_at.isoformat(),
db_updated_at.isoformat(),
)
refresh_bm25_corpus(session, config, db_updated_at=db_updated_at)
return
logger.info(
"ebook_bm25_index_current path=%s chunks=%s created_at=%s",
index_path,
manifest.chunk_count if manifest else 0,
manifest.created_at.isoformat() if manifest else None,
)
def refresh_bm25_corpus(
session: Session,
config: EbookSearchConfig,
*,
db_updated_at: datetime | None = None,
) -> BM25Manifest:
"""Rebuild and persist the BM25 corpus from the current database chunks."""
index_path = bm25_index_path(config)
records = fetch_bm25_corpus_records(session)
manifest = BM25Manifest(
created_at=datetime.now(tz=UTC),
db_updated_at=db_updated_at if db_updated_at is not None else corpus_last_updated_at(session),
chunk_count=len(records),
)
write_bm25_corpus(index_path, records, manifest)
logger.info(
"ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s",
index_path,
manifest.chunk_count,
manifest.created_at.isoformat(),
)
return manifest
@cache
def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
"""Load the BM25 corpus into memory once per process.
Background refresh tasks clear this cache after rebuilding the on-disk corpus.
"""
index_path = bm25_index_path(config)
logger.info("ebook_bm25_corpus_cache_load path=%s", index_path)
manifest = read_bm25_manifest(index_path)
if manifest is None or not bm25_index_exists(index_path, manifest):
msg = f"BM25 corpus is not available: {index_path}"
raise BM25CorpusUnavailableError(msg)
if manifest.chunk_count == 0:
return BM25Corpus(retriever=None, records=(), manifest=manifest)
retriever = bm25s.BM25.load(index_path, load_corpus=True, mmap=True)
records = tuple(dict(record) for record in retriever.corpus)
return BM25Corpus(retriever=retriever, records=records, manifest=manifest)
def score_bm25_corpus(query: str, corpus: BM25Corpus, *, limit: int) -> list[tuple[dict[str, object], float]]:
"""Score a query against a loaded BM25 corpus."""
if corpus.retriever is None or not corpus.records:
return []
k = min(limit, len(corpus.records))
documents, scores = corpus.retriever.retrieve(
bm25s.tokenize(query, show_progress=False),
corpus=list(corpus.records),
k=k,
show_progress=False,
)
results: list[tuple[dict[str, object], float]] = []
for document, score in zip(documents[0], scores[0], strict=True):
score_value = float(score)
if score_value <= 0:
continue
results.append((dict(document), score_value))
return results
def fetch_bm25_corpus_records(session: Session) -> list[dict[str, object]]:
"""Fetch BM25 corpus records from the database."""
statement = (
select(
EbookChunk.id.label("chunk_id"),
EbookChunk.text.label("text"),
EbookSource.title.label("source_title"),
EbookSource.author.label("source_author"),
EbookChapter.title.label("chapter_title"),
EbookChunk.page_label.label("page_label"),
EbookChunk.search_text.label("bm25_text"),
)
.select_from(EbookChunk)
.join(EbookSource, EbookSource.id == EbookChunk.source_id)
.outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id)
.order_by(EbookChunk.id)
)
return [dict(row) for row in session.execute(statement).mappings()]
def corpus_last_updated_at(session: Session) -> datetime | None:
"""Return the latest source/chapter/chunk update timestamp relevant to BM25 text."""
update_times = union_all(
select(func.max(EbookSource.updated).label("updated")),
select(func.max(EbookChapter.updated).label("updated")),
select(func.max(EbookChunk.updated).label("updated")),
).subquery()
return session.scalar(select(func.max(update_times.c.updated)))
def write_bm25_corpus(index_path: Path, records: list[dict[str, object]], manifest: BM25Manifest) -> None:
"""Write a BM25 corpus and manifest atomically."""
index_path.parent.mkdir(parents=True, exist_ok=True)
temp_path = Path(tempfile.mkdtemp(prefix=f"{index_path.name}.", dir=index_path.parent))
try:
if records:
retriever = bm25s.BM25()
texts = [str(record["bm25_text"]) for record in records]
retriever.index(bm25s.tokenize(texts, show_progress=False), show_progress=False)
retriever.save(temp_path, corpus=records, show_progress=False)
write_bm25_manifest(temp_path, manifest)
if index_path.exists():
shutil.rmtree(index_path)
temp_path.rename(index_path)
except Exception:
shutil.rmtree(temp_path, ignore_errors=True)
raise
def read_bm25_manifest(index_path: Path) -> BM25Manifest | None:
"""Read the BM25 manifest if it exists and is valid."""
manifest_path = index_path / MANIFEST_NAME
if not manifest_path.exists():
return None
body = json.loads(manifest_path.read_text(encoding="utf-8"))
return BM25Manifest(
created_at=datetime.fromisoformat(str(body["created_at"])),
db_updated_at=datetime.fromisoformat(str(body["db_updated_at"])) if body.get("db_updated_at") else None,
chunk_count=int(body["chunk_count"]),
)
def write_bm25_manifest(index_path: Path, manifest: BM25Manifest) -> None:
"""Write the BM25 manifest to an index directory."""
body = {
"created_at": manifest.created_at.isoformat(),
"db_updated_at": manifest.db_updated_at.isoformat() if manifest.db_updated_at else None,
"chunk_count": manifest.chunk_count,
}
(index_path / MANIFEST_NAME).write_text(json.dumps(body, indent=2, sort_keys=True), encoding="utf-8")
def bm25_index_exists(index_path: Path, manifest: BM25Manifest | None) -> bool:
"""Return whether a usable persisted BM25 index exists."""
if manifest is None or not index_path.is_dir():
return False
if manifest.chunk_count == 0:
return True
return all((index_path / file_name).exists() for file_name in REQUIRED_INDEX_FILES)
-117
View File
@@ -1,117 +0,0 @@
"""Configuration for the EPUB search app."""
from __future__ import annotations
from dataclasses import dataclass
from os import getenv
def getenv_bool(name: str, *, default: bool) -> bool:
"""Read a boolean environment variable with a default fallback."""
value = getenv(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
def getenv_int(name: str, *, default: int) -> int:
"""Read an integer environment variable with a default fallback."""
value = getenv(name)
if value is None or not value.strip():
return default
return int(value)
@dataclass(frozen=True)
class RerankConfig:
"""vLLM reranker settings."""
enabled: bool = False
base_url: str = "http://192.168.90.25:8001"
model: str = "qwen3-reranker-06b"
candidates: int = 24
timeout_seconds: float = 30.0
@dataclass(frozen=True)
class EbookSearchConfig:
"""Runtime settings for EPUB search."""
rerank: RerankConfig
top_k: int = 12
library_paths: tuple[str, ...] = ()
vllm_base_url: str = "https://ollama.com/v1"
vllm_api_key: str = "not-needed"
chat_model: str = "deepseek-v4-flash"
answer_enabled: bool = True
embedding_base_url: str = "http://192.168.90.25:8000/v1"
embedding_api_key: str = "not-needed"
embedding_model: str = "qwen3-embedding-0.6b"
embedding_batch_size: int = 32
bm25_index_dir: str = ".ebook_search_bm25"
bm25_refresh_delay_seconds: int = 60
def load_rerank_config() -> RerankConfig:
"""Load reranker config from environment variables."""
return RerankConfig(
enabled=getenv_bool("EBOOK_SEARCH_RERANK_ENABLED", default=False),
base_url=getenv("EBOOK_SEARCH_RERANK_BASE_URL", "http://192.168.90.25:8001"),
model=getenv("EBOOK_SEARCH_RERANK_MODEL", "qwen3-reranker-06b"),
candidates=getenv_int("EBOOK_SEARCH_RERANK_CANDIDATES", default=24),
timeout_seconds=float(getenv_int("EBOOK_SEARCH_RERANK_TIMEOUT_SECONDS", default=30)),
)
def load_config() -> EbookSearchConfig:
"""Load EPUB search config from environment variables."""
return EbookSearchConfig(
rerank=load_rerank_config(),
top_k=getenv_int("EBOOK_SEARCH_TOP_K", default=12),
library_paths=library_paths_from_env(),
vllm_base_url=getenv("EBOOK_SEARCH_VLLM_BASE_URL", "https://ollama.com/v1"),
vllm_api_key=getenv("EBOOK_SEARCH_VLLM_API_KEY") or getenv("OLLAMA_API_KEY") or "not-needed",
chat_model=getenv("EBOOK_SEARCH_CHAT_MODEL", "deepseek-v4-flash"),
answer_enabled=getenv_bool("EBOOK_SEARCH_ANSWER_ENABLED", default=True),
embedding_base_url=getenv("EBOOK_SEARCH_EMBEDDING_BASE_URL", "http://192.168.90.25:8000/v1"),
embedding_api_key=getenv("EBOOK_SEARCH_EMBEDDING_API_KEY", "not-needed"),
embedding_model=normalize_embedding_model(),
embedding_batch_size=getenv_int("EBOOK_SEARCH_EMBEDDING_BATCH_SIZE", default=32),
bm25_index_dir=getenv("EBOOK_SEARCH_BM25_INDEX_DIR", ".ebook_search_bm25"),
bm25_refresh_delay_seconds=getenv_int("EBOOK_SEARCH_BM25_REFRESH_DELAY_SECONDS", default=60),
)
def normalize_embedding_model(default: str = "qwen3-embedding-0.6b") -> str:
"""Normalize supported embedding aliases to provider model names."""
aliases = {
"Qwen3-Embedding-0.6B": "qwen3-embedding-0.6b",
"Qwen3-Embedding-4B": "qwen3-embedding-4b",
"Qwen3-Embedding-8B": "qwen3-embedding-8b",
"Qwen/Qwen3-Embedding-0.6B": "qwen3-embedding-0.6b",
"Qwen/Qwen3-Embedding-4B": "qwen3-embedding-4b",
"Qwen/Qwen3-Embedding-8B": "qwen3-embedding-8b",
"qwen3-embedding:0.6b": "qwen3-embedding-0.6b",
"qwen3-embedding:4b": "qwen3-embedding-4b",
"qwen3-embedding:8b": "qwen3-embedding-8b",
"qwen3-embedding-0.6b": "qwen3-embedding-0.6b",
"qwen3-embedding-4b": "qwen3-embedding-4b",
"qwen3-embedding-8b": "qwen3-embedding-8b",
}
model = getenv("EBOOK_SEARCH_EMBEDDING_MODEL", default)
standard_model = aliases.get(model)
if standard_model is None:
error = f"Embedding model {model} is not supported. Supported models are {aliases.keys()}"
raise ValueError(error)
return standard_model
def library_paths_from_env() -> tuple[str, ...]:
"""Read configured EPUB library paths from the environment."""
value = getenv("EBOOK_SEARCH_LIBRARY_PATHS")
if value is None:
return ()
return tuple(path for path in value.split(":") if path)
-170
View File
@@ -1,170 +0,0 @@
"""Embedding model helpers."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING
from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import insert
from python.ebook_search.llm_interface import request_embeddings
from python.orm.richie import (
EbookChunk,
EbookChunkEmbedding1024,
EbookChunkEmbedding2560,
EbookChunkEmbedding4096,
EbookEmbeddingModel,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from collections.abc import Sequence
from sqlalchemy.orm import Session
from python.ebook_search.config import EbookSearchConfig
MODEL_DIMENSIONS = {
"qwen3-embedding-0.6b": 1024,
"qwen3-embedding-4b": 2560,
"qwen3-embedding-8b": 4096,
}
def get_embedding_table(
dimension: int,
) -> type[EbookChunkEmbedding1024 | EbookChunkEmbedding2560 | EbookChunkEmbedding4096]:
"""Return the embedding table mapped to an embedding dimension."""
embedding_tables = {
1024: EbookChunkEmbedding1024,
2560: EbookChunkEmbedding2560,
4096: EbookChunkEmbedding4096,
}
table = embedding_tables.get(dimension)
if not table:
msg = f"Embedding dimension {dimension} is not supported"
raise ValueError(msg)
return table
@dataclass(frozen=True)
class EmbeddingModelStats:
"""Embedding coverage for one model."""
model_name: str
dimension: int
embedded_chunks: int
total_chunks: int
@property
def missing_chunks(self) -> int:
"""Return chunks missing this embedding model."""
return max(self.total_chunks - self.embedded_chunks, 0)
def embed_texts(texts: Sequence[str], config: EbookSearchConfig) -> list[list[float]]:
"""Embed text with the configured vLLM embedding model."""
logger.info(
"ebook_embed_request_start base_url=%s model=%s count=%s",
config.embedding_base_url,
config.embedding_model,
len(texts),
)
vectors = request_embeddings(texts, config)
expected_dimension = MODEL_DIMENSIONS[config.embedding_model]
for vector in vectors:
if len(vector) != expected_dimension:
msg = f"Expected {expected_dimension} dimensions, got {len(vector)}"
raise ValueError(msg)
logger.info(
"ebook_embed_request_complete model=%s count=%s dimension=%s",
config.embedding_model,
len(vectors),
expected_dimension,
)
return vectors
def embed_query(query: str, config: EbookSearchConfig) -> list[float]:
"""Embed a search query with the Qwen retrieval instruction."""
instructed_query = f"Instruct: Retrieve relevant passages for the query.\nQuery: {query}"
return embed_texts([instructed_query], config)[0]
def ensure_embedding_models(session: Session) -> None:
"""Ensure supported embedding model rows exist."""
for name, dimension in MODEL_DIMENSIONS.items():
existing = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == name))
if existing is None:
session.add(EbookEmbeddingModel(name=name, dimension=dimension, is_default=name == "qwen3-embedding-0.6b"))
logger.info("ebook_embedding_model_created model=%s dimension=%s", name, dimension)
session.flush()
def embedding_model_stats(session: Session) -> list[EmbeddingModelStats]:
"""Return embedding coverage counts for every supported model."""
total_chunks = session.scalar(select(func.count(EbookChunk.id))) or 0
models = {
model.name: model
for model in session.scalars(
select(EbookEmbeddingModel)
.where(EbookEmbeddingModel.name.in_(MODEL_DIMENSIONS))
.order_by(EbookEmbeddingModel.name)
)
}
stats: list[EmbeddingModelStats] = []
for model_name, dimension in MODEL_DIMENSIONS.items():
model = models.get(model_name)
embedded_chunks = 0
if model is not None:
table = get_embedding_table(dimension)
embedded_chunks = session.scalar(select(func.count(table.id)).where(table.model_id == model.id)) or 0
stats.append(
EmbeddingModelStats(
model_name=model_name,
dimension=dimension,
embedded_chunks=embedded_chunks,
total_chunks=total_chunks,
)
)
return stats
def embed_missing_chunks(session: Session, config: EbookSearchConfig) -> int:
"""Embed chunks missing embeddings for the configured model."""
ensure_embedding_models(session)
model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == config.embedding_model))
if model is None:
supported_models = ", ".join(MODEL_DIMENSIONS)
msg = f"Unknown embedding model: {config.embedding_model}. Supported models: {supported_models}"
raise ValueError(msg)
table = get_embedding_table(model.dimension)
chunks = list(
session.scalars(
select(EbookChunk)
.outerjoin(table, (table.chunk_id == EbookChunk.id) & (table.model_id == model.id))
.where(table.id.is_(None))
.order_by(EbookChunk.id)
.limit(config.embedding_batch_size)
)
)
if not chunks:
logger.info("ebook_embed_missing_none model=%s", config.embedding_model)
return 0
logger.info("ebook_embed_missing_batch_start model=%s count=%s", config.embedding_model, len(chunks))
vectors = embed_texts([chunk.text for chunk in chunks], config)
rows = [
{"chunk_id": chunk.id, "model_id": model.id, "embedding": vector}
for chunk, vector in zip(chunks, vectors, strict=True)
]
statement = insert(table).values(rows).on_conflict_do_nothing(index_elements=["chunk_id", "model_id"])
session.execute(statement)
session.flush()
logger.info("ebook_embed_missing_batch_complete model=%s count=%s", config.embedding_model, len(rows))
return len(rows)
-95
View File
@@ -1,95 +0,0 @@
"""EPUB parsing helpers."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING
from bs4 import BeautifulSoup
from ebooklib import ITEM_DOCUMENT, epub
if TYPE_CHECKING:
from pathlib import Path
WHITESPACE_RE = re.compile(r"\s+")
@dataclass(frozen=True)
class ParsedChapter:
"""Text extracted from one EPUB spine document."""
title: str | None
href: str | None
text: str
page_labels: tuple[str, ...]
@dataclass(frozen=True)
class ParsedEpub:
"""Parsed EPUB metadata and text."""
title: str
author: str | None
language: str | None
publisher: str | None
identifier: str | None
chapters: tuple[ParsedChapter, ...]
def parse_epub(path: Path) -> ParsedEpub:
"""Parse EPUB metadata and spine text."""
book = epub.read_epub(path)
chapters = []
for item in book.get_items_of_type(ITEM_DOCUMENT):
soup = BeautifulSoup(item.get_content(), "html.parser")
title = chapter_title(soup)
page_labels = tuple(extract_page_labels(soup))
text = clean_text(soup.get_text(" "))
if text:
chapters.append(ParsedChapter(title=title, href=item.get_name(), text=text, page_labels=page_labels))
return ParsedEpub(
title=metadata_value(book, "title") or path.stem,
author=metadata_value(book, "creator"),
language=metadata_value(book, "language"),
publisher=metadata_value(book, "publisher"),
identifier=metadata_value(book, "identifier"),
chapters=tuple(chapters),
)
def metadata_value(book: epub.EpubBook, name: str) -> str | None:
"""Return the first non-empty Dublin Core metadata value for a name."""
values = book.get_metadata("DC", name)
if not values:
return None
value = values[0][0]
return str(value).strip() or None
def chapter_title(soup: BeautifulSoup) -> str | None:
"""Extract the best available title from an EPUB document soup."""
heading = soup.find(["h1", "h2", "h3"])
if heading is None:
title = soup.find("title")
if title is None:
return None
return clean_text(title.get_text(" ")) or None
return clean_text(heading.get_text(" ")) or None
def extract_page_labels(soup: BeautifulSoup) -> list[str]:
"""Extract EPUB page-break labels from a document soup."""
labels: list[str] = []
for tag in soup.find_all(attrs={"epub:type": "pagebreak"}):
label = tag.get("title") or tag.get("aria-label") or tag.get_text(" ")
clean = clean_text(str(label))
if clean:
labels.append(clean)
return labels
def clean_text(text: str) -> str:
"""Normalize whitespace in extracted EPUB text."""
return WHITESPACE_RE.sub(" ", text).strip()
-190
View File
@@ -1,190 +0,0 @@
"""EPUB ingestion into Richie DB."""
from __future__ import annotations
import hashlib
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING
import tiktoken
from sqlalchemy import or_, select
from python.ebook_search.epub_parse import parse_epub
from python.orm.richie import EbookChapter, EbookChunk, EbookSource
logger = logging.getLogger(__name__)
DEFAULT_CHUNK_TOKENS = 700
DEFAULT_CHUNK_OVERLAP = 100
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from python.ebook_search.config import EbookSearchConfig
from python.ebook_search.epub_parse import ParsedChapter
@dataclass(frozen=True)
class TextChunk:
"""A token-bounded chunk of text."""
text: str
token_start: int
token_count: int
def chunk_text(
text: str,
*,
chunk_tokens: int = DEFAULT_CHUNK_TOKENS,
overlap_tokens: int = DEFAULT_CHUNK_OVERLAP,
) -> list[TextChunk]:
"""Split text into overlapping token chunks."""
if chunk_tokens <= 0:
msg = "chunk_tokens must be positive"
raise ValueError(msg)
if overlap_tokens < 0 or overlap_tokens >= chunk_tokens:
msg = "overlap_tokens must be non-negative and smaller than chunk_tokens"
raise ValueError(msg)
encoding = tiktoken.get_encoding("cl100k_base")
tokens = encoding.encode(text)
if not tokens:
return []
chunks: list[TextChunk] = []
step = chunk_tokens - overlap_tokens
for start in range(0, len(tokens), step):
chunk = tokens[start : start + chunk_tokens]
if not chunk:
continue
chunks.append(
TextChunk(
text=encoding.decode(chunk).strip(),
token_start=start,
token_count=len(chunk),
)
)
if start + chunk_tokens >= len(tokens):
break
return [chunk for chunk in chunks if chunk.text]
def ingest_configured_paths(session: Session, config: EbookSearchConfig) -> int:
"""Ingest every EPUB found under configured library paths."""
count = 0
for library_path in config.library_paths:
path = Path(library_path).expanduser()
logger.info("ebook_ingest_path_start path=%s", path)
if path.is_file() and path.suffix.lower() == ".epub":
count += int(ingest_file(session, path))
elif path.is_dir():
for epub_path in sorted(path.rglob("*.epub")):
count += int(ingest_file(session, epub_path))
else:
logger.warning("ebook_ingest_path_missing path=%s", path)
logger.info("ebook_ingest_paths_complete changed_files=%s configured_paths=%s", count, len(config.library_paths))
return count
def ingest_file(session: Session, path: Path) -> bool:
"""Ingest one EPUB file. Return True when the database changed."""
resolved_path = path.expanduser().resolve()
logger.info("ebook_ingest_file_start path=%s", resolved_path)
file_hash = sha256_file(resolved_path)
existing = find_existing_source(session, resolved_path, file_hash)
if existing is not None and existing.file_sha256 == file_hash:
stat = resolved_path.stat()
existing.file_path = str(resolved_path)
existing.file_mtime = datetime.fromtimestamp(stat.st_mtime, tz=UTC)
existing.file_size = stat.st_size
session.flush()
logger.info("ebook_ingest_file_unchanged source_id=%s path=%s", existing.id, resolved_path)
return False
if existing is not None:
logger.info("ebook_ingest_file_replacing source_id=%s path=%s", existing.id, resolved_path)
session.delete(existing)
session.flush()
stat = resolved_path.stat()
parsed = parse_epub(resolved_path)
source = EbookSource(
title=parsed.title,
author=parsed.author,
language=parsed.language,
publisher=parsed.publisher,
identifier=parsed.identifier,
file_path=str(resolved_path),
file_sha256=file_hash,
file_mtime=datetime.fromtimestamp(stat.st_mtime, tz=UTC),
file_size=stat.st_size,
)
session.add(source)
session.flush()
chunk_index = 0
for spine_index, parsed_chapter in enumerate(parsed.chapters):
chapter = EbookChapter(
source_id=source.id,
spine_index=spine_index,
title=parsed_chapter.title,
href=parsed_chapter.href,
)
session.add(chapter)
session.flush()
chunk_index = add_chapter_chunks(session, source, chapter, parsed_chapter, chunk_index)
session.flush()
logger.info(
"ebook_ingest_file_complete source_id=%s path=%s chapters=%s chunks=%s",
source.id,
resolved_path,
len(parsed.chapters),
chunk_index,
)
return True
def find_existing_source(session: Session, path: Path, file_hash: str) -> EbookSource | None:
"""Find an existing source by canonical path or file hash."""
return session.scalar(
select(EbookSource).where(or_(EbookSource.file_path == str(path), EbookSource.file_sha256 == file_hash))
)
def add_chapter_chunks(
session: Session,
source: EbookSource,
chapter: EbookChapter,
parsed_chapter: ParsedChapter,
chunk_index: int,
) -> int:
"""Add chunk rows for one parsed chapter and return the next chunk index."""
page_label = parsed_chapter.page_labels[0] if parsed_chapter.page_labels else None
for text_chunk in chunk_text(parsed_chapter.text):
session.add(
EbookChunk(
source_id=source.id,
chapter_id=chapter.id,
chunk_index=chunk_index,
text=text_chunk.text,
token_start=text_chunk.token_start,
token_count=text_chunk.token_count,
page_label=page_label,
content_sha256=hashlib.sha256(text_chunk.text.encode()).hexdigest(),
search_text=f"{source.title} {source.author or ''} {chapter.title or ''} {text_chunk.text}",
)
)
chunk_index += 1
return chunk_index
def sha256_file(path: Path) -> str:
"""Calculate the SHA-256 digest for a file."""
digest = hashlib.sha256()
with path.open("rb") as file:
for block in iter(lambda: file.read(1024 * 1024), b""):
digest.update(block)
return digest.hexdigest()
-143
View File
@@ -1,143 +0,0 @@
"""LLM provider HTTP adapters."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import httpx
if TYPE_CHECKING:
from collections.abc import Sequence
from python.ebook_search.config import EbookSearchConfig, RerankConfig
logger = logging.getLogger(__name__)
def auth_headers(api_key: str) -> dict[str, str]:
"""Build authorization headers when an API key is configured."""
if api_key == "not-needed":
return {}
return {"Authorization": f"Bearer {api_key}"}
def request_embeddings(texts: Sequence[str], config: EbookSearchConfig) -> list[list[float]]:
"""Request embeddings from the configured OpenAI-compatible endpoint."""
try:
response = httpx.post(
f"{config.embedding_base_url.rstrip('/')}/embeddings",
headers=auth_headers(config.embedding_api_key),
json={"model": config.embedding_model, "input": list(texts)},
timeout=60,
)
response.raise_for_status()
return embedding_vectors_from_response(response.json())
except (httpx.HTTPError, ValueError, KeyError, TypeError) as error:
logger.exception(
"ebook_embed_request_failed base_url=%s model=%s count=%s",
config.embedding_base_url,
config.embedding_model,
len(texts),
)
msg = f"Embedding request failed. base_url={config.embedding_base_url} model={config.embedding_model}"
raise RuntimeError(msg) from error
def embedding_vectors_from_response(body: object) -> list[list[float]]:
"""Extract embedding vectors from an OpenAI-compatible embedding response."""
if not isinstance(body, dict):
msg = "Embedding response is not an object"
raise TypeError(msg)
data = body["data"]
if not isinstance(data, list):
msg = "Embedding response data is not a list"
raise TypeError(msg)
vectors: list[list[float]] = []
for item in data:
if not isinstance(item, dict):
msg = "Embedding item is not an object"
raise TypeError(msg)
embedding = item["embedding"]
if not isinstance(embedding, list):
msg = "Embedding value is not a list"
raise TypeError(msg)
vectors.append([float(value) for value in embedding])
return vectors
def request_rerank(
query: str,
documents: Sequence[str],
config: RerankConfig,
) -> object | None:
"""Request rerank scores from the configured vLLM endpoint."""
payload = {
"model": config.model,
"query": query,
"documents": list(documents),
}
response = httpx.post(
f"{config.base_url.rstrip('/')}/rerank",
json=payload,
timeout=config.timeout_seconds,
)
response.raise_for_status()
try:
return response.json()
except ValueError:
logger.debug("ebook_rerank_response_invalid_json", extra={"response": response.text})
return None
def request_chat_completion(
config: EbookSearchConfig,
messages: Sequence[dict[str, str]],
) -> str:
"""Request a chat completion from the configured OpenAI-compatible endpoint."""
try:
response = httpx.post(
f"{config.vllm_base_url.rstrip('/')}/chat/completions",
headers=auth_headers(config.vllm_api_key),
json={
"model": config.chat_model,
"messages": list(messages),
"temperature": 0,
},
timeout=60,
)
response.raise_for_status()
return chat_content_from_response(response.json())
except (httpx.HTTPError, ValueError, KeyError, TypeError) as error:
msg = f"Chat request failed. base_url={config.vllm_base_url} model={config.chat_model}"
raise RuntimeError(msg) from error
def chat_content_from_response(body: object) -> str:
"""Extract text content from an OpenAI-compatible chat response."""
if not isinstance(body, dict):
msg = "Chat response is not an object"
raise TypeError(msg)
choices = body["choices"]
if not isinstance(choices, list) or not choices:
msg = "Chat response has no choices"
raise ValueError(msg)
first = choices[0]
if not isinstance(first, dict):
msg = "Chat choice is not an object"
raise TypeError(msg)
message = first["message"]
if not isinstance(message, dict):
msg = "Chat message is not an object"
raise TypeError(msg)
content = message.get("content") or ""
if not isinstance(content, str):
msg = "Chat content is not text"
raise TypeError(msg)
return content
-129
View File
@@ -1,129 +0,0 @@
"""vLLM-backed optional reranking."""
from __future__ import annotations
import logging
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING
from python.ebook_search.llm_interface import request_rerank
if TYPE_CHECKING:
from python.ebook_search.config import RerankConfig
from python.ebook_search.search import SearchResult
logger = logging.getLogger(__name__)
RERANK_SCORE_WEIGHT = 0.7
HYBRID_SCORE_WEIGHT = 0.3
@dataclass(frozen=True)
class RerankResult:
"""A relevance score for one candidate chunk."""
chunk_id: int
score: float
def rerank_chunks(query: str, candidates: list[SearchResult], config: RerankConfig) -> list[SearchResult]:
"""Rerank candidates with a vLLM rerank endpoint."""
if not candidates:
return []
logger.info(
"ebook_rerank_request_start base_url=%s model=%s candidates=%s",
config.base_url,
config.model,
len(candidates),
)
scores = score_candidates(query, candidates, config)
results = sorted(
(
replace(
result,
score=final_rerank_score(result, scores[result.chunk_id].score, candidates),
rerank_score=scores[result.chunk_id].score,
)
for result in candidates
),
key=lambda result: result.score,
reverse=True,
)
logger.info(
"ebook_rerank_request_complete base_url=%s model=%s candidates=%s",
config.base_url,
config.model,
len(results),
)
return results
def score_candidates(
query: str,
candidates: list[SearchResult],
config: RerankConfig,
) -> dict[int, RerankResult]:
"""Score candidate chunks with the configured rerank API."""
body = request_rerank(query, [candidate.text for candidate in candidates], config)
if body is None:
return zero_rerank_scores(candidates)
scores = parse_vllm_scores(body, candidates)
for result in scores.values():
logger.debug("ebook_rerank_candidate_scored chunk_id=%s score=%s", result.chunk_id, result.score)
return scores
def parse_vllm_scores(body: object, candidates: list[SearchResult]) -> dict[int, RerankResult]:
"""Parse vLLM rerank scores into chunk-id keyed results."""
if not isinstance(body, dict):
logger.debug("ebook_rerank_response_not_object", extra={"response": body})
return zero_rerank_scores(candidates)
results = body.get("results") or body.get("data")
if not isinstance(results, list):
logger.debug("ebook_rerank_response_missing_results", extra={"response": body})
return zero_rerank_scores(candidates)
scores = zero_rerank_scores(candidates)
for item in results:
if not isinstance(item, dict):
continue
index = item.get("index")
score = item.get("relevance_score", item.get("score"))
if not isinstance(index, int) or index < 0 or index >= len(candidates):
continue
if not isinstance(score, int | float):
continue
chunk_id = candidates[index].chunk_id
scores[chunk_id] = RerankResult(chunk_id=chunk_id, score=clamp_score(float(score)))
return scores
def zero_rerank_scores(candidates: list[SearchResult]) -> dict[int, RerankResult]:
"""Return zero relevance scores for all candidate chunks."""
return {candidate.chunk_id: RerankResult(chunk_id=candidate.chunk_id, score=0.0) for candidate in candidates}
def clamp_score(score: float) -> float:
"""Clamp a rerank score into the supported 0.0 to 1.0 range."""
return min(max(score, 0.0), 1.0)
def final_rerank_score(result: SearchResult, rerank_score: float, candidates: list[SearchResult]) -> float:
"""Combine rerank relevance with normalized hybrid retrieval evidence."""
return (RERANK_SCORE_WEIGHT * rerank_score) + (HYBRID_SCORE_WEIGHT * normalized_hybrid_score(result, candidates))
def normalized_hybrid_score(result: SearchResult, candidates: list[SearchResult]) -> float:
"""Normalize a candidate hybrid score against the rerank candidate set."""
hybrid_scores = [
candidate.fused_score if candidate.fused_score is not None else candidate.score for candidate in candidates
]
low = min(hybrid_scores)
high = max(hybrid_scores)
if high == low:
return 1.0
score = result.fused_score if result.fused_score is not None else result.score
return (score - low) / (high - low)
-377
View File
@@ -1,377 +0,0 @@
"""Hybrid search orchestration."""
from __future__ import annotations
import logging
import re
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING
from pgvector.sqlalchemy import Vector
from sqlalchemy import literal, select
from sqlalchemy.orm import Session
from python.ebook_search.bm25_corpus import (
load_bm25_corpus,
score_bm25_corpus,
)
from python.ebook_search.embeddings import MODEL_DIMENSIONS, embed_query, get_embedding_table
from python.ebook_search.rerank import rerank_chunks
from python.ebook_search.timing import RuntimeStep, timed_result
from python.orm.richie import (
EbookChapter,
EbookChunk,
EbookEmbeddingModel,
EbookSource,
)
if TYPE_CHECKING:
from collections.abc import Mapping
from sqlalchemy.engine import Engine
from python.ebook_search.config import EbookSearchConfig
logger = logging.getLogger(__name__)
BM25_CANDIDATE_LIMIT = 120
@dataclass(frozen=True)
class SearchResult:
"""One source chunk returned by search."""
chunk_id: int
text: str
source_title: str
score: float = 0.0
vector_score: float | None = None
bm25_score: float | None = None
fused_score: float | None = None
rerank_score: float | None = None
source_author: str | None = None
chapter_title: str | None = None
page_label: str | None = None
rank_source: str = "Hybrid"
@dataclass(frozen=True)
class SearchResponse:
"""Search output for the UI."""
query: str
results: list[SearchResult]
rank_label: str
timings: tuple[RuntimeStep, ...] = ()
@property
def total_runtime_ms(self) -> float:
"""Return total measured runtime for the response."""
return sum(step.duration_ms for step in self.timings if step.counts_toward_total)
@dataclass(frozen=True)
class RetrievalResponse:
"""Parallel retrieval output for vector and BM25 candidates."""
vector_results: list[SearchResult]
lexical_results: list[SearchResult]
timings: tuple[RuntimeStep, ...]
def search_ebooks(
engine: Engine,
query: str,
config: EbookSearchConfig,
*,
rerank: bool = False,
) -> SearchResponse:
"""Run hybrid vector/BM25 search and optional reranking."""
if not query.strip():
logger.info("ebook_search_empty_query")
return SearchResponse(query=query, results=[], rank_label="Hybrid")
logger.info("ebook_search_start query_length=%s rerank=%s", len(query), rerank)
timings: list[RuntimeStep] = []
bm25_query, timing = timed_result("BM25 query preparation", retrieval_query_from_text, query)
timings.append(timing)
retrieval, timing = timed_result(
"Hybrid retrieval",
parallel_retrieval,
engine,
query,
bm25_query,
config,
)
timings.extend(retrieval.timings)
timings.append(timing)
fused, timing = timed_result(
"Reciprocal rank fusion",
reciprocal_rank_fusion,
retrieval.vector_results,
retrieval.lexical_results,
)
timings.append(timing)
if config.rerank.enabled and rerank:
response, timing = timed_result("Rerank", apply_rerank, query, fused, config)
else:
response, timing = timed_result("Rerank skipped", skip_rerank, query, fused, config)
timings.append(timing)
response = replace(response, timings=tuple(timings))
logger.info(
"ebook_search_complete vector_candidates=%s lexical_candidates=%s "
"fused_candidates=%s returned=%s rank_label=%s runtime_ms=%.1f",
len(retrieval.vector_results),
len(retrieval.lexical_results),
len(fused),
len(response.results),
response.rank_label,
response.total_runtime_ms,
)
return response
def parallel_retrieval(
engine: Engine,
vector_query: str,
bm25_query: str,
config: EbookSearchConfig,
) -> RetrievalResponse:
"""Run vector and BM25 candidate retrieval concurrently with separate database sessions."""
with ThreadPoolExecutor(max_workers=2, thread_name_prefix="ebook-search") as executor:
vector_future = executor.submit(
timed_result,
"Embedding + vector search",
vector_candidates,
engine,
vector_query,
config,
)
bm25_future = executor.submit(
timed_result,
"BM25 search",
bm25_candidates,
bm25_query,
config,
)
vector_results, vector_timing = vector_future.result()
lexical_results, lexical_timing = bm25_future.result()
logger.info(
"ebook_parallel_retrieval_complete vector_candidates=%s lexical_candidates=%s",
len(vector_results),
len(lexical_results),
)
return RetrievalResponse(
vector_results=vector_results,
lexical_results=lexical_results,
timings=(
replace(vector_timing, counts_toward_total=False),
replace(lexical_timing, counts_toward_total=False),
),
)
def skip_rerank(
query: str,
candidates: list[SearchResult],
config: EbookSearchConfig,
) -> SearchResponse:
"""Return fused hybrid results without reranking."""
logger.info("ebook_rerank_skipped candidates=%s", len(candidates))
return SearchResponse(query=query, results=candidates[: config.top_k], rank_label="Hybrid")
def apply_rerank(
query: str,
candidates: list[SearchResult],
config: EbookSearchConfig,
) -> SearchResponse:
"""Rerank already-fused hybrid candidates."""
reranked = rerank_chunks(query, candidates[: config.rerank.candidates], config.rerank)
logger.info(
"ebook_rerank_complete input_candidates=%s returned=%s",
min(len(candidates), config.rerank.candidates),
len(reranked),
)
return SearchResponse(
query=query,
results=[replace(result, rank_source="Hybrid + rerank") for result in reranked[: config.top_k]],
rank_label="Hybrid + rerank",
)
def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) -> list[SearchResult]:
"""Return pgvector cosine candidates for a natural-language query."""
with Session(engine) as session:
model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == config.embedding_model))
if model is None:
msg = f"Embedding model is not registered: {config.embedding_model}"
raise ValueError(msg)
expected_dimension = MODEL_DIMENSIONS[config.embedding_model]
if model.dimension != expected_dimension:
msg = f"Model row dimension {model.dimension} does not match configured dimension {expected_dimension}"
raise ValueError(msg)
embedding = embed_query(query, config)
limit = max(config.rerank.candidates, config.top_k) * 4
embedding_table = get_embedding_table(model.dimension)
embedding_param = literal(embedding, type_=Vector(model.dimension))
distance = embedding_table.embedding.op("<=>")(embedding_param)
score = (literal(1.0) - distance).label("score")
statement = (
select(
EbookChunk.id.label("chunk_id"),
EbookChunk.text.label("text"),
EbookSource.title.label("source_title"),
EbookSource.author.label("source_author"),
EbookChapter.title.label("chapter_title"),
EbookChunk.page_label.label("page_label"),
score,
)
.select_from(embedding_table)
.join(EbookChunk, EbookChunk.id == embedding_table.chunk_id)
.join(EbookSource, EbookSource.id == EbookChunk.source_id)
.outerjoin(EbookChapter, EbookChapter.id == EbookChunk.chapter_id)
.where(embedding_table.model_id == model.id)
.order_by(distance)
.limit(limit)
)
rows = session.execute(statement).mappings()
results = [search_result_from_row(row) for row in rows]
logger.info(
"ebook_vector_search_complete model=%s dimension=%s candidates=%s",
config.embedding_model,
model.dimension,
len(results),
)
return results
def bm25_candidates(query: str, config: EbookSearchConfig) -> list[SearchResult]:
"""Return BM25-ranked lexical candidates using the persisted corpus."""
corpus = load_bm25_corpus(config)
if not corpus.records:
logger.info("ebook_bm25_search_complete corpus=0 candidates=0")
return []
scored_records = score_bm25_corpus(query, corpus, limit=BM25_CANDIDATE_LIMIT)
results = [
replace(search_result_from_row(record), score=score, vector_score=None, bm25_score=score)
for record, score in scored_records
]
max_score = results[0].bm25_score if results else 0.0
logger.info(
"ebook_bm25_search_complete corpus=%s candidates=%s max_score=%.6f",
len(corpus.records),
len(results),
max_score,
)
return results
def reciprocal_rank_fusion(
vector_results: list[SearchResult],
lexical_results: list[SearchResult],
*,
rank_constant: int = 60,
) -> list[SearchResult]:
"""Fuse vector and lexical rankings with Reciprocal Rank Fusion."""
by_chunk: dict[int, SearchResult] = {}
scores: dict[int, float] = {}
vector_scores: dict[int, float] = {}
bm25_scores: dict[int, float] = {}
for rank, result in enumerate(vector_results, start=1):
by_chunk.setdefault(result.chunk_id, result)
vector_scores[result.chunk_id] = result.vector_score if result.vector_score is not None else result.score
scores[result.chunk_id] = scores.get(result.chunk_id, 0.0) + (1 / (rank_constant + rank))
for rank, result in enumerate(lexical_results, start=1):
by_chunk.setdefault(result.chunk_id, result)
bm25_scores[result.chunk_id] = result.bm25_score if result.bm25_score is not None else result.score
scores[result.chunk_id] = scores.get(result.chunk_id, 0.0) + (1 / (rank_constant + rank))
return sorted(
(
replace(
result,
score=scores[result.chunk_id],
vector_score=vector_scores.get(result.chunk_id),
bm25_score=bm25_scores.get(result.chunk_id),
fused_score=scores[result.chunk_id],
rank_source="Hybrid",
)
for result in by_chunk.values()
),
key=lambda result: result.score,
reverse=True,
)
def search_result_from_row(row: Mapping[str, object]) -> SearchResult:
"""Convert a database row mapping into a search result."""
return SearchResult(
chunk_id=int(row["chunk_id"]),
text=str(row["text"]),
source_title=str(row["source_title"]),
source_author=optional_str(row["source_author"]),
chapter_title=optional_str(row["chapter_title"]),
page_label=optional_str(row["page_label"]),
score=float(row["score"]) if "score" in row else 0.0,
vector_score=float(row["score"]) if "score" in row else None,
)
def optional_str(value: object) -> str | None:
"""Convert nullable database values to optional strings."""
if value is None:
return None
return str(value)
TOKEN_RE = re.compile(r"[A-Za-z0-9_]+")
def tokens(text_value: str) -> list[str]:
"""Extract tokens from a text value.
This is a simple approximation of the tokenization used by PostgreSQL's full-text search,
which is sufficient for BM25 candidate retrieval. It lowercases tokens and includes alphanumeric characters and
underscores.
"""
return [match.group(0).lower() for match in TOKEN_RE.finditer(text_value)]
QUERY_STOP_WORDS = {
"a",
"an",
"and",
"are",
"as",
"at",
"does",
"for",
"in",
"is",
"of",
"the",
"to",
"what",
"when",
"where",
"which",
"who",
"why",
}
def retrieval_query_from_text(query: str) -> str:
"""Remove generic question words while preserving entity and series terms."""
keywords = [token for token in tokens(query) if token not in QUERY_STOP_WORDS]
if not keywords:
return query
return " ".join(keywords)
-36
View File
@@ -1,36 +0,0 @@
"""Runtime timing helpers for EPUB search."""
from __future__ import annotations
from dataclasses import dataclass
from time import perf_counter
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
@dataclass(frozen=True)
class RuntimeStep:
"""Elapsed runtime for one named search step."""
name: str
duration_ms: float
counts_toward_total: bool = True
def runtime_step_from_start(name: str, start_seconds: float) -> RuntimeStep:
"""Create a runtime step from a prior perf_counter timestamp."""
return RuntimeStep(name=name, duration_ms=(perf_counter() - start_seconds) * 1000)
def timed_result[T, **P](
name: str,
operation: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> tuple[T, RuntimeStep]:
"""Run an operation and return its result plus elapsed runtime."""
start_seconds = perf_counter()
result = operation(*args, **kwargs)
return result, runtime_step_from_start(name, start_seconds)
-6
View File
@@ -1,6 +0,0 @@
"""Reusable FastAPI tools."""
from python.fastapi_tools.db import DbSession, get_db
from python.fastapi_tools.zstd_middleware import ZstdMiddleware
__all__ = ["DbSession", "ZstdMiddleware", "get_db"]
-347
View File
@@ -1,347 +0,0 @@
"""Small Gitea API client for repository automation."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Self
from urllib.parse import quote
import httpx
DEFAULT_PAGE_SIZE = 100
EXPECTED_NO_CONTENT = 204
EXPECTED_CREATED = 201
EXPECTED_OK = 200
@dataclass(frozen=True)
class CreatedIssue:
"""Issue data returned by Gitea."""
number: int | None
html_url: str | None
title: str
@dataclass(frozen=True)
class PullRequest:
"""Pull request data returned by Gitea."""
number: int
title: str
html_url: str | None
labels: tuple[str, ...]
head_branch: str | None
base_branch: str | None
@dataclass(frozen=True)
class WorkflowJob:
"""Workflow job data returned by Gitea Actions."""
id: int
name: str
run_id: int | None
status: str | None
conclusion: str | None
class GiteaError(RuntimeError):
"""Raised when Gitea rejects an API request."""
def split_repo_name(repo: str) -> tuple[str, str]:
"""Split an owner/repo string into its parts."""
owner, separator, repo_name = repo.partition("/")
if not separator or not owner or not repo_name:
msg = f"Invalid repository name: {repo}"
raise ValueError(msg)
return owner, repo_name
class GiteaClient:
"""HTTP client for the subset of Gitea APIs used in this repository."""
def __init__(
self,
*,
base_url: str,
token: str,
timeout: int = 30,
transport: httpx.BaseTransport | None = None,
) -> None:
"""Initialize the Gitea client."""
self._client = httpx.Client(
base_url=base_url.rstrip("/"),
timeout=timeout,
headers={"Authorization": f"token {token}"},
transport=transport,
)
def create_issue(
self,
*,
owner: str,
repo: str,
title: str,
body: str,
labels: list[int] | None = None,
) -> CreatedIssue:
"""Create a Gitea issue."""
payload: dict[str, object] = {"title": title, "body": body, "labels": labels or []}
response = self._request(
"POST",
f"/api/v1/repos/{owner}/{repo}/issues",
expected_statuses={EXPECTED_CREATED},
json=payload,
)
data = response.json()
return CreatedIssue(
number=_optional_int(data.get("number")),
html_url=_optional_str(data.get("html_url")),
title=str(data.get("title", title)),
)
def resolve_label_ids(self, *, owner: str, repo: str, labels: list[str]) -> list[int]:
"""Resolve label names to Gitea label IDs."""
if not labels:
return []
available_labels: dict[str, int] = {}
page = 1
while True:
response = self._request(
"GET",
f"/api/v1/repos/{owner}/{repo}/labels",
params={"page": page, "limit": DEFAULT_PAGE_SIZE},
)
batch = response.json()
if not batch:
break
for label in batch:
label_name = str(label.get("name", ""))
label_id = _optional_int(label.get("id"))
if label_name and label_id is not None:
available_labels[label_name] = label_id
if len(batch) < DEFAULT_PAGE_SIZE:
break
page += 1
missing = [label for label in labels if label not in available_labels]
if missing:
missing_names = ", ".join(sorted(missing))
msg = f"Missing Gitea labels: {missing_names}"
raise GiteaError(msg)
return [available_labels[label] for label in labels]
def list_open_pull_requests(
self,
*,
owner: str,
repo: str,
labels: list[str] | None = None,
head: str | None = None,
) -> list[PullRequest]:
"""List open pull requests for a repository."""
expected_labels = set(labels or [])
pull_requests: list[PullRequest] = []
page = 1
while True:
response = self._request(
"GET",
f"/api/v1/repos/{owner}/{repo}/pulls",
params={"state": "open", "page": page, "limit": DEFAULT_PAGE_SIZE},
)
batch = response.json()
if not batch:
break
for item in batch:
pull_request = _pull_request_from_api(item)
if head and pull_request.head_branch != head:
continue
if expected_labels and not expected_labels.issubset(set(pull_request.labels)):
continue
pull_requests.append(pull_request)
if len(batch) < DEFAULT_PAGE_SIZE:
break
page += 1
return pull_requests
def create_pull_request(
self,
*,
owner: str,
repo: str,
title: str,
body: str,
head: str,
base: str,
labels: list[str] | None = None,
) -> PullRequest:
"""Create a pull request."""
payload: dict[str, object] = {
"title": title,
"body": body,
"head": head,
"base": base,
}
if labels:
payload["labels"] = self.resolve_label_ids(owner=owner, repo=repo, labels=labels)
response = self._request(
"POST",
f"/api/v1/repos/{owner}/{repo}/pulls",
expected_statuses={EXPECTED_CREATED},
json=payload,
)
return _pull_request_from_api(response.json())
def merge_pull_request(
self,
*,
owner: str,
repo: str,
number: int,
merge_method: str = "rebase",
head_commit_id: str | None = None,
delete_branch_after_merge: bool = False,
) -> None:
"""Merge a pull request."""
payload: dict[str, object] = {
"Do": merge_method,
"delete_branch_after_merge": delete_branch_after_merge,
}
if head_commit_id:
payload["head_commit_id"] = head_commit_id
self._request(
"POST",
f"/api/v1/repos/{owner}/{repo}/pulls/{number}/merge",
json=payload,
)
def dispatch_workflow(self, *, owner: str, repo: str, workflow_id: str, ref: str) -> None:
"""Trigger a workflow_dispatch run."""
workflow_path = quote(workflow_id, safe="")
self._request(
"POST",
f"/api/v1/repos/{owner}/{repo}/actions/workflows/{workflow_path}/dispatches",
expected_statuses={EXPECTED_OK, EXPECTED_NO_CONTENT},
json={"ref": ref},
)
def list_run_jobs(self, *, owner: str, repo: str, run_id: str | int) -> list[WorkflowJob]:
"""List workflow jobs for a specific run."""
jobs: list[WorkflowJob] = []
page = 1
while True:
response = self._request(
"GET",
f"/api/v1/repos/{owner}/{repo}/actions/jobs",
params={"page": page, "limit": DEFAULT_PAGE_SIZE},
)
payload = response.json()
batch = payload.get("jobs", [])
if not batch:
break
for item in batch:
if str(item.get("run_id")) != str(run_id):
continue
jobs.append(_workflow_job_from_api(item))
if len(batch) < DEFAULT_PAGE_SIZE:
break
page += 1
return jobs
def download_job_logs(self, *, owner: str, repo: str, job_id: int) -> str:
"""Download logs for a workflow job."""
response = self._request(
"GET",
f"/api/v1/repos/{owner}/{repo}/actions/jobs/{job_id}/logs",
)
return response.text
def close(self) -> None:
"""Close the underlying HTTP client."""
self._client.close()
def __enter__(self) -> Self:
"""Enter the context manager."""
return self
def __exit__(self, *args: object) -> None:
"""Close the HTTP client."""
self.close()
def _request(
self,
method: str,
path: str,
*,
expected_statuses: set[int] | None = None,
**kwargs: object,
) -> httpx.Response:
"""Send an HTTP request and validate the response status."""
response = self._client.request(method, path, **kwargs)
statuses = expected_statuses or {EXPECTED_OK}
if response.status_code not in statuses:
msg = f"Gitea request failed ({response.status_code}): {response.text}"
raise GiteaError(msg)
return response
def _pull_request_from_api(data: dict[str, object]) -> PullRequest:
"""Convert Gitea API pull-request data into a dataclass."""
number = _optional_int(data.get("number")) or _optional_int(data.get("index"))
if number is None:
msg = "Gitea pull request payload is missing a number"
raise GiteaError(msg)
labels = tuple(str(label.get("name", "")) for label in data.get("labels", []))
head = data.get("head", {})
base = data.get("base", {})
return PullRequest(
number=number,
title=str(data.get("title", "")),
html_url=_optional_str(data.get("html_url")),
labels=tuple(label for label in labels if label),
head_branch=_optional_str(head.get("ref")) or _optional_str(data.get("head_branch")),
base_branch=_optional_str(base.get("ref")) or _optional_str(data.get("base_branch")),
)
def _workflow_job_from_api(data: dict[str, object]) -> WorkflowJob:
"""Convert Gitea API workflow-job data into a dataclass."""
job_id = _optional_int(data.get("id"))
if job_id is None:
msg = "Gitea workflow job payload is missing an ID"
raise GiteaError(msg)
return WorkflowJob(
id=job_id,
name=str(data.get("name", "")),
run_id=_optional_int(data.get("run_id")),
status=_optional_str(data.get("status")),
conclusion=_optional_str(data.get("conclusion")),
)
def _optional_int(value: object) -> int | None:
"""Convert an API value to an integer when present."""
if value is None:
return None
return int(value)
def _optional_str(value: object) -> str | None:
"""Convert an API value to a string when present."""
if value is None:
return None
return str(value)
-148
View File
@@ -1,148 +0,0 @@
"""Automation helpers for flake.lock pull requests on Gitea."""
from __future__ import annotations
import subprocess
from os import getenv
from typing import Annotated
import typer
from python.gitea import GiteaClient, PullRequest, split_repo_name
DEFAULT_BASE_BRANCH = "main"
DEFAULT_BRANCH = "automation/update-flake-lock"
DEFAULT_GITEA_URL = "https://gitea.tmmworkshop.com"
PR_LABELS = ["dependencies", "automated", "flake_lock_update"]
PR_CHECK_WORKFLOWS = ["build_systems.yml", "treefmt.yml", "pytest.yml"]
PR_TITLE = "Update flake.lock"
PR_BODY = "Automated flake.lock update."
app = typer.Typer(add_completion=False)
def run_cmd(cmd: list[str], *, check: bool = True) -> subprocess.CompletedProcess[str]:
"""Run a subprocess command."""
return subprocess.run(cmd, capture_output=True, text=True, check=check)
def ensure_flake_lock_pull_request(
client: GiteaClient,
*,
owner: str,
repo: str,
branch: str,
base: str,
) -> PullRequest:
"""Return an existing flake.lock PR for the branch or create one."""
pull_requests = client.list_open_pull_requests(owner=owner, repo=repo, head=branch)
if pull_requests:
return pull_requests[0]
return client.create_pull_request(
owner=owner,
repo=repo,
title=PR_TITLE,
body=PR_BODY,
head=branch,
base=base,
labels=PR_LABELS,
)
def find_flake_lock_pull_request(client: GiteaClient, *, owner: str, repo: str) -> PullRequest | None:
"""Find the first open flake.lock pull request."""
pull_requests = client.list_open_pull_requests(owner=owner, repo=repo, labels=["flake_lock_update"])
if not pull_requests:
return None
return pull_requests[0]
def dispatch_pull_request_checks(client: GiteaClient, *, owner: str, repo: str, branch: str) -> None:
"""Dispatch the workflows that normally run for pull requests."""
for workflow in PR_CHECK_WORKFLOWS:
client.dispatch_workflow(owner=owner, repo=repo, workflow_id=workflow, ref=branch)
def has_worktree_changes() -> bool:
"""Return whether `flake.lock` has worktree changes."""
result = run_cmd(["git", "diff", "--quiet", "--", "flake.lock"], check=False)
return result.returncode != 0
def commit_flake_lock_update(*, branch: str) -> None:
"""Commit the updated lock file to the automation branch."""
run_cmd(["git", "config", "user.name", "gitea-actions[bot]"])
run_cmd(["git", "config", "user.email", "gitea-actions@tmmworkshop.com"])
run_cmd(["git", "checkout", "-B", branch])
run_cmd(["git", "add", "flake.lock"])
run_cmd(["git", "commit", "-m", "chore: update flake.lock"])
def push_branch(*, branch: str) -> None:
"""Push the automation branch to origin."""
run_cmd(["git", "push", "origin", f"HEAD:{branch}", "--force"])
def _required_gitea_token() -> str:
"""Read the required Gitea token from the environment."""
token = getenv("GITEA_TOKEN")
if token:
return token
msg = "GITEA_TOKEN environment variable is required"
raise RuntimeError(msg)
@app.command()
def update(
repo: Annotated[str, typer.Option("--repo", help="Gitea repository in owner/repo form")],
base: Annotated[str, typer.Option("--base", help="Base branch")] = DEFAULT_BASE_BRANCH,
branch: Annotated[str, typer.Option("--branch", help="Automation branch")] = DEFAULT_BRANCH,
) -> None:
"""Commit flake.lock changes and ensure a pull request exists."""
if not has_worktree_changes():
typer.echo("No flake.lock changes detected")
return
commit_flake_lock_update(branch=branch)
push_branch(branch=branch)
owner, repo_name = split_repo_name(repo)
with GiteaClient(
base_url=getenv("GITEA_URL", DEFAULT_GITEA_URL),
token=_required_gitea_token(),
) as client:
pull_request = ensure_flake_lock_pull_request(
client,
owner=owner,
repo=repo_name,
branch=branch,
base=base,
)
# We can remove this if Gitea fixes the following issue:
# https://github.com/go-gitea/gitea/issues/33963
dispatch_pull_request_checks(client, owner=owner, repo=repo_name, branch=branch)
typer.echo(pull_request.html_url or f"Pull request #{pull_request.number}")
@app.command()
def merge(
repo: Annotated[str, typer.Option("--repo", help="Gitea repository in owner/repo form")],
) -> None:
"""Merge the first open flake.lock pull request."""
owner, repo_name = split_repo_name(repo)
with GiteaClient(
base_url=getenv("GITEA_URL", DEFAULT_GITEA_URL),
token=_required_gitea_token(),
) as client:
pull_request = find_flake_lock_pull_request(client, owner=owner, repo=repo_name)
if not pull_request:
typer.echo("No open PR found with label flake_lock_update")
return
client.merge_pull_request(owner=owner, repo=repo_name, number=pull_request.number, merge_method="rebase")
typer.echo(f"Merged PR #{pull_request.number}")
if __name__ == "__main__":
app()
+2 -24
View File
@@ -31,24 +31,8 @@ def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
def get_postgres_engine(
*,
name: str = "POSTGRES",
pool_pre_ping: bool = True,
vector_engine: bool = False,
) -> Engine:
"""Create a SQLAlchemy engine from environment variables.
Args:
name (str, optional): The name of the environment variable prefix. Defaults to "POSTGRES".
pool_pre_ping (bool, optional): Whether to ping the database before each connection. Defaults to True.
This fixes the issue of trying to use a conection that has timed out on the database side.
vector_engine (bool, optional): Whether to use the vector search schema. Defaults to False.
This updates the search path the incldued the vecore types and operators.
Returns:
Engine: The SQLAlchemy engine.
"""
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine:
"""Create a SQLAlchemy engine from environment variables."""
database, host, port, username, password = get_connection_info(name)
url = URL.create(
@@ -60,14 +44,8 @@ def get_postgres_engine(
database=database,
)
connect_args = {}
# There more better way to do this is with separate PG account and a dedicated vector schema for the vector types
if vector_engine:
connect_args["options"] = "-csearch_path=main,public"
return create_engine(
url=url,
pool_pre_ping=pool_pre_ping,
pool_recycle=1800,
connect_args=connect_args,
)
-20
View File
@@ -2,7 +2,6 @@
from __future__ import annotations
from python.orm.richie.audiobook import Audiobook, AudiobookAuthor, AudiobookSeries
from python.orm.richie.base import RichieBase, TableBase, TableBaseBig, TableBaseSmall
from python.orm.richie.contact import (
Contact,
@@ -11,30 +10,11 @@ from python.orm.richie.contact import (
Need,
RelationshipType,
)
from python.orm.richie.ebook import (
EbookChapter,
EbookChunk,
EbookChunkEmbedding1024,
EbookChunkEmbedding2560,
EbookChunkEmbedding4096,
EbookEmbeddingModel,
EbookSource,
)
__all__ = [
"Audiobook",
"AudiobookAuthor",
"AudiobookSeries",
"Contact",
"ContactNeed",
"ContactRelationship",
"EbookChapter",
"EbookChunk",
"EbookChunkEmbedding1024",
"EbookChunkEmbedding2560",
"EbookChunkEmbedding4096",
"EbookEmbeddingModel",
"EbookSource",
"Need",
"RelationshipType",
"RichieBase",
-55
View File
@@ -1,55 +0,0 @@
"""Audiobook catalog models."""
from __future__ import annotations
from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.richie.base import TableBase
class AudiobookAuthor(TableBase):
"""Canonical audiobook author."""
__tablename__ = "audiobook_author"
__table_args__ = (UniqueConstraint("name"),)
name: Mapped[str] = mapped_column(String, unique=True)
books: Mapped[list[Audiobook]] = relationship("Audiobook", back_populates="author")
series: Mapped[list[AudiobookSeries]] = relationship("AudiobookSeries", back_populates="author")
class AudiobookSeries(TableBase):
"""Canonical audiobook series."""
__tablename__ = "audiobook_series"
__table_args__ = (UniqueConstraint("author_id", "name"),)
name: Mapped[str] = mapped_column(String)
author_id: Mapped[int] = mapped_column(ForeignKey("main.audiobook_author.id", ondelete="CASCADE"))
author: Mapped[AudiobookAuthor] = relationship("AudiobookAuthor", back_populates="series")
books: Mapped[list[Audiobook]] = relationship("Audiobook", back_populates="series")
class Audiobook(TableBase):
"""Canonical audiobook title."""
__tablename__ = "audiobook"
__table_args__ = (
UniqueConstraint(
"author_id",
"series_id",
"title",
postgresql_nulls_not_distinct=True,
),
)
title: Mapped[str] = mapped_column(String)
author_id: Mapped[int] = mapped_column(ForeignKey("main.audiobook_author.id", ondelete="CASCADE"))
series_id: Mapped[int | None] = mapped_column(ForeignKey("main.audiobook_series.id", ondelete="SET NULL"))
series_index: Mapped[float] = mapped_column(default=0.0)
author: Mapped[AudiobookAuthor] = relationship("AudiobookAuthor", back_populates="books")
series: Mapped[AudiobookSeries | None] = relationship("AudiobookSeries", back_populates="books")
-130
View File
@@ -1,130 +0,0 @@
"""EPUB search models."""
from __future__ import annotations
from datetime import datetime
from pgvector.sqlalchemy import Vector
from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from python.orm.richie.base import TableBase, TableBaseBig
class EbookSource(TableBase):
"""One indexed EPUB file."""
__tablename__ = "ebook_source"
__table_args__ = (
UniqueConstraint("file_path"),
UniqueConstraint("file_sha256"),
)
title: Mapped[str]
author: Mapped[str | None]
language: Mapped[str | None]
publisher: Mapped[str | None]
identifier: Mapped[str | None]
file_path: Mapped[str]
file_sha256: Mapped[str] = mapped_column(String(64))
file_mtime: Mapped[datetime] = mapped_column(DateTime(timezone=True))
file_size: Mapped[int] = mapped_column(BigInteger)
chapters: Mapped[list[EbookChapter]] = relationship(
"EbookChapter",
back_populates="source",
cascade="all, delete-orphan",
passive_deletes=True,
)
chunks: Mapped[list[EbookChunk]] = relationship(
"EbookChunk",
back_populates="source",
cascade="all, delete-orphan",
passive_deletes=True,
)
class EbookChapter(TableBase):
"""A chapter or spine document inside an EPUB."""
__tablename__ = "ebook_chapter"
__table_args__ = (UniqueConstraint("source_id", "spine_index"),)
source_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_source.id", ondelete="CASCADE"))
spine_index: Mapped[int]
title: Mapped[str | None]
href: Mapped[str | None]
source: Mapped[EbookSource] = relationship("EbookSource", back_populates="chapters")
chunks: Mapped[list[EbookChunk]] = relationship(
"EbookChunk",
back_populates="chapter",
cascade="all, delete-orphan",
passive_deletes=True,
)
class EbookChunk(TableBaseBig):
"""A searchable text chunk."""
__tablename__ = "ebook_chunk"
__table_args__ = (
UniqueConstraint("source_id", "chunk_index", name="uq_ebook_chunk_source_id_chunk_index"),
UniqueConstraint("source_id", "content_sha256", name="uq_ebook_chunk_source_id_content_sha256"),
)
source_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_source.id", ondelete="CASCADE"))
chapter_id: Mapped[int | None] = mapped_column(ForeignKey("main.ebook_chapter.id", ondelete="SET NULL"))
chunk_index: Mapped[int]
text: Mapped[str]
token_start: Mapped[int]
token_count: Mapped[int]
page_label: Mapped[str | None]
content_sha256: Mapped[str] = mapped_column(String(64))
search_text: Mapped[str]
source: Mapped[EbookSource] = relationship("EbookSource", back_populates="chunks")
chapter: Mapped[EbookChapter | None] = relationship("EbookChapter", back_populates="chunks")
class EbookEmbeddingModel(TableBase):
"""A supported embedding model."""
__tablename__ = "ebook_embedding_model"
name: Mapped[str] = mapped_column(String, unique=True)
dimension: Mapped[int]
is_default: Mapped[bool] = mapped_column(Boolean, default=False)
class EbookChunkEmbedding1024(TableBaseBig):
"""1024-dimensional chunk embedding."""
__tablename__ = "ebook_chunk_embedding_1024"
__table_args__ = (UniqueConstraint("chunk_id", "model_id"),)
chunk_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_chunk.id", ondelete="CASCADE"))
model_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_embedding_model.id", ondelete="CASCADE"))
embedding: Mapped[list[float]] = mapped_column(Vector(1024))
class EbookChunkEmbedding2560(TableBaseBig):
"""2560-dimensional chunk embedding."""
__tablename__ = "ebook_chunk_embedding_2560"
__table_args__ = (UniqueConstraint("chunk_id", "model_id"),)
chunk_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_chunk.id", ondelete="CASCADE"))
model_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_embedding_model.id", ondelete="CASCADE"))
embedding: Mapped[list[float]] = mapped_column(Vector(2560))
class EbookChunkEmbedding4096(TableBaseBig):
"""4096-dimensional chunk embedding."""
__tablename__ = "ebook_chunk_embedding_4096"
__table_args__ = (UniqueConstraint("chunk_id", "model_id"),)
chunk_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_chunk.id", ondelete="CASCADE"))
model_id: Mapped[int] = mapped_column(ForeignKey("main.ebook_embedding_model.id", ondelete="CASCADE"))
embedding: Mapped[list[float]] = mapped_column(Vector(4096))
+25
View File
@@ -0,0 +1,25 @@
# Unsloth fine-tuning container for Qwen 3.5 4B on RTX 3090.
#
# Build:
# docker build -f python/prompt_bench/Dockerfile.finetune -t bill-finetune .
#
# Run:
# docker run --rm --device=nvidia.com/gpu=all --ipc=host \
# -v $(pwd)/output:/workspace/output \
# -v $(pwd)/output/finetune_dataset.jsonl:/workspace/dataset.jsonl:ro \
# -v /zfs/models/hf:/models \
# bill-finetune \
# --dataset /workspace/dataset.jsonl \
# --output-dir /workspace/output/qwen-bill-summarizer
FROM ghcr.io/unslothai/unsloth:latest
RUN pip install --no-cache-dir typer
WORKDIR /workspace
COPY python/prompt_bench/finetune.py python/prompt_bench/finetune.py
COPY python/prompt_bench/summarization_prompts.py python/prompt_bench/summarization_prompts.py
COPY python/prompt_bench/__init__.py python/prompt_bench/__init__.py
COPY python/__init__.py python/__init__.py
ENTRYPOINT ["python", "-m", "python.prompt_bench.finetune"]
+1
View File
@@ -0,0 +1 @@
"""Prompt benchmarking system for evaluating LLMs via vLLM."""
@@ -0,0 +1,233 @@
"""Submit an OpenAI Batch API bill-summarization job over compressed text.
Reads the first N bills from a CSV with a `text_content` column, compresses
each via `bill_token_compression.compress_bill_text`, builds a JSONL file of
summarization requests, and submits it as an asynchronous Batch API job
against `/v1/chat/completions`. Also writes a CSV of per-bill pre/post-
compression token counts.
"""
from __future__ import annotations
import csv
import json
import logging
import re
import sys
from os import getenv
from pathlib import Path
from typing import Annotated
import httpx
import typer
from tiktoken import Encoding, get_encoding
from python.prompt_bench.bill_token_compression import compress_bill_text
from python.prompt_bench.summarization_prompts import SUMMARIZATION_SYSTEM_PROMPT, SUMMARIZATION_USER_TEMPLATE
logger = logging.getLogger(__name__)
OPENAI_API_BASE = "https://api.openai.com/v1"
def load_bills(csv_path: Path, count: int = 0) -> list[tuple[str, str]]:
"""Return (bill_id, text_content) tuples with non-empty text.
If `count` is 0 or negative, all rows are returned.
"""
csv.field_size_limit(sys.maxsize)
bills: list[tuple[str, str]] = []
with csv_path.open(newline="", encoding="utf-8") as handle:
reader = csv.DictReader(handle)
for row in reader:
text_content = (row.get("text_content") or "").strip()
if not text_content:
continue
bill_id = row.get("bill_id") or row.get("id") or f"row-{len(bills)}"
version_code = row.get("version_code") or ""
unique_id = f"{bill_id}-{version_code}" if version_code else bill_id
bills.append((unique_id, text_content))
if count > 0 and len(bills) >= count:
break
return bills
def safe_filename(value: str) -> str:
"""Make a string safe for use as a filename or batch custom_id."""
return re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("_") or "unnamed"
def build_request(custom_id: str, model: str, bill_text: str) -> dict:
"""Build one OpenAI batch request line."""
return {
"custom_id": custom_id,
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": model,
"messages": [
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
{"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)},
],
},
}
def write_jsonl(path: Path, lines: list[dict]) -> None:
"""Write a list of dicts as JSONL."""
with path.open("w", encoding="utf-8") as handle:
for line in lines:
handle.write(json.dumps(line, ensure_ascii=False))
handle.write("\n")
def upload_file(client: httpx.Client, path: Path) -> str:
"""Upload a JSONL file to the OpenAI Files API and return its file id."""
with path.open("rb") as handle:
response = client.post(
f"{OPENAI_API_BASE}/files",
files={"file": (path.name, handle, "application/jsonl")},
data={"purpose": "batch"},
)
response.raise_for_status()
return response.json()["id"]
def prepare_requests(
bills: list[tuple[str, str]],
*,
model: str,
encoder: Encoding,
) -> tuple[list[dict], list[dict]]:
"""Build (request_lines, token_rows) from bills.
Each bill is compressed before being turned into a request line.
Each `token_rows` entry has chars + token counts for one bill so the caller
can write a per-bill CSV.
"""
request_lines: list[dict] = []
token_rows: list[dict] = []
for bill_id, text_content in bills:
raw_token_count = len(encoder.encode(text_content))
compressed_text = compress_bill_text(text_content)
compressed_token_count = len(encoder.encode(compressed_text))
token_rows.append(
{
"bill_id": bill_id,
"raw_chars": len(text_content),
"compressed_chars": len(compressed_text),
"raw_tokens": raw_token_count,
"compressed_tokens": compressed_token_count,
"token_ratio": (compressed_token_count / raw_token_count) if raw_token_count else None,
},
)
safe_id = safe_filename(bill_id)
request_lines.append(build_request(safe_id, model, compressed_text))
return request_lines, token_rows
def write_token_csv(path: Path, token_rows: list[dict]) -> tuple[int, int]:
"""Write per-bill token counts to CSV. Returns (raw_total, compressed_total)."""
with path.open("w", newline="", encoding="utf-8") as handle:
writer = csv.DictWriter(
handle,
fieldnames=["bill_id", "raw_chars", "compressed_chars", "raw_tokens", "compressed_tokens", "token_ratio"],
)
writer.writeheader()
writer.writerows(token_rows)
raw_total = sum(row["raw_tokens"] for row in token_rows)
compressed_total = sum(row["compressed_tokens"] for row in token_rows)
return raw_total, compressed_total
def create_batch(client: httpx.Client, input_file_id: str, description: str) -> dict:
"""Create a batch job and return its full response payload."""
response = client.post(
f"{OPENAI_API_BASE}/batches",
json={
"input_file_id": input_file_id,
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"metadata": {"description": description},
},
)
response.raise_for_status()
return response.json()
def main(
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"),
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write JSONL + metadata")] = Path(
"output/openai_batch",
),
model: Annotated[str, typer.Option(help="OpenAI model id")] = "gpt-5-mini",
count: Annotated[int, typer.Option(help="Max bills to process, 0 = all")] = 0,
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
"""Submit an OpenAI Batch job of compressed bill summaries."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
if not api_key:
message = "Neither CLOSEDAI_TOKEN nor OPENAI_API_KEY is set"
raise typer.BadParameter(message)
if not csv_path.is_file():
message = f"CSV not found: {csv_path}"
raise typer.BadParameter(message)
output_dir.mkdir(parents=True, exist_ok=True)
logger.info("Loading %d bills from %s", count, csv_path)
bills = load_bills(csv_path, count)
if len(bills) < count:
logger.warning("Only %d bills available (requested %d)", len(bills), count)
encoder = get_encoding("o200k_base")
request_lines, token_rows = prepare_requests(bills, model=model, encoder=encoder)
token_csv_path = output_dir / "token_counts.csv"
raw_tokens_total, compressed_tokens_total = write_token_csv(token_csv_path, token_rows)
logger.info(
"Token counts: raw=%d compressed=%d ratio=%.3f -> %s",
raw_tokens_total,
compressed_tokens_total,
(compressed_tokens_total / raw_tokens_total) if raw_tokens_total else 0.0,
token_csv_path,
)
jsonl_path = output_dir / "requests.jsonl"
write_jsonl(jsonl_path, request_lines)
logger.info("Wrote %s (%d bills)", jsonl_path, len(request_lines))
headers = {"Authorization": f"Bearer {api_key}"}
with httpx.Client(headers=headers, timeout=httpx.Timeout(300.0)) as client:
logger.info("Uploading JSONL")
file_id = upload_file(client, jsonl_path)
logger.info("Uploaded: %s", file_id)
logger.info("Creating batch")
batch = create_batch(client, file_id, f"compressed bill summaries x{len(request_lines)} ({model})")
logger.info("Batch created: %s", batch["id"])
metadata = {
"model": model,
"count": len(bills),
"jsonl": str(jsonl_path),
"input_file_id": file_id,
"batch_id": batch["id"],
"raw_tokens_total": raw_tokens_total,
"compressed_tokens_total": compressed_tokens_total,
"batch": batch,
}
metadata_path = output_dir / "batch.json"
metadata_path.write_text(json.dumps(metadata, indent=2))
logger.info("Wrote metadata to %s", metadata_path)
def cli() -> None:
"""Typer entry point."""
typer.run(main)
if __name__ == "__main__":
cli()
@@ -0,0 +1,162 @@
"""Lossless-ish text compression for Congressional bill text."""
from __future__ import annotations
import re
STATES = (
"Alabama",
"Alaska",
"Arizona",
"Arkansas",
"California",
"Colorado",
"Connecticut",
"Delaware",
"Florida",
"Georgia",
"Hawaii",
"Idaho",
"Illinois",
"Indiana",
"Iowa",
"Kansas",
"Kentucky",
"Louisiana",
"Maine",
"Maryland",
"Massachusetts",
"Michigan",
"Minnesota",
"Mississippi",
"Missouri",
"Montana",
"Nebraska",
"Nevada",
"New Hampshire",
"New Jersey",
"New Mexico",
"New York",
"North Carolina",
"North Dakota",
"Ohio",
"Oklahoma",
"Oregon",
"Pennsylvania",
"Rhode Island",
"South Carolina",
"South Dakota",
"Tennessee",
"Texas",
"Utah",
"Vermont",
"Virginia",
"Washington",
"West Virginia",
"Wisconsin",
"Wyoming",
"Puerto Rico",
"Guam",
"American Samoa",
"District of Columbia",
"US Virgin Islands",
)
STATE_PATTERNS = [(re.compile(re.escape(state), re.IGNORECASE), state) for state in STATES]
def normalize_state_names(text: str) -> str:
"""Replace any casing of state names with title case."""
for pattern, replacement in STATE_PATTERNS:
text = pattern.sub(replacement, text)
return text
def strip_number_commas(text: str) -> str:
"""Remove commas from numeric thousands separators."""
return re.sub(r"(\d{1,3}(?:,\d{3})+)", lambda match: match.group().replace(",", ""), text)
def strip_horizontal_rules(text: str) -> str:
"""Remove ASCII horizontal-rule lines built from underscores, dashes, equals, or asterisks."""
return re.sub(r"^\s*[_\-=\*]{3,}\s*$", "", text, flags=re.MULTILINE)
def collapse_double_dashes(text: str) -> str:
"""Replace ``--`` em-dash stand-ins with a single space so they don't tokenize oddly."""
return text.replace("--", " ")
def collapse_inline_whitespace(text: str) -> str:
"""Collapse runs of horizontal whitespace (spaces, tabs) into a single space, leaving newlines intact."""
return re.sub(r"[^\S\n]+", " ", text)
def collapse_blank_lines(text: str) -> str:
"""Collapse three-or-more consecutive newlines down to a blank-line separator."""
return re.sub(r"\n{3,}", "\n\n", text)
def trim_line_edges(text: str) -> str:
"""Strip spaces immediately before and after newline characters on every line."""
text = re.sub(r" +\n", "\n", text)
return re.sub(r"\n +", "\n", text)
def shorten_section_markers(text: str) -> str:
"""Rewrite ``Sec. 12.`` style section headings as the more compact ``SEC 12``."""
return re.sub(r"(?i)sec\.\s*(\d+[a-zA-Z]?)\.", r"SEC \1", text)
def unwrap_parens(text: str) -> str:
"""Strip parentheses around short alphanumeric labels like ``(a)`` or ``(12)``."""
return re.sub(r"\(([a-zA-Z0-9]+)\)", r"\1", text)
def strip_typeset_quotes(text: str) -> str:
"""Remove the `` and '' typeset quote markers used in the GPO bill format."""
return text.replace("``", "").replace("''", "")
def normalize_usc_acronym(text: str) -> str:
"""Collapse ``U.S.C.`` to ``USC`` to save tokens on the common citation."""
return text.replace("U.S.C.", "USC")
def normalize_us_acronym(text: str) -> str:
"""Normalize the various ``U.S.``/``U. S.`` spellings to the bare ``US`` form."""
for acronym in ("U. S.", "u. s.", "U.S. ", "u.s. "):
text = text.replace(acronym, "US ")
return text
def collapse_ellipses(text: str) -> str:
"""Collapse runs of two-or-more periods (``...``, ``....``) down to a single period."""
return re.sub(r"\.{2,}", ".", text)
COMPRESSION_STEPS = (
strip_horizontal_rules,
collapse_double_dashes,
collapse_inline_whitespace,
collapse_blank_lines,
trim_line_edges,
shorten_section_markers,
unwrap_parens,
strip_typeset_quotes,
normalize_usc_acronym,
normalize_us_acronym,
strip_number_commas,
collapse_ellipses,
normalize_state_names,
)
def compress_bill_text(text: str) -> str:
"""Apply lossless-ish whitespace and boilerplate compression to bill text.
Runs every transform in :data:`COMPRESSION_STEPS` in order, then strips
leading/trailing whitespace from the final result.
"""
for step in COMPRESSION_STEPS:
text = step(text)
return text.strip()
+236
View File
@@ -0,0 +1,236 @@
"""Run two interactive OpenAI chat-completion sweeps over bill text.
Reads the first N bills from a CSV with a `text_content` column and sends two
sweeps through `/v1/chat/completions` concurrently — one with the raw bill
text, one with the compressed bill text. Each request's prompt is saved to
disk alongside the OpenAI response id so the prompts and responses can be
correlated later.
"""
from __future__ import annotations
import csv
import json
import logging
import re
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from os import getenv
from pathlib import Path
from typing import Annotated
import httpx
import typer
from python.prompt_bench.bill_token_compression import compress_bill_text
from python.prompt_bench.summarization_prompts import SUMMARIZATION_SYSTEM_PROMPT, SUMMARIZATION_USER_TEMPLATE
logger = logging.getLogger(__name__)
OPENAI_API_BASE = "https://api.openai.com/v1"
DEFAULT_MODEL = "gpt-5.4-mini"
DEFAULT_COUNT = 100
SEED = 42
def load_bills(csv_path: Path, count: int) -> list[tuple[str, str]]:
"""Return up to `count` (bill_id, text_content) tuples with non-empty text."""
csv.field_size_limit(sys.maxsize)
bills: list[tuple[str, str]] = []
with csv_path.open(newline="", encoding="utf-8") as handle:
reader = csv.DictReader(handle)
for row in reader:
text_content = (row.get("text_content") or "").strip()
if not text_content:
continue
bill_id = row.get("bill_id") or row.get("id") or f"row-{len(bills)}"
version_code = row.get("version_code") or ""
unique_id = f"{bill_id}-{version_code}" if version_code else bill_id
bills.append((unique_id, text_content))
if len(bills) >= count:
break
return bills
def build_messages(bill_text: str) -> list[dict]:
"""Return the system + user message pair for a bill."""
return [
{"role": "system", "content": SUMMARIZATION_SYSTEM_PROMPT},
{"role": "user", "content": SUMMARIZATION_USER_TEMPLATE.format(text_content=bill_text)},
]
def safe_filename(value: str) -> str:
"""Make a string safe for use as a filename."""
return re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("_") or "unnamed"
def run_one_request(
client: httpx.Client,
*,
bill_id: str,
label: str,
bill_text: str,
model: str,
output_path: Path,
) -> tuple[bool, float, str | None]:
"""Send one chat-completion request and persist prompt + response.
Returns (success, elapsed_seconds, response_id).
"""
messages = build_messages(bill_text)
payload = {
"model": model,
"messages": messages,
"seed": SEED,
}
start = time.monotonic()
record: dict = {
"bill_id": bill_id,
"label": label,
"model": model,
"seed": SEED,
"input_chars": len(bill_text),
"messages": messages,
}
try:
response = client.post(f"{OPENAI_API_BASE}/chat/completions", json=payload)
response.raise_for_status()
body = response.json()
except httpx.HTTPStatusError as error:
elapsed = time.monotonic() - start
record["error"] = {
"status_code": error.response.status_code,
"body": error.response.text,
"elapsed_seconds": elapsed,
}
output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2))
logger.exception("HTTP error for %s/%s after %.2fs", label, bill_id, elapsed)
return False, elapsed, None
except Exception as error:
elapsed = time.monotonic() - start
record["error"] = {"message": str(error), "elapsed_seconds": elapsed}
output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2))
logger.exception("Failed: %s/%s after %.2fs", label, bill_id, elapsed)
return False, elapsed, None
elapsed = time.monotonic() - start
response_id = body.get("id")
record["response_id"] = response_id
record["elapsed_seconds"] = elapsed
record["usage"] = body.get("usage")
record["response"] = body
output_path.write_text(json.dumps(record, ensure_ascii=False, indent=2))
logger.info("Done: %s/%s id=%s in %.2fs", label, bill_id, response_id, elapsed)
return True, elapsed, response_id
def main(
csv_path: Annotated[Path, typer.Option("--csv", help="Bills CSV path")] = Path("bills.csv"),
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to write per-request JSON")] = Path(
"output/openai_runs",
),
model: Annotated[str, typer.Option(help="OpenAI model id")] = DEFAULT_MODEL,
count: Annotated[int, typer.Option(help="Number of bills per set")] = DEFAULT_COUNT,
concurrency: Annotated[int, typer.Option(help="Concurrent in-flight requests")] = 16,
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
"""Run two interactive OpenAI sweeps (compressed + uncompressed) over bill text."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
api_key = getenv("CLOSEDAI_TOKEN") or getenv("OPENAI_API_KEY")
if not api_key:
message = "Neither CLOSEDAI_TOKEN nor OPENAI_API_KEY is set"
raise typer.BadParameter(message)
if not csv_path.is_file():
message = f"CSV not found: {csv_path}"
raise typer.BadParameter(message)
compressed_dir = output_dir / "compressed"
uncompressed_dir = output_dir / "uncompressed"
compressed_dir.mkdir(parents=True, exist_ok=True)
uncompressed_dir.mkdir(parents=True, exist_ok=True)
logger.info("Loading %d bills from %s", count, csv_path)
bills = load_bills(csv_path, count)
if len(bills) < count:
logger.warning("Only %d bills available (requested %d)", len(bills), count)
tasks: list[tuple[str, str, str, Path]] = []
for bill_id, text_content in bills:
filename = f"{safe_filename(bill_id)}.json"
tasks.append((bill_id, "compressed", compress_bill_text(text_content), compressed_dir / filename))
tasks.append((bill_id, "uncompressed", text_content, uncompressed_dir / filename))
logger.info("Submitting %d requests at concurrency=%d", len(tasks), concurrency)
headers = {"Authorization": f"Bearer {api_key}"}
completed = 0
failed = 0
index: list[dict] = []
wall_start = time.monotonic()
with (
httpx.Client(headers=headers, timeout=httpx.Timeout(300.0)) as client,
ThreadPoolExecutor(
max_workers=concurrency,
) as executor,
):
future_to_task = {
executor.submit(
run_one_request,
client,
bill_id=bill_id,
label=label,
bill_text=bill_text,
model=model,
output_path=output_path,
): (bill_id, label, output_path)
for bill_id, label, bill_text, output_path in tasks
}
for future in as_completed(future_to_task):
bill_id, label, output_path = future_to_task[future]
success, elapsed, response_id = future.result()
if success:
completed += 1
else:
failed += 1
index.append(
{
"bill_id": bill_id,
"label": label,
"response_id": response_id,
"elapsed_seconds": elapsed,
"success": success,
"path": str(output_path),
},
)
wall_elapsed = time.monotonic() - wall_start
summary = {
"model": model,
"count": len(bills),
"completed": completed,
"failed": failed,
"wall_seconds": wall_elapsed,
"concurrency": concurrency,
"results": index,
}
summary_path = output_dir / "summary.json"
summary_path.write_text(json.dumps(summary, indent=2))
logger.info(
"Done: completed=%d failed=%d wall=%.1fs summary=%s",
completed,
failed,
wall_elapsed,
summary_path,
)
def cli() -> None:
"""Typer entry point."""
typer.run(main)
if __name__ == "__main__":
cli()
@@ -0,0 +1 @@
"""Prompt benchmarking system for evaluating LLMs via vLLM."""
+165
View File
@@ -0,0 +1,165 @@
"""Docker container lifecycle management for Unsloth fine-tuning."""
from __future__ import annotations
import logging
import subprocess
from pathlib import Path
from typing import Annotated
import typer
from python.prompt_bench.containers.lib import check_gpu_free
logger = logging.getLogger(__name__)
CONTAINER_NAME = "bill-finetune"
FINETUNE_IMAGE = "bill-finetune:latest"
DOCKERFILE_PATH = "/home/richie/dotfiles/python/prompt_bench/Dockerfile.finetune"
DEFAULT_HF_CACHE = Path("/zfs/models/hf")
def build_image() -> None:
"""Build the fine-tuning Docker image."""
logger.info("Building fine-tuning image: %s", FINETUNE_IMAGE)
result = subprocess.run(
["docker", "build", "-f", DOCKERFILE_PATH, "-t", FINETUNE_IMAGE, "."],
text=True,
check=False,
)
if result.returncode != 0:
message = "Failed to build fine-tuning image"
raise RuntimeError(message)
logger.info("Image built: %s", FINETUNE_IMAGE)
def start_finetune(
*,
dataset_path: Path,
output_dir: Path,
hf_cache: Path = DEFAULT_HF_CACHE,
) -> None:
"""Run the fine-tuning container.
Args:
dataset_path: Host path to the fine-tuning JSONL dataset.
output_dir: Host path where the trained model will be saved.
hf_cache: Host path to HuggingFace model cache (bind-mounted to avoid re-downloading).
validation_split: Fraction of data held out for validation.
"""
dataset_path = dataset_path.resolve()
output_dir = output_dir.resolve()
if not dataset_path.is_file():
message = f"Dataset not found: {dataset_path}"
raise FileNotFoundError(message)
output_dir.mkdir(parents=True, exist_ok=True)
stop_finetune()
hf_cache = hf_cache.resolve()
hf_cache.mkdir(parents=True, exist_ok=True)
command = [
"docker",
"run",
"--name",
CONTAINER_NAME,
"--device=nvidia.com/gpu=all",
"--ipc=host",
"-v",
f"{hf_cache}:/root/.cache/huggingface",
"-v",
f"{output_dir}:/workspace/output/qwen-bill-summarizer",
"-v",
f"{dataset_path}:/workspace/dataset.jsonl:ro",
FINETUNE_IMAGE,
"--dataset",
"/workspace/dataset.jsonl",
"--output-dir",
"/workspace/output/qwen-bill-summarizer",
]
logger.info("Starting fine-tuning container")
logger.info(" Dataset: %s", dataset_path)
logger.info(" Output: %s", output_dir)
result = subprocess.run(command, text=True, check=False)
if result.returncode != 0:
message = f"Fine-tuning container exited with code {result.returncode}"
raise RuntimeError(message)
logger.info("Fine-tuning complete. Model saved to %s", output_dir)
def stop_finetune() -> None:
"""Stop and remove the fine-tuning container."""
logger.info("Stopping fine-tuning container")
subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False)
subprocess.run(["docker", "rm", "-f", CONTAINER_NAME], capture_output=True, check=False)
def logs_finetune() -> str | None:
"""Return recent logs from the fine-tuning container, or None if not running."""
result = subprocess.run(
["docker", "logs", "--tail", "50", CONTAINER_NAME],
capture_output=True,
text=True,
check=False,
)
if result.returncode != 0:
return None
return result.stdout + result.stderr
app = typer.Typer(help="Fine-tuning container management.")
@app.command()
def build() -> None:
"""Build the fine-tuning Docker image."""
build_image()
@app.command()
def run(
dataset: Annotated[Path, typer.Option(help="Fine-tuning JSONL")] = Path(
"/home/richie/dotfiles/data/finetune_dataset.jsonl"
),
output_dir: Annotated[Path, typer.Option(help="Where to save the trained model")] = Path(
"/home/richie/dotfiles/data/output/qwen-bill-summarizer",
),
hf_cache: Annotated[Path, typer.Option(help="Host path to HuggingFace model cache")] = DEFAULT_HF_CACHE,
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
"""Run fine-tuning inside a Docker container."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
check_gpu_free()
start_finetune(
dataset_path=dataset,
output_dir=output_dir,
hf_cache=hf_cache,
)
@app.command()
def stop() -> None:
"""Stop and remove the fine-tuning container."""
stop_finetune()
@app.command()
def logs() -> None:
"""Show recent logs from the fine-tuning container."""
output = logs_finetune()
if output is None:
typer.echo("No running fine-tuning container found.")
raise typer.Exit(code=1)
typer.echo(output)
def cli() -> None:
"""Typer entry point."""
app()
if __name__ == "__main__":
cli()
+23
View File
@@ -0,0 +1,23 @@
from __future__ import annotations
import logging
import subprocess
logger = logging.getLogger(__name__)
def check_gpu_free() -> None:
"""Warn if GPU-heavy processes (e.g. Ollama) are running."""
result = subprocess.run(
["nvidia-smi", "--query-compute-apps=pid,process_name", "--format=csv,noheader"],
capture_output=True,
text=True,
check=False,
)
if result.returncode != 0:
logger.warning("Could not query GPU processes: %s", result.stderr.strip())
return
processes = result.stdout.strip()
if processes:
logger.warning("GPU processes detected:\n%s", processes)
logger.warning("Consider stopping Ollama (sudo systemctl stop ollama) before benchmarking")
+70
View File
@@ -0,0 +1,70 @@
"""Docker container lifecycle management for vLLM."""
from __future__ import annotations
import logging
import subprocess
logger = logging.getLogger(__name__)
CONTAINER_NAME = "vllm-bench"
VLLM_IMAGE = "vllm/vllm-openai:v0.19.0"
def start_vllm(
*,
model: str,
port: int,
model_dir: str,
gpu_memory_utilization: float,
) -> None:
"""Start a vLLM container serving the given model.
Args:
model: HuggingFace model directory name (relative to model_dir).
port: Host port to bind.
model_dir: Host path containing HuggingFace model directories.
gpu_memory_utilization: Fraction of GPU memory to use (0-1).
"""
command = [
"docker",
"run",
"-d",
"--name",
CONTAINER_NAME,
"--device=nvidia.com/gpu=all",
"--ipc=host",
"-v",
f"{model_dir}:/models",
"-p",
f"{port}:8000",
VLLM_IMAGE,
"--model",
f"/models/{model}",
"--served-model-name",
model,
"--gpu-memory-utilization",
str(gpu_memory_utilization),
"--max-model-len",
"4096",
]
logger.info("Starting vLLM container with model: %s", model)
stop_vllm()
result = subprocess.run(command, capture_output=True, text=True, check=False)
if result.returncode != 0:
msg = f"Failed to start vLLM container: {result.stderr.strip()}"
raise RuntimeError(msg)
logger.info("vLLM container started: %s", result.stdout.strip()[:12])
def stop_vllm() -> None:
"""Stop and remove the vLLM benchmark container."""
logger.info("Stopping vLLM container")
subprocess.run(["docker", "stop", CONTAINER_NAME], capture_output=True, check=False)
subprocess.run(["docker", "rm", "-f", CONTAINER_NAME], capture_output=True, check=False)
subprocess.run(
["docker", "network", "disconnect", "-f", "bridge", CONTAINER_NAME],
capture_output=True,
check=False,
)
logger.info("vLLM container stopped and removed")
+75
View File
@@ -0,0 +1,75 @@
"""HuggingFace model downloader."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Annotated
import typer
from huggingface_hub import snapshot_download
from python.prompt_bench.models import BenchmarkConfig
logger = logging.getLogger(__name__)
def local_model_path(repo: str, model_dir: str) -> Path:
"""Return the local directory path for a HuggingFace repo."""
return Path(model_dir) / repo
def is_model_present(repo: str, model_dir: str) -> bool:
"""Check if a model has already been downloaded."""
path = local_model_path(repo, model_dir)
return path.exists() and any(path.iterdir())
def download_model(repo: str, model_dir: str) -> Path:
"""Download a HuggingFace model to the local model directory.
Skips the download if the model directory already exists and contains files.
"""
local_path = local_model_path(repo, model_dir)
if is_model_present(repo, model_dir):
logger.info("Model already exists: %s", local_path)
return local_path
logger.info("Downloading model: %s -> %s", repo, local_path)
snapshot_download(
repo_id=repo,
local_dir=str(local_path),
)
logger.info("Download complete: %s", repo)
return local_path
def download_all(config: BenchmarkConfig) -> None:
"""Download every model listed in the config, top to bottom."""
for repo in config.models:
download_model(repo, config.model_dir)
def main(
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
"""Download all models listed in the benchmark config."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
if not config.is_file():
message = f"Config file does not exist: {config}"
raise typer.BadParameter(message)
benchmark_config = BenchmarkConfig.from_toml(config)
download_all(benchmark_config)
def cli() -> None:
"""Typer entry point."""
typer.run(main)
if __name__ == "__main__":
cli()
+214
View File
@@ -0,0 +1,214 @@
"""Fine-tune Qwen 3.5 4B on bill summarization data using Unsloth.
Loads a ChatML-style JSONL dataset (system/user/assistant messages),
applies QLoRA with 4-bit quantization, and saves the merged model
in HuggingFace format. Designed for a single RTX 3090 (24GB).
Usage:
python -m python.prompt_bench.finetune \
--dataset output/finetune_dataset.jsonl \
--output-dir output/qwen-bill-summarizer
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated
import tomllib
import typer
from unsloth import FastLanguageModel
from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer
logger = logging.getLogger(__name__)
@dataclass
class LoraConfig:
"""LoRA adapter hyperparameters."""
rank: int
alpha: int
dropout: float
targets: list[str]
@dataclass
class TrainingConfig:
"""Training loop hyperparameters."""
learning_rate: float
epochs: int
batch_size: int
gradient_accumulation: int
max_seq_length: int
warmup_ratio: float
weight_decay: float
logging_steps: int
save_steps: int
@dataclass
class FinetuneConfig:
"""Top-level finetune configuration."""
base_model: str
lora: LoraConfig
training: TrainingConfig
@classmethod
def from_toml(cls, config_path: Path) -> FinetuneConfig:
"""Load finetune config from a TOML file."""
raw = tomllib.loads(config_path.read_text())["finetune"]
return cls(
base_model=raw["base_model"],
lora=LoraConfig(**raw["lora"]),
training=TrainingConfig(**raw["training"]),
)
def _messages_to_chatml(messages: list[dict]) -> str:
r"""Convert a message list to Qwen ChatML format.
Produces:
<|im_start|>system\n...\n<|im_end|>
<|im_start|>user\n...\n<|im_end|>
<|im_start|>assistant\n...\n<|im_end|>
"""
parts = []
for message in messages:
role = message["role"]
content = message["content"]
parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")
return "\n".join(parts)
def load_dataset_from_jsonl(path: Path) -> Dataset:
"""Load a ChatML JSONL file into a HuggingFace Dataset.
Each line must have {"messages": [{"role": ..., "content": ...}, ...]}.
Pre-formats into a `text` column with the Qwen ChatML template applied,
which SFTTrainer consumes directly.
"""
records = []
with path.open(encoding="utf-8") as handle:
for raw_line in handle:
stripped = raw_line.strip()
if stripped:
entry = json.loads(stripped)
records.append({"text": _messages_to_chatml(entry["messages"])})
logger.info("Loaded %d examples from %s", len(records), path)
return Dataset.from_list(records)
def main(
dataset_path: Annotated[Path, typer.Option("--dataset", help="Fine-tuning JSONL")] = Path(
"output/finetune_dataset.jsonl",
),
validation_split: Annotated[float, typer.Option("--val-split", help="Fraction held out for validation")] = 0.1,
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path(
"output/qwen-bill-summarizer",
),
config_path: Annotated[
Path,
typer.Option("--config", help="TOML config file"),
] = Path(__file__).parent / "config.toml",
save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False,
) -> None:
"""Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
logging.basicConfig(level="INFO", format="%(asctime)s %(levelname)s %(name)s: %(message)s")
if not dataset_path.is_file():
message = f"Dataset not found: {dataset_path}"
raise typer.BadParameter(message)
config = FinetuneConfig.from_toml(config_path)
logger.info("Loading base model: %s", config.base_model)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=config.base_model,
max_seq_length=config.training.max_seq_length,
load_in_4bit=True,
dtype=None,
)
logger.info("Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha)
model = FastLanguageModel.get_peft_model(
model,
r=config.lora.rank,
lora_alpha=config.lora.alpha,
lora_dropout=config.lora.dropout,
target_modules=config.lora.targets,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=42,
)
full_dataset = load_dataset_from_jsonl(dataset_path)
split = full_dataset.train_test_split(test_size=validation_split, seed=42)
train_dataset = split["train"]
validation_dataset = split["test"]
logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset))
training_args = TrainingArguments(
output_dir=str(output_dir / "checkpoints"),
num_train_epochs=config.training.epochs,
per_device_train_batch_size=config.training.batch_size,
gradient_accumulation_steps=config.training.gradient_accumulation,
learning_rate=config.training.learning_rate,
warmup_ratio=config.training.warmup_ratio,
weight_decay=config.training.weight_decay,
lr_scheduler_type="cosine",
logging_steps=config.training.logging_steps,
save_steps=config.training.save_steps,
save_total_limit=3,
eval_strategy="steps",
eval_steps=config.training.save_steps,
load_best_model_at_end=True,
bf16=True,
optim="adamw_8bit",
seed=42,
report_to="none",
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
args=training_args,
max_seq_length=config.training.max_seq_length,
packing=True,
)
logger.info(
"Starting training: %d train, %d val, %d epochs",
len(train_dataset),
len(validation_dataset),
config.training.epochs,
)
trainer.train()
merged_path = str(output_dir / "merged")
logger.info("Saving merged model to %s", merged_path)
model.save_pretrained_merged(merged_path, tokenizer, save_method="merged_16bit")
if save_gguf:
gguf_path = str(output_dir / "gguf")
logger.info("Saving GGUF to %s", gguf_path)
model.save_pretrained_gguf(gguf_path, tokenizer, quantization_method="q4_k_m")
logger.info("Done! Model saved to %s", output_dir)
def cli() -> None:
"""Typer entry point."""
typer.run(main)
if __name__ == "__main__":
cli()
+215
View File
@@ -0,0 +1,215 @@
"""CLI entry point for the prompt benchmarking system."""
from __future__ import annotations
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Annotated
import typer
from python.prompt_bench.containers.lib import check_gpu_free
from python.prompt_bench.containers.vllm import start_vllm, stop_vllm
from python.prompt_bench.downloader import is_model_present
from python.prompt_bench.models import BenchmarkConfig
from python.prompt_bench.vllm_client import VLLMClient
logger = logging.getLogger(__name__)
def discover_prompts(input_dir: Path) -> list[Path]:
"""Find all .txt files in the input directory."""
prompts = list(input_dir.glob("*.txt"))
if not prompts:
message = f"No .txt files found in {input_dir}"
raise FileNotFoundError(message)
return prompts
def _run_prompt(
client: VLLMClient,
prompt_path: Path,
*,
repo: str,
model_dir_name: str,
model_output: Path,
temperature: float,
) -> tuple[bool, float]:
"""Run a single prompt. Returns (success, elapsed_seconds)."""
filename = prompt_path.name
output_path = model_output / filename
start = time.monotonic()
try:
prompt_text = prompt_path.read_text()
response = client.complete(prompt_text, model_dir_name, temperature=temperature)
output_path.write_text(response)
elapsed = time.monotonic() - start
logger.info("Completed: %s / %s in %.2fs", repo, filename, elapsed)
except Exception:
elapsed = time.monotonic() - start
error_path = model_output / f"{filename}.error"
logger.exception("Failed: %s / %s after %.2fs", repo, filename, elapsed)
error_path.write_text(f"Error processing {filename}")
return False, elapsed
return True, elapsed
def benchmark_model(
client: VLLMClient,
prompts: list[Path],
*,
repo: str,
model_dir_name: str,
model_output: Path,
temperature: float,
concurrency: int,
) -> tuple[int, int]:
"""Run all prompts against a single model in parallel.
vLLM batches concurrent requests internally, so submitting many at once is
significantly faster than running them serially.
"""
pending = [prompt for prompt in prompts if not (model_output / prompt.name).exists()]
skipped = len(prompts) - len(pending)
if skipped:
logger.info("Skipping %d prompts with existing output for %s", skipped, repo)
if not pending:
logger.info("Nothing to do for %s", repo)
return 0, 0
completed = 0
failed = 0
latencies: list[float] = []
wall_start = time.monotonic()
with ThreadPoolExecutor(max_workers=concurrency) as executor:
futures = [
executor.submit(
_run_prompt,
client,
prompt_path,
repo=repo,
model_dir_name=model_dir_name,
model_output=model_output,
temperature=temperature,
)
for prompt_path in pending
]
for future in as_completed(futures):
success, elapsed = future.result()
latencies.append(elapsed)
if success:
completed += 1
else:
failed += 1
wall_elapsed = time.monotonic() - wall_start
attempted = completed + failed
avg_latency = sum(latencies) / attempted
throughput = attempted / wall_elapsed if wall_elapsed > 0 else 0.0
timing = {
"repo": repo,
"wall_seconds": wall_elapsed,
"attempted": attempted,
"completed": completed,
"failed": failed,
"avg_latency_seconds": avg_latency,
"throughput_prompts_per_second": throughput,
"concurrency": concurrency,
}
timing_path = model_output / "_timing.json"
timing_path.write_text(json.dumps(timing, indent=2))
return completed, failed
def run_benchmark(
config: BenchmarkConfig,
input_dir: Path,
output_dir: Path,
) -> None:
"""Execute the benchmark across all models and prompts."""
prompts = discover_prompts(input_dir)
logger.info("Found %d prompts in %s", len(prompts), input_dir)
check_gpu_free()
total_completed = 0
total_failed = 0
for repo in config.models:
if not is_model_present(repo, config.model_dir):
logger.warning("Skipping (not downloaded): %s", repo)
continue
model_output = output_dir / repo
model_output.mkdir(parents=True, exist_ok=True)
logger.info("=== Benchmarking model: %s ===", repo)
stop_vllm()
try:
start_vllm(
model=repo,
port=config.port,
model_dir=config.model_dir,
gpu_memory_utilization=config.gpu_memory_utilization,
)
except RuntimeError:
logger.exception("Failed to start vLLM for %s, skipping", repo)
continue
logger.info("vLLM started for %s", repo)
try:
with VLLMClient(port=config.port, timeout=config.timeout) as client:
client.wait_ready(max_wait=config.vllm_startup_timeout)
completed, failed = benchmark_model(
client,
prompts,
repo=repo,
model_dir_name=repo,
model_output=model_output,
temperature=config.temperature,
concurrency=config.concurrency,
)
total_completed += completed
total_failed += failed
finally:
stop_vllm()
logger.info("=== Benchmark complete ===")
logger.info("Completed: %d | Failed: %d", total_completed, total_failed)
def main(
input_dir: Annotated[Path, typer.Argument(help="Directory containing input .txt prompt files")],
config: Annotated[Path, typer.Option(help="Path to TOML config file")] = Path("bench.toml"),
output_dir: Annotated[Path, typer.Option(help="Output directory for results")] = Path("output"),
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
"""Run prompts through multiple LLMs via vLLM and save results."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
if not input_dir.is_dir():
message = f"Input directory does not exist: {input_dir}"
raise typer.BadParameter(message)
if not config.is_file():
message = f"Config file does not exist: {config}"
raise typer.BadParameter(message)
benchmark_config = BenchmarkConfig.from_toml(config)
output_dir.mkdir(parents=True, exist_ok=True)
run_benchmark(benchmark_config, input_dir, output_dir)
def cli() -> None:
"""Typer entry point."""
typer.run(main)
if __name__ == "__main__":
cli()
+30
View File
@@ -0,0 +1,30 @@
"""Pydantic models for benchmark configuration."""
from __future__ import annotations
import tomllib
from typing import TYPE_CHECKING
from pydantic import BaseModel
if TYPE_CHECKING:
from pathlib import Path
class BenchmarkConfig(BaseModel):
"""Top-level benchmark configuration loaded from TOML."""
models: list[str]
model_dir: str = "/zfs/models/hf"
port: int = 8000
gpu_memory_utilization: float = 0.90
temperature: float = 0.0
timeout: int = 300
concurrency: int = 4
vllm_startup_timeout: int = 900
@classmethod
def from_toml(cls, config_path: Path) -> BenchmarkConfig:
"""Load benchmark config from a TOML file."""
raw = tomllib.loads(config_path.read_text())["bench"]
return cls(**raw)
@@ -0,0 +1,34 @@
SUMMARIZATION_SYSTEM_PROMPT = """You are a legislative analyst extracting policy substance from Congressional bill text.
Your job is to compress a bill into a dense, neutral structured summary that captures every distinct policy action — including secondary effects that might be buried in subsections.
EXTRACTION RULES:
- IGNORE: whereas clauses, congressional findings that are purely political statements, recitals, preambles, citations of existing law by number alone, and procedural boilerplate.
- FOCUS ON: operative verbs — what the bill SHALL do, PROHIBIT, REQUIRE, AUTHORIZE, AMEND, APPROPRIATE, or ESTABLISH.
- SURFACE ALL THREADS: If the bill touches multiple policy areas, list each thread separately. Do not collapse them.
- BE CONCRETE: Name the affected population, the mechanism, and the direction (expands/restricts/maintains).
- STAY NEUTRAL: No political framing. Describe what the text does, not what its sponsors claim it does.
OUTPUT FORMAT — plain structured text, not JSON:
OPERATIVE ACTIONS:
[Numbered list of what the bill actually does, one action per line, max 20 words each]
AFFECTED POPULATIONS:
[Who gains something, who loses something, or whose behavior is regulated]
MECHANISMS:
[How it works: new funding, mandate, prohibition, amendment to existing statute, grant program, study commission, etc.]
POLICY THREADS:
[List each distinct policy domain this bill touches, even minor ones. Use plain language, not domain codes.]
SYMBOLIC/PROCEDURAL ONLY:
[Yes or No — is this bill primarily a resolution, designation, or awareness declaration with no operative effect?]
LENGTH TARGET: 150-250 words total. Be ruthless about cutting. Density over completeness."""
SUMMARIZATION_USER_TEMPLATE = """Summarize the following Congressional bill according to your instructions.
BILL TEXT:
{text_content}"""
@@ -0,0 +1,114 @@
"""Build a fine-tuning JSONL dataset from batch request + output files.
Joins the original request JSONL (system + user messages) with the batch
output JSONL (assistant completions) by custom_id to produce a ChatML-style
messages JSONL suitable for fine-tuning.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Annotated
import typer
logger = logging.getLogger(__name__)
HTTP_OK = 200
def load_requests(path: Path) -> dict[str, list[dict]]:
"""Parse request JSONL into {custom_id: messages}."""
results: dict[str, list[dict]] = {}
with path.open(encoding="utf-8") as handle:
for raw_line in handle:
stripped = raw_line.strip()
if not stripped:
continue
record = json.loads(stripped)
custom_id = record["custom_id"]
messages = record["body"]["messages"]
results[custom_id] = messages
return results
def load_completions(path: Path) -> dict[str, str]:
"""Parse batch output JSONL into {custom_id: assistant_content}."""
results: dict[str, str] = {}
with path.open(encoding="utf-8") as handle:
for line_number, raw_line in enumerate(handle, 1):
stripped = raw_line.strip()
if not stripped:
continue
record = json.loads(stripped)
custom_id = record["custom_id"]
response = record.get("response", {})
if response.get("status_code") != HTTP_OK:
logger.warning("Skipping %s (line %d): status %s", custom_id, line_number, response.get("status_code"))
continue
body = response.get("body", {})
choices = body.get("choices", [])
if not choices:
logger.warning("Skipping %s (line %d): no choices", custom_id, line_number)
continue
content = choices[0].get("message", {}).get("content", "")
if not content:
logger.warning("Skipping %s (line %d): empty content", custom_id, line_number)
continue
results[custom_id] = content
return results
def main(
requests_path: Annotated[Path, typer.Option("--requests", help="Batch request JSONL")] = Path(
"output/openai_batch/requests.jsonl",
),
batch_output: Annotated[Path, typer.Option("--batch-output", help="Batch output JSONL")] = Path(
"batch_69d84558d91c819091d53f08d78f9fd6_output.jsonl",
),
output_path: Annotated[Path, typer.Option("--output", help="Fine-tuning JSONL output")] = Path(
"output/finetune_dataset.jsonl",
),
log_level: Annotated[str, typer.Option(help="Log level")] = "INFO",
) -> None:
"""Build fine-tuning dataset by joining request and output JSONL files."""
logging.basicConfig(level=log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger.info("Loading requests from %s", requests_path)
requests = load_requests(requests_path)
logger.info("Loaded %d requests", len(requests))
logger.info("Loading completions from %s", batch_output)
completions = load_completions(batch_output)
logger.info("Loaded %d completions", len(completions))
output_path.parent.mkdir(parents=True, exist_ok=True)
matched = 0
skipped = 0
with output_path.open("w", encoding="utf-8") as handle:
for custom_id, messages in requests.items():
assistant_content = completions.get(custom_id)
if assistant_content is None:
skipped += 1
continue
example = {
"messages": [*messages, {"role": "assistant", "content": assistant_content}],
}
handle.write(json.dumps(example, ensure_ascii=False))
handle.write("\n")
matched += 1
logger.info("Wrote %d examples to %s (skipped %d unmatched)", matched, output_path, skipped)
def cli() -> None:
"""Typer entry point."""
typer.run(main)
if __name__ == "__main__":
cli()
+97
View File
@@ -0,0 +1,97 @@
"""Sum token usage across compressed and uncompressed run directories."""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Annotated
import typer
logger = logging.getLogger(__name__)
@dataclass
class UsageTotals:
"""Aggregate usage counters for a directory of run records."""
files: int = 0
errors: int = 0
prompt_tokens: int = 0
cached_tokens: int = 0
completion_tokens: int = 0
reasoning_tokens: int = 0
total_tokens: int = 0
per_file: list[tuple[str, int, int, int]] = field(default_factory=list)
def tally_directory(directory: Path) -> UsageTotals:
"""Return aggregated usage stats for every JSON record in a directory."""
totals = UsageTotals()
decoder = json.JSONDecoder()
for path in sorted(directory.glob("*.json")):
text = path.read_text().lstrip()
record, _ = decoder.raw_decode(text)
totals.files += 1
usage = record.get("usage")
if not usage:
totals.errors += 1
continue
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", 0)
cached_tokens = (usage.get("prompt_tokens_details") or {}).get("cached_tokens", 0)
reasoning_tokens = (usage.get("completion_tokens_details") or {}).get("reasoning_tokens", 0)
totals.prompt_tokens += prompt_tokens
totals.completion_tokens += completion_tokens
totals.total_tokens += total_tokens
totals.cached_tokens += cached_tokens
totals.reasoning_tokens += reasoning_tokens
totals.per_file.append((path.name, prompt_tokens, completion_tokens, total_tokens))
return totals
def log_totals(label: str, totals: UsageTotals) -> None:
"""Log a one-block summary for a directory."""
counted = totals.files - totals.errors
average_total = totals.total_tokens / counted if counted else 0
logger.info("[%s]", label)
logger.info(" files : %d (with usage: %d, errors: %d)", totals.files, counted, totals.errors)
logger.info(" prompt tokens : %d", totals.prompt_tokens)
logger.info(" cached tokens : %d", totals.cached_tokens)
logger.info(" completion tok : %d", totals.completion_tokens)
logger.info(" reasoning tok : %d", totals.reasoning_tokens)
logger.info(" total tokens : %d", totals.total_tokens)
logger.info(" avg total/file : %.1f", average_total)
def main(
runs_dir: Annotated[Path, typer.Option("--runs-dir")] = Path("output/openai_runs_temp_1"),
log_level: Annotated[str, typer.Option("--log-level")] = "INFO",
) -> None:
"""Print token usage totals for the compressed and uncompressed run directories."""
logging.basicConfig(level=log_level, format="%(message)s")
grand = UsageTotals()
for label in ("compressed", "uncompressed"):
directory = runs_dir / label
if not directory.is_dir():
logger.warning("%s: directory not found at %s", label, directory)
continue
totals = tally_directory(directory)
log_totals(label, totals)
grand.files += totals.files
grand.errors += totals.errors
grand.prompt_tokens += totals.prompt_tokens
grand.cached_tokens += totals.cached_tokens
grand.completion_tokens += totals.completion_tokens
grand.reasoning_tokens += totals.reasoning_tokens
grand.total_tokens += totals.total_tokens
log_totals("grand total", grand)
if __name__ == "__main__":
typer.run(main)
+68
View File
@@ -0,0 +1,68 @@
"""OpenAI-compatible client for vLLM's API."""
from __future__ import annotations
import logging
import time
from typing import Self
import httpx
logger = logging.getLogger(__name__)
READY_POLL_INTERVAL = 2.0
class VLLMClient:
"""Talk to a vLLM server via its OpenAI-compatible API.
Args:
host: vLLM host.
port: vLLM port.
timeout: Per-request timeout in seconds.
"""
def __init__(self, *, host: str = "localhost", port: int = 8000, timeout: int = 300) -> None:
"""Create a client connected to a vLLM server."""
self._client = httpx.Client(base_url=f"http://{host}:{port}", timeout=timeout)
def wait_ready(self, max_wait: int) -> None:
"""Poll /v1/models until the server is ready or timeout."""
deadline = time.monotonic() + max_wait
while time.monotonic() < deadline:
try:
response = self._client.get("/v1/models")
if response.is_success:
logger.info("vLLM server is ready")
return
except httpx.TransportError:
pass
time.sleep(READY_POLL_INTERVAL)
msg = f"vLLM server not ready after {max_wait}s"
raise TimeoutError(msg)
def complete(self, prompt: str, model: str, *, temperature: float = 0.0, max_tokens: int = 4096) -> str:
"""Send a prompt to /v1/completions and return the response text."""
payload = {
"model": model,
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
}
logger.info("Sending prompt to %s (%d chars)", model, len(prompt))
response = self._client.post("/v1/completions", json=payload)
response.raise_for_status()
data = response.json()
return data["choices"][0]["text"]
def close(self) -> None:
"""Close the HTTP client."""
self._client.close()
def __enter__(self) -> Self:
"""Enter the context manager."""
return self
def __exit__(self, *args: object) -> None:
"""Close the HTTP client on exit."""
self.close()
-1
View File
@@ -1 +0,0 @@
"""Audiobook tools."""
-471
View File
@@ -1,471 +0,0 @@
"""Convert Audible AAX downloads into Audiobookshelf-friendly M4B files."""
from __future__ import annotations
import json
import logging
import re
import shutil
import subprocess
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass
from os import getenv
from pathlib import Path # noqa: TC003 This is required for the typer CLI
from typing import TYPE_CHECKING, Annotated, Any
from uuid import uuid7
import typer
from python.common import configure_logger
from python.orm.common import get_postgres_engine
from python.tools.audiobook.metadata_agent import (
AgentConfig,
StandardBookMetadata,
standard_book_metadata,
write_agent_log,
)
if TYPE_CHECKING:
from sqlalchemy.engine import Engine
logger = logging.getLogger(__name__)
SENSITIVE_COMMAND_ARGUMENTS = {"-activation_bytes"}
BOOK_RANGE_PATTERN = re.compile(r"(?:^|-)books?-(?P<start>[1-9]\d*)-(?P<end>[1-9]\d*)(?:-|$)")
@dataclass(frozen=True)
class ConversionConfig:
"""Runtime settings for one conversion command."""
resolved_output: Path
ollama_api_key: str
agent_config: AgentConfig
engine: Engine
activation_bytes: str | None
dry_run: bool
overwrite: bool
work_directory_name: str = ".audible_convert"
dry_run_directory_name: str = "dry-run"
temp_directory_name: str = "tmp"
log_directory_name: str = "logs"
review_directory_name: str = "review"
@dataclass(frozen=True)
class ConcurrentConversionResult:
"""Result from running ffmpeg and metadata resolution together."""
metadata: StandardBookMetadata | None
conversion_error: Exception | None
metadata_error: Exception | None
class CommandExecutionError(RuntimeError):
"""Command failed without exposing sensitive arguments."""
def __init__(self, arguments: list[str], returncode: int) -> None:
"""Create a redacted command failure."""
self.arguments = tuple(arguments)
self.returncode = returncode
command = " ".join(redact_command_arguments(arguments))
super().__init__(f"Command failed with exit code {returncode}: {command}")
def main(
input_directory: Annotated[Path, typer.Argument(help="Directory audible-cli downloads AAX files into.")],
output_directory: Annotated[Path, typer.Argument(help="Audiobook output directory.")],
*,
dry_run: Annotated[
bool,
typer.Option("--dry-run", help="Print planned output files and write marker files without converting."),
] = False,
overwrite: Annotated[bool, typer.Option("--overwrite", help="Overwrite existing M4B files.")] = False,
) -> None:
"""Convert AAX files from a download directory into M4B files."""
configure_logger()
resolved_input = input_directory.resolve(strict=True)
resolved_output = output_directory.resolve()
if not dry_run:
resolved_output.mkdir(parents=True, exist_ok=True)
ollama_api_key = getenv("OLLAMA_API_KEY")
if not ollama_api_key:
msg = "OLLAMA_API_KEY is required for audiobook metadata resolution"
raise RuntimeError(msg)
config = ConversionConfig(
resolved_output=resolved_output,
ollama_api_key=ollama_api_key,
agent_config=AgentConfig(),
engine=get_postgres_engine(name="RICHIE"),
activation_bytes=getenv("AUDIBLE_ACTIVATION_BYTES"),
dry_run=dry_run,
overwrite=overwrite,
)
aax_files = sorted(resolved_input.glob("*.aax"))
if not aax_files:
logger.info("No AAX files found in %s", resolved_input)
return
for aax_file in aax_files:
logger.info("Converting %s", aax_file)
convert_aax_file_with_agent(aax_file, config)
def run_command(arguments: list[str], *, capture: bool = False) -> subprocess.CompletedProcess[str]:
"""Run a command and return the completed process.
Args:
arguments: Command and arguments to run.
capture: Whether to capture stdout and stderr.
Returns:
The completed process.
"""
logger.debug("%s", " ".join(redact_command_arguments(arguments)))
try:
return subprocess.run(arguments, check=True, capture_output=capture, text=True)
except subprocess.CalledProcessError as error:
raise CommandExecutionError(arguments, error.returncode) from error
def redact_command_arguments(arguments: list[str]) -> list[str]:
"""Return command arguments with sensitive values redacted."""
redacted = []
redact_next = False
for argument in arguments:
if redact_next:
redacted.append("<redacted>")
redact_next = False
continue
redacted.append(argument)
redact_next = argument in SENSITIVE_COMMAND_ARGUMENTS
return redacted
def read_metadata(aax_file: Path) -> dict[str, str]:
"""Read ffprobe format tags from an AAX file.
Args:
aax_file: AAX file to inspect.
Returns:
Lower-cased metadata tag names mapped to their values.
"""
completed = run_command(
[
"ffprobe",
"-v",
"quiet",
"-print_format",
"json",
"-show_format",
str(aax_file),
],
capture=True,
)
ffprobe_data: dict[str, Any] = json.loads(completed.stdout)
tags = ffprobe_data.get("format", {}).get("tags", {})
return {str(key).lower(): str(value) for key, value in tags.items()}
def output_stem(metadata: StandardBookMetadata) -> str:
"""Build the output stem for a book.
Args:
metadata: Book metadata.
Returns:
Output stem in author-series_01-title form.
"""
index_slug = series_index_slug(metadata.series_index, metadata.title)
return f"{metadata.author}-{metadata.series}_{index_slug}-{metadata.title}"
def series_index_slug(series_index: float, title: str = "") -> str:
"""Return a filename-safe series index."""
if title_range := title_series_range_slug(series_index, title):
return title_range
index = float(series_index)
if index.is_integer():
return f"{int(index):02}"
return f"{int(index):02}.5"
def title_series_range_slug(series_index: float, title: str) -> str | None:
"""Return a series range slug found in an omnibus title."""
index = float(series_index)
if not index.is_integer():
return None
first_index = int(index)
for match in BOOK_RANGE_PATTERN.finditer(title):
start = int(match.group("start"))
end = int(match.group("end"))
if start == first_index and end > start:
return f"{start:02}-{end:02}"
return None
def metadata_output_path(output_directory: Path, metadata: StandardBookMetadata) -> Path:
"""Build the final M4B path from resolved metadata."""
stem = output_stem(metadata)
return output_directory / stem / f"{stem}.m4b"
def convert_aax_file(
aax_file: Path,
destination: Path,
activation_bytes: str | None,
*,
overwrite: bool,
) -> None:
"""Convert an AAX file into an M4B file.
Args:
aax_file: Source AAX file.
destination: Destination M4B file.
activation_bytes: Optional Audible activation bytes for ffmpeg.
overwrite: Whether to overwrite an existing M4B.
"""
if destination.exists() and not overwrite:
logger.info("Skipping existing file %s", destination)
return
destination.parent.mkdir(parents=True, exist_ok=True)
arguments = ["ffmpeg", "-hide_banner", "-y" if overwrite else "-n"]
if activation_bytes:
arguments.extend(["-activation_bytes", activation_bytes])
arguments.extend(["-i", str(aax_file), "-map_metadata", "0", "-c", "copy", str(destination)])
run_command(arguments)
def write_review_file(
*,
destination: Path | None,
ffprobe_metadata: dict[str, str],
log_file: Path,
metadata: StandardBookMetadata | None,
reason: str,
review_file: Path,
source: Path,
temp_file: Path | None,
) -> None:
"""Write a manual review file for an unresolved conversion."""
review_file.parent.mkdir(parents=True, exist_ok=True)
payload = {
"destination": str(destination) if destination else None,
"ffprobe_metadata": ffprobe_metadata,
"metadata": asdict(metadata) if metadata else None,
"reason": reason,
"source": str(source),
"temp_file": str(temp_file) if temp_file else None,
}
review_file.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
write_agent_log(log_file, "review_written", path=str(review_file), reason=reason)
def cleanup_temp_output(temp_file: Path) -> None:
"""Remove a run's temporary output directory."""
shutil.rmtree(temp_file.parent, ignore_errors=True)
def dry_run_aax_file_with_agent(
aax_file: Path,
ffprobe_metadata: dict[str, str],
engine: Engine,
config: ConversionConfig,
log_file: Path,
review_file: Path,
) -> None:
"""Resolve and print the planned output path without converting."""
metadata = standard_book_metadata(
aax_file.name,
ffprobe_metadata,
engine,
log_file,
config.ollama_api_key,
config.agent_config,
)
destination = None if metadata.needs_review else metadata_output_path(config.resolved_output, metadata)
if metadata.needs_review:
write_review_file(
destination=destination,
ffprobe_metadata=ffprobe_metadata,
log_file=log_file,
metadata=metadata,
reason="metadata_needs_review",
review_file=review_file,
source=aax_file,
temp_file=None,
)
typer.echo(f"{aax_file} -> REVIEW {review_file}")
else:
stem = output_stem(metadata)
dry_run_file = (
config.resolved_output / config.work_directory_name / config.dry_run_directory_name / stem / f"{stem}.m4b"
)
dry_run_file.parent.mkdir(parents=True, exist_ok=True)
dry_run_file.write_text(f"{destination}\n", encoding="utf-8")
write_agent_log(
log_file,
"dry_run_file_written",
destination=str(destination),
path=str(dry_run_file),
)
typer.echo(f"{aax_file} -> {destination}")
def convert_temp_file_and_resolve_metadata(
aax_file: Path,
temp_file: Path,
ffprobe_metadata: dict[str, str],
config: ConversionConfig,
log_file: Path,
) -> ConcurrentConversionResult:
"""Run ffmpeg and metadata resolution in parallel."""
conversion_error: Exception | None = None
metadata_error: Exception | None = None
metadata: StandardBookMetadata | None = None
with ThreadPoolExecutor(max_workers=2) as executor:
conversion_future = executor.submit(
convert_aax_file,
aax_file,
temp_file,
config.activation_bytes,
overwrite=True,
)
metadata_future = executor.submit(
standard_book_metadata,
aax_file.name,
ffprobe_metadata,
config.engine,
log_file,
config.ollama_api_key,
config.agent_config,
)
conversion_error = conversion_future.exception()
if conversion_error is None:
conversion_future.result()
metadata_error = metadata_future.exception()
if metadata_error is None:
metadata = metadata_future.result()
return ConcurrentConversionResult(
metadata=metadata,
conversion_error=conversion_error,
metadata_error=metadata_error,
)
def convert_aax_file_with_agent(aax_file: Path, config: ConversionConfig) -> None:
"""Convert one AAX file using the metadata agent for the final path."""
run_id = uuid7().hex
log_file = config.resolved_output / config.work_directory_name / config.log_directory_name / f"{run_id}.jsonl"
review_file = config.resolved_output / config.work_directory_name / config.review_directory_name / f"{run_id}.json"
write_agent_log(log_file, "conversion_start", source=str(aax_file), dry_run=config.dry_run)
try:
ffprobe_metadata = read_metadata(aax_file)
except Exception as error:
logger.exception("ffprobe failed")
write_review_file(
destination=None,
ffprobe_metadata={},
log_file=log_file,
metadata=None,
reason=f"ffprobe_failed: {error}",
review_file=review_file,
source=aax_file,
temp_file=None,
)
return
if config.dry_run:
dry_run_aax_file_with_agent(
aax_file,
ffprobe_metadata,
config.engine,
config,
log_file,
review_file,
)
return
temp_file = (
config.resolved_output / config.work_directory_name / config.temp_directory_name / run_id / "converted.m4b"
)
temp_file.parent.mkdir(parents=True, exist_ok=True)
result = convert_temp_file_and_resolve_metadata(aax_file, temp_file, ffprobe_metadata, config, log_file)
if result.conversion_error:
reason = f"ffmpeg_failed: {result.conversion_error}"
write_review_file(
destination=None,
ffprobe_metadata=ffprobe_metadata,
log_file=log_file,
metadata=result.metadata,
reason=reason,
review_file=review_file,
source=aax_file,
temp_file=temp_file if temp_file.exists() else None,
)
return
if result.metadata_error:
write_review_file(
destination=None,
ffprobe_metadata=ffprobe_metadata,
log_file=log_file,
metadata=None,
reason=f"metadata_failed: {result.metadata_error}",
review_file=review_file,
source=aax_file,
temp_file=temp_file,
)
return
if result.metadata is None or result.metadata.needs_review:
write_review_file(
destination=None,
ffprobe_metadata=ffprobe_metadata,
log_file=log_file,
metadata=result.metadata,
reason="metadata_needs_review",
review_file=review_file,
source=aax_file,
temp_file=temp_file,
)
return
destination = metadata_output_path(config.resolved_output, result.metadata)
if destination.exists() and not config.overwrite:
write_agent_log(log_file, "destination_exists", destination=str(destination))
cleanup_temp_output(temp_file)
return
destination.parent.mkdir(parents=True, exist_ok=True)
try:
temp_file.replace(destination)
except Exception as error: # noqa: BLE001
write_review_file(
destination=destination,
ffprobe_metadata=ffprobe_metadata,
log_file=log_file,
metadata=result.metadata,
reason=f"rename_failed: {error}",
review_file=review_file,
source=aax_file,
temp_file=temp_file if temp_file.exists() else None,
)
else:
cleanup_temp_output(temp_file)
write_agent_log(log_file, "conversion_complete", destination=str(destination))
if __name__ == "__main__":
typer.run(main)
-176
View File
@@ -1,176 +0,0 @@
"""Import audiobook catalog authors and series from CSV files."""
from __future__ import annotations
import csv
import logging
from pathlib import Path # noqa: TC003 This is required for the typer CLI
from typing import Annotated
import typer
from sqlalchemy import select
from sqlalchemy.orm import Session
from python.common import configure_logger
from python.orm.common import get_postgres_engine
from python.orm.richie import AudiobookAuthor, AudiobookSeries
logger = logging.getLogger(__name__)
AUTHOR_NAME_COLUMN = "author_name"
ID_COLUMN = "id"
NAME_COLUMN = "name"
class CatalogImportError(ValueError):
"""CSV catalog import failed validation."""
def main(
authors_csv: Annotated[Path, typer.Argument(help="CSV with name and optional id.")],
series_csv: Annotated[Path, typer.Argument(help="CSV with name, author_name, and optional id.")],
) -> None:
"""Upsert audiobook authors and series from CSV files."""
configure_logger()
try:
engine = get_postgres_engine(name="RICHIE")
with Session(engine) as session:
author_count = upsert_authors_from_csv(session, authors_csv)
series_count = upsert_series_from_csv(session, series_csv)
session.commit()
except CatalogImportError as error:
typer.echo(str(error), err=True)
raise typer.Exit(code=1) from error
logger.info("Upserted %s authors and %s series", author_count, series_count)
def upsert_authors_from_csv(session: Session, authors_csv: Path) -> int:
"""Upsert authors from a CSV file."""
count = 0
for row_number, row in csv_rows(authors_csv):
name = required_csv_value(row, authors_csv, row_number, NAME_COLUMN)
upsert_author(session, name, csv_id(row, authors_csv, row_number))
count += 1
return count
def upsert_series_from_csv(session: Session, series_csv: Path) -> int:
"""Upsert series from a CSV file."""
count = 0
for row_number, row in csv_rows(series_csv):
series_name = required_csv_value(row, series_csv, row_number, NAME_COLUMN)
author_name = required_csv_value(row, series_csv, row_number, AUTHOR_NAME_COLUMN)
author = find_author_by_name(session, author_name)
if author is None:
msg = f"{series_csv}:{row_number}: author not found: {author_name}"
raise CatalogImportError(msg)
upsert_series(session, series_name, author, csv_id(row, series_csv, row_number))
count += 1
return count
def upsert_author(session: Session, name: str, author_id: int | None) -> AudiobookAuthor:
"""Upsert one author by id or exact name."""
if author_id is not None:
author = session.get(AudiobookAuthor, author_id)
if author is None:
author = AudiobookAuthor(id=author_id, name=name)
session.add(author)
else:
author.name = name
session.flush()
return author
author = find_author_by_name(session, name)
if author is None:
author = AudiobookAuthor(name=name)
session.add(author)
session.flush()
return author
def upsert_series(
session: Session,
name: str,
author: AudiobookAuthor,
series_id: int | None,
) -> AudiobookSeries:
"""Upsert one series by id or exact author/name match."""
if series_id is not None:
series = session.get(AudiobookSeries, series_id)
if series is None:
series = AudiobookSeries(id=series_id, name=name, author=author)
session.add(series)
else:
series.name = name
series.author = author
session.flush()
return series
series = find_series_by_name_and_author(session, name, author.id)
if series is None:
series = AudiobookSeries(name=name, author=author)
session.add(series)
session.flush()
return series
def find_author_by_name(session: Session, name: str) -> AudiobookAuthor | None:
"""Find one author by exact name."""
return session.scalar(select(AudiobookAuthor).where(AudiobookAuthor.name == name))
def find_series_by_name_and_author(
session: Session,
name: str,
author_id: int,
) -> AudiobookSeries | None:
"""Find one series by exact name and author."""
return session.scalar(
select(AudiobookSeries).where(
AudiobookSeries.name == name,
AudiobookSeries.author_id == author_id,
),
)
def csv_rows(csv_path: Path) -> list[tuple[int, dict[str, str | None]]]:
"""Read a CSV file as numbered rows."""
with csv_path.open(newline="", encoding="utf-8") as file:
reader = csv.DictReader(file)
if reader.fieldnames is None:
msg = f"{csv_path}: missing CSV header"
raise CatalogImportError(msg)
return [(row_number, row) for row_number, row in enumerate(reader, start=2)]
def required_csv_value(
row: dict[str, str | None],
csv_path: Path,
row_number: int,
column: str,
) -> str:
"""Read a required CSV value."""
value = row.get(column)
if value and value.strip():
return value.strip()
msg = f"{csv_path}:{row_number}: missing required column value: {column}"
raise CatalogImportError(msg)
def csv_id(row: dict[str, str | None], csv_path: Path, row_number: int) -> int | None:
"""Read an optional id field from a CSV row."""
value = row.get(ID_COLUMN)
if value is None or not value.strip():
return None
try:
return int(value)
except ValueError as error:
msg = f"{csv_path}:{row_number}: id must be an integer: {value}"
raise CatalogImportError(msg) from error
return None
if __name__ == "__main__":
typer.run(main)
-599
View File
@@ -1,599 +0,0 @@
"""LLM tool calling support for audiobook metadata resolution."""
from __future__ import annotations
import json
import re
import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING
from sqlalchemy import or_, select
from python.orm.richie import Audiobook, AudiobookAuthor, AudiobookSeries
if TYPE_CHECKING:
from pathlib import Path
from sqlalchemy.orm import Session
from python.tools.audiobook.metadata_agent import AgentConfig
CATALOG_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:_[a-z0-9]+)*$")
TITLE_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
LogWriter = Callable[..., None]
class MetadataResolutionError(ValueError):
"""Metadata resolution failed validation."""
@dataclass(frozen=True)
class EnsuredBook:
"""Book row plus whether it was created."""
book: Audiobook
action: str
class CatalogToolRegistry:
"""Controlled catalog tools exposed to the metadata model."""
def __init__(
self,
session: Session,
log_path: Path,
config: AgentConfig,
write_log: LogWriter,
) -> None:
"""Create a registry bound to one database session and audit log."""
self.session = session
self.log_path = log_path
self.config = config
self.write_log = write_log
self.seen_author_ids: set[int] = set()
self.seen_series_ids: set[int] = set()
self.seen_book_ids: set[int] = set()
self.created_author_ids: set[int] = set()
self.created_series_ids: set[int] = set()
self.created_book_ids: set[int] = set()
def tool_schemas(self) -> list[dict[str, object]]:
"""Return Ollama tool schemas."""
schemas = [
{
"type": "function",
"function": {
"name": "search_authors",
"description": "Search canonical audiobook authors by slug or noisy source text.",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
},
},
{
"type": "function",
"function": {
"name": "search_series",
"description": "Search canonical audiobook series by slug or noisy source text.",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string"},
"author_id": {"type": ["integer", "null"]},
},
"required": ["query"],
},
},
},
{
"type": "function",
"function": {
"name": "search_books",
"description": "Search canonical audiobook titles with optional author and series filters.",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string"},
"author_id": {"type": ["integer", "null"]},
"series_id": {"type": ["integer", "null"]},
},
"required": ["query"],
},
},
},
{
"type": "function",
"function": {
"name": "ensure_author",
"description": "Normalize an author name to a catalog slug, then return or create that author.",
"parameters": {
"type": "object",
"properties": {"name": {"type": "string"}},
"required": ["name"],
},
},
},
{
"type": "function",
"function": {
"name": "ensure_series",
"description": "Normalize a series name to a catalog slug, then return or create it for an author.",
"parameters": {
"type": "object",
"properties": {
"name": {"type": "string"},
"author_id": {"type": "integer"},
},
"required": ["name", "author_id"],
},
},
},
{
"type": "function",
"function": {
"name": "ensure_book",
"description": "Normalize a title to a book slug, then return or create it for an author/series.",
"parameters": {
"type": "object",
"properties": {
"title": {"type": "string"},
"author_id": {"type": "integer"},
"series_id": {"type": ["integer", "null"]},
"series_index": {"type": "number", "multipleOf": 0.5},
},
"required": ["title", "author_id", "series_id", "series_index"],
},
},
},
]
enabled_tool_names = set(self.config.tool_names)
return [schema for schema in schemas if schema["function"]["name"] in enabled_tool_names]
def run(self, name: str, arguments: dict[str, object]) -> list[dict[str, object]]:
"""Run one catalog tool and audit the call."""
handlers = {
"search_authors": self.run_search_authors,
"search_series": self.run_search_series,
"search_books": self.run_search_books,
"ensure_author": self.run_ensure_author,
"ensure_series": self.run_ensure_series,
"ensure_book": self.run_ensure_book,
}
handler = handlers.get(name)
if handler is None:
self.write_log(self.log_path, "tool_error", tool=name, arguments=arguments, error="unknown_tool")
msg = f"Unknown audiobook metadata tool: {name}"
raise MetadataResolutionError(msg)
if name not in self.config.tool_names:
self.write_log(self.log_path, "tool_error", tool=name, arguments=arguments, error="tool_not_enabled")
msg = f"Audiobook metadata tool is not enabled: {name}"
raise MetadataResolutionError(msg)
started = time.perf_counter()
self.write_log(self.log_path, "tool_call", tool=name, arguments=arguments)
result = handler(arguments)
duration_ms = round((time.perf_counter() - started) * 1000, 3)
self.write_log(
self.log_path,
"tool_result",
tool=name,
duration_ms=duration_ms,
result_count=len(result),
preview=result[:3],
)
return result
def get_author(self, author_id: int) -> AudiobookAuthor | None:
"""Return an author by id."""
return self.session.get(AudiobookAuthor, author_id)
def get_book(self, book_id: int) -> Audiobook | None:
"""Return a book by id."""
return self.session.get(Audiobook, book_id)
def get_series(self, series_id: int) -> AudiobookSeries | None:
"""Return a series by id."""
return self.session.get(AudiobookSeries, series_id)
def prune_unused_created_rows(self, *, author_id: int, book_id: int | None, series_id: int | None) -> None:
"""Remove catalog rows created during this run but not used by final metadata."""
used_book_ids = {book_id} if book_id is not None else set()
for created_book_id in self.created_book_ids - used_book_ids:
if book := self.get_book(created_book_id):
self.session.delete(book)
self.session.flush()
used_series_ids = {series_id} if series_id is not None else set()
for created_series_id in self.created_series_ids - used_series_ids:
series = self.get_series(created_series_id)
if series and not series.books:
self.session.delete(series)
self.session.flush()
for created_author_id in self.created_author_ids - {author_id}:
author = self.get_author(created_author_id)
if author and not author.books and not author.series:
self.session.delete(author)
def run_search_authors(self, arguments: dict[str, object]) -> list[dict[str, object]]:
"""Search authors from tool arguments and remember returned ids."""
query = required_string(arguments, "query")
statement = select(AudiobookAuthor).order_by(AudiobookAuthor.name).limit(self.config.max_tool_results)
if terms := query_terms(query):
statement = statement.where(or_(*(AudiobookAuthor.name.ilike(f"%{term}%") for term in terms)))
authors = self.session.scalars(statement).all()
self.seen_author_ids.update(author.id for author in authors)
return [{"id": author.id, "name": author.name} for author in authors]
def run_search_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
"""Search series from tool arguments and remember returned ids."""
query = required_string(arguments, "query")
author_id = optional_int(arguments.get("author_id"), "author_id")
statement = select(AudiobookSeries).order_by(AudiobookSeries.name).limit(self.config.max_tool_results)
if terms := query_terms(query):
statement = statement.where(or_(*(AudiobookSeries.name.ilike(f"%{term}%") for term in terms)))
if author_id is not None:
statement = statement.where(AudiobookSeries.author_id == author_id)
series_rows = self.session.scalars(statement).all()
self.seen_series_ids.update(series.id for series in series_rows)
self.seen_author_ids.update(series.author_id for series in series_rows)
return [
{
"id": series.id,
"name": series.name,
"author_id": series.author_id,
"author": series.author.name,
}
for series in series_rows
]
def run_search_books(self, arguments: dict[str, object]) -> list[dict[str, object]]:
"""Search books from tool arguments and remember returned ids."""
query = required_string(arguments, "query")
author_id = optional_int(arguments.get("author_id"), "author_id")
series_id = optional_int(arguments.get("series_id"), "series_id")
statement = select(Audiobook).order_by(Audiobook.title).limit(self.config.max_tool_results)
if terms := query_terms(query):
statement = statement.where(or_(*(Audiobook.title.ilike(f"%{term}%") for term in terms)))
if author_id is not None:
statement = statement.where(Audiobook.author_id == author_id)
if series_id is not None:
statement = statement.where(Audiobook.series_id == series_id)
books = self.session.scalars(statement).all()
self.seen_book_ids.update(book.id for book in books)
self.seen_author_ids.update(book.author_id for book in books)
self.seen_series_ids.update(book.series_id for book in books if book.series_id is not None)
return [
{
"id": book.id,
"title": book.title,
"author_id": book.author_id,
"author": book.author.name,
"series_id": book.series_id,
"series": book.series.name if book.series else self.config.standalone_series,
"series_index": book.series_index,
}
for book in books
]
def run_ensure_author(self, arguments: dict[str, object]) -> list[dict[str, object]]:
"""Ensure an author from tool arguments and return a tool result."""
name = normalize_catalog_slug(required_string(arguments, "name"))
validate_catalog_slug(name, "author")
author = self.session.scalar(select(AudiobookAuthor).where(AudiobookAuthor.name == name))
action = "existing"
if author is None:
author = AudiobookAuthor(name=name)
self.session.add(author)
self.session.flush()
self.created_author_ids.add(author.id)
action = "created"
self.seen_author_ids.add(author.id)
return [{"id": author.id, "name": author.name, "action": action}]
def run_ensure_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
"""Ensure a series from tool arguments and return a tool result."""
name = normalize_catalog_slug(required_string(arguments, "name"))
author_id = required_int(arguments, "author_id")
validate_catalog_slug(name, "series")
author = self.required_author(author_id)
series = self.find_series_by_catalog_slug(name, author.id)
action = "existing"
if series is None:
series = AudiobookSeries(name=name, author=author)
self.session.add(series)
self.session.flush()
self.created_series_ids.add(series.id)
action = "created"
self.seen_author_ids.add(author.id)
self.seen_series_ids.add(series.id)
return [self.series_result(series, action)]
def run_ensure_book(self, arguments: dict[str, object]) -> list[dict[str, object]]:
"""Ensure a book from tool arguments and return a tool result."""
title = required_string(arguments, "title")
author_id = required_int(arguments, "author_id")
series_id = optional_int(arguments.get("series_id"), "series_id")
series_index = required_series_index(arguments, "series_index")
ensured = self.ensure_book(title, author_id, series_id, series_index)
return [self.book_result(ensured.book, ensured.action)]
def ensure_book(
self,
title: str,
author_id: int,
series_id: int | None,
series_index: float,
) -> EnsuredBook:
"""Return an existing book row, or create it after validating ownership."""
title = normalize_title_slug(title)
validate_title_slug(title)
author = self.required_author(author_id)
series = None
if series_id is None:
if series_index != 0:
msg = "standalone books must use series_index 0"
raise MetadataResolutionError(msg)
else:
series = self.required_series(series_id)
if series.author_id != author.id:
msg = f"series_id {series_id} does not belong to author_id {author_id}"
raise MetadataResolutionError(msg)
if series_index <= 0:
msg = "series books must use a positive series_index"
raise MetadataResolutionError(msg)
statement = select(Audiobook).where(
Audiobook.title == title,
Audiobook.author_id == author.id,
)
if series is None:
statement = statement.where(Audiobook.series_id.is_(None))
else:
statement = statement.where(Audiobook.series_id == series.id)
book = self.session.scalar(statement)
if book is None:
book = Audiobook(title=title, author=author, series=series, series_index=series_index)
self.session.add(book)
self.session.flush()
self.created_book_ids.add(book.id)
action = "created"
else:
action = "existing"
self.seen_book_ids.add(book.id)
self.seen_author_ids.add(author.id)
if book.series_id is not None:
self.seen_series_ids.add(book.series_id)
return EnsuredBook(book=book, action=action)
def required_author(self, author_id: int) -> AudiobookAuthor:
"""Return an author or fail metadata resolution."""
author = self.get_author(author_id)
if author is None:
msg = f"author_id {author_id} does not exist"
raise MetadataResolutionError(msg)
return author
def required_series(self, series_id: int) -> AudiobookSeries:
"""Return a series or fail metadata resolution."""
series = self.get_series(series_id)
if series is None:
msg = f"series_id {series_id} does not exist"
raise MetadataResolutionError(msg)
return series
def find_series_by_catalog_slug(self, name: str, author_id: int) -> AudiobookSeries | None:
"""Return a series by exact slug or underscore-insensitive slug."""
exact = self.session.scalar(
select(AudiobookSeries).where(
AudiobookSeries.name == name,
AudiobookSeries.author_id == author_id,
),
)
if exact is not None:
return exact
compact_name = compact_catalog_slug(name)
series_rows = self.session.scalars(
select(AudiobookSeries).where(AudiobookSeries.author_id == author_id).order_by(AudiobookSeries.name),
).all()
for series in series_rows:
if compact_catalog_slug(series.name) == compact_name:
return series
return None
def series_result(self, series: AudiobookSeries, action: str) -> dict[str, object]:
"""Build a normalized series tool result."""
return {
"id": series.id,
"name": series.name,
"author_id": series.author_id,
"author": series.author.name,
"action": action,
}
def book_result(self, book: Audiobook, action: str) -> dict[str, object]:
"""Build a normalized book tool result."""
return {
"id": book.id,
"title": book.title,
"author_id": book.author_id,
"author": book.author.name,
"series_id": book.series_id,
"series": book.series.name if book.series else self.config.standalone_series,
"series_index": book.series_index,
"action": action,
}
def run_tool_calls(
messages: list[dict[str, object]],
message: dict[str, object],
tool_calls: list[tuple[str, dict[str, object]]],
registry: CatalogToolRegistry,
log_path: Path,
write_log: LogWriter,
) -> str | None:
"""Run tool calls, append tool messages, and return fatal error text when stopped."""
messages.append(message)
for tool_name, arguments in tool_calls:
try:
tool_result = registry.run(tool_name, arguments)
except MetadataResolutionError as error:
if is_fatal_tool_error(error):
return str(error)
write_log(log_path, "tool_error", tool=tool_name, arguments=arguments, error=str(error))
messages.append(
{
"role": "tool",
"tool_name": tool_name,
"content": json.dumps({"error": str(error)}, sort_keys=True),
},
)
continue
messages.append(
{
"role": "tool",
"tool_name": tool_name,
"content": json.dumps(tool_result, sort_keys=True),
},
)
return None
def parse_tool_calls(message: dict[str, object]) -> list[tuple[str, dict[str, object]]]:
"""Parse Ollama tool calls from a response message."""
raw_tool_calls = message.get("tool_calls") or []
if not isinstance(raw_tool_calls, list):
msg = "tool_calls must be a list"
raise MetadataResolutionError(msg)
tool_calls = []
for raw_call in raw_tool_calls:
if not isinstance(raw_call, dict):
msg = "tool call must be an object"
raise MetadataResolutionError(msg)
function = raw_call.get("function")
if not isinstance(function, dict):
msg = "tool call is missing function"
raise MetadataResolutionError(msg)
name = function.get("name")
if not isinstance(name, str) or not name:
msg = "tool call is missing function name"
raise MetadataResolutionError(msg)
arguments = parse_tool_arguments(function.get("arguments", {}))
tool_calls.append((name, arguments))
return tool_calls
def parse_tool_arguments(raw_arguments: object) -> dict[str, object]:
"""Parse tool call arguments returned by Ollama."""
if isinstance(raw_arguments, dict):
return {str(key): value for key, value in raw_arguments.items()}
if isinstance(raw_arguments, str):
parsed = json.loads(raw_arguments) if raw_arguments else {}
if isinstance(parsed, dict):
return {str(key): value for key, value in parsed.items()}
msg = "tool arguments must be an object"
raise MetadataResolutionError(msg)
def validate_title_slug(title: str) -> None:
"""Validate a canonical book title slug."""
if not TITLE_SLUG_PATTERN.fullmatch(title):
msg = f"title slug is invalid: {title}"
raise MetadataResolutionError(msg)
def validate_catalog_slug(value: str, label: str) -> None:
"""Validate a canonical catalog slug."""
if not CATALOG_SLUG_PATTERN.fullmatch(value):
msg = f"{label} slug is invalid: {value}"
raise MetadataResolutionError(msg)
def normalize_catalog_slug(value: str) -> str:
"""Normalize noisy catalog names into lower snake-case slugs."""
return re.sub(r"[^a-z0-9]+", "_", value.strip().casefold()).strip("_")
def compact_catalog_slug(value: str) -> str:
"""Return a catalog slug comparison key that ignores underscores."""
return normalize_catalog_slug(value).replace("_", "")
def normalize_title_slug(value: str) -> str:
"""Normalize noisy book titles into lower kebab-case slugs."""
return re.sub(r"[^a-z0-9]+", "-", value.strip().casefold()).strip("-")
def is_fatal_tool_error(error: MetadataResolutionError) -> bool:
"""Return whether a tool error should stop the agent immediately."""
message = str(error)
return message.startswith(
(
"Unknown audiobook metadata tool",
"Audiobook metadata tool is not enabled",
),
)
def query_terms(query: str) -> tuple[str, ...]:
"""Return text variants useful for matching noisy audiobook metadata."""
normalized = query.strip().casefold()
underscore_slug = normalize_catalog_slug(normalized)
compact_slug = compact_catalog_slug(normalized)
hyphen_slug = normalize_title_slug(normalized)
return tuple(dict.fromkeys(term for term in (normalized, underscore_slug, compact_slug, hyphen_slug) if term))
def required_string(data: dict[str, object], key: str) -> str:
"""Read a required string field."""
value = data.get(key)
if not isinstance(value, str) or not value.strip():
msg = f"{key} must be a non-empty string"
raise MetadataResolutionError(msg)
return value.strip()
def required_int(data: dict[str, object], key: str) -> int:
"""Read a required integer field."""
value = data.get(key)
if isinstance(value, bool) or not isinstance(value, int):
msg = f"{key} must be an integer"
raise MetadataResolutionError(msg)
return value
def required_series_index(data: dict[str, object], key: str) -> float:
"""Read a required whole-number or half-number series index."""
value = data.get(key)
if isinstance(value, bool) or not isinstance(value, int | float):
msg = f"{key} must be a number"
raise MetadataResolutionError(msg)
series_index = float(value)
if not (series_index * 2).is_integer():
msg = f"{key} must be a whole number or .5 increment"
raise MetadataResolutionError(msg)
return series_index
def optional_int(value: object, key: str) -> int | None:
"""Read an optional integer field."""
if value is None:
return None
if isinstance(value, bool) or not isinstance(value, int):
msg = f"{key} must be an integer or null"
raise MetadataResolutionError(msg)
return value
-575
View File
@@ -1,575 +0,0 @@
"""Resolve audiobook metadata with a controlled Ollama tool loop."""
from __future__ import annotations
import json
import re
from dataclasses import asdict, dataclass, is_dataclass, replace
from os import PathLike
from typing import TYPE_CHECKING
import httpx
from sqlalchemy.orm import Session
from python.common import utcnow
from python.tools.audiobook.llm_tool_calling import (
CatalogToolRegistry,
MetadataResolutionError,
normalize_title_slug,
optional_int,
parse_tool_calls,
required_int,
required_series_index,
required_string,
run_tool_calls,
validate_catalog_slug,
validate_title_slug,
)
if TYPE_CHECKING:
from pathlib import Path
from sqlalchemy.engine import Engine
from python.orm.richie import AudiobookAuthor
FENCED_JSON_PATTERN = re.compile(r"^```(?:json)?\s*(?P<json>.*?)\s*```$", re.IGNORECASE | re.DOTALL)
@dataclass(frozen=True)
class AgentConfig:
"""Runtime settings for the audiobook metadata agent."""
model: str = "deepseek-v4-flash:cloud"
ollama_chat_url: str = "https://ollama.com/api/chat"
http_timeout_seconds: int = 300
max_agent_turns: int = 8
max_tool_results: int = 10
min_confidence: float = 0.85
invalid_final_retries: int = 1
standalone_series: str = "standalone"
tool_names: tuple[str, ...] = (
"search_authors",
"search_series",
"search_books",
"ensure_author",
"ensure_series",
"ensure_book",
)
@dataclass(frozen=True)
class StandardBookMetadata:
"""Canonical metadata for the final audiobook path."""
author_id: int
author: str
book_id: int | None
title: str
series_id: int | None
series: str
series_index: float
confidence: float
needs_review: bool
evidence: list[str]
@dataclass(frozen=True)
class FinalMetadataFields:
"""Raw model fields after schema validation."""
author_id: int
book_id: int | None
title: str
series_id: int | None
series_index: float
confidence: float
evidence: list[str]
@dataclass(frozen=True)
class ResolvedBookFields:
"""Book fields after optional catalog book resolution."""
book_id: int | None
title: str
series_id: int | None
series_index: float
@dataclass(frozen=True)
class AgentStepResult:
"""Outcome from one model response."""
metadata: StandardBookMetadata | None
invalid_final_count: int
should_continue: bool
def standard_book_metadata(
aax_file_name: str,
aax_metadata_from_ffprobe: dict[str, str],
engine: Engine,
log_path: Path,
ollama_api_key: str,
config: AgentConfig,
) -> StandardBookMetadata:
"""Resolve canonical audiobook metadata with the configured Ollama Cloud model."""
with Session(engine) as session:
registry = CatalogToolRegistry(session, log_path, config, write_agent_log)
agent = AudiobookMetadataAgent(
registry=registry, log_path=log_path, ollama_api_key=ollama_api_key, config=config
)
metadata = agent.run(aax_file_name, aax_metadata_from_ffprobe)
if metadata.needs_review:
session.rollback()
else:
registry.prune_unused_created_rows(
author_id=metadata.author_id,
book_id=metadata.book_id,
series_id=metadata.series_id,
)
session.commit()
return metadata
class AudiobookMetadataAgent:
"""Ollama-backed metadata resolver with a fixed local tool registry."""
def __init__(
self,
*,
registry: CatalogToolRegistry,
log_path: Path,
ollama_api_key: str,
config: AgentConfig,
) -> None:
"""Create an Ollama metadata agent."""
self._registry = registry
self._log_path = log_path
self._ollama_api_key = ollama_api_key
self._config = config
def run(self, aax_file_name: str, aax_metadata_from_ffprobe: dict[str, str]) -> StandardBookMetadata:
"""Resolve metadata for one AAX file."""
messages = [
{"role": "system", "content": system_prompt()},
{"role": "user", "content": user_prompt(aax_file_name, aax_metadata_from_ffprobe)},
]
invalid_final_count = 0
result: StandardBookMetadata | None = None
for turn in range(1, self._config.max_agent_turns + 1):
step = self.run_step(messages, turn, invalid_final_count)
invalid_final_count = step.invalid_final_count
if step.should_continue:
continue
result = step.metadata
break
if result is None:
return self.force_final_response(messages)
return result
def run_step(
self,
messages: list[dict[str, object]],
turn: int,
invalid_final_count: int,
) -> AgentStepResult:
"""Run one model turn and return the next agent-loop action."""
data = self.chat(messages, turn)
message = data.get("message")
if not isinstance(message, dict):
return AgentStepResult(
metadata=review_metadata("Ollama response did not include a message", self._config),
invalid_final_count=invalid_final_count,
should_continue=False,
)
try:
tool_calls = parse_tool_calls(message)
except (json.JSONDecodeError, MetadataResolutionError) as error:
return AgentStepResult(
metadata=review_metadata(str(error), self._config),
invalid_final_count=invalid_final_count,
should_continue=False,
)
if tool_calls:
fatal_error = run_tool_calls(messages, message, tool_calls, self._registry, self._log_path, write_agent_log)
if fatal_error is not None:
return AgentStepResult(
metadata=review_metadata(fatal_error, self._config),
invalid_final_count=invalid_final_count,
should_continue=False,
)
return AgentStepResult(metadata=None, invalid_final_count=invalid_final_count, should_continue=True)
return self.handle_final_message(messages, message, invalid_final_count)
def handle_final_message(
self,
messages: list[dict[str, object]],
message: dict[str, object],
invalid_final_count: int,
) -> AgentStepResult:
"""Validate a final model message or request one retry."""
content = message.get("content")
if not isinstance(content, str):
return AgentStepResult(
metadata=review_metadata("Ollama final response did not include string content", self._config),
invalid_final_count=invalid_final_count,
should_continue=False,
)
try:
resolved = self.validate_final(parse_final_json_content(content))
except (json.JSONDecodeError, MetadataResolutionError) as error:
return self.handle_invalid_final(messages, error, invalid_final_count)
write_agent_log(self._log_path, "final_metadata", metadata=resolved)
return AgentStepResult(metadata=resolved, invalid_final_count=invalid_final_count, should_continue=False)
def handle_invalid_final(
self,
messages: list[dict[str, object]],
error: json.JSONDecodeError | MetadataResolutionError,
invalid_final_count: int,
) -> AgentStepResult:
"""Log invalid final JSON and either retry or return review metadata."""
invalid_final_count += 1
write_agent_log(
self._log_path,
"final_validation_error",
error=str(error),
invalid_final_count=invalid_final_count,
)
if invalid_final_count > self._config.invalid_final_retries:
return AgentStepResult(
metadata=review_metadata(str(error), self._config),
invalid_final_count=invalid_final_count,
should_continue=False,
)
messages.append(
{
"role": "user",
"content": (
"Your previous final answer was invalid. Return only valid JSON matching the required "
f"schema. Validation error: {error}"
),
},
)
return AgentStepResult(metadata=None, invalid_final_count=invalid_final_count, should_continue=True)
def force_final_response(self, messages: list[dict[str, object]]) -> StandardBookMetadata:
"""Request a no-tool final answer after the normal turn limit."""
messages.append({"role": "user", "content": forced_final_prompt()})
write_agent_log(self._log_path, "forced_final_request", reason="max_turns")
data = self.chat(messages, self._config.max_agent_turns + 1, tools_enabled=False)
message = data.get("message")
if not isinstance(message, dict):
return review_metadata("Ollama forced final response did not include a message", self._config)
content = message.get("content")
if not isinstance(content, str):
return review_metadata("Ollama forced final response did not include string content", self._config)
try:
resolved = self.validate_final(parse_final_json_content(content))
except (json.JSONDecodeError, MetadataResolutionError) as error:
return review_metadata(f"Ollama forced final response was invalid: {error}", self._config)
write_agent_log(self._log_path, "final_metadata", metadata=resolved)
return resolved
def chat(self, messages: list[dict[str, object]], turn: int, *, tools_enabled: bool = True) -> dict[str, object]:
"""Send one chat request to Ollama and log the request and response."""
payload = {
"model": self._config.model,
"messages": messages,
"stream": False,
"options": {"temperature": 0.1},
}
tool_names = []
if tools_enabled:
payload["tools"] = self._registry.tool_schemas()
tool_names = self._config.tool_names
write_agent_log(
self._log_path,
"model_request",
model=self._config.model,
turn=turn,
message_count=len(messages),
tool_names=tool_names,
tools_enabled=tools_enabled,
)
write_agent_log(
self._log_path,
"llm_messages_sent",
model=self._config.model,
turn=turn,
messages=messages,
tools_enabled=tools_enabled,
)
response = httpx.post(
self._config.ollama_chat_url,
headers={"Authorization": f"Bearer {self._ollama_api_key}"},
json=payload,
timeout=self._config.http_timeout_seconds,
)
response.raise_for_status()
raw_data = response.json()
if not isinstance(raw_data, dict):
return {}
data = {str(key): value for key, value in raw_data.items()}
message = data.get("message", {})
content = message.get("content") if isinstance(message, dict) else ""
write_agent_log(
self._log_path,
"llm_message_received",
model=self._config.model,
turn=turn,
message=message,
)
write_agent_log(
self._log_path,
"model_response",
model=self._config.model,
turn=turn,
has_tool_calls=bool(isinstance(message, dict) and message.get("tool_calls")),
content_chars=len(content) if isinstance(content, str) else 0,
)
return data
def validate_final(self, raw_metadata: object) -> StandardBookMetadata:
"""Validate final model metadata against catalog rows."""
fields = parse_final_metadata_fields(raw_metadata)
fields = replace(fields, title=normalize_title_slug(fields.title))
author = self.validate_author(fields.author_id)
validate_title_slug(fields.title)
book_fields = self.resolve_book_fields(fields)
series = self.validate_series(fields.author_id, book_fields.series_id, book_fields.series_index)
return StandardBookMetadata(
author_id=fields.author_id,
author=author.name,
book_id=book_fields.book_id,
title=book_fields.title,
series_id=book_fields.series_id,
series=series,
series_index=book_fields.series_index,
confidence=fields.confidence,
needs_review=fields.confidence < self._config.min_confidence,
evidence=fields.evidence,
)
def validate_author(self, author_id: int) -> AudiobookAuthor:
"""Validate that an author id was seen and exists."""
if author_id not in self._registry.seen_author_ids:
msg = f"author_id {author_id} was not returned by search_authors"
raise MetadataResolutionError(msg)
author = self._registry.get_author(author_id)
if author is None:
msg = f"author_id {author_id} does not exist"
raise MetadataResolutionError(msg)
validate_catalog_slug(author.name, "author")
return author
def resolve_book_fields(self, fields: FinalMetadataFields) -> ResolvedBookFields:
"""Resolve final book fields from a seen book id or created book."""
if fields.book_id is None:
ensured = self._registry.ensure_book(
fields.title,
fields.author_id,
fields.series_id,
fields.series_index,
)
return ResolvedBookFields(
book_id=ensured.book.id,
title=ensured.book.title,
series_id=ensured.book.series_id,
series_index=ensured.book.series_index,
)
if fields.book_id not in self._registry.seen_book_ids:
msg = f"book_id {fields.book_id} was not returned by search_books"
raise MetadataResolutionError(msg)
book = self._registry.get_book(fields.book_id)
if book is None:
msg = f"book_id {fields.book_id} does not exist"
raise MetadataResolutionError(msg)
if book.author_id != fields.author_id:
msg = f"book_id {fields.book_id} does not belong to author_id {fields.author_id}"
raise MetadataResolutionError(msg)
return ResolvedBookFields(
book_id=fields.book_id,
title=book.title,
series_id=book.series_id,
series_index=book.series_index,
)
def validate_series(self, author_id: int, series_id: int | None, series_index: float) -> str:
"""Validate final series fields and return the canonical series slug."""
if series_id is None:
if series_index != 0:
msg = "standalone books must use series_index 0"
raise MetadataResolutionError(msg)
return self._config.standalone_series
if series_id not in self._registry.seen_series_ids:
msg = f"series_id {series_id} was not returned by search_series"
raise MetadataResolutionError(msg)
series = self._registry.get_series(series_id)
if series is None:
msg = f"series_id {series_id} does not exist"
raise MetadataResolutionError(msg)
if series.author_id != author_id:
msg = f"series_id {series_id} does not belong to author_id {author_id}"
raise MetadataResolutionError(msg)
if series_index <= 0:
msg = "series books must use a positive series_index"
raise MetadataResolutionError(msg)
validate_catalog_slug(series.name, "series")
return series.name
def write_agent_log(log_path: Path, event: str, **fields: object) -> None:
"""Append one JSONL audit event."""
log_path.parent.mkdir(parents=True, exist_ok=True)
record = {
"created": utcnow().isoformat(),
"event": event,
**{key: json_log_value(value) for key, value in fields.items()},
}
with log_path.open("a", encoding="utf-8") as file:
file.write(json.dumps(record, sort_keys=True))
file.write("\n")
def json_log_value(value: object) -> object:
"""Return a JSON-serializable value for audit logs."""
if is_dataclass(value) and not isinstance(value, type):
return json_log_value(asdict(value))
if isinstance(value, dict):
return {str(key): json_log_value(item) for key, item in value.items()}
if isinstance(value, list | tuple):
return [json_log_value(item) for item in value]
if isinstance(value, set):
return [json_log_value(item) for item in sorted(value, key=str)]
if isinstance(value, PathLike):
return str(value)
return value
def system_prompt() -> str:
"""Return the stable system prompt."""
return """You standardize Audible audiobook metadata against a private catalog.
Rules:
- You must use the provided tools before returning final metadata.
- Only use author_id, series_id, or book_id values returned by tools.
- Return final metadata as JSON only. Do not wrap it in Markdown.
- The final JSON object must contain author_id, book_id, title, series_id, series_index, confidence, and evidence.
- title must be a canonical title slug using lower-case words separated by hyphens.
- Use series_id null and series_index 0 for standalone books.
- If you use a series_id, series_index must be a whole number or .5 value greater than 0.
- Treat series slugs that differ only by underscores as the same series. Prefer the existing catalog row instead of
creating a new series.
- Detect omnibus or box-set editions that contain multiple numbered novels, books, or novellas.
- For an omnibus, make a best-effort range from the filename, tags, and catalog rows. Keep series_index as the
first covered book number and include the range in the title when the source title includes it, for example
books-1-3.
- Be careful with omnibuses of novels or novellas later published as one book: keep the omnibus as the audiobook's
book record unless catalog rows clearly identify a better match.
- Do not create publisher collections or author collections as series unless the book metadata clearly gives a
numbered series.
- Series belong to authors. Use a series_id only when it belongs to the selected author_id.
- Always search for the author before creating one. If no exact author slug exists, call ensure_author.
- Always search for a series with author_id before creating one. If no exact series slug exists, call ensure_series.
- Always search for a book before creating one. If no exact title slug exists, call ensure_book.
- If a tool returns an error, correct your tool arguments or final metadata before continuing.
- confidence must be a number from 0 to 1.
- evidence must be a short list of strings explaining which filename, tags, and catalog rows support the answer."""
def forced_final_prompt() -> str:
"""Return the no-tools finalization prompt."""
return (
"Stop calling tools. Return final metadata as JSON only using the tool results already provided. "
"If search_books returned no matching rows but author and series are known, use book_id null and resolve "
"the title slug from the AAX filename and ffprobe tags. The validator will create the missing book. "
"Use only author_id and series_id values returned by earlier tool results."
)
def user_prompt(aax_file_name: str, metadata: dict[str, str]) -> str:
"""Build the user prompt from source metadata."""
return (
"Resolve this Audible audiobook.\n\n"
f"AAX file name: {aax_file_name}\n\n"
"ffprobe format tags:\n"
f"{json.dumps(metadata, indent=2, sort_keys=True)}"
)
def parse_final_json_content(content: str) -> object:
"""Parse final model content, accepting bare or fenced JSON."""
stripped = content.strip()
if match := FENCED_JSON_PATTERN.fullmatch(stripped):
stripped = match.group("json").strip()
return json.loads(stripped)
def parse_final_metadata_fields(raw_metadata: object) -> FinalMetadataFields:
"""Parse the model's final JSON object into typed fields."""
if not isinstance(raw_metadata, dict):
msg = "Final metadata must be a JSON object"
raise MetadataResolutionError(msg)
data = {str(key): value for key, value in raw_metadata.items()}
return FinalMetadataFields(
author_id=required_int(data, "author_id"),
book_id=optional_int(data.get("book_id"), "book_id"),
title=required_string(data, "title"),
series_id=optional_int(data.get("series_id"), "series_id"),
series_index=required_series_index(data, "series_index"),
confidence=required_float(data, "confidence"),
evidence=required_string_list(data, "evidence"),
)
def review_metadata(reason: str, config: AgentConfig) -> StandardBookMetadata:
"""Return a metadata result that must be reviewed manually."""
return StandardBookMetadata(
author_id=0,
author="unknown_author",
book_id=None,
title="unknown-title",
series_id=None,
series=config.standalone_series,
series_index=0,
confidence=0,
needs_review=True,
evidence=[reason],
)
def required_float(data: dict[str, object], key: str) -> float:
"""Read a required float field."""
value = data.get(key)
if isinstance(value, bool) or not isinstance(value, int | float):
msg = f"{key} must be a number"
raise MetadataResolutionError(msg)
confidence = float(value)
if confidence < 0 or confidence > 1:
msg = f"{key} must be between 0 and 1"
raise MetadataResolutionError(msg)
return confidence
def required_string_list(data: dict[str, object], key: str) -> list[str]:
"""Read a required list of strings."""
value = data.get(key)
if not isinstance(value, list) or not value or not all(isinstance(item, str) for item in value):
msg = f"{key} must be a non-empty list of strings"
raise MetadataResolutionError(msg)
strings = [item.strip() for item in value if item.strip()]
if not strings:
msg = f"{key} must include at least one non-empty string"
raise MetadataResolutionError(msg)
return strings
-17
View File
@@ -1,17 +0,0 @@
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
RUN apt-get update \
&& apt-get install -y --no-install-recommends python3 python3-pip ffmpeg \
&& rm -rf /var/lib/apt/lists/*
RUN pip3 install --no-cache-dir --upgrade pip \
&& pip3 install --no-cache-dir faster-whisper requests
WORKDIR /app
COPY python/tools/whisper/inference.py /app/inference.py
ENTRYPOINT ["python3", "/app/inference.py"]
@@ -1,2 +0,0 @@
*
!python/tools/whisper/inference.py
-1
View File
@@ -1 +0,0 @@
"""Whisper transcription tools (host orchestrator and container entrypoint)."""
-136
View File
@@ -1,136 +0,0 @@
"""Container entrypoint that transcribes a directory of audio files with faster-whisper.
Run inside the whisper-transcribe docker image; segment timestamps are grouped
into one-minute buckets so the output reads as ``[HH:MM:00] text``.
"""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
from faster_whisper import WhisperModel
logger = logging.getLogger(__name__)
AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".opus", ".mp4", ".mkv", ".webm", ".aac"}
BUCKET_SECONDS = 60
BEAM_SIZE = 5
SECONDS_PER_HOUR = 3600
SECONDS_PER_MINUTE = 60
def format_timestamp(total_seconds: float) -> str:
"""Render a whole-minute timestamp as ``HH:MM:00``.
Args:
total_seconds: Offset in seconds from the start of the audio.
Returns:
A zero-padded ``HH:MM:00`` string.
"""
hours = int(total_seconds // SECONDS_PER_HOUR)
minutes = int((total_seconds % SECONDS_PER_HOUR) // SECONDS_PER_MINUTE)
return f"{hours:02d}:{minutes:02d}:00"
def transcribe_file(model: WhisperModel, audio_path: Path, output_path: Path) -> None:
"""Transcribe one audio file and write the bucketed transcript to disk.
Args:
model: Loaded faster-whisper model.
audio_path: Source audio file.
output_path: Destination ``.txt`` path.
"""
logger.info("Transcribing %s", audio_path)
segments, info = model.transcribe(
str(audio_path),
language="en",
beam_size=BEAM_SIZE,
vad_filter=True,
)
logger.info("Duration %.1fs", info.duration)
buckets: dict[int, list[str]] = {}
for segment in segments:
bucket = int(segment.start // BUCKET_SECONDS)
buckets.setdefault(bucket, []).append(segment.text.strip())
lines = [f"[{format_timestamp(bucket * BUCKET_SECONDS)}] {' '.join(buckets[bucket])}" for bucket in sorted(buckets)]
output_path.write_text("\n\n".join(lines) + "\n", encoding="utf-8")
logger.info("Wrote %s", output_path)
def find_audio_files(input_directory: Path) -> list[Path]:
"""Collect every audio file under ``input_directory``.
Args:
input_directory: Directory to walk recursively.
Returns:
Sorted list of audio file paths.
"""
return sorted(
path for path in input_directory.rglob("*") if path.is_file() and path.suffix.lower() in AUDIO_EXTENSIONS
)
def configure_container_logger() -> None:
"""Configure logging for the container (stdout, INFO)."""
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
)
def parse_arguments() -> argparse.Namespace:
"""Parse CLI arguments for the container entrypoint.
Returns:
Parsed argparse namespace.
"""
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--input", type=Path, default=Path("/audio"))
parser.add_argument("--output", type=Path, default=Path("/output"))
parser.add_argument("--model", default="large-v3")
parser.add_argument(
"--download-only",
action="store_true",
help="Download the model into the cache volume and exit without transcribing.",
)
return parser.parse_args()
def main() -> None:
"""Load the model, then either exit (download-only) or transcribe the directory."""
configure_container_logger()
arguments = parse_arguments()
logger.info("Loading model %s on CUDA", arguments.model)
model = WhisperModel(arguments.model, device="cuda", compute_type="float16")
if arguments.download_only:
logger.info("Model ready; exiting (download-only mode)")
return
arguments.output.mkdir(parents=True, exist_ok=True)
audio_files = find_audio_files(arguments.input)
if not audio_files:
logger.warning("No audio files found in %s", arguments.input)
return
logger.info("Found %d audio file(s)", len(audio_files))
for audio_path in audio_files:
relative = audio_path.relative_to(arguments.input)
output_path = arguments.output / relative.with_suffix(".txt")
output_path.parent.mkdir(parents=True, exist_ok=True)
if output_path.exists():
logger.info("Skip %s (already transcribed)", relative)
continue
transcribe_file(model, audio_path, output_path)
if __name__ == "__main__":
main()
-167
View File
@@ -1,167 +0,0 @@
"""Build and run the whisper transcription docker container on demand.
The container is started fresh for each invocation and removed on exit
(``docker run --rm``). The model is cached in a named docker volume so
only the first run pays the download cost.
"""
from __future__ import annotations
import logging
import subprocess
from pathlib import Path
from typing import Annotated
import typer
from python.common import configure_logger
logger = logging.getLogger(__name__)
class Config:
"""Paths and names for the whisper-transcribe Docker workflow."""
image_tag = "whisper-transcribe:latest"
model_volume = "whisper-models"
repo_root = Path(__file__).resolve().parents[3]
dockerfile = Path(__file__).resolve().parent / "Dockerfile"
huggingface_cache = "/root/.cache/huggingface"
def run_docker(arguments: list[str]) -> None:
"""Run a docker subcommand, streaming output and raising on failure.
Args:
arguments: Arguments to pass to the ``docker`` binary.
Raises:
subprocess.CalledProcessError: If docker exits non-zero.
"""
logger.info("docker %s", " ".join(arguments))
subprocess.run(["docker", *arguments], check=True)
def build_image() -> None:
"""Build the whisper-transcribe image using the repo root as build context."""
logger.info("Building image %s", Config.image_tag)
run_docker(
[
"build",
"--tag",
Config.image_tag,
"--file",
str(Config.dockerfile),
str(Config.repo_root),
],
)
def model_cache_present(model: str) -> bool:
"""Check whether the given model is already downloaded in the cache volume.
Args:
model: faster-whisper model name (e.g. ``large-v3``).
Returns:
True if the HuggingFace cache directory for the model exists in the volume.
"""
cache_directory = f"hub/models--Systran--faster-whisper-{model}"
completed = subprocess.run(
[
"docker",
"run",
"--rm",
"--volume",
f"{Config.model_volume}:/cache",
"alpine",
"test",
"-d",
f"/cache/{cache_directory}",
],
check=False,
)
return completed.returncode == 0
def download_model(model: str) -> None:
"""Download the model into the cache volume and exit.
Args:
model: faster-whisper model name.
"""
logger.info("Downloading model %s into volume %s", model, Config.model_volume)
run_docker(
[
"run",
"--rm",
"--device=nvidia.com/gpu=all",
"--ipc=host",
"--volume",
f"{Config.model_volume}:{Config.huggingface_cache}",
Config.image_tag,
"--model",
model,
"--download-only",
],
)
def transcribe(input_directory: Path, output_directory: Path, model: str) -> None:
"""Run transcription on every audio file under ``input_directory``.
Args:
input_directory: Host path containing audio files (mounted read-only).
output_directory: Host path for ``.txt`` transcripts.
model: faster-whisper model name.
"""
logger.info("Transcribing %s -> %s (model=%s)", input_directory, output_directory, model)
run_docker(
[
"run",
"--rm",
"--device=nvidia.com/gpu=all",
"--ipc=host",
"--volume",
f"{input_directory}:/audio:ro",
"--volume",
f"{output_directory}:/output",
"--volume",
f"{Config.model_volume}:{Config.huggingface_cache}",
Config.image_tag,
"--model",
model,
],
)
def main(
input_directory: Annotated[Path, typer.Argument(help="Directory of audio files to transcribe.")],
output_directory: Annotated[Path, typer.Argument(help="Directory to write .txt transcripts to.")],
model: Annotated[str, typer.Option(help="faster-whisper model name.")] = "large-v3",
*,
force_download: Annotated[
bool,
typer.Option("--force-download", help="Re-download the model even if already cached."),
] = False,
) -> None:
"""Build the image, ensure the model is cached, then transcribe and stop."""
configure_logger()
resolved_input = input_directory.resolve(strict=True)
output_directory.mkdir(parents=True, exist_ok=True)
resolved_output = output_directory.resolve()
build_image()
if force_download or not model_cache_present(model):
download_model(model)
else:
logger.info("Model %s already cached in volume %s", model, Config.model_volume)
transcribe(resolved_input, resolved_output, model)
logger.info("Done. Container stopped.")
if __name__ == "__main__":
typer.run(main)
+2 -9
View File
@@ -1,13 +1,11 @@
{ inputs, pkgs, ... }:
{
imports = [
"${inputs.self}/users/math"
"${inputs.self}/users/richie"
"${inputs.self}/users/steve"
"${inputs.self}/users/math"
"${inputs.self}/common/global"
"${inputs.self}/common/optional/docker.nix"
"${inputs.self}/common/optional/scanner.nix"
"${inputs.self}/common/optional/monitoring-agent.nix"
"${inputs.self}/common/optional/steam.nix"
"${inputs.self}/common/optional/syncthing_base.nix"
"${inputs.self}/common/optional/systemd-boot.nix"
@@ -28,12 +26,7 @@
networking = {
hostName = "bob";
hostId = "7c678a41";
firewall = {
enable = true;
allowedTCPPorts = [
8000
];
};
firewall.enable = true;
networkmanager.enable = true;
};
+1 -5
View File
@@ -28,13 +28,9 @@
allowDiscards = true;
keyFileSize = 4096;
keyFile = "/dev/disk/by-id/usb-Samsung_Flash_Drive_FIT_0374620080067131-0:0";
fallbackToPassword = true;
};
};
zfs.extraPools = [
"storage"
];
kernelModules = [ "kvm-amd" ];
extraModulePackages = [ ];
};
+2 -5
View File
@@ -4,7 +4,7 @@
host = "0.0.0.0";
enable = true;
syncModels = false;
syncModels = true;
loadModels = [
"codellama:7b"
"deepscaler:1.5b"
@@ -42,14 +42,11 @@
"qwen3:8b"
"qwen3.5:27b"
"qwen3.5:35b"
"qwen3.6:27b"
"qwen3.6:35b"
"rinex20/translategemma3:12b"
"translategemma:12b"
"translategemma:27b"
"translategemma:4b"
];
models = "/zfs/storage/models";
models = "/zfs/models";
openFirewall = true;
};
}
+11
View File
@@ -0,0 +1,11 @@
#!/bin/bash
# zpools
# storage
sudo zpool create -f -o ashift=12 -O acltype=posixacl -O atime=off -O dnodesize=auto -O xattr=sa -O compression=zstd -m /zfs/storage storage mirror
sudo zpool create -o ashift=12 -O acltype=posixacl -O atime=off -O dnodesize=auto -O xattr=sa -O compression=zstd -m /zfs/storage storage
# storage datasets
sudo zfs create storage/models -o recordsize=1M
+1 -1
View File
@@ -24,6 +24,6 @@ monthly = 0
["root_pool/models"]
15_min = 4
hourly = 24
hourly = 2
daily = 0
monthly = 0
-10
View File
@@ -31,15 +31,5 @@
];
fsWatcherEnabled = true;
};
"recordings" = {
path = "/home/richie/recordings";
devices = [
"jeeves"
"phone"
"rhapsody-in-green"
];
fsWatcherEnabled = true;
};
};
}
+1
View File
@@ -26,6 +26,7 @@
allowDiscards = true;
keyFileSize = 4096;
keyFile = "/dev/disk/by-id/usb-USB_SanDisk_3.2Gen1_03021630090925173333-0:0";
fallbackToPassword = true;
};
};
kernelModules = [ "kvm-intel" ];
+2 -11
View File
@@ -4,21 +4,17 @@ let
in
{
imports = [
"${inputs.self}/users/dov"
"${inputs.self}/users/math"
"${inputs.self}/users/richie"
"${inputs.self}/users/steve"
"${inputs.self}/users/math"
"${inputs.self}/users/dov"
"${inputs.self}/common/global"
"${inputs.self}/common/optional/docker.nix"
"${inputs.self}/common/optional/monitoring-agent.nix"
"${inputs.self}/common/optional/ssh_decrypt.nix"
"${inputs.self}/common/optional/syncthing_base.nix"
"${inputs.self}/common/optional/update.nix"
"${inputs.self}/common/optional/zerotier.nix"
./monitoring
./docker
./services
./web_services
./hardware.nix
./networking.nix
./programs.nix
@@ -39,10 +35,5 @@ in
zerotierone.joinNetworks = [ "a09acf02330d37b9" ];
};
users.groups = {
nornsight = { };
nornsight-admin = { };
};
system.stateVersion = "24.05";
}
+1
View File
@@ -9,6 +9,7 @@ let
inherit device;
keyFileSize = 4096;
keyFile = "/dev/disk/by-id/usb-XIAO_USB_Drive_24587CE29074-0:0";
fallbackToPassword = true;
};
makeLuksSSD =
device:
@@ -1,426 +0,0 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "grafana",
"uid": "-- Grafana --"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 0,
"links": [],
"panels": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 0,
"y": 0
},
"id": 1,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "100 * (1 - avg by (instance) (rate(node_cpu_seconds_total{mode=\"idle\"}[5m])))",
"legendFormat": "{{instance}}",
"range": true,
"refId": "A"
}
],
"title": "CPU Used",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 6,
"y": 0
},
"id": 2,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "100 * (1 - (node_memory_MemAvailable_bytes / node_memory_MemTotal_bytes))",
"legendFormat": "{{instance}}",
"range": true,
"refId": "A"
}
],
"title": "RAM Used",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 12,
"y": 0
},
"id": 3,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "100 * (1 - (node_memory_SwapFree_bytes / node_memory_SwapTotal_bytes))",
"legendFormat": "{{instance}}",
"range": true,
"refId": "A"
}
],
"title": "Swap Used",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "short"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 18,
"y": 0
},
"id": 4,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "node_load1",
"legendFormat": "{{instance}} load1",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "node_load5",
"legendFormat": "{{instance}} load5",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "node_load15",
"legendFormat": "{{instance}} load15",
"range": true,
"refId": "C"
}
],
"title": "Load",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "Bps"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 8
},
"id": 5,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "sum by (instance) (rate(node_disk_read_bytes_total[5m]))",
"legendFormat": "{{instance}} read",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "sum by (instance) (rate(node_disk_written_bytes_total[5m]))",
"legendFormat": "{{instance}} write",
"range": true,
"refId": "B"
}
],
"title": "Disk Throughput",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 12,
"y": 8
},
"id": 6,
"options": {
"cellHeight": "sm",
"showHeader": true,
"sortBy": [
{
"desc": true,
"displayName": "Value"
}
]
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "100 * (1 - (node_filesystem_avail_bytes{mountpoint=~\"(/|/home|/var|/zfs.*)\",fstype!=\"\"} / node_filesystem_size_bytes{mountpoint=~\"(/|/home|/var|/zfs.*)\",fstype!=\"\"}))",
"format": "table",
"instant": true,
"legendFormat": "{{instance}} {{mountpoint}}",
"refId": "A"
}
],
"title": "Filesystem Usage",
"type": "table"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "percentunit"
},
"overrides": []
},
"gridPos": {
"h": 10,
"w": 12,
"x": 0,
"y": 17
},
"id": 7,
"options": {
"cellHeight": "sm",
"showHeader": true,
"sortBy": [
{
"desc": true,
"displayName": "Value"
}
]
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "topk(10, rate(namedprocess_namegroup_cpu_seconds_total[5m]))",
"format": "table",
"instant": true,
"legendFormat": "{{instance}} {{groupname}}",
"refId": "A"
}
],
"title": "Top Grouped CPU",
"type": "table"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "bytes"
},
"overrides": []
},
"gridPos": {
"h": 10,
"w": 12,
"x": 12,
"y": 17
},
"id": 8,
"options": {
"cellHeight": "sm",
"showHeader": true,
"sortBy": [
{
"desc": true,
"displayName": "Value"
}
]
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "topk(10, namedprocess_namegroup_memory_bytes{memtype=\"resident\"})",
"format": "table",
"instant": true,
"legendFormat": "{{instance}} {{groupname}}",
"refId": "A"
}
],
"title": "Top Grouped Memory",
"type": "table"
}
],
"refresh": "30s",
"schemaVersion": 39,
"style": "dark",
"tags": [
"monitoring"
],
"templating": {
"list": []
},
"time": {
"from": "now-24h",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "Overview",
"uid": "monitor-overview",
"version": 1,
"weekStart": ""
}
@@ -1,216 +0,0 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "grafana",
"uid": "-- Grafana --"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 0,
"links": [],
"panels": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "percentunit"
},
"overrides": []
},
"gridPos": {
"h": 10,
"w": 12,
"x": 0,
"y": 0
},
"id": 1,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "topk(10, rate(namedprocess_namegroup_cpu_seconds_total[5m]))",
"legendFormat": "{{instance}} {{groupname}}",
"range": true,
"refId": "A"
}
],
"title": "Grouped CPU",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "bytes"
},
"overrides": []
},
"gridPos": {
"h": 10,
"w": 12,
"x": 12,
"y": 0
},
"id": 2,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "topk(10, namedprocess_namegroup_memory_bytes{memtype=\"resident\"})",
"legendFormat": "{{instance}} {{groupname}}",
"range": true,
"refId": "A"
}
],
"title": "Grouped Resident Memory",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "Bps"
},
"overrides": []
},
"gridPos": {
"h": 10,
"w": 12,
"x": 0,
"y": 10
},
"id": 3,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "topk(10, rate(namedprocess_namegroup_read_bytes_total[5m]))",
"legendFormat": "{{instance}} {{groupname}}",
"range": true,
"refId": "A"
}
],
"title": "Grouped Read I/O",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "Bps"
},
"overrides": []
},
"gridPos": {
"h": 10,
"w": 12,
"x": 12,
"y": 10
},
"id": 4,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "topk(10, rate(namedprocess_namegroup_write_bytes_total[5m]))",
"legendFormat": "{{instance}} {{groupname}}",
"range": true,
"refId": "A"
}
],
"title": "Grouped Write I/O",
"type": "timeseries"
}
],
"refresh": "30s",
"schemaVersion": 39,
"style": "dark",
"tags": [
"monitoring",
"process"
],
"templating": {
"list": []
},
"time": {
"from": "now-7d",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "Process History Grouped",
"uid": "monitor-process-history",
"version": 1,
"weekStart": ""
}
@@ -1,224 +0,0 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "grafana",
"uid": "-- Grafana --"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 0,
"links": [],
"panels": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-pid-short"
},
"fieldConfig": {
"defaults": {
"unit": "percentunit"
},
"overrides": []
},
"gridPos": {
"h": 10,
"w": 12,
"x": 0,
"y": 0
},
"id": 1,
"options": {
"cellHeight": "sm",
"showHeader": true,
"sortBy": [
{
"desc": true,
"displayName": "Value"
}
]
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-pid-short"
},
"editorMode": "code",
"expr": "topk(20, rate(namedprocess_namegroup_cpu_seconds_total[2m]))",
"format": "table",
"instant": true,
"legendFormat": "{{instance}} {{groupname}}",
"refId": "A"
}
],
"title": "Top PID CPU",
"type": "table"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-pid-short"
},
"fieldConfig": {
"defaults": {
"unit": "bytes"
},
"overrides": []
},
"gridPos": {
"h": 10,
"w": 12,
"x": 12,
"y": 0
},
"id": 2,
"options": {
"cellHeight": "sm",
"showHeader": true,
"sortBy": [
{
"desc": true,
"displayName": "Value"
}
]
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-pid-short"
},
"editorMode": "code",
"expr": "topk(20, namedprocess_namegroup_memory_bytes{memtype=\"resident\"})",
"format": "table",
"instant": true,
"legendFormat": "{{instance}} {{groupname}}",
"refId": "A"
}
],
"title": "Top PID RSS",
"type": "table"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-pid-short"
},
"fieldConfig": {
"defaults": {
"unit": "Bps"
},
"overrides": []
},
"gridPos": {
"h": 10,
"w": 12,
"x": 0,
"y": 10
},
"id": 3,
"options": {
"cellHeight": "sm",
"showHeader": true,
"sortBy": [
{
"desc": true,
"displayName": "Value"
}
]
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-pid-short"
},
"editorMode": "code",
"expr": "topk(20, rate(namedprocess_namegroup_read_bytes_total[2m]))",
"format": "table",
"instant": true,
"legendFormat": "{{instance}} {{groupname}}",
"refId": "A"
}
],
"title": "Top PID Read I/O",
"type": "table"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-pid-short"
},
"fieldConfig": {
"defaults": {
"unit": "Bps"
},
"overrides": []
},
"gridPos": {
"h": 10,
"w": 12,
"x": 12,
"y": 10
},
"id": 4,
"options": {
"cellHeight": "sm",
"showHeader": true,
"sortBy": [
{
"desc": true,
"displayName": "Value"
}
]
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-pid-short"
},
"editorMode": "code",
"expr": "topk(20, rate(namedprocess_namegroup_write_bytes_total[2m]))",
"format": "table",
"instant": true,
"legendFormat": "{{instance}} {{groupname}}",
"refId": "A"
}
],
"title": "Top PID Write I/O",
"type": "table"
}
],
"refresh": "15s",
"schemaVersion": 39,
"style": "dark",
"tags": [
"monitoring",
"process"
],
"templating": {
"list": []
},
"time": {
"from": "now-10m",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "Process Live PID",
"uid": "monitor-process-pid",
"version": 1,
"weekStart": ""
}
@@ -1,351 +0,0 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "grafana",
"uid": "-- Grafana --"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 0,
"links": [],
"panels": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 8,
"x": 0,
"y": 0
},
"id": 1,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "100 * (zfs_pool_allocated_bytes / zfs_pool_size_bytes)",
"legendFormat": "{{instance}} {{pool}}",
"range": true,
"refId": "A"
}
],
"title": "Pool Usage",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "bytes"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 8,
"x": 8,
"y": 0
},
"id": 2,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "zfs_pool_free_bytes",
"legendFormat": "{{instance}} {{pool}}",
"range": true,
"refId": "A"
}
],
"title": "Pool Free Bytes",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "bytes"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 8,
"x": 16,
"y": 0
},
"id": 3,
"options": {
"cellHeight": "sm",
"showHeader": true,
"sortBy": [
{
"desc": true,
"displayName": "Value"
}
]
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "topk(20, zfs_dataset_used_bytes{type=\"filesystem\"})",
"format": "table",
"instant": true,
"legendFormat": "{{instance}} {{name}}",
"refId": "A"
}
],
"title": "Top Filesystems by Used Bytes",
"type": "table"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "ns"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 8
},
"id": 4,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "topk(20, zpool_iostat_total_wait_read_ns{vdev!=\"_pool\"})",
"legendFormat": "{{host}} {{pool}} {{vdev}}",
"range": true,
"refId": "A"
}
],
"title": "ZFS Read Wait",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "ns"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 12,
"y": 8
},
"id": 5,
"options": {
"legend": {
"displayMode": "list",
"placement": "bottom"
},
"tooltip": {
"mode": "multi"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "topk(20, zpool_iostat_total_wait_write_ns{vdev!=\"_pool\"})",
"legendFormat": "{{host}} {{pool}} {{vdev}}",
"range": true,
"refId": "A"
}
],
"title": "ZFS Write Wait",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "celsius"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 17
},
"id": 6,
"options": {
"cellHeight": "sm",
"showHeader": true,
"sortBy": [
{
"desc": true,
"displayName": "Value"
}
]
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "smartctl_device_temperature{temperature_type=\"current\"}",
"format": "table",
"instant": true,
"legendFormat": "{{instance}} {{device}}",
"refId": "A"
}
],
"title": "Disk Temperature",
"type": "table"
},
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"fieldConfig": {
"defaults": {
"unit": "short"
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 12,
"x": 12,
"y": 17
},
"id": 7,
"options": {
"cellHeight": "sm",
"showHeader": true,
"sortBy": [
{
"desc": false,
"displayName": "Value"
}
]
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prom-main"
},
"editorMode": "code",
"expr": "smartctl_device_smart_status",
"format": "table",
"instant": true,
"legendFormat": "{{instance}} {{device}}",
"refId": "A"
}
],
"title": "SMART Health",
"type": "table"
}
],
"refresh": "30s",
"schemaVersion": 39,
"style": "dark",
"tags": [
"monitoring",
"zfs"
],
"templating": {
"list": []
},
"time": {
"from": "now-24h",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "Storage and ZFS",
"uid": "monitor-storage",
"version": 1,
"weekStart": ""
}
-186
View File
@@ -1,186 +0,0 @@
{
lib,
pkgs,
...
}:
let
vars = import ../vars.nix;
prometheusDataRoot = "${vars.database}/prometheus";
mainPrometheusDataDir = "${prometheusDataRoot}/main";
pidPrometheusDataDir = "${prometheusDataRoot}/pid-short";
prometheusYaml = pkgs.formats.yaml { };
mkPrometheusConfig =
name: cfg:
let
configFile = prometheusYaml.generate "${name}.yaml" cfg;
in
pkgs.runCommand "${name}-checked.yaml"
{
nativeBuildInputs = [ pkgs.prometheus.cli ];
}
''
promtool check config ${configFile}
cp ${configFile} $out
'';
mkTarget = host: address: {
targets = [ address ];
labels.instance = host;
};
mainPrometheusConfig = mkPrometheusConfig "prometheus-main" {
global = {
scrape_interval = "30s";
scrape_timeout = "10s";
evaluation_interval = "30s";
};
scrape_configs = [
{
job_name = "node";
static_configs = [
(mkTarget "jeeves" "192.168.90.40:9100")
(mkTarget "bob" "192.168.90.25:9100")
];
}
{
job_name = "process_grouped";
static_configs = [
(mkTarget "jeeves" "192.168.90.40:9256")
(mkTarget "bob" "192.168.90.25:9256")
];
}
{
job_name = "smartctl";
static_configs = [
(mkTarget "jeeves" "192.168.90.40:9633")
(mkTarget "bob" "192.168.90.25:9633")
];
}
{
job_name = "zfs";
static_configs = [
(mkTarget "jeeves" "192.168.90.40:9134")
(mkTarget "bob" "192.168.90.25:9134")
];
}
];
};
pidPrometheusConfig = mkPrometheusConfig "prometheus-pid-short" {
global = {
scrape_interval = "15s";
scrape_timeout = "10s";
evaluation_interval = "15s";
};
scrape_configs = [
{
job_name = "process_pid";
static_configs = [
(mkTarget "jeeves" "192.168.90.40:9257")
(mkTarget "bob" "192.168.90.25:9257")
];
}
];
};
mkPrometheusService =
{
dataDir,
configFile,
port,
retention,
}:
{
after = [
"zfs-media-database-prometheus.mount"
"network.target"
];
requires = [ "zfs-media-database-prometheus.mount" ];
wantedBy = [ "multi-user.target" ];
unitConfig.RequiresMountsFor = [ dataDir ];
serviceConfig = {
ExecStart = "${lib.getExe pkgs.prometheus} ${
lib.escapeShellArgs [
"--config.file=${configFile}"
"--storage.tsdb.path=${dataDir}"
"--storage.tsdb.retention.time=${retention}"
"--web.listen-address=127.0.0.1:${toString port}"
]
}";
User = "prometheus";
Group = "prometheus";
Restart = "always";
RestartSec = "5s";
WorkingDirectory = dataDir;
ReadWritePaths = [ dataDir ];
CapabilityBoundingSet = [ "" ];
DeviceAllow = [ "/dev/null rw" ];
DevicePolicy = "strict";
LockPersonality = true;
MemoryDenyWriteExecute = true;
NoNewPrivileges = true;
PrivateDevices = true;
PrivateTmp = true;
ProtectClock = true;
ProtectControlGroups = true;
ProtectHome = true;
ProtectHostname = true;
ProtectKernelLogs = true;
ProtectKernelModules = true;
ProtectKernelTunables = true;
ProtectProc = "invisible";
ProtectSystem = "strict";
RemoveIPC = true;
RestrictAddressFamilies = [
"AF_INET"
"AF_INET6"
"AF_UNIX"
];
RestrictNamespaces = true;
RestrictRealtime = true;
RestrictSUIDSGID = true;
SystemCallArchitectures = "native";
SystemCallFilter = [
"@system-service"
"~@privileged"
];
};
};
in
{
users = {
groups.prometheus = { };
users.prometheus = {
isSystemUser = true;
group = "prometheus";
description = "Prometheus daemon user";
};
};
systemd = {
services = {
prometheus-main = mkPrometheusService {
configFile = mainPrometheusConfig;
dataDir = mainPrometheusDataDir;
port = 9090;
retention = "90d";
};
prometheus-pid-short = mkPrometheusService {
configFile = pidPrometheusConfig;
dataDir = pidPrometheusDataDir;
port = 9092;
retention = "10m";
};
};
tmpfiles.rules = [
"d ${prometheusDataRoot} 0755 root root - -"
"d ${mainPrometheusDataDir} 0750 prometheus prometheus - -"
"d ${pidPrometheusDataDir} 0750 prometheus prometheus - -"
];
};
}

Some files were not shown because too many files have changed in this diff Show More