Salesforce CodeGen Tutorial: Generate, Validate, and Rerank Python Functions With Unit Tests and Safety Checks
In this tutorial, we implement an end-to-end workflow for Salesforce CodeGen. We load a CodeGen model from Hugging Face, prepare it for code generation, and use it to generate Python functions from natural-language prompts. We then go beyond basic inference by adding:

- function extraction
- syntax checking
- static safety checks
- unit-test-based validation
- best-of-N candidate reranking
- multi-step program synthesis
- prompt-style experimentation
- benchmark visualization
- artifact export

Through this workflow, we see how CodeGen can serve as more than a code-completion model. It can also be one component of a structured code-generation pipeline that evaluates, filters, and organizes generated solutions.

Loading the Salesforce CodeGen Model from Hugging Face

```python
import os, sys, subprocess, textwrap, json, re, time, math, ast, tempfile, multiprocessing as mp
from pathlib import Path


def sh(cmd):
    """Echo a shell command, then run it (raises CalledProcessError if it fails)."""
    print(f"\n$ {cmd}")
    subprocess.run(cmd, shell=True, check=True)


# Install/upgrade the dependencies into the current interpreter's environment.
sh(f"{sys.executable} -m pip install -q -U transformers accelerate safetensors einops datasets evaluate pandas matplotlib tqdm rich radon tiktoken")

import torch
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from rich import print  # rich's print replaces the builtin for styled output
from rich.panel import Panel
from rich.syntax import Syntax
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from radon.complexity import cc_visit

OUT_DIR = Path("/content/codegen_advanced_tutorial")  # Colab path; change it when running elsewhere
OUT_DIR.mkdir(parents=True, exist_ok=True)
set_seed(42)  # seeds Python, NumPy and PyTorch RNGs for reproducible sampling

print(Panel.fit("Salesforce CodeGen Advanced Tutorial", style="bold green"))
print("\nRuntime information")
print("Python:", sys.version.split()[0])
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("CUDA memory GB:", round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2))

# The default checkpoint can be overridden with the CODEGEN_MODEL_ID environment variable.
MODEL_ID = os.environ.get("CODEGEN_MODEL_ID", "Salesforce/codegen-350M-mono")
MODEL_OPTIONS = {
    "easy_colab_default": "Salesforce/codegen-350M-mono",
    "larger_codegen1": "Salesforce/codegen-2B-mono",
    "codegen2_1b": "Salesforce/codegen2-1B_P",
    "codegen25_7b_mono": "Salesforce/codegen25-7b-mono_P",
}
print("\nSelected model:", MODEL_ID)
print("Available model examples:", MODEL_OPTIONS)

# Enable trust_remote_code only for CodeGen2 / CodeGen2.5 checkpoints;
# the original CodeGen models are supported natively by transformers.
trust_remote_code = any(x in MODEL_ID.lower() for x in ["codegen2", "codegen25"])
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32  # fp16 on GPU, fp32 on CPU

print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=trust_remote_code)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # no pad token defined: fall back to EOS

print("Loading model...")
load_kwargs = {
    "trust_remote_code": trust_remote_code,
    "low_cpu_mem_usage": True,
    "torch_dtype": dtype,
}
if torch.cuda.is_available():
    load_kwargs["device_map"] = "auto"  # let accelerate place the weights on the GPU(s)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
if not torch.cuda.is_available():
    model.to(device)
model.eval()


def count_parameters(model):
    return sum(p.numel() for p in model.parameters())


print(f"Loaded {MODEL_ID}")
print(f"Parameter count: {count_parameters(model)/1e6:.1f}M")


def generate_text(
    prompt,
    max_new_tokens=180,
    temperature=0.35,
    top_p=0.92,
    top_k=50,
    do_sample=True,
    num_return_sequences=1,
    repetition_penalty=1.05,
):
    """Sample completions for `prompt`; each decoded string still contains the prompt."""
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=num_return_sequences,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


def print_code(title, code):
    """Render `code` with syntax highlighting under a titled panel."""
    print(Panel.fit(title, style="bold cyan"))
    print(Syntax(code, "python", theme="monokai", line_numbers=True))
```

We install the required libraries and prepare the environment for running Salesforce CodeGen. The script first prints the Python and PyTorch versions and, when a GPU is available, its name and memory.

It then selects a checkpoint. The default is `Salesforce/codegen-350M-mono`, and the `CODEGEN_MODEL_ID` environment variable can override it. The script loads the tokenizer and model from Hugging Face: in float16 with `device_map="auto"` on a GPU, and in float32 on the CPU. `trust_remote_code` is enabled only when the model ID names a CodeGen2 or CodeGen2.5 checkpoint.

Finally, we define two helpers used throughout the tutorial. `generate_text` samples one or more completions for a prompt; each decoded string still includes the prompt. `print_code` renders code with syntax highlighting.

Building Extraction, Safety, and Unit-Test Validation Utilities

```python
def extract_function_source(full_text, function_name):
    """Cut the definition of `function_name` out of raw model output ("" if absent)."""
    text = full_text.replace("\r\n", "\n")
    # If the model wrapped its answer in a Markdown code fence, keep only the fenced body.
    fence = re.search(r"`{3}(?:python)?\n(.*?)`{3}", text, flags=re.S | re.I)
    if fence:
        text = fence.group(1)
    pattern = rf"^def\s+{re.escape(function_name)}\s*\("
    match = re.search(pattern, text, flags=re.M)
    if not match:
        return ""
    lines = text[match.start():].splitlines()
    collected = []
    for i, line in enumerate(lines):
        if i > 0:
            # Stop at the next top-level def/class, a __main__ guard, or a top-level assignment.
            if line.startswith("def ") or line.startswith("class "):
                break
            if line.startswith("if __name__"):
                break
            if line and not line.startswith((" ", "\t", "#")) and re.match(r"^[A-Za-z_][A-Za-z0-9_]*\s*=", line):
                break
        collected.append(line)
    source = "\n".join(collected).rstrip()
    try:
        ast.parse(source)
        return source
    except SyntaxError:
        # E.g. generation stopped mid-statement at max_new_tokens: keep the longest
        # prefix that parses. If none parses, the broken source is returned unchanged
        # so that syntax_ok can report the error.
        fixed_lines = []
        for line in collected:
            fixed_lines.append(line)
            candidate = "\n".join(fixed_lines).rstrip()
            try:
                ast.parse(candidate)
                source = candidate
            except SyntaxError:
                pass
    return source if source.strip().startswith("def ") else ""


def syntax_ok(source):
    try:
        ast.parse(source)
        return True, ""
    except SyntaxError as e:
        return False, str(e)


# Names that generated code may not reference at all.
FORBIDDEN_NAMES = {
    "eval", "exec", "compile", "open", "input", "__import__", "globals", "locals",
    "vars", "dir", "getattr", "setattr", "delattr", "help", "breakpoint", "exit", "quit",
}
# Statement types that are rejected outright.
FORBIDDEN_NODES = (
    ast.Import, ast.ImportFrom, ast.Global, ast.Nonlocal, ast.With, ast.AsyncWith,
    ast.AsyncFunctionDef, ast.ClassDef, ast.Delete, ast.Raise,
)
# The only builtins visible to generated code when it is executed.
ALLOWED_BUILTINS = {
    "abs": abs, "all": all, "any": any, "bool": bool, "dict": dict, "enumerate": enumerate,
    "float": float, "int": int, "isinstance": isinstance, "len": len, "list": list,
    "map": map, "max": max, "min": min, "pow": pow, "range": range, "reversed": reversed,
    "round": round, "set": set, "sorted": sorted, "str": str, "sum": sum, "tuple": tuple,
    "zip": zip,
}


def static_safety_check(source):
    """Reject forbidden statements, names, dunder access, and calls before running anything."""
    try:
        tree = ast.parse(source)
    except SyntaxError as e:
        return False, f"SyntaxError: {e}"
    for node in ast.walk(tree):
        if isinstance(node, FORBIDDEN_NODES):
            return False, f"Forbidden AST node: {type(node).__name__}"
        if isinstance(node, ast.Name):
            if node.id in FORBIDDEN_NAMES or node.id.startswith("__"):
                return False, f"Forbidden name: {node.id}"
        if isinstance(node, ast.Attribute):
            if node.attr.startswith("__"):
                return False, f"Forbidden attribute: {node.attr}"
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name) and node.func.id in FORBIDDEN_NAMES:
                return False, f"Forbidden call: {node.func.id}"
    return True, "passed"


def _worker_run_tests(source, function_name, tests, queue):
    """Exec `source` with whitelisted builtins, run `tests`, and put a result dict on `queue`."""
    try:
        # One dict serves as both globals and locals. With separate dicts, the def
        # would land in locals while the function body looks names up in globals,
        # so a recursive function could not call itself.
        namespace = {"__builtins__": ALLOWED_BUILTINS}
        compiled = compile(source, "<generated_code>", "exec")
        exec(compiled, namespace)
        fn = namespace.get(function_name)
        if fn is None:
            queue.put({"ok": False, "error": f"{function_name} not found", "passed": 0, "total": len(tests)})
            return
        passed = 0
        details = []
        for test in tests:
            args = test.get("args", [])
            kwargs = test.get("kwargs", {})
            expected = test["expected"]
            result = fn(*args, **kwargs)
            ok = result == expected
            passed += int(ok)
            details.append({
                "args": args,
                "kwargs": kwargs,
                "expected": expected,
                "result": result,
                "ok": ok,
            })
        queue.put({"ok": passed == len(tests), "error": "", "passed": passed, "total": len(tests), "details": details})
    except Exception as e:
        # Any exception, including one raised by a single test, fails the whole candidate.
        queue.put({"ok": False, "error": repr(e), "passed": 0, "total": len(tests)})
```