#!/usr/bin/env python3
"""
Workflow proxy server — serves static files and proxies API calls.

Routes:
  /api/fal/*        → https://queue.fal.run/*
  /api/openrouter/* → https://openrouter.ai/api/*
  everything else   → static files from this directory
"""

import http.server
import json
import os
import ssl
import sys
import urllib.request
import urllib.error
from socketserver import ThreadingMixIn

# SSL context for outbound HTTPS requests (fixes macOS/Linux cert issues).
# Prefer certifi's CA bundle when available: some Python installs (notably
# macOS framework builds) ship without a usable system certificate store.
ssl_ctx = ssl.create_default_context()
try:
    import certifi
    ssl_ctx.load_verify_locations(certifi.where())
except ImportError:
    # NOTE(review): without certifi this disables certificate verification
    # entirely, via a private API (`ssl._create_unverified_context`). That is
    # acceptable only for local development — confirm this server is never
    # exposed beyond localhost.
    ssl_ctx = ssl._create_unverified_context()

# Listening port; overridable via the PORT environment variable.
PORT = int(os.environ.get('PORT', 8080))
# Directory containing this script — static files are served from here.
DIR = os.path.dirname(os.path.abspath(__file__))

# fal.ai endpoint: base URL for the /api/fal/* proxy route.
FAL_QUEUE = 'https://queue.fal.run'


class Handler(http.server.SimpleHTTPRequestHandler):
    """Static-file handler with transparent proxying for /api/* routes.

    Routes (mirrors the module docstring):
      /api/fal/*        → FAL_QUEUE (queue.fal.run)
      /api/openrouter/* → https://openrouter.ai/api/*
      everything else   → static files from DIR

    All responses (static and proxied) carry permissive CORS headers so a
    browser front-end served from any origin can call the API routes.
    """

    def __init__(self, *args, **kwargs):
        # Pin the served directory so a cwd change cannot expose other files.
        super().__init__(*args, directory=DIR, **kwargs)

    def end_headers(self):
        # Inject CORS headers into every response, then flush as usual.
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
        self.send_header('Access-Control-Allow-Headers', '*')
        super().end_headers()

    def do_OPTIONS(self):
        # CORS preflight: empty 200; the CORS headers come from end_headers().
        self.send_response(200)
        self.end_headers()

    def do_GET(self):
        # Fix: the original only proxied /api/fal/ on GET, so GET requests to
        # /api/openrouter/* fell through to the static handler (404) despite
        # the documented routing table.
        if self.path.startswith('/api/fal/'):
            self._proxy_fal('GET')
        elif self.path.startswith('/api/openrouter/'):
            self._proxy_openrouter('GET')
        else:
            super().do_GET()

    def do_POST(self):
        if self.path.startswith('/api/fal/'):
            self._proxy_fal('POST')
        elif self.path.startswith('/api/openrouter/'):
            self._proxy_openrouter('POST')
        else:
            # POST is only meaningful on the API routes.
            self._send_payload(405, b'{"error":"Method Not Allowed"}')

    # ── Proxy: queue.fal.run ──

    def _proxy_fal(self, method):
        # Strip the /api/fal/ prefix; the remainder (including any query
        # string) is appended to the upstream base URL unchanged.
        path = self.path[len('/api/fal/'):]
        url = f'{FAL_QUEUE}/{path}'
        self._forward(method, url)

    # ── Proxy: openrouter.ai ──

    def _proxy_openrouter(self, method):
        path = self.path[len('/api/openrouter/'):]
        url = f'https://openrouter.ai/api/{path}'
        self._forward(method, url)

    # ── Generic forward ──

    def _send_payload(self, status, body, content_type='application/json'):
        """Write a complete response: status line, headers, and body.

        Always sets Content-Length so clients need not rely on connection
        close to delimit the body. `content_type` of None/'' omits the
        Content-Type header (used when upstream sent none).
        """
        self.send_response(status)
        if content_type:
            self.send_header('Content-Type', content_type)
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _forward(self, method, url):
        """Forward the current request to `url` and relay the response.

        Copies only an allow-list of request headers (Authorization,
        Content-Type, Accept), forwards the request body for POST/PUT, and
        maps network-level failures to a 502 JSON error. Upstream HTTP
        errors are relayed verbatim (status code and payload).
        """
        headers = {}
        for key in ('Authorization', 'Content-Type', 'Accept'):
            val = self.headers.get(key)
            if val:
                headers[key] = val

        body = None
        if method in ('POST', 'PUT'):
            length = int(self.headers.get('Content-Length', 0))
            if length > 0:
                body = self.rfile.read(length)

        try:
            req = urllib.request.Request(url, data=body, headers=headers, method=method)
            with urllib.request.urlopen(req, context=ssl_ctx) as resp:
                self._send_payload(resp.status, resp.read(),
                                   resp.headers.get('Content-Type'))

        except urllib.error.HTTPError as e:
            # Upstream rejected the request — pass its status/body through so
            # the client sees the real error.
            self._send_payload(e.code, e.read(),
                               e.headers.get('Content-Type', 'application/json'))

        except Exception as e:
            # DNS failure, refused connection, TLS error, timeout, etc.
            self._send_payload(502, json.dumps({'error': str(e)}).encode())

    def log_message(self, fmt, *args):
        # Tag proxied-request log lines; default formatting otherwise.
        msg = str(args[0]) if args else ''
        if '/api/' in msg:
            print(f'  PROXY  {fmt % args}')
        else:
            super().log_message(fmt, *args)


if __name__ == '__main__':
    # One daemon thread per request; threads die with the main process on
    # Ctrl-C instead of blocking shutdown.
    class ThreadedServer(ThreadingMixIn, http.server.HTTPServer):
        daemon_threads = True

    print('\n  Workflow Server')
    print(f'  http://localhost:{PORT}/step2-refine.html\n')

    httpd = ThreadedServer(('', PORT), Handler)
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print('\n  Stopped.')
