import os, time, json, random
import requests
import gradio as gr

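# Configuration: provider and model ids are read from environment variables.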
PROVIDER = os.getenv("PROVIDER", "hf_model").strip()
MODEL_ID = os.getenv("MODEL_ID", "LLM360/K2-Think").strip()
ALT_MODEL_ID = os.getenv("ALT_MODEL_ID", "Qwen/Qwen2.5-7B-Instruct").strip()
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()


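# Small HTTP GET helper with retries, exponential backoff and jitter.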
def _get(url, params=None, headers=None, timeout=12, retries=2, backoff=1.6):
    for i in range(retries + 1):
        try:
            r = requests.get(url, params=params, headers=headers, timeout=timeout)
            r.raise_for_status()
            return r
        except Exception:
            if i == retries:
                raise
            time.sleep((backoff ** i) + random.random() * 0.25)


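# Geocode a city name with the OpenStreetMap Nominatim search API.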
def geocode_city(city: str):
    r = _get("https://nominatim.openstreetmap.org/search",
             params={"q": city, "format": "json", "limit": 1},
             headers={"User-Agent": "climamind-space"})
    j = r.json()
    if not j:
        raise RuntimeError("City not found")
    return {"lat": float(j[0]["lat"]), "lon": float(j[0]["lon"]), "name": j[0]["display_name"]}


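# Fetch current and hourly weather from the Open-Meteo forecast API.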
def fetch_open_meteo(lat, lon):
    r = _get("https://api.open-meteo.com/v1/forecast", params={
        "latitude": lat, "longitude": lon,
        "current": "temperature_2m,relative_humidity_2m,wind_speed_10m,precipitation,uv_index",
        "hourly": "temperature_2m,relative_humidity_2m,wind_speed_10m,precipitation_probability,uv_index",
        "timezone": "auto"
    })
    return r.json()


def fetch_pm25(lat, lon):
    try:
        r = _get("https://air-quality-api.open-meteo.com/v1/air-quality", params={
            "latitude": lat, "longitude": lon, "hourly": "pm2_5", "timezone": "auto"
        }, headers={"User-Agent": "climamind-space"})
        j = r.json()
        hourly = j.get("hourly", {})
        values = hourly.get("pm2_5") or []
        if values:
            return values[-1]
    except Exception:
        pass
    return None


def fetch_factors(lat, lon):
    wx = fetch_open_meteo(lat, lon)
    cur = wx.get("current", {}) or {}
    factors = {
        "temp_c": cur.get("temperature_2m"),
        "rh": cur.get("relative_humidity_2m"),
        "wind_kmh": cur.get("wind_speed_10m"),
        "precip_mm": cur.get("precipitation"),
        "uv": cur.get("uv_index"),
        "pm25": fetch_pm25(lat, lon),
    }
    return {"factors": factors, "raw": wx}


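# Heuristic indices (0-100) derived from the current observations.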
def drying_index(temp_c, rh, wind_kmh, cloud_frac=None):
    base = (temp_c or 0) * 1.2 + (wind_kmh or 0) * 0.8 - (rh or 0) * 0.9
    if cloud_frac is not None:
        base -= 20 * cloud_frac
    return max(0, min(100, round(base)))


def heat_stress_index(temp_c, rh, wind_kmh):
    hs = (temp_c or 0) * 1.1 + (rh or 0) * 0.3 - (wind_kmh or 0) * 0.2
    return max(0, min(100, round(hs)))


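# Prompt template: the model must use only the supplied observations and reply with strict JSON.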
PROMPT = """You are ClimaMind, a climate reasoning assistant. Use ONLY the observations provided and return STRICT JSON.

Location: {loc} (lat={lat}, lon={lon}), local time: {t_local}
Observations: temp={temp_c}°C, rh={rh}%, wind={wind_kmh} km/h, precip={precip_mm} mm, uv={uv}, pm25={pm25}
Derived: drying_index={d_idx}, heat_stress_index={hs_idx}

Task: Answer the user’s query: "{query}" for the next 24 hours.
Steps:
1) Identify the relevant factors.
2) Reason causally (2–3 steps).
3) Give a concise recommendation with time window(s) and a confidence.
4) Output a short WHY-TRACE (3 bullets).
Return JSON ONLY:
{{
  "answer": "...",
  "why_trace": ["...", "...", "..."],
  "risk_badge": "Low"|"Moderate"|"High"
}}
"""


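# Offline stub: a fixed, well-formed JSON answer used when no model backend is available.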
def call_stub(_prompt: str) -> str:
    return json.dumps({
        "answer": "Based on 32°C, 50% RH and 12 km/h wind, cotton dries in ~2–3h (faster after 2pm).",
        "why_trace": [
            "Higher temperature and wind increase evaporation rate",
            "Moderate humidity slightly slows drying",
            "Lower afternoon cloud cover speeds it up"
        ],
        "risk_badge": "Low"
    })


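# Serverless path: try MODEL_ID first, then ALT_MODEL_ID, via the Hugging Face Inference API.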
def call_hf_model(prompt: str) -> tuple[str, str]:
    from huggingface_hub import InferenceClient
    attempts = [m for m in [MODEL_ID, ALT_MODEL_ID] if m]
    for mid in attempts:
        try:
            client = InferenceClient(model=mid, token=(HF_TOKEN or None))
            out = client.text_generation(
                prompt,
                max_new_tokens=200,
                temperature=0.1,
                repetition_penalty=1.05,
                do_sample=False,
            )
            return str(out), mid
        except Exception as e:
            print(f"[HF_MODEL] Failed on {mid}: {repr(e)}")
            continue

    raise RuntimeError(f"No serverless provider available. Tried: {attempts}")


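# Local GPU path: load MODEL_ID once with 4-bit (NF4) quantization via bitsandbytes.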
_local_loaded = False
def _ensure_local_loaded():
    # Lazily load the tokenizer and model exactly once.
    global _local_loaded, tokenizer, model
    if _local_loaded:
        return
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype="bfloat16",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        device_map="auto",
        quantization_config=bnb_cfg,
        low_cpu_mem_usage=True,
    )
    _local_loaded = True


def call_local(prompt: str) -> tuple[str, str]:
    _ensure_local_loaded()
    import torch
    if hasattr(tokenizer, "apply_chat_template"):
        messages = [{"role": "user", "content": prompt}]
        # return_dict=True is needed so that **inputs unpacks into generate();
        # apply_chat_template would otherwise return a bare tensor.
        inputs = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device)
    else:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
            repetition_penalty=1.05,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens so the prompt (which itself contains
    # braces) does not confuse the downstream JSON extraction.
    gen = out[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(gen, skip_special_tokens=True), MODEL_ID


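# Build the prompt, call the configured provider, and parse the model's JSON reply.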
def reason_answer(loc, coords, factors, query):
    d_idx = drying_index(factors.get("temp_c"), factors.get("rh"), factors.get("wind_kmh"))
    hs_idx = heat_stress_index(factors.get("temp_c"), factors.get("rh"), factors.get("wind_kmh"))
    t_local = time.strftime("%Y-%m-%d %H:%M")
    prompt = PROMPT.format(
        loc=loc, lat=coords["lat"], lon=coords["lon"], t_local=t_local,
        temp_c=factors.get("temp_c"), rh=factors.get("rh"), wind_kmh=factors.get("wind_kmh"),
        precip_mm=factors.get("precip_mm"), uv=factors.get("uv"), pm25=factors.get("pm25"),
        d_idx=d_idx, hs_idx=hs_idx, query=query
    )

    if PROVIDER == "hf_model":
        try:
            raw, model_used = call_hf_model(prompt)
        except Exception as e:
            print("[HF_MODEL] Falling back to stub:", repr(e))
            raw, model_used = call_stub(prompt), "stub"
    elif PROVIDER == "local":
        raw, model_used = call_local(prompt)
    else:
        raw, model_used = call_stub(prompt), "stub"

    # Extract the text between the first '{' and the last '}' and parse it as JSON.
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end == -1:
        parsed = {
            "answer": "The reasoning service returned non-JSON text. Please try again.",
            "why_trace": ["Response formatting issue", "Keep temperature low", "Retry once"],
            "risk_badge": "Low"
        }
    else:
        try:
            parsed = json.loads(raw[start:end+1])
        except Exception:
            parsed = {
                "answer": "Failed to parse JSON from model output.",
                "why_trace": ["JSON parsing error", "Reduce tokens/temperature", "Retry once"],
                "risk_badge": "Low"
            }
    parsed["_model_used"] = model_used
    return parsed


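# Gradio callback: geocode the city, fetch live factors, reason over them, and render Markdown.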
def app(city, question):
    geo = geocode_city(city)
    data = fetch_factors(geo["lat"], geo["lon"])
    ans = reason_answer(
        geo["name"], {"lat": geo["lat"], "lon": geo["lon"]},
        data["factors"], question
    )
    fx = ", ".join([f"{k}={v}" for k, v in data["factors"].items()])
    why_list = ans.get("why_trace") or []
    why = "\n• " + "\n• ".join(why_list) if why_list else "\n• (no trace returned)"
    model_used = ans.pop("_model_used", "unknown")
    md = (
        f"**Answer:** {ans.get('answer','(no answer)')}\n\n"
        f"**Why-trace:**{why}\n\n"
        f"**Risk:** {ans.get('risk_badge','N/A')}\n\n"
        f"**Factors:** {fx}\n\n"
        f"<sub>Provider: {PROVIDER} • Model: `{model_used}`</sub>"
    )
    return md


demo = gr.Interface(
    fn=app,
    inputs=[
        gr.Textbox(label="City", value="New Delhi"),
        gr.Dropdown(
            choices=[
                "If I wash clothes now, when will they dry?",
                "Should I water my plants today or wait?",
                "What is the heat/wildfire risk today? Explain briefly."
            ],
            label="Question",
            value="If I wash clothes now, when will they dry?"
        )
    ],
    outputs=gr.Markdown(label="ClimaMind"),
    title="ClimaMind — K2-Think + Live Climate Data",
    description="Serverless inference tries K2-Think first and falls back to Qwen; set PROVIDER=local to run on a GPU Space. The stub answer is the last resort.",
    flagging_mode="never",
    concurrency_limit=2,
)


demo.queue(max_size=8)

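# Example launch commands (assuming this file is saved as app.py; filename is illustrative):
#   PROVIDER=hf_model MODEL_ID=LLM360/K2-Think HF_TOKEN=hf_xxx python app.py
#   PROVIDER=local python app.py   # requires a GPU for the 4-bit local path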
if __name__ == "__main__":
    demo.launch()