Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 2,431 Bytes
8b5be8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import json
import argparse
from pathlib import Path
import chromadb
from chromadb.config import Settings
def _to_jsonable(value):
    """Return *value* in a form ``json.dumps`` can serialize.

    Objects exposing a callable ``tolist()`` (e.g. NumPy arrays and NumPy
    scalars, which recent ChromaDB releases return for embeddings) are
    converted to plain Python lists/scalars; anything else is returned as-is.
    """
    tolist = getattr(value, "tolist", None)
    return tolist() if callable(tolist) else value


def export_collection(collection, output_dir: Path, include_embeddings=False):
    """Export one ChromaDB collection to ``<output_dir>/<collection.name>.json``.

    The file holds a JSON list with one object per record, each with ``id``,
    ``document`` and ``metadata`` keys, plus ``embedding`` when
    *include_embeddings* is true. A field the store did not return is written
    as ``null``.

    Args:
        collection: A ChromaDB collection -- anything with a ``name`` attribute
            and a ``get(include=...)`` method returning a dict of parallel
            sequences keyed by ``ids``/``documents``/``metadatas``/``embeddings``.
        output_dir: Existing directory to write the JSON file into.
        include_embeddings: Also fetch and export the embedding vectors.
    """
    # Pull everything in a single call (very large collections may need
    # pagination via get(limit=..., offset=...)).
    include_fields = ["documents", "metadatas"]
    if include_embeddings:
        include_fields.append("embeddings")
    items = collection.get(include=include_fields)

    # Compare against None explicitly: newer ChromaDB returns embeddings as a
    # NumPy array, whose truth value is ambiguous and would raise ValueError.
    documents = items.get("documents")
    metadatas = items.get("metadatas")
    embeddings = items.get("embeddings") if include_embeddings else None

    data = []
    for idx, _id in enumerate(items["ids"]):
        record = {
            "id": _id,
            "document": documents[idx] if documents is not None else None,
            "metadata": metadatas[idx] if metadatas is not None else None,
        }
        if include_embeddings:
            # NumPy rows are not JSON-serializable; convert them to lists.
            record["embedding"] = (
                _to_jsonable(embeddings[idx]) if embeddings is not None else None
            )
        data.append(record)

    # Write to <collection>.json
    out_path = output_dir / f"{collection.name}.json"
    out_path.write_text(
        json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"✔ Exported {collection.name} → {out_path}")
def main():
    """Command-line entry point: export every collection in a persisted ChromaDB store.

    Parses ``--db-path`` (required), ``--output`` (default ``chroma_exports``)
    and ``--include-embeddings``. It creates the output folder if needed and
    then writes one ``<collection>.json`` file per collection via
    ``export_collection``. Exits with a usage error if ``--db-path`` is not an
    existing directory.
    """
    parser = argparse.ArgumentParser(description="Export ChromaDB collections to JSON.")
    parser.add_argument(
        "--db-path",
        type=str,
        required=True,
        help="Path to the chromadb_store folder (where the DB is persisted)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="chroma_exports",
        help="Output folder for json files",
    )
    parser.add_argument(
        "--include-embeddings",
        action="store_true",
        help="Include embeddings in the export (off by default)",
    )
    args = parser.parse_args()

    db_path = Path(args.db_path).expanduser().resolve()
    # PersistentClient would create a brand-new, empty store at a missing
    # path, so a mistyped --db-path would "succeed" while exporting nothing.
    if not db_path.is_dir():
        parser.error(f"--db-path {db_path} is not an existing directory")

    output_dir = Path(args.output).expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Connect to the persistent ChromaDB store
    client = chromadb.PersistentClient(
        path=str(db_path),
        settings=Settings(anonymized_telemetry=False),
    )

    # chromadb 0.6.x returns collection names (str) from list_collections(),
    # while other releases return Collection objects; accept both shapes.
    for entry in client.list_collections():
        name = entry if isinstance(entry, str) else entry.name
        collection = client.get_collection(name)
        export_collection(
            collection, output_dir, include_embeddings=args.include_embeddings
        )
    print("\n🎉 All collections exported!")
# Run the exporter only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|