Files
html/top-ia/speculative_parallel.sh

51 lines
2.4 KiB
Bash
Executable File

#!/bin/bash
# Usage: <script> QUESTION...
# Every positional argument is folded into one question string
# ("$*" joins them with the first IFS character, a space by default).
Q="$*"

# No question given: emit a JSON error object on stdout and exit with status 1.
if [[ -z "$Q" ]]; then
  printf '%s\n' '{"error":"need question"}'
  exit 1
fi

# Pull in the secrets file. Its stderr is discarded, so a missing or
# unreadable file is silent and the script simply carries on.
source /etc/weval/secrets.env 2>/dev/null

# Expose the question and the three provider keys to child processes
# through the environment.
export Q
export CEREBRAS_API_KEY GROQ_KEY NVIDIA_NIM_KEY
python3 <<'PY'
"""Send one question to three chat-completion endpoints in parallel.

Reads the question from the Q environment variable and the API keys from
CEREBRAS_API_KEY, GROQ_KEY and NVIDIA_NIM_KEY. Prints a single JSON object
on stdout containing every draft, the fastest successful draft, and a
comparison of parallel wall time against the sum of per-call latencies.
"""
import os, json, urllib.request, time
from concurrent.futures import ThreadPoolExecutor, as_completed

q = os.environ['Q']

# (provider label, endpoint URL, env var holding the key, model id)
PROVIDERS = [
    ("cerebras", "https://api.cerebras.ai/v1/chat/completions", "CEREBRAS_API_KEY", "llama3.1-8b"),
    ("groq", "https://api.groq.com/openai/v1/chat/completions", "GROQ_KEY", "llama-3.1-8b-instant"),
    ("nvidia", "https://integrate.api.nvidia.com/v1/chat/completions", "NVIDIA_NIM_KEY", "meta/llama-3.1-8b-instruct"),
]


def call(provider, url, key, model, prompt, max_tok=300):
    """POST a single-message chat completion request and time it.

    Args:
        provider: label copied into the result.
        url: chat-completions endpoint.
        key: bearer token; an empty value fails fast without a request.
        model: model identifier sent in the request body.
        prompt: user message content.
        max_tok: value sent as "max_tokens".

    Returns:
        dict with "provider", "text" (at most 600 characters of the reply,
        or an "ERR:..." message), "ms" (elapsed milliseconds) and "ok"
        (True only for a non-empty reply). Never raises: every exception is
        converted into a failed result.
    """
    t0 = time.time()

    def draft(text, ok):
        return {"provider": provider, "text": text,
                "ms": int((time.time() - t0) * 1000), "ok": ok}

    if not key:
        # Sending "Bearer " with no token can only fail; report the real cause.
        return draft("ERR:missing API key", False)
    try:
        body = json.dumps({
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tok,
        }).encode()
        req = urllib.request.Request(url, data=body, headers={
            "Authorization": "Bearer " + key,
            "Content-Type": "application/json",
        })
        # Context manager closes the HTTP response once the body is read.
        with urllib.request.urlopen(req, timeout=15) as resp:
            payload = json.loads(resp.read())
        choices = payload.get("choices") or [{}]
        content = (choices[0].get("message") or {}).get("content") or ""
        if not content.strip():
            # An empty reply must not be eligible to win as the answer.
            return draft("ERR:empty response", False)
        return draft(content[:600], True)
    except Exception as e:
        return draft(f"ERR:{str(e)[:60]}", False)


# Phase 1: request all drafts in parallel (Cerebras + Groq + NVIDIA).
t_start = time.time()
with ThreadPoolExecutor(max_workers=len(PROVIDERS)) as ex:
    futures = [
        ex.submit(call, name, url, os.environ.get(key_var, ''), model, q, 250)
        for name, url, key_var, model in PROVIDERS
    ]
    # Collected in completion order, not submission order.
    drafts = [f.result() for f in as_completed(futures)]
parallel_ms = int((time.time() - t_start) * 1000)

# Phase 2: take the lowest-latency successful draft as the answer.
# No cross-draft agreement check is performed. Success is decided by the
# explicit "ok" flag, so a genuine answer that happens to begin with "ERR"
# is not discarded.
valid = [d for d in drafts if d["ok"]]
fastest = min(valid, key=lambda x: x["ms"]) if valid else None

# Score: a serial run would cost about sum(ms); the parallel run cost the
# measured wall time.
serial_estimate = sum(d["ms"] for d in drafts)
speedup = round(serial_estimate / parallel_ms, 2) if parallel_ms > 0 else 0

print(json.dumps({
    "question": q,
    # "ok" is internal bookkeeping; keep the emitted draft schema unchanged.
    "drafts": [{k: v for k, v in d.items() if k != "ok"} for d in drafts],
    "fastest": fastest["provider"] if fastest else None,
    "answer": fastest["text"] if fastest else "all failed",
    "valid_drafts": len(valid),
    "parallel_ms": parallel_ms,
    "serial_estimate_ms": serial_estimate,
    "speedup": f"{speedup}x"
}, ensure_ascii=False))
PY