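"""GLM-4.5V CAD Generator.

A Gradio app for Hugging Face Spaces that generates CADQuery Python code
from images of 3D CAD models, using the GLM-4.5V family of vision-language
models on Zero GPU hardware.
"""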
import spaces
import gradio as gr
import torch
from transformers import pipeline

# Hugging Face model IDs for each UI choice, defined once so the loader
# and the generator stay in sync.
MODEL_MAP = {
    "GLM-4.5V-AWQ": "QuantTrio/GLM-4.5V-AWQ",
    "GLM-4.5V-FP8": "zai-org/GLM-4.5V-FP8",
    "GLM-4.5V": "zai-org/GLM-4.5V",
}

# Global model cache for Zero GPU compatibility: loaded pipelines persist
# across GPU calls so each model is only downloaded once per process.
models = {}

def _get_pipeline(model_name):
    """Return the cached pipeline for model_name, loading it on first use."""
    if model_name not in models:
        models[model_name] = pipeline(
            "image-text-to-text",
            model=model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
    return models[model_name]

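# spaces.GPU requests a Zero GPU slot for up to `duration` seconds while the
# decorated call runs; with device_map="auto", loading inside such a call
# places the model on the GPU when one is available.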
@spaces.GPU(duration=300)
def load_model_on_gpu(model_choice):
    """Load a GLM model on the GPU - separated out so it can be tested alone."""
    model_name = MODEL_MAP.get(model_choice)
    if not model_name:
        return False, f"Unknown model: {model_choice}"

    if model_name in models:
        return True, f"βœ… {model_choice} already loaded"

    try:
        _get_pipeline(model_name)
        return True, f"βœ… {model_choice} loaded successfully"
    except Exception as e:
        return False, f"❌ Failed to load {model_choice}: {str(e)[:200]}"

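# Generation gets a shorter GPU window than loading (120s vs 300s): inference
# on an already-cached pipeline is much quicker than a first-time weight load.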
@spaces.GPU(duration=120)
def generate_code(image, model_choice, prompt_style):
    """Generate CADQuery code - main GPU function."""
    if image is None:
        return "❌ Please upload an image first."
    
    # Prompt templates; "Chain-of-Thought" deliberately opens a ```python
    # fence so the model continues straight into code.
    prompts = {
        "Simple": "Generate CADQuery Python code for this 3D model:",
        "Detailed": "Analyze this 3D CAD model and generate Python CADQuery code.\n\nRequirements:\n- Import cadquery as cq\n- Store result in 'result' variable\n- Use proper CADQuery syntax\n\nCode:",
        "Chain-of-Thought": "Analyze this 3D CAD model step by step:\n\nStep 1: Identify the basic geometry\nStep 2: Note any features\nStep 3: Generate clean CADQuery Python code\n\n```python\nimport cadquery as cq\n\n# Generated code:"
    }
    
    try:
        # Resolve the model ID and load it (or reuse the cached pipeline).
        model_name = MODEL_MAP.get(model_choice)
        if not model_name:
            return f"❌ Unknown model: {model_choice}"
        pipe = _get_pipeline(model_name)
        
        # Build a chat-style multimodal message: the image plus the text prompt.
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompts[prompt_style]}
            ]
        }]
        
        # temperature only takes effect when sampling is enabled.
        result = pipe(messages, max_new_tokens=512, do_sample=True, temperature=0.7)

        # The pipeline may return the reply as a plain string or, for chat
        # inputs, as a full message list; handle both shapes.
        if isinstance(result, list) and len(result) > 0:
            generated_text = result[0].get("generated_text", str(result))
        else:
            generated_text = str(result)
        if isinstance(generated_text, list):
            # Chat format: take the content of the last (assistant) message.
            generated_text = generated_text[-1].get("content", "") if generated_text else ""
        if not isinstance(generated_text, str):
            generated_text = str(generated_text)

        # Extract the first ```python fenced block, if present.
        code = generated_text.strip()
        if "```python" in code:
            start = code.find("```python") + len("```python")
            end = code.find("```", start)
            if end > start:
                code = code[start:end].strip()

        # Make sure the snippet is runnable on its own.
        if "import cadquery" not in code:
            code = "import cadquery as cq\n\n" + code
        
        return f"""## 🎯 Generated CADQuery Code

```python
{code}
```

## πŸ“Š Info
- **Model**: {model_choice}
- **Prompt**: {prompt_style}
- **Device**: {"GPU" if torch.cuda.is_available() else "CPU"}

## πŸ”§ Usage
```bash
pip install cadquery
python your_script.py
```
"""
        
    except Exception as e:
        return f"❌ **Generation Failed**: {str(e)[:500]}"

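# A possible extension (not wired into the UI): sanity-check a generated
# snippet by executing it and looking for the 'result' object the prompts ask
# for. Assumes cadquery is installed; exec() of model output is unsafe outside
# a sandbox, so treat this as a sketch only.
def validate_cadquery_code(code):
    """Return (ok, message) after exec-ing code and checking for 'result'."""
    namespace = {}
    try:
        exec(code, namespace)
    except Exception as e:
        return False, f"Execution failed: {e}"
    if "result" not in namespace:
        return False, "Code ran but defined no 'result' object."
    return True, f"OK: result is {type(namespace['result']).__name__}"

# Example with a hand-written snippet (cadquery's basic Workplane API):
# ok, msg = validate_cadquery_code(
#     "import cadquery as cq\nresult = cq.Workplane('XY').box(10, 10, 5)"
# )
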
def test_model(model_choice):
    """Test model loading."""
    success, message = load_model_on_gpu(model_choice)
    return f"## Test Result\n\n{message}"

def system_info():
    """Get system info."""
    info = f"""## πŸ–₯️ System Information

- **CUDA Available**: {torch.cuda.is_available()}
- **CUDA Devices**: {torch.cuda.device_count() if torch.cuda.is_available() else 0}
- **PyTorch Version**: {torch.__version__}
- **Device**: {"GPU" if torch.cuda.is_available() else "CPU"}
"""
    return info

# Create interface
with gr.Blocks(title="GLM-4.5V CAD Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # πŸ”§ GLM-4.5V CAD Generator
    
    Generate CADQuery Python code from 3D CAD model images using GLM-4.5V models!
    
    **Models**: GLM-4.5V-AWQ (fastest) | GLM-4.5V-FP8 (balanced) | GLM-4.5V (best quality)
    """)
    
    with gr.Tab("πŸš€ Generate"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload CAD Model Image")
                model_choice = gr.Dropdown(
                    choices=["GLM-4.5V-AWQ", "GLM-4.5V-FP8", "GLM-4.5V"],
                    value="GLM-4.5V-AWQ",
                    label="Select Model"
                )
                prompt_style = gr.Dropdown(
                    choices=["Simple", "Detailed", "Chain-of-Thought"],
                    value="Chain-of-Thought", 
                    label="Prompt Style"
                )
                generate_btn = gr.Button("πŸš€ Generate CADQuery Code", variant="primary")
            
            with gr.Column():
                output = gr.Markdown("Upload an image and click Generate!")
        
        generate_btn.click(
            fn=generate_code,
            inputs=[image_input, model_choice, prompt_style],
            outputs=output
        )
    
    with gr.Tab("πŸ§ͺ Test"):
        with gr.Row():
            with gr.Column():
                test_model_choice = gr.Dropdown(
                    choices=["GLM-4.5V-AWQ", "GLM-4.5V-FP8", "GLM-4.5V"],
                    value="GLM-4.5V-AWQ",
                    label="Model to Test"
                )
                test_btn = gr.Button("πŸ§ͺ Test Model")
            
            with gr.Column():
                test_output = gr.Markdown("Click Test Model to check loading.")
        
        test_btn.click(fn=test_model, inputs=test_model_choice, outputs=test_output)
    
    with gr.Tab("βš™οΈ System"):
        info_display = gr.Markdown()
        refresh_btn = gr.Button("πŸ”„ Refresh")
        
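        # demo.load runs once when the page opens; the button refreshes on demand.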
        demo.load(fn=system_info, outputs=info_display)
        refresh_btn.click(fn=system_info, outputs=info_display)

if __name__ == "__main__":
    print("πŸš€ Starting GLM-4.5V CAD Generator...")
    print(f"CUDA available: {torch.cuda.is_available()}")
    demo.launch(show_error=True)  # share links are not supported on Spaces; the app is already public